From ae0cc86a76684721bce6058406e93804db0758ea Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Wed, 31 May 2023 11:13:22 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(endpoints.py):=20fix=20predi?= =?UTF-8?q?ct=20endpoint=20to=20load=20flow=20from=20database=20instead=20?= =?UTF-8?q?of=20request=20=E2=9C=A8=20feat(endpoints.py):=20add=20support?= =?UTF-8?q?=20for=20returning=20intermediate=20steps=20in=20predict=20resp?= =?UTF-8?q?onse=20The=20predict=20endpoint=20was=20fixed=20to=20load=20the?= =?UTF-8?q?=20flow=20from=20the=20database=20instead=20of=20the=20request.?= =?UTF-8?q?=20This=20ensures=20that=20the=20correct=20flow=20is=20used=20f?= =?UTF-8?q?or=20prediction.=20Additionally,=20support=20for=20returning=20?= =?UTF-8?q?intermediate=20steps=20in=20the=20predict=20response=20was=20ad?= =?UTF-8?q?ded.=20This=20allows=20for=20better=20debugging=20and=20underst?= =?UTF-8?q?anding=20of=20the=20prediction=20process.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/endpoints.py | 22 ++++++++++++++-------- src/backend/langflow/api/schemas.py | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/backend/langflow/api/endpoints.py b/src/backend/langflow/api/endpoints.py index dacdad64b..19d3d5f46 100644 --- a/src/backend/langflow/api/endpoints.py +++ b/src/backend/langflow/api/endpoints.py @@ -1,10 +1,11 @@ +import json +from langflow.database.models.flow import Flow from langflow.utils.logger import logger from importlib.metadata import version -from fastapi import APIRouter, File, HTTPException, UploadFile +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from langflow.api.schemas import ( - ExportedFlow, GraphData, PredictRequest, PredictResponse, @@ -12,7 +13,9 @@ from langflow.api.schemas import ( from langflow.interface.run import process_graph_cached from langflow.interface.types import build_langchain_types_dict from langflow.cache import cache_manager - +from langflow.database.base import get_session +from sqlmodel import Session +from sqlmodel import select # build router router = APIRouter(tags=["Base"]) @@ -22,14 +25,17 @@ def get_all(): return build_langchain_types_dict() -@router.post("/predict", response_model=PredictResponse) -async def get_load(predict_request: PredictRequest): +@router.post("/predict/{flow_id}", status_code=200, response_model=PredictResponse) +async def get_load(predict_request: PredictRequest, flow_id: str, session: Session= Depends(get_session)): + try: - exported_flow: ExportedFlow = predict_request.exported_flow - graph_data: GraphData = exported_flow.data + flow_obj = session.get(Flow, flow_id) + if flow_obj is None: + raise ValueError(f"Flow {flow_id} not found") + graph_data: GraphData = json.loads(flow_obj.flow) data = graph_data.dict() response = process_graph_cached(data, predict_request.message) - return PredictResponse(result=response.get("result", "")) + return PredictResponse(result=response.get("result", ""), intermediate_steps=response.get("thought", "")) except Exception as e: # Log stack trace logger.exception(e) diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py index a9cb4dcb6..7d5de2b63 100644 --- a/src/backend/langflow/api/schemas.py +++ b/src/backend/langflow/api/schemas.py @@ -24,13 +24,13 @@ class PredictRequest(BaseModel): """Predict request schema.""" message: str - exported_flow: ExportedFlow class PredictResponse(BaseModel): """Predict response schema.""" result: str + intermediate_steps: str = "" class ChatMessage(BaseModel):