diff --git a/src/backend/langflow/api/endpoints.py b/src/backend/langflow/api/endpoints.py index dacdad64b..19d3d5f46 100644 --- a/src/backend/langflow/api/endpoints.py +++ b/src/backend/langflow/api/endpoints.py @@ -1,10 +1,11 @@ +import json +from langflow.database.models.flow import Flow from langflow.utils.logger import logger from importlib.metadata import version -from fastapi import APIRouter, File, HTTPException, UploadFile +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from langflow.api.schemas import ( - ExportedFlow, GraphData, PredictRequest, PredictResponse, @@ -12,7 +13,9 @@ from langflow.api.schemas import ( from langflow.interface.run import process_graph_cached from langflow.interface.types import build_langchain_types_dict from langflow.cache import cache_manager - +from langflow.database.base import get_session +from sqlmodel import Session +from sqlmodel import select # build router router = APIRouter(tags=["Base"]) @@ -22,14 +25,17 @@ def get_all(): return build_langchain_types_dict() -@router.post("/predict", response_model=PredictResponse) -async def get_load(predict_request: PredictRequest): +@router.post("/predict/{flow_id}", status_code=200, response_model=PredictResponse) +async def get_load(predict_request: PredictRequest, flow_id: str, session: Session= Depends(get_session)): + try: - exported_flow: ExportedFlow = predict_request.exported_flow - graph_data: GraphData = exported_flow.data + flow_obj = session.get(Flow, flow_id) + if flow_obj is None: + raise ValueError(f"Flow {flow_id} not found") + graph_data: GraphData = json.loads(flow_obj.flow) data = graph_data.dict() response = process_graph_cached(data, predict_request.message) - return PredictResponse(result=response.get("result", "")) + return PredictResponse(result=response.get("result", ""), intermediate_steps=response.get("thought", "")) except Exception as e: # Log stack trace logger.exception(e) diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py index a9cb4dcb6..7d5de2b63 100644 --- a/src/backend/langflow/api/schemas.py +++ b/src/backend/langflow/api/schemas.py @@ -24,13 +24,13 @@ class PredictRequest(BaseModel): """Predict request schema.""" message: str - exported_flow: ExportedFlow class PredictResponse(BaseModel): """Predict response schema.""" result: str + intermediate_steps: str = "" class ChatMessage(BaseModel):