From c26175aede4b94dcf61cea20c8e0ff3cae2d4946 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 27 Feb 2024 13:46:00 -0300 Subject: [PATCH] Add new run_flow_with_caching endpoint --- src/backend/langflow/api/v1/endpoints.py | 139 ++++++++++++++++++++--- 1 file changed, 124 insertions(+), 15 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 8a3f67ddf..2dc79e85a 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -3,11 +3,15 @@ from typing import Annotated, Any, List, Optional, Union import sqlalchemy as sa from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status +from loguru import logger +from sqlmodel import select + from langflow.api.utils import update_frontend_node_with_template_values from langflow.api.v1.schemas import ( CustomComponentCode, PreloadResponse, ProcessResponse, + RunResponse, TaskResponse, TaskStatusResponse, UploadFileResponse, @@ -15,15 +19,23 @@ from langflow.api.v1.schemas import ( from langflow.interface.custom.custom_component import CustomComponent from langflow.interface.custom.directory_reader import DirectoryReader from langflow.interface.custom.utils import build_custom_component_template -from langflow.processing.process import build_graph_and_generate_result, process_graph_cached, process_tweaks +from langflow.processing.process import ( + build_graph_and_generate_result, + process_graph_cached, + process_tweaks, + run_graph, +) from langflow.services.auth.utils import api_key_security, get_current_active_user from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow from langflow.services.database.models.user.model import User -from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service +from langflow.services.deps import ( + get_session, + get_session_service, + get_settings_service, + get_task_service, +) from langflow.services.session.service import SessionService -from loguru import logger -from sqlmodel import select try: from langflow.worker import process_graph_cached_task @@ -33,9 +45,10 @@ except ImportError: raise NotImplementedError("Celery is not installed") -from langflow.services.task.service import TaskService from sqlmodel import Session +from langflow.services.task.service import TaskService + # build router router = APIRouter(tags=["Base"]) @@ -80,9 +93,15 @@ async def process_graph_data( ) if session_id is None: # Generate a session ID - session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data) + session_id = get_session_service().generate_key( + session_id=session_id, data_graph=graph_data + ) task_id, task = await task_service.launch_task( - process_graph_cached_task if task_service.use_celery else process_graph_cached, + ( + process_graph_cached_task + if task_service.use_celery + else process_graph_cached + ), graph_data, inputs, clear_cache, @@ -176,7 +195,11 @@ async def preload_flow( else: if session_id is None: session_id = flow_id - flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first() + flow = session.exec( + select(Flow) + .where(Flow.id == flow_id) + .where(Flow.user_id == api_key_user.id) + ).first() if flow is None: raise ValueError(f"Flow {flow_id} not found") @@ -197,6 +220,76 @@ async def preload_flow( raise HTTPException(status_code=500, detail=str(exc)) from exc +@router.post("/run/{flow_id}", response_model=ProcessResponse) +async def run_flow_with_caching( + session: Annotated[Session, Depends(get_session)], + flow_id: str, + inputs: Optional[Union[List[dict], dict]] = None, + tweaks: Optional[dict] = None, + session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 + api_key_user: User = Depends(api_key_security), + session_service: SessionService = Depends(get_session_service), +): + try: + if session_id: + session_data = await session_service.load_session(session_id) + graph, artifacts = session_data if session_data else (None, None) + task_result: Any = None + task_status = None + if not graph: + raise ValueError("Graph not found in the session") + task_result = await run_graph( + graph, + session_id, + inputs, + artifacts=artifacts, + session_service=session_service, + ) + + else: + # Get the flow that matches the flow_id and belongs to the user + # flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() + flow = session.exec( + select(Flow) + .where(Flow.id == flow_id) + .where(Flow.user_id == api_key_user.id) + ).first() + if flow is None: + raise ValueError(f"Flow {flow_id} not found") + + if flow.data is None: + raise ValueError(f"Flow {flow_id} has no data") + graph_data = flow.data + graph_data = process_tweaks(graph_data, tweaks) + task_result = await run_graph( + graph_data, + inputs, + tweaks, + session_id, + session_service=session_service, + ) + + return RunResponse( + outputs=task_result, session_id=session_id, status=task_status + ) + except sa.exc.StatementError as exc: + # StatementError('(builtins.ValueError) badly formed hexadecimal UUID string') + if "badly formed hexadecimal UUID string" in str(exc): + # This means the Flow ID is not a valid UUID which means it can't find the flow + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc + except ValueError as exc: + if f"Flow {flow_id} not found" in str(exc): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc + else: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc) + ) from exc + + @router.post( "/predict/{flow_id}", response_model=ProcessResponse, @@ -269,7 +362,11 @@ async def process( # Get the flow that matches the flow_id and belongs to the user # flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() - flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first() + flow = session.exec( + select(Flow) + .where(Flow.id == flow_id) + .where(Flow.user_id == api_key_user.id) + ).first() if flow is None: raise ValueError(f"Flow {flow_id} not found") @@ -289,12 +386,18 @@ async def process( # StatementError('(builtins.ValueError) badly formed hexadecimal UUID string') if "badly formed hexadecimal UUID string" in str(exc): # This means the Flow ID is not a valid UUID which means it can't find the flow - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc except ValueError as exc: if f"Flow {flow_id} not found" in str(exc): - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc else: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc) + ) from exc except Exception as e: # Log stack trace logger.exception(e) @@ -364,12 +467,16 @@ async def custom_component( built_frontend_node = build_custom_component_template(component, user_id=user.id) - built_frontend_node = update_frontend_node_with_template_values(built_frontend_node, raw_code.frontend_node) + built_frontend_node = update_frontend_node_with_template_values( + built_frontend_node, raw_code.frontend_node + ) return built_frontend_node @router.post("/custom_component/reload", status_code=HTTPStatus.OK) -async def reload_custom_component(path: str, user: User = Depends(get_current_active_user)): +async def reload_custom_component( + path: str, user: User = Depends(get_current_active_user) +): from langflow.interface.custom.utils import build_custom_component_template try: @@ -391,6 +498,8 @@ async def custom_component_update( ): component = CustomComponent(code=raw_code.code) - component_node = build_custom_component_template(component, user_id=user.id, update_field=raw_code.field) + component_node = build_custom_component_template( + component, user_id=user.id, update_field=raw_code.field + ) # Update the field return component_node