🔧 fix(endpoints.py): import TaskManager from the correct module to fix import error
🔧 fix(endpoints.py): add missing import statement for TaskManager 🔧 fix(endpoints.py): add missing import statement for sync parameter
This commit is contained in:
parent
f656b71173
commit
e303881155
3 changed files with 52 additions and 73 deletions
|
|
@ -33,6 +33,7 @@ from langflow.services.utils import get_session
|
|||
from langflow.worker import process_graph_cached_task
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.services.task.manager import TaskManager
|
||||
|
||||
# build router
|
||||
router = APIRouter(tags=["Base"])
|
||||
|
|
@ -97,7 +98,9 @@ async def process_flow(
|
|||
tweaks: Optional[dict] = None,
|
||||
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
task_manager: "TaskManager" = Depends(get_task_manager),
|
||||
api_key_user: User = Depends(api_key_security),
|
||||
sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821
|
||||
):
|
||||
"""
|
||||
Endpoint to process an input with a given flow_id.
|
||||
|
|
@ -128,18 +131,30 @@ async def process_flow(
|
|||
graph_data = process_tweaks(graph_data, tweaks)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error processing tweaks: {exc}")
|
||||
task_manager = get_task_manager()
|
||||
task_id = task_manager.launch_task(
|
||||
process_graph_cached_task
|
||||
if task_manager.use_celery
|
||||
else process_graph_cached,
|
||||
graph_data,
|
||||
inputs,
|
||||
clear_cache,
|
||||
session_id,
|
||||
)
|
||||
task = task_manager.get_task(task_id)
|
||||
return ProcessResponse(result=task.status, id=task_id)
|
||||
if sync:
|
||||
task_id, result = await task_manager.launch_and_await_task(
|
||||
process_graph_cached_task
|
||||
if task_manager.use_celery
|
||||
else process_graph_cached,
|
||||
graph_data,
|
||||
inputs,
|
||||
clear_cache,
|
||||
session_id,
|
||||
)
|
||||
task_result = result.result
|
||||
session_id = result.session_id
|
||||
else:
|
||||
task_id, task = await task_manager.launch_task(
|
||||
process_graph_cached_task
|
||||
if task_manager.use_celery
|
||||
else process_graph_cached,
|
||||
graph_data,
|
||||
inputs,
|
||||
clear_cache,
|
||||
session_id,
|
||||
)
|
||||
task_result = task.status
|
||||
return ProcessResponse(result=task_result, id=task_id, session_id=session_id)
|
||||
except sa.exc.StatementError as exc:
|
||||
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
|
||||
if "badly formed hexadecimal UUID string" in str(exc):
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import asyncio
|
||||
from typing import Any, Callable, Union
|
||||
import logging
|
||||
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.task.utils import AsyncIOTaskResult, get_celery_worker_status
|
||||
from langflow.services.task.backends.anyio import AnyIOBackend
|
||||
from langflow.services.task.backends.base import TaskBackend
|
||||
from langflow.services.task.utils import get_celery_worker_status
|
||||
|
||||
try:
|
||||
from celery.result import AsyncResult
|
||||
from langflow.worker import celery_app
|
||||
|
||||
try:
|
||||
|
|
@ -20,51 +20,36 @@ except ImportError:
|
|||
|
||||
|
||||
class TaskManager(Service):
|
||||
STATUS_PENDING = "PENDING"
|
||||
STATUS_FINISHED = "FINISHED"
|
||||
STATUS_UNKNOWN = "UNKNOWN"
|
||||
name = "task_manager"
|
||||
|
||||
def __init__(self):
|
||||
self.tasks = {} # For storing asyncio tasks
|
||||
self.celery_results = {} # For storing Celery AsyncResult instances
|
||||
if USE_CELERY:
|
||||
from langflow.worker import celery_app
|
||||
|
||||
self.celery_app = celery_app
|
||||
|
||||
else:
|
||||
self.celery_app = None # To store the celery app if available
|
||||
self.backend = self.get_backend()
|
||||
self.use_celery = USE_CELERY
|
||||
|
||||
def launch_task(
|
||||
def get_backend(self) -> TaskBackend:
|
||||
if USE_CELERY:
|
||||
from langflow.services.task.backends.celery import CeleryBackend
|
||||
|
||||
return CeleryBackend()
|
||||
return AnyIOBackend()
|
||||
|
||||
# In your TaskManager class
|
||||
async def launch_and_await_task(
|
||||
self,
|
||||
task_func: Callable[..., Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Union[int, str]:
|
||||
if USE_CELERY:
|
||||
task = task_func.apply_async(args=args, kwargs=kwargs)
|
||||
self.celery_results[task.id] = task
|
||||
return task.id
|
||||
else:
|
||||
task = asyncio.create_task(task_func(*args, **kwargs))
|
||||
task_id = str(id(task))
|
||||
self.tasks[task_id] = AsyncIOTaskResult(task)
|
||||
) -> Any:
|
||||
if not self.use_celery:
|
||||
return None, await task_func(*args, **kwargs)
|
||||
task = task_func.apply(args=args, kwargs=kwargs)
|
||||
result = task.get()
|
||||
return task.id, result
|
||||
|
||||
def set_result(future):
|
||||
try:
|
||||
self.tasks[task_id] = AsyncIOTaskResult(future)
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
async def launch_task(
|
||||
self, task_func: Callable[..., Any], *args: Any, **kwargs: Any
|
||||
) -> Union[str, str]:
|
||||
return await self.backend.launch_task(task_func, *args, **kwargs)
|
||||
|
||||
task.add_done_callback(set_result)
|
||||
return task_id
|
||||
|
||||
# Update the get_task_status function in TaskManager class
|
||||
def get_task(
|
||||
self, task_id: Union[int, str]
|
||||
) -> Union[AsyncResult, AsyncIOTaskResult]:
|
||||
if self.use_celery:
|
||||
return AsyncResult(task_id, app=self.celery_app)
|
||||
return self.tasks.get(task_id)
|
||||
def get_task(self, task_id: Union[int, str]) -> Any:
|
||||
return self.backend.get_task(task_id)
|
||||
|
|
|
|||
|
|
@ -1,24 +1,3 @@
|
|||
from asyncio import Task
|
||||
|
||||
|
||||
class AsyncIOTaskResult:
|
||||
def __init__(self, task: Task):
|
||||
self._task = task
|
||||
|
||||
@property
|
||||
def status(self) -> str:
|
||||
if self._task.done():
|
||||
return "FAILURE" if self._task.exception() is not None else "SUCCESS"
|
||||
return "PENDING"
|
||||
|
||||
@property
|
||||
def result(self) -> any:
|
||||
return self._task.result() if self._task.done() else None
|
||||
|
||||
def ready(self) -> bool:
|
||||
return self._task.done()
|
||||
|
||||
|
||||
def get_celery_worker_status(app):
|
||||
i = app.control.inspect()
|
||||
availability = i.ping()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue