From b8d8eccbffb2f601ce6a94936a9c206df2cad22a Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sat, 16 Sep 2023 11:13:08 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(anyio.py):=20add=20error=20h?= =?UTF-8?q?andling=20when=20launching=20a=20task=20to=20prevent=20crashes?= =?UTF-8?q?=20and=20log=20the=20error=20=F0=9F=90=9B=20fix(celery.py):=20a?= =?UTF-8?q?dd=20type=20hinting=20to=20the=20launch=5Ftask=20method=20and?= =?UTF-8?q?=20return=20the=20AsyncResult=20object=20for=20better=20task=20?= =?UTF-8?q?tracking?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langflow/services/task/backends/anyio.py | 29 +++++++++++++++---- .../langflow/services/task/backends/celery.py | 9 ++++-- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/backend/langflow/services/task/backends/anyio.py b/src/backend/langflow/services/task/backends/anyio.py index 20691212d..6c833443a 100644 --- a/src/backend/langflow/services/task/backends/anyio.py +++ b/src/backend/langflow/services/task/backends/anyio.py @@ -1,6 +1,7 @@ from typing import Any, Callable, Tuple import anyio from langflow.services.task.backends.base import TaskBackend +from loguru import logger class AnyIOTaskResult: @@ -38,13 +39,29 @@ class AnyIOBackend(TaskBackend): async def launch_task( self, task_func: Callable[..., Any], *args: Any, **kwargs: Any - ) -> Tuple[str, AnyIOTaskResult]: # sourcery skip: remove-unnecessary-cast + ) -> Tuple[str, AnyIOTaskResult]: + """ + Launch a new task in an asynchronous manner. + + Parameters: + task_func: The asynchronous function to run. + *args: Positional arguments to pass to task_func. + **kwargs: Keyword arguments to pass to task_func. + + Returns: + A tuple containing a unique task ID and the task result object. + """ async with anyio.create_task_group() as tg: - task_result = AnyIOTaskResult(tg) - tg.start_soon(task_result.run, task_func, *args, **kwargs) - task_id = str(id(task_result)) - self.tasks[task_id] = task_result - return task_id, task_result + try: + task_result = AnyIOTaskResult(tg) + tg.start_soon(task_result.run, task_func, *args, **kwargs) + task_id = str(id(task_result)) + self.tasks[task_id] = task_result + logger.info(f"Task {task_id} started.") + return task_id, task_result + except Exception as e: + logger.error(f"An error occurred while launching the task: {e}") + return None, None def get_task(self, task_id: str) -> Any: return self.tasks.get(task_id) diff --git a/src/backend/langflow/services/task/backends/celery.py b/src/backend/langflow/services/task/backends/celery.py index 3a6e1f450..f4f81b9cc 100644 --- a/src/backend/langflow/services/task/backends/celery.py +++ b/src/backend/langflow/services/task/backends/celery.py @@ -1,5 +1,5 @@ from typing import Any, Callable -from celery.result import AsyncResult # type: ignore +from celery.result import AsyncResult from langflow.services.task.backends.base import TaskBackend from langflow.worker import celery_app @@ -11,10 +11,13 @@ class CeleryBackend(TaskBackend): def launch_task( self, task_func: Callable[..., Any], *args: Any, **kwargs: Any ) -> str: + # I need to type the delay method to make it easier + from celery import Task + if not hasattr(task_func, "delay"): raise ValueError(f"Task function {task_func} does not have a delay method") - task = task_func.delay(*args, **kwargs) - return task.id + task: Task = task_func.delay(*args, **kwargs) + return task.id, AsyncResult(task.id, app=self.celery_app) def get_task(self, task_id: str) -> Any: return AsyncResult(task_id, app=self.celery_app)