🐛 fix(endpoints.py): import missing get_task_manager function to fix NameError

 feat(endpoints.py): add support for launching tasks using task manager to improve task management
🐛 fix(endpoints.py): change process_graph_cached_task function call to use task_manager.launch_task to fix AttributeError
🐛 fix(endpoints.py): change process_flow function to use task_manager.get_task to fix AttributeError
🐛 fix(endpoints.py): change get_task_status function to use task_manager.get_task to fix AttributeError
🐛 fix(process.py): change process_graph_cached function to be async to fix TypeError
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-09-01 17:22:30 -03:00
commit e26d61fcfe
2 changed files with 18 additions and 9 deletions

View file

@ -5,9 +5,9 @@ from langflow.services.auth.utils import api_key_security, get_current_active_us
from langflow.services.cache.utils import save_uploaded_file
from langflow.services.database.models.flow import Flow
from langflow.processing.process import process_tweaks
from langflow.processing.process import process_graph_cached, process_tweaks
from langflow.services.database.models.user.user import User
from langflow.services.utils import get_settings_manager
from langflow.services.utils import get_settings_manager, get_task_manager
from langflow.utils.logger import logger
from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body, status
import sqlalchemy as sa
@ -33,7 +33,6 @@ from langflow.services.utils import get_session
from langflow.worker import process_graph_cached_task
from sqlmodel import Session
from celery.result import AsyncResult
# build router
router = APIRouter(tags=["Base"])
@ -129,11 +128,18 @@ async def process_flow(
graph_data = process_tweaks(graph_data, tweaks)
except Exception as exc:
logger.error(f"Error processing tweaks: {exc}")
task: "AsyncResult" = process_graph_cached_task.delay(
graph_data, inputs, clear_cache, session_id
task_manager = get_task_manager()
task_id = task_manager.launch_task(
process_graph_cached_task
if task_manager.use_celery
else process_graph_cached,
graph_data,
inputs,
clear_cache,
session_id,
)
return ProcessResponse(result=task.state, id=task.id)
task = task_manager.get_task(task_id)
return ProcessResponse(result=task.status, id=task_id)
except sa.exc.StatementError as exc:
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
@ -158,7 +164,10 @@ async def process_flow(
@router.get("/task/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status(task_id: str):
task = AsyncResult(task_id)
task_manager = get_task_manager()
task = task_manager.get_task(task_id)
if task is None:
raise HTTPException(status_code=404, detail="Task not found")
return TaskStatusResponse(
status=task.status, result=task.result if task.ready() else None
)

View file

@ -173,7 +173,7 @@ def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
# return result, session_id
def process_graph_cached(
async def process_graph_cached(
data_graph: Dict[str, Any],
inputs: Optional[dict] = None,
clear_cache=False,