From a562c1f98e9f9cbc2d02dd4e788f0f15264c7d51 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sun, 14 Jan 2024 15:21:43 -0300 Subject: [PATCH] Refactor API endpoints and add preload functionality --- src/backend/langflow/api/v1/endpoints.py | 142 ++++++++++++++++++----- 1 file changed, 115 insertions(+), 27 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 0247be7e7..c8ef86906 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -3,12 +3,10 @@ from typing import Annotated, Any, List, Optional, Union import sqlalchemy as sa from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status -from loguru import logger -from sqlmodel import select - from langflow.api.utils import update_frontend_node_with_template_values from langflow.api.v1.schemas import ( CustomComponentCode, + PreloadResponse, ProcessResponse, TaskResponse, TaskStatusResponse, @@ -17,12 +15,15 @@ from langflow.api.v1.schemas import ( from langflow.interface.custom.custom_component import CustomComponent from langflow.interface.custom.directory_reader import DirectoryReader from langflow.interface.custom.utils import build_custom_component_template -from langflow.processing.process import process_graph_cached, process_tweaks +from langflow.processing.process import build_graph_and_generate_result, process_graph_cached, process_tweaks from langflow.services.auth.utils import api_key_security, get_current_active_user from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow from langflow.services.database.models.user.model import User from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service +from langflow.services.session.service import SessionService +from loguru import logger +from sqlmodel import select try: from langflow.worker import process_graph_cached_task @@ -32,9 +33,8 @@ except ImportError: raise NotImplementedError("Celery is not installed") -from sqlmodel import Session - from langflow.services.task.service import TaskService +from sqlmodel import Session # build router router = APIRouter(tags=["Base"]) @@ -148,6 +148,55 @@ async def process_json( raise HTTPException(status_code=500, detail=str(exc)) from exc +# Endpoint to preload a graph +@router.post("/process/preload/{flow_id}", response_model=PreloadResponse) +async def preload_flow( + session: Annotated[Session, Depends(get_session)], + flow_id: str, + session_id: Optional[str] = None, + session_service: SessionService = Depends(get_session_service), + api_key_user: User = Depends(api_key_security), + clear_session: Annotated[bool, Body(embed=True)] = False, # noqa: F821 +): + try: + # Get the flow that matches the flow_id and belongs to the user + # flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() + if clear_session: + session_service.clear_session(session_id) + # Check if the session exists + session_data = await session_service.load_session(session_id) + # Session data is a tuple of (graph, artifacts) + # or (None, None) if the session is empty + if isinstance(session_data, tuple): + graph, artifacts = session_data + is_clear = graph is None and artifacts is None + else: + is_clear = session_data is None + return PreloadResponse(session_id=session_id, is_clear=is_clear) + else: + if session_id is None: + session_id = flow_id + flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first() + if flow is None: + raise ValueError(f"Flow {flow_id} not found") + + if flow.data is None: + raise ValueError(f"Flow {flow_id} has no data") + graph_data = flow.data + session_service.clear_session(session_id) + # Load the graph using SessionService + session_data = await session_service.load_session(session_id, graph_data) + graph, artifacts = session_data if session_data else (None, None) + if not graph: + raise ValueError("Graph not found in the session") + _ = await graph.build() + session_service.update_session(session_id, (graph, artifacts)) + return PreloadResponse(session_id=session_id) + except Exception as exc: + logger.exception(exc) + raise HTTPException(status_code=500, detail=str(exc)) from exc + + @router.post( "/predict/{flow_id}", response_model=ProcessResponse, @@ -167,36 +216,75 @@ async def process( task_service: "TaskService" = Depends(get_task_service), api_key_user: User = Depends(api_key_security), sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821 + session_service: SessionService = Depends(get_session_service), ): """ Endpoint to process an input with a given flow_id. """ try: - if api_key_user is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API Key", + if session_id: + session_data = await session_service.load_session(session_id) + graph, artifacts = session_data if session_data else (None, None) + task_result: Any = None + task_status = None + task_id = None + if not graph: + raise ValueError("Graph not found in the session") + result = await build_graph_and_generate_result( + graph=graph, + inputs=inputs, + artifacts=artifacts, + session_id=session_id, + session_service=session_service, + ) + task_id = str(id(result)) + if isinstance(result, dict) and "result" in result: + task_result = result["result"] + session_id = result["session_id"] + elif hasattr(result, "result") and hasattr(result, "session_id"): + task_result = result.result + + session_id = result.session_id + else: + task_result = result + if task_id: + task_response = TaskResponse(id=task_id, href=f"api/v1/task/{task_id}") + else: + task_response = None + return ProcessResponse( + result=task_result, + status=task_status, + task=task_response, + session_id=session_id, + backend=task_service.backend_name, ) - # Get the flow that matches the flow_id and belongs to the user - # flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() - flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first() - if flow is None: - raise ValueError(f"Flow {flow_id} not found") + else: + if api_key_user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API Key", + ) - if flow.data is None: - raise ValueError(f"Flow {flow_id} has no data") - graph_data = flow.data - return await process_graph_data( - graph_data=graph_data, - inputs=inputs, - tweaks=tweaks, - clear_cache=clear_cache, - session_id=session_id, - task_service=task_service, - sync=sync, - ) + # Get the flow that matches the flow_id and belongs to the user + # flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() + flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first() + if flow is None: + raise ValueError(f"Flow {flow_id} not found") + + if flow.data is None: + raise ValueError(f"Flow {flow_id} has no data") + graph_data = flow.data + return await process_graph_data( + graph_data=graph_data, + inputs=inputs, + tweaks=tweaks, + clear_cache=clear_cache, + session_id=session_id, + task_service=task_service, + sync=sync, + ) except sa.exc.StatementError as exc: # StatementError('(builtins.ValueError) badly formed hexadecimal UUID string') if "badly formed hexadecimal UUID string" in str(exc):