🔧 fix(endpoints.py): import TaskManager from the correct module to fix import error

🔧 fix(endpoints.py): add missing import statement for TaskManager
🔧 fix(endpoints.py): add missing import statement for sync parameter
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-09-04 09:37:18 -03:00
commit e303881155
3 changed files with 52 additions and 73 deletions

View file

@ -33,6 +33,7 @@ from langflow.services.utils import get_session
from langflow.worker import process_graph_cached_task
from sqlmodel import Session
from langflow.services.task.manager import TaskManager
# build router
router = APIRouter(tags=["Base"])
@ -97,7 +98,9 @@ async def process_flow(
tweaks: Optional[dict] = None,
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
task_manager: "TaskManager" = Depends(get_task_manager),
api_key_user: User = Depends(api_key_security),
sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821
):
"""
Endpoint to process an input with a given flow_id.
@ -128,18 +131,30 @@ async def process_flow(
graph_data = process_tweaks(graph_data, tweaks)
except Exception as exc:
logger.error(f"Error processing tweaks: {exc}")
task_manager = get_task_manager()
task_id = task_manager.launch_task(
process_graph_cached_task
if task_manager.use_celery
else process_graph_cached,
graph_data,
inputs,
clear_cache,
session_id,
)
task = task_manager.get_task(task_id)
return ProcessResponse(result=task.status, id=task_id)
if sync:
task_id, result = await task_manager.launch_and_await_task(
process_graph_cached_task
if task_manager.use_celery
else process_graph_cached,
graph_data,
inputs,
clear_cache,
session_id,
)
task_result = result.result
session_id = result.session_id
else:
task_id, task = await task_manager.launch_task(
process_graph_cached_task
if task_manager.use_celery
else process_graph_cached,
graph_data,
inputs,
clear_cache,
session_id,
)
task_result = task.status
return ProcessResponse(result=task_result, id=task_id, session_id=session_id)
except sa.exc.StatementError as exc:
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):

View file

@ -1,12 +1,12 @@
import asyncio
from typing import Any, Callable, Union
import logging
from langflow.services.base import Service
from langflow.services.task.utils import AsyncIOTaskResult, get_celery_worker_status
from langflow.services.task.backends.anyio import AnyIOBackend
from langflow.services.task.backends.base import TaskBackend
from langflow.services.task.utils import get_celery_worker_status
try:
from celery.result import AsyncResult
from langflow.worker import celery_app
try:
@ -20,51 +20,36 @@ except ImportError:
class TaskManager(Service):
STATUS_PENDING = "PENDING"
STATUS_FINISHED = "FINISHED"
STATUS_UNKNOWN = "UNKNOWN"
name = "task_manager"
def __init__(self):
self.tasks = {} # For storing asyncio tasks
self.celery_results = {} # For storing Celery AsyncResult instances
if USE_CELERY:
from langflow.worker import celery_app
self.celery_app = celery_app
else:
self.celery_app = None # To store the celery app if available
self.backend = self.get_backend()
self.use_celery = USE_CELERY
def launch_task(
def get_backend(self) -> TaskBackend:
if USE_CELERY:
from langflow.services.task.backends.celery import CeleryBackend
return CeleryBackend()
return AnyIOBackend()
# In your TaskManager class
async def launch_and_await_task(
self,
task_func: Callable[..., Any],
*args: Any,
**kwargs: Any,
) -> Union[int, str]:
if USE_CELERY:
task = task_func.apply_async(args=args, kwargs=kwargs)
self.celery_results[task.id] = task
return task.id
else:
task = asyncio.create_task(task_func(*args, **kwargs))
task_id = str(id(task))
self.tasks[task_id] = AsyncIOTaskResult(task)
) -> Any:
if not self.use_celery:
return None, await task_func(*args, **kwargs)
task = task_func.apply(args=args, kwargs=kwargs)
result = task.get()
return task.id, result
def set_result(future):
try:
self.tasks[task_id] = AsyncIOTaskResult(future)
except Exception as e:
logging.error(f"An error occurred: {e}")
async def launch_task(
self, task_func: Callable[..., Any], *args: Any, **kwargs: Any
) -> Union[str, str]:
return await self.backend.launch_task(task_func, *args, **kwargs)
task.add_done_callback(set_result)
return task_id
# Update the get_task_status function in TaskManager class
def get_task(
self, task_id: Union[int, str]
) -> Union[AsyncResult, AsyncIOTaskResult]:
if self.use_celery:
return AsyncResult(task_id, app=self.celery_app)
return self.tasks.get(task_id)
def get_task(self, task_id: Union[int, str]) -> Any:
return self.backend.get_task(task_id)

View file

@ -1,24 +1,3 @@
from asyncio import Task
class AsyncIOTaskResult:
def __init__(self, task: Task):
self._task = task
@property
def status(self) -> str:
if self._task.done():
return "FAILURE" if self._task.exception() is not None else "SUCCESS"
return "PENDING"
@property
def result(self) -> any:
return self._task.result() if self._task.done() else None
def ready(self) -> bool:
return self._task.done()
def get_celery_worker_status(app):
i = app.control.inspect()
availability = i.ping()