🐛 fix(endpoints.py): import missing get_task_manager function to fix NameError
✨ feat(endpoints.py): add support for launching tasks using task manager to improve task management 🐛 fix(endpoints.py): change process_graph_cached_task function call to use task_manager.launch_task to fix AttributeError 🐛 fix(endpoints.py): change process_flow function to use task_manager.get_task to fix AttributeError 🐛 fix(endpoints.py): change get_task_status function to use task_manager.get_task to fix AttributeError 🐛 fix(process.py): change process_graph_cached function to be async to fix TypeError
This commit is contained in:
parent
bd6655b0db
commit
e26d61fcfe
2 changed files with 18 additions and 9 deletions
|
|
@ -5,9 +5,9 @@ from langflow.services.auth.utils import api_key_security, get_current_active_us
|
|||
|
||||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.processing.process import process_tweaks
|
||||
from langflow.processing.process import process_graph_cached, process_tweaks
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.utils import get_settings_manager
|
||||
from langflow.services.utils import get_settings_manager, get_task_manager
|
||||
from langflow.utils.logger import logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body, status
|
||||
import sqlalchemy as sa
|
||||
|
|
@ -33,7 +33,6 @@ from langflow.services.utils import get_session
|
|||
from langflow.worker import process_graph_cached_task
|
||||
from sqlmodel import Session
|
||||
|
||||
from celery.result import AsyncResult
|
||||
|
||||
# build router
|
||||
router = APIRouter(tags=["Base"])
|
||||
|
|
@ -129,11 +128,18 @@ async def process_flow(
|
|||
graph_data = process_tweaks(graph_data, tweaks)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error processing tweaks: {exc}")
|
||||
|
||||
task: "AsyncResult" = process_graph_cached_task.delay(
|
||||
graph_data, inputs, clear_cache, session_id
|
||||
task_manager = get_task_manager()
|
||||
task_id = task_manager.launch_task(
|
||||
process_graph_cached_task
|
||||
if task_manager.use_celery
|
||||
else process_graph_cached,
|
||||
graph_data,
|
||||
inputs,
|
||||
clear_cache,
|
||||
session_id,
|
||||
)
|
||||
return ProcessResponse(result=task.state, id=task.id)
|
||||
task = task_manager.get_task(task_id)
|
||||
return ProcessResponse(result=task.status, id=task_id)
|
||||
except sa.exc.StatementError as exc:
|
||||
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
|
||||
if "badly formed hexadecimal UUID string" in str(exc):
|
||||
|
|
@ -158,7 +164,10 @@ async def process_flow(
|
|||
|
||||
@router.get("/task/{task_id}/status", response_model=TaskStatusResponse)
|
||||
async def get_task_status(task_id: str):
|
||||
task = AsyncResult(task_id)
|
||||
task_manager = get_task_manager()
|
||||
task = task_manager.get_task(task_id)
|
||||
if task is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return TaskStatusResponse(
|
||||
status=task.status, result=task.result if task.ready() else None
|
||||
)
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
|
|||
# return result, session_id
|
||||
|
||||
|
||||
def process_graph_cached(
|
||||
async def process_graph_cached(
|
||||
data_graph: Dict[str, Any],
|
||||
inputs: Optional[dict] = None,
|
||||
clear_cache=False,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue