🚀 feat(endpoints.py): add authentication to predict endpoint using HTTPBearer

🐛 fix(endpoints.py): change predict endpoint to use Flow object instead of flow_id
🐛 fix(endpoints.py): add support for processing tweaks in predict endpoint
The predict endpoint now requires authentication using HTTPBearer. The flow_id is now extracted from the bearer token instead of being passed as a parameter. This improves security as the flow_id is not exposed in the URL. The predict endpoint now uses the Flow object instead of the flow_id to retrieve the graph data. This improves code readability and reduces the number of database queries. The predict endpoint now supports processing tweaks, which allows for more flexibility in the processing of messages.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-13 12:48:52 -03:00
commit 7cc14e83b8

View file

@ -1,9 +1,10 @@
from langflow.database.models.flow import Flow
from langflow.processing.process import process_graph_cached
from langflow.processing.process import process_graph_cached, process_tweaks
from langflow.utils.logger import logger
from importlib.metadata import version
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import HTTPBearer
from langflow.api.v1.schemas import (
PredictRequest,
@ -17,23 +18,40 @@ from sqlmodel import Session
# build router
router = APIRouter(tags=["Base"])
security = HTTPBearer()
def get_flow_from_token(
bearer: HTTPBearer = Depends(security), session: Session = Depends(get_session)
) -> str:
# Extract the token, which is the flow_id in this case
flow_id = bearer.credentials
# Check if the flow_id exists in the database
flow = session.get(Flow, flow_id)
if flow is None:
raise HTTPException(status_code=401, detail="Invalid token")
return flow
@router.get("/all")
def get_all():
return build_langchain_types_dict()
@router.post("/predict/{flow_id}", status_code=200, response_model=PredictResponse)
@router.post("/predict", response_model=PredictResponse)
async def get_load(
predict_request: PredictRequest,
flow_id: str,
session: Session = Depends(get_session),
flow: Flow = Depends(get_flow_from_token),
):
"""
Endpoint to process a message using the flow passed in the bearer token.
"""
try:
flow_obj = session.get(Flow, flow_id)
if flow_obj is None:
raise ValueError(f"Flow {flow_id} not found")
graph_data = flow_obj.data
graph_data = flow.data
if predict_request.tweaks:
graph_data = process_tweaks(graph_data, predict_request.tweaks)
response = process_graph_cached(graph_data, predict_request.message)
return PredictResponse(
result=response.get("result", ""),