The function name was changed to predict_flow to better reflect the functionality of the endpoint.
69 lines
2 KiB
Python
69 lines
2 KiB
Python
from langflow.database.models.flow import Flow
|
|
from langflow.processing.process import process_graph_cached, process_tweaks
|
|
from langflow.utils.logger import logger
|
|
from importlib.metadata import version
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from fastapi.security import HTTPBearer
|
|
|
|
from langflow.api.v1.schemas import (
|
|
PredictRequest,
|
|
PredictResponse,
|
|
)
|
|
|
|
from langflow.interface.types import build_langchain_types_dict
|
|
from langflow.database.base import get_session
|
|
from sqlmodel import Session
|
|
|
|
# build router
|
|
# Router that the endpoints in this module register on via @router.get /
# @router.post; its routes are grouped under the "Base" tag.
router = APIRouter(tags=["Base"])
|
|
|
|
# Bearer-token security scheme; injected into get_flow_from_token through
# Depends(security) so the request's bearer credentials can be read.
security = HTTPBearer()
|
|
|
|
|
|
def get_flow_from_token(
    bearer: HTTPBearer = Depends(security), session: Session = Depends(get_session)
) -> Flow:
    """Resolve the request's bearer token to a stored flow.

    The token's credentials string is used directly as the flow id, i.e. as
    the primary key passed to ``session.get(Flow, ...)``.

    Args:
        bearer: Value injected by the ``security`` dependency; only its
            ``credentials`` attribute is read.
            NOTE(review): annotated as ``HTTPBearer``, but the object FastAPI
            injects here is presumably ``HTTPAuthorizationCredentials`` --
            confirm and fix the annotation once that name is imported.
        session: Database session injected by ``get_session``.

    Returns:
        The ``Flow`` found for the given id (previously mis-annotated as
        ``str``; the function has always returned the flow object).

    Raises:
        HTTPException: With status 401 and detail "Invalid token" when no
            flow is found for the id.
    """
    # The token carries no separate secret: it *is* the flow id.
    flow_id = bearer.credentials

    # A missing row means the token does not identify any flow.
    flow = session.get(Flow, flow_id)
    if flow is None:
        raise HTTPException(status_code=401, detail="Invalid token")
    return flow
|
|
|
|
|
|
@router.get("/all")
def get_all():
    # GET /all: respond with whatever build_langchain_types_dict() produces.
    # Documented with comments rather than a docstring so the route's
    # generated API description stays exactly as it was.
    all_types = build_langchain_types_dict()
    return all_types
|
|
|
|
|
|
@router.post("/predict", response_model=PredictResponse)
async def predict_flow(
    predict_request: PredictRequest,
    flow: Flow = Depends(get_flow_from_token),
):
    """
    Endpoint to process a message using the flow passed in the bearer token.
    """
    # Inputs:  predict_request supplies `message` and optional `tweaks`;
    #          flow is injected by get_flow_from_token.
    # Output:  PredictResponse built from the "result" and "thought" entries
    #          of the processing response (each defaults to "").
    # Errors:  any exception raised while processing is logged with its
    #          stack trace and re-raised as HTTP 500 carrying str(error).
    try:
        graph = flow.data

        # Apply the caller's tweaks to the graph only when some were sent.
        tweaks = predict_request.tweaks
        if tweaks:
            graph = process_tweaks(graph, tweaks)

        outcome = process_graph_cached(graph, predict_request.message)
        result = outcome.get("result", "")
        steps = outcome.get("thought", "")
        return PredictResponse(result=result, intermediate_steps=steps)
    except Exception as exc:
        # Log stack trace
        logger.exception(exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
|
|
|
|
# get endpoint to return version of langflow
|
|
@router.get("/version")
def get_version():
    # GET /version: look up the installed version of the "langflow"
    # distribution via importlib.metadata and return it under "version".
    installed = version("langflow")
    return {"version": installed}
|