From 34e116ddf27b3afe7d4c74b9c9c4144cd157ccf1 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 18 Dec 2023 10:52:33 -0300 Subject: [PATCH] Add process_json endpoint for processing JSON data --- src/backend/langflow/api/v1/endpoints.py | 179 +++++++++++++++-------- 1 file changed, 117 insertions(+), 62 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index ccba9674c..aef197403 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -3,24 +3,26 @@ from typing import Annotated, Optional, Union import sqlalchemy as sa from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status +from loguru import logger +from sqlmodel import select + from langflow.api.utils import update_frontend_node_with_template_values -from langflow.api.v1.schemas import (CustomComponentCode, ProcessResponse, - TaskResponse, TaskStatusResponse, - UploadFileResponse) +from langflow.api.v1.schemas import ( + CustomComponentCode, + ProcessResponse, + TaskResponse, + TaskStatusResponse, + UploadFileResponse, +) from langflow.interface.custom.custom_component import CustomComponent from langflow.interface.custom.directory_reader import DirectoryReader -from langflow.interface.custom.utils import (build_custom_component_template, - create_and_validate_component) +from langflow.interface.custom.utils import build_custom_component_template, create_and_validate_component from langflow.processing.process import process_graph_cached, process_tweaks -from langflow.services.auth.utils import (api_key_security, - get_current_active_user) +from langflow.services.auth.utils import api_key_security, get_current_active_user from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow from langflow.services.database.models.user.model import User -from langflow.services.deps import (get_session, get_session_service, - get_settings_service, get_task_service) -from loguru import logger -from sqlmodel import select +from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service try: from langflow.worker import process_graph_cached_task @@ -30,13 +32,81 @@ except ImportError: raise NotImplementedError("Celery is not installed") -from langflow.services.task.service import TaskService from sqlmodel import Session +from langflow.services.task.service import TaskService + # build router router = APIRouter(tags=["Base"]) +async def process_graph_data( + graph_data: dict, + inputs: Optional[dict] = None, + tweaks: Optional[dict] = None, + clear_cache: bool = False, + session_id: Optional[str] = None, + task_service: "TaskService" = Depends(get_task_service), + sync: bool = True, +): + task_result = None + task_status = None + if tweaks: + try: + graph_data = process_tweaks(graph_data, tweaks) + except Exception as exc: + logger.error(f"Error processing tweaks: {exc}") + if sync: + task_id, result = await task_service.launch_and_await_task( + process_graph_cached_task if task_service.use_celery else process_graph_cached, + graph_data, + inputs, + clear_cache, + session_id, + ) + if isinstance(result, dict) and "result" in result: + task_result = result["result"] + session_id = result["session_id"] + elif hasattr(result, "result") and hasattr(result, "session_id"): + task_result = result.result + + session_id = result.session_id + else: + logger.warning( + "This is an experimental feature and may not work as expected." + "Please report any issues to our GitHub repository." + ) + if session_id is None: + # Generate a session ID + session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data) + task_id, task = await task_service.launch_task( + process_graph_cached_task if task_service.use_celery else process_graph_cached, + graph_data, + inputs, + clear_cache, + session_id, + ) + task_status = task.status + if task.status == "FAILURE": + logger.error(f"Task {task_id} failed: {task.traceback}") + task_result = str(task._exception) + else: + task_result = task.result + + if task_id: + task_response = TaskResponse(id=task_id, href=f"api/v1/task/{task_id}") + else: + task_response = None + + return ProcessResponse( + result=task_result, + status=task_status, + task=task_response, + session_id=session_id, + backend=task_service.backend_name, + ) + + @router.get("/all", dependencies=[Depends(get_current_active_user)]) def get_all( settings_service=Depends(get_settings_service), @@ -50,7 +120,32 @@ def get_all( raise HTTPException(status_code=500, detail=str(exc)) from exc -# For backwards compatibility we will keep the old endpoint +@router.post("/process/json", response_model=ProcessResponse) +async def process_json( + session: Annotated[Session, Depends(get_session)], + data: dict, + inputs: Optional[dict] = None, + tweaks: Optional[dict] = None, + clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821 + session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 + task_service: "TaskService" = Depends(get_task_service), + sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821 +): + try: + return await process_graph_data( + graph_data=data, + inputs=inputs, + tweaks=tweaks, + clear_cache=clear_cache, + session_id=session_id, + task_service=task_service, + sync=sync, + ) + except Exception as exc: + logger.exception(exc) + raise HTTPException(status_code=500, detail=str(exc)) from exc + + @router.post( "/predict/{flow_id}", response_model=ProcessResponse, @@ -91,54 +186,14 @@ async def process( if flow.data is None: raise ValueError(f"Flow {flow_id} has no data") graph_data = flow.data - task_result = None - if tweaks: - try: - graph_data = process_tweaks(graph_data, tweaks) - except Exception as exc: - logger.error(f"Error processing tweaks: {exc}") - if sync: - task_id, result = await task_service.launch_and_await_task( - process_graph_cached_task if task_service.use_celery else process_graph_cached, - graph_data, - inputs, - clear_cache, - session_id, - ) - if isinstance(result, dict) and "result" in result: - task_result = result["result"] - session_id = result["session_id"] - elif hasattr(result, "result") and hasattr(result, "session_id"): - task_result = result.result - - session_id = result.session_id - else: - logger.warning( - "This is an experimental feature and may not work as expected." - "Please report any issues to our GitHub repository." - ) - if session_id is None: - # Generate a session ID - session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data) - task_id, task = await task_service.launch_task( - process_graph_cached_task if task_service.use_celery else process_graph_cached, - graph_data, - inputs, - clear_cache, - session_id, - ) - task_result = task.status - - if task_id: - task_response = TaskResponse(id=task_id, href=f"api/v1/task/{task_id}") - else: - task_response = None - - return ProcessResponse( - result=task_result, - task=task_response, + return await process_graph_data( + graph_data=graph_data, + inputs=inputs, + tweaks=tweaks, + clear_cache=clear_cache, session_id=session_id, - backend=task_service.backend_name, + task_service=task_service, + sync=sync, ) except sa.exc.StatementError as exc: # StatementError('(builtins.ValueError) badly formed hexadecimal UUID string') @@ -161,6 +216,8 @@ async def get_task_status(task_id: str): task_service = get_task_service() task = task_service.get_task(task_id) result = None + if task is None: + raise HTTPException(status_code=404, detail="Task not found") if task.ready(): result = task.result # If result isinstance of Exception, can we get the traceback? @@ -172,8 +229,6 @@ async def get_task_status(task_id: str): elif hasattr(result, "result"): result = result.result - if task is None: - raise HTTPException(status_code=404, detail="Task not found") if task.status == "FAILURE": result = str(task.result) logger.error(f"Task {task_id} failed: {task.traceback}")