🐛 fix(endpoints.py): fix predict endpoint to load flow from database instead of request

 feat(endpoints.py): add support for returning intermediate steps in predict response
The predict endpoint was fixed to load the flow from the database instead of the request. This ensures that the correct flow is used for prediction. Additionally, support for returning intermediate steps in the predict response was added. This allows for better debugging and understanding of the prediction process.
This commit is contained in:
Gabriel Almeida 2023-05-31 11:13:22 -03:00
commit ae0cc86a76
2 changed files with 15 additions and 9 deletions

View file

@ -1,10 +1,11 @@
import json
from langflow.database.models.flow import Flow
from langflow.utils.logger import logger
from importlib.metadata import version
from fastapi import APIRouter, File, HTTPException, UploadFile
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from langflow.api.schemas import (
ExportedFlow,
GraphData,
PredictRequest,
PredictResponse,
@ -12,7 +13,9 @@ from langflow.api.schemas import (
from langflow.interface.run import process_graph_cached
from langflow.interface.types import build_langchain_types_dict
from langflow.cache import cache_manager
from langflow.database.base import get_session
from sqlmodel import Session
from sqlmodel import select
# build router
router = APIRouter(tags=["Base"])
@ -22,14 +25,17 @@ def get_all():
return build_langchain_types_dict()
@router.post("/predict", response_model=PredictResponse)
async def get_load(predict_request: PredictRequest):
@router.post("/predict/{flow_id}", status_code=200, response_model=PredictResponse)
async def get_load(predict_request: PredictRequest, flow_id: str, session: Session= Depends(get_session)):
try:
exported_flow: ExportedFlow = predict_request.exported_flow
graph_data: GraphData = exported_flow.data
flow_obj = session.get(Flow, flow_id)
if flow_obj is None:
raise ValueError(f"Flow {flow_id} not found")
graph_data: GraphData = json.loads(flow_obj.flow)
data = graph_data.dict()
response = process_graph_cached(data, predict_request.message)
return PredictResponse(result=response.get("result", ""))
return PredictResponse(result=response.get("result", ""), intermediate_steps=response.get("thought", ""))
except Exception as e:
# Log stack trace
logger.exception(e)

View file

@ -24,13 +24,13 @@ class PredictRequest(BaseModel):
"""Predict request schema."""
message: str
exported_flow: ExportedFlow
class PredictResponse(BaseModel):
"""Predict response schema."""
result: str
intermediate_steps: str = ""
class ChatMessage(BaseModel):