diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 988130ba7..048b4baae 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -33,6 +33,7 @@ from langflow.services.utils import get_session from langflow.worker import process_graph_cached_task from sqlmodel import Session +from langflow.services.task.manager import TaskManager # build router router = APIRouter(tags=["Base"]) @@ -97,7 +98,9 @@ async def process_flow( tweaks: Optional[dict] = None, clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821 session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 + task_manager: "TaskManager" = Depends(get_task_manager), api_key_user: User = Depends(api_key_security), + sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821 ): """ Endpoint to process an input with a given flow_id. @@ -128,18 +131,30 @@ async def process_flow( graph_data = process_tweaks(graph_data, tweaks) except Exception as exc: logger.error(f"Error processing tweaks: {exc}") - task_manager = get_task_manager() - task_id = task_manager.launch_task( - process_graph_cached_task - if task_manager.use_celery - else process_graph_cached, - graph_data, - inputs, - clear_cache, - session_id, - ) - task = task_manager.get_task(task_id) - return ProcessResponse(result=task.status, id=task_id) + if sync: + task_id, result = await task_manager.launch_and_await_task( + process_graph_cached_task + if task_manager.use_celery + else process_graph_cached, + graph_data, + inputs, + clear_cache, + session_id, + ) + task_result = result.result + session_id = result.session_id + else: + task_id, task = await task_manager.launch_task( + process_graph_cached_task + if task_manager.use_celery + else process_graph_cached, + graph_data, + inputs, + clear_cache, + session_id, + ) + task_result = task.status + return ProcessResponse(result=task_result, id=task_id, session_id=session_id) except sa.exc.StatementError as exc: # StatementError('(builtins.ValueError) badly formed hexadecimal UUID string') if "badly formed hexadecimal UUID string" in str(exc): diff --git a/src/backend/langflow/services/task/manager.py b/src/backend/langflow/services/task/manager.py index 6fc9dfa16..c3e5e6444 100644 --- a/src/backend/langflow/services/task/manager.py +++ b/src/backend/langflow/services/task/manager.py @@ -1,12 +1,12 @@ -import asyncio from typing import Any, Callable, Union import logging from langflow.services.base import Service -from langflow.services.task.utils import AsyncIOTaskResult, get_celery_worker_status +from langflow.services.task.backends.anyio import AnyIOBackend +from langflow.services.task.backends.base import TaskBackend +from langflow.services.task.utils import get_celery_worker_status try: - from celery.result import AsyncResult from langflow.worker import celery_app try: @@ -20,51 +20,36 @@ except ImportError: class TaskManager(Service): - STATUS_PENDING = "PENDING" - STATUS_FINISHED = "FINISHED" - STATUS_UNKNOWN = "UNKNOWN" name = "task_manager" def __init__(self): - self.tasks = {} # For storing asyncio tasks - self.celery_results = {} # For storing Celery AsyncResult instances - if USE_CELERY: - from langflow.worker import celery_app - - self.celery_app = celery_app - - else: - self.celery_app = None # To store the celery app if available + self.backend = self.get_backend() self.use_celery = USE_CELERY - def launch_task( + def get_backend(self) -> TaskBackend: + if USE_CELERY: + from langflow.services.task.backends.celery import CeleryBackend + + return CeleryBackend() + return AnyIOBackend() + + # In your TaskManager class + async def launch_and_await_task( self, task_func: Callable[..., Any], *args: Any, **kwargs: Any, - ) -> Union[int, str]: - if USE_CELERY: - task = task_func.apply_async(args=args, kwargs=kwargs) - self.celery_results[task.id] = task - return task.id - else: - task = asyncio.create_task(task_func(*args, **kwargs)) - task_id = str(id(task)) - self.tasks[task_id] = AsyncIOTaskResult(task) + ) -> Any: + if not self.use_celery: + return None, await task_func(*args, **kwargs) + task = task_func.apply(args=args, kwargs=kwargs) + result = task.get() + return task.id, result - def set_result(future): - try: - self.tasks[task_id] = AsyncIOTaskResult(future) - except Exception as e: - logging.error(f"An error occurred: {e}") + async def launch_task( + self, task_func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> Union[str, str]: + return await self.backend.launch_task(task_func, *args, **kwargs) - task.add_done_callback(set_result) - return task_id - - # Update the get_task_status function in TaskManager class - def get_task( - self, task_id: Union[int, str] - ) -> Union[AsyncResult, AsyncIOTaskResult]: - if self.use_celery: - return AsyncResult(task_id, app=self.celery_app) - return self.tasks.get(task_id) + def get_task(self, task_id: Union[int, str]) -> Any: + return self.backend.get_task(task_id) diff --git a/src/backend/langflow/services/task/utils.py b/src/backend/langflow/services/task/utils.py index 7736c256a..412b33ae2 100644 --- a/src/backend/langflow/services/task/utils.py +++ b/src/backend/langflow/services/task/utils.py @@ -1,24 +1,3 @@ -from asyncio import Task - - -class AsyncIOTaskResult: - def __init__(self, task: Task): - self._task = task - - @property - def status(self) -> str: - if self._task.done(): - return "FAILURE" if self._task.exception() is not None else "SUCCESS" - return "PENDING" - - @property - def result(self) -> any: - return self._task.result() if self._task.done() else None - - def ready(self) -> bool: - return self._task.done() - - def get_celery_worker_status(app): i = app.control.inspect() availability = i.ping()