🐛 fix(endpoints.py): fix predict endpoint to load flow from database instead of request
✨ feat(endpoints.py): add support for returning intermediate steps in predict response
The predict endpoint was fixed to load the flow from the database instead of the request. This ensures that the correct flow is used for prediction. Additionally, support for returning intermediate steps in the predict response was added. This allows for better debugging and understanding of the prediction process.
This commit is contained in:
parent
bcd77a641d
commit
ae0cc86a76
2 changed files with 15 additions and 9 deletions
|
|
@ -1,10 +1,11 @@
|
|||
import json
|
||||
from langflow.database.models.flow import Flow
|
||||
from langflow.utils.logger import logger
|
||||
from importlib.metadata import version
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, UploadFile
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
|
||||
from langflow.api.schemas import (
|
||||
ExportedFlow,
|
||||
GraphData,
|
||||
PredictRequest,
|
||||
PredictResponse,
|
||||
|
|
@ -12,7 +13,9 @@ from langflow.api.schemas import (
|
|||
from langflow.interface.run import process_graph_cached
|
||||
from langflow.interface.types import build_langchain_types_dict
|
||||
from langflow.cache import cache_manager
|
||||
|
||||
from langflow.database.base import get_session
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import select
|
||||
# build router
|
||||
router = APIRouter(tags=["Base"])
|
||||
|
||||
|
|
@ -22,14 +25,17 @@ def get_all():
|
|||
return build_langchain_types_dict()
|
||||
|
||||
|
||||
@router.post("/predict", response_model=PredictResponse)
|
||||
async def get_load(predict_request: PredictRequest):
|
||||
@router.post("/predict/{flow_id}", status_code=200, response_model=PredictResponse)
|
||||
async def get_load(predict_request: PredictRequest, flow_id: str, session: Session= Depends(get_session)):
|
||||
|
||||
try:
|
||||
exported_flow: ExportedFlow = predict_request.exported_flow
|
||||
graph_data: GraphData = exported_flow.data
|
||||
flow_obj = session.get(Flow, flow_id)
|
||||
if flow_obj is None:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
graph_data: GraphData = json.loads(flow_obj.flow)
|
||||
data = graph_data.dict()
|
||||
response = process_graph_cached(data, predict_request.message)
|
||||
return PredictResponse(result=response.get("result", ""))
|
||||
return PredictResponse(result=response.get("result", ""), intermediate_steps=response.get("thought", ""))
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
|
|||
|
|
@ -24,13 +24,13 @@ class PredictRequest(BaseModel):
|
|||
"""Predict request schema."""
|
||||
|
||||
message: str
|
||||
exported_flow: ExportedFlow
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
|
||||
"""Predict response schema."""
|
||||
|
||||
result: str
|
||||
intermediate_steps: str = ""
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue