diff --git a/src/backend/langflow/api/endpoints.py b/src/backend/langflow/api/endpoints.py index b8290e691..d4e592901 100644 --- a/src/backend/langflow/api/endpoints.py +++ b/src/backend/langflow/api/endpoints.py @@ -5,6 +5,12 @@ from fastapi import APIRouter, HTTPException from langflow.interface.run import process_graph_cached from langflow.interface.types import build_langchain_types_dict +from langflow.api.schemas import ( + ExportedFlow, + GraphData, + PredictRequest, + PredictResponse, +) # build router router = APIRouter() @@ -16,10 +22,14 @@ def get_all(): return build_langchain_types_dict() -@router.post("/predict") -def get_load(data: Dict[str, Any]): +@router.post("/predict", response_model=PredictResponse) +async def get_load(predict_request: PredictRequest): try: - return process_graph_cached(data) + exported_flow: ExportedFlow = predict_request.exported_flow + graph_data: GraphData = exported_flow.data + data = graph_data.dict() + response = process_graph_cached(data, predict_request.message) + return PredictResponse(result=response.get("result", "")) except Exception as e: # Log stack trace logger.exception(e) diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py index dd157d85f..2d14bad50 100644 --- a/src/backend/langflow/api/schemas.py +++ b/src/backend/langflow/api/schemas.py @@ -1,8 +1,37 @@ -from typing import Any, Union +from typing import Any, Union, Dict, List from pydantic import BaseModel, validator +class GraphData(BaseModel): + """Data inside the exported flow.""" + + nodes: List[Dict[str, Any]] + edges: List[Dict[str, Any]] + + +class ExportedFlow(BaseModel): + """Exported flow from LangFlow.""" + + description: str + name: str + id: str + data: GraphData + + +class PredictRequest(BaseModel): + """Predict request schema.""" + + message: str + exported_flow: ExportedFlow + + +class PredictResponse(BaseModel): + """Predict response schema.""" + + result: str + + class ChatMessage(BaseModel): """Chat message schema.""" diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index 68639785c..4f2143e2e 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -100,13 +100,12 @@ def process_graph(data_graph: Dict[str, Any]): return {"result": str(result), "thought": thought.strip()} -def process_graph_cached(data_graph: Dict[str, Any]): +def process_graph_cached(data_graph: Dict[str, Any], message: str): """ Process graph by extracting input variables and replacing ZeroShotPrompt with PromptTemplate,then run the graph and return the result and thought. """ # Load langchain object - message = data_graph.pop("message", "") is_first_message = len(data_graph.get("chatHistory", [])) == 0 langchain_object = load_or_build_langchain_object(data_graph, is_first_message) logger.debug("Loaded langchain object") @@ -119,7 +118,7 @@ def process_graph_cached(data_graph: Dict[str, Any]): # Generate result and thought logger.debug("Generating result and thought") - result, thought = get_result_and_steps(langchain_object, message) + result, thought = get_result_and_thought(langchain_object, message) logger.debug("Generated result and thought") return {"result": str(result), "thought": thought.strip()} @@ -241,7 +240,7 @@ def get_result_and_steps(langchain_object, message: str): return result, thought -def async_get_result_and_steps(langchain_object, message: str): +def get_result_and_thought(langchain_object, message: str): """Get result and thought from extracted json""" try: if hasattr(langchain_object, "verbose"): @@ -296,34 +295,6 @@ def async_get_result_and_steps(langchain_object, message: str): return result, thought -def get_result_and_thought(extracted_json: Dict[str, Any], message: str): - """Get result and thought from extracted json""" - try: - langchain_object = loading.load_langchain_type_from_config( - config=extracted_json - ) - with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer): - output = langchain_object(message) - intermediate_steps = ( - output.get("intermediate_steps", []) if isinstance(output, dict) else [] - ) - result = ( - output.get(langchain_object.output_keys[0]) - if isinstance(output, dict) - else output - ) - - if intermediate_steps: - thought = format_intermediate_steps(intermediate_steps) - else: - thought = output_buffer.getvalue() - - except Exception as e: - result = f"Error: {str(e)}" - thought = "" - return result, thought - - def format_intermediate_steps(intermediate_steps): formatted_chain = "> Entering new AgentExecutor chain...\n" for step in intermediate_steps: