🐛 fix(anyio.py): add error handling when launching a task to prevent crashes and log the error
🐛 fix(celery.py): add type hinting to the launch_task method and return the AsyncResult object for better task tracking
This commit is contained in:
parent
abc4e8a3e0
commit
b8d8eccbff
2 changed files with 29 additions and 9 deletions
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Any, Callable, Tuple
|
||||
import anyio
|
||||
from langflow.services.task.backends.base import TaskBackend
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class AnyIOTaskResult:
|
||||
|
|
@ -38,13 +39,29 @@ class AnyIOBackend(TaskBackend):
|
|||
|
||||
async def launch_task(
|
||||
self, task_func: Callable[..., Any], *args: Any, **kwargs: Any
|
||||
) -> Tuple[str, AnyIOTaskResult]: # sourcery skip: remove-unnecessary-cast
|
||||
) -> Tuple[str, AnyIOTaskResult]:
|
||||
"""
|
||||
Launch a new task in an asynchronous manner.
|
||||
|
||||
Parameters:
|
||||
task_func: The asynchronous function to run.
|
||||
*args: Positional arguments to pass to task_func.
|
||||
**kwargs: Keyword arguments to pass to task_func.
|
||||
|
||||
Returns:
|
||||
A tuple containing a unique task ID and the task result object.
|
||||
"""
|
||||
async with anyio.create_task_group() as tg:
|
||||
task_result = AnyIOTaskResult(tg)
|
||||
tg.start_soon(task_result.run, task_func, *args, **kwargs)
|
||||
task_id = str(id(task_result))
|
||||
self.tasks[task_id] = task_result
|
||||
return task_id, task_result
|
||||
try:
|
||||
task_result = AnyIOTaskResult(tg)
|
||||
tg.start_soon(task_result.run, task_func, *args, **kwargs)
|
||||
task_id = str(id(task_result))
|
||||
self.tasks[task_id] = task_result
|
||||
logger.info(f"Task {task_id} started.")
|
||||
return task_id, task_result
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while launching the task: {e}")
|
||||
return None, None
|
||||
|
||||
def get_task(self, task_id: str) -> Any:
|
||||
return self.tasks.get(task_id)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Any, Callable
|
||||
from celery.result import AsyncResult # type: ignore
|
||||
from celery.result import AsyncResult
|
||||
from langflow.services.task.backends.base import TaskBackend
|
||||
from langflow.worker import celery_app
|
||||
|
||||
|
|
@ -11,10 +11,13 @@ class CeleryBackend(TaskBackend):
|
|||
def launch_task(
|
||||
self, task_func: Callable[..., Any], *args: Any, **kwargs: Any
|
||||
) -> str:
|
||||
# I need to type the delay method to make it easier
|
||||
from celery import Task
|
||||
|
||||
if not hasattr(task_func, "delay"):
|
||||
raise ValueError(f"Task function {task_func} does not have a delay method")
|
||||
task = task_func.delay(*args, **kwargs)
|
||||
return task.id
|
||||
task: Task = task_func.delay(*args, **kwargs)
|
||||
return task.id, AsyncResult(task.id, app=self.celery_app)
|
||||
|
||||
def get_task(self, task_id: str) -> Any:
|
||||
return AsyncResult(task_id, app=self.celery_app)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue