diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index e38906fbf..e99fb723f 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -1,9 +1,10 @@ from langflow.database.models.flow import Flow -from langflow.processing.process import process_graph_cached +from langflow.processing.process import process_graph_cached, process_tweaks from langflow.utils.logger import logger from importlib.metadata import version from fastapi import APIRouter, Depends, HTTPException +from fastapi.security import HTTPBearer from langflow.api.v1.schemas import ( PredictRequest, @@ -17,23 +18,40 @@ from sqlmodel import Session # build router router = APIRouter(tags=["Base"]) +security = HTTPBearer() + + +def get_flow_from_token( + bearer: HTTPBearer = Depends(security), session: Session = Depends(get_session) +) -> str: + # Extract the token, which is the flow_id in this case + flow_id = bearer.credentials + # Check if the flow_id exists in the database + flow = session.get(Flow, flow_id) + if flow is None: + raise HTTPException(status_code=401, detail="Invalid token") + return flow + @router.get("/all") def get_all(): return build_langchain_types_dict() -@router.post("/predict/{flow_id}", status_code=200, response_model=PredictResponse) +@router.post("/predict", response_model=PredictResponse) async def get_load( predict_request: PredictRequest, - flow_id: str, - session: Session = Depends(get_session), + flow: Flow = Depends(get_flow_from_token), ): + """ + Endpoint to process a message using the flow passed in the bearer token. + """ + try: - flow_obj = session.get(Flow, flow_id) - if flow_obj is None: - raise ValueError(f"Flow {flow_id} not found") - graph_data = flow_obj.data + graph_data = flow.data + if predict_request.tweaks: + graph_data = process_tweaks(graph_data, predict_request.tweaks) + response = process_graph_cached(graph_data, predict_request.message) return PredictResponse( result=response.get("result", ""),