From 7cc14e83b83972411c0fcd9bb8a432f9edb988f8 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 13 Jun 2023 12:48:52 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20feat(endpoints.py):=20add=20auth?= =?UTF-8?q?entication=20to=20predict=20endpoint=20using=20HTTPBearer=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(endpoints.py):=20change=20predict=20endpoint?= =?UTF-8?q?=20to=20use=20Flow=20object=20instead=20of=20flow=5Fid=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(endpoints.py):=20add=20support=20for=20proce?= =?UTF-8?q?ssing=20tweaks=20in=20predict=20endpoint=20The=20predict=20endp?= =?UTF-8?q?oint=20now=20requires=20authentication=20using=20HTTPBearer.=20?= =?UTF-8?q?The=20flow=5Fid=20is=20now=20extracted=20from=20the=20bearer=20?= =?UTF-8?q?token=20instead=20of=20being=20passed=20as=20a=20parameter.=20T?= =?UTF-8?q?his=20improves=20security=20as=20the=20flow=5Fid=20is=20not=20e?= =?UTF-8?q?xposed=20in=20the=20URL.=20The=20predict=20endpoint=20now=20use?= =?UTF-8?q?s=20the=20Flow=20object=20instead=20of=20the=20flow=5Fid=20to?= =?UTF-8?q?=20retrieve=20the=20graph=20data.=20This=20improves=20code=20re?= =?UTF-8?q?adability=20and=20reduces=20the=20number=20of=20database=20quer?= =?UTF-8?q?ies.=20The=20predict=20endpoint=20now=20supports=20processing?= =?UTF-8?q?=20tweaks,=20which=20allows=20for=20more=20flexibility=20in=20t?= =?UTF-8?q?he=20processing=20of=20messages.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/endpoints.py | 34 ++++++++++++++++++------ 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index e38906fbf..e99fb723f 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -1,9 +1,10 @@ from langflow.database.models.flow import Flow -from langflow.processing.process import process_graph_cached +from langflow.processing.process import process_graph_cached, process_tweaks from langflow.utils.logger import logger from importlib.metadata import version from fastapi import APIRouter, Depends, HTTPException +from fastapi.security import HTTPBearer from langflow.api.v1.schemas import ( PredictRequest, @@ -17,23 +18,40 @@ from sqlmodel import Session # build router router = APIRouter(tags=["Base"]) +security = HTTPBearer() + + +def get_flow_from_token( + bearer: HTTPBearer = Depends(security), session: Session = Depends(get_session) +) -> str: + # Extract the token, which is the flow_id in this case + flow_id = bearer.credentials + # Check if the flow_id exists in the database + flow = session.get(Flow, flow_id) + if flow is None: + raise HTTPException(status_code=401, detail="Invalid token") + return flow + @router.get("/all") def get_all(): return build_langchain_types_dict() -@router.post("/predict/{flow_id}", status_code=200, response_model=PredictResponse) +@router.post("/predict", response_model=PredictResponse) async def get_load( predict_request: PredictRequest, - flow_id: str, - session: Session = Depends(get_session), + flow: Flow = Depends(get_flow_from_token), ): + """ + Endpoint to process a message using the flow passed in the bearer token. + """ + try: - flow_obj = session.get(Flow, flow_id) - if flow_obj is None: - raise ValueError(f"Flow {flow_id} not found") - graph_data = flow_obj.data + graph_data = flow.data + if predict_request.tweaks: + graph_data = process_tweaks(graph_data, predict_request.tweaks) + response = process_graph_cached(graph_data, predict_request.message) return PredictResponse( result=response.get("result", ""),