diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index e38906fbf..35a14b822 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -1,9 +1,10 @@ from langflow.database.models.flow import Flow -from langflow.processing.process import process_graph_cached +from langflow.processing.process import process_graph_cached, process_tweaks from langflow.utils.logger import logger from importlib.metadata import version from fastapi import APIRouter, Depends, HTTPException +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from langflow.api.v1.schemas import ( PredictRequest, @@ -17,23 +18,40 @@ from sqlmodel import Session # build router router = APIRouter(tags=["Base"]) +security = HTTPBearer() + + +def get_flow_from_token( + bearer: HTTPAuthorizationCredentials = Depends(security), session: Session = Depends(get_session) +) -> Flow: + # Extract the token, which is the flow_id in this case + flow_id = bearer.credentials + # Check if the flow_id exists in the database + flow = session.get(Flow, flow_id) + if flow is None: + raise HTTPException(status_code=401, detail="Invalid token") + return flow + @router.get("/all") def get_all(): return build_langchain_types_dict() -@router.post("/predict/{flow_id}", status_code=200, response_model=PredictResponse) -async def get_load( +@router.post("/predict", response_model=PredictResponse) +async def predict_flow( predict_request: PredictRequest, - flow_id: str, - session: Session = Depends(get_session), + flow: Flow = Depends(get_flow_from_token), ): + """ + Endpoint to process a message using the flow passed in the bearer token. 
+ """ + try: - flow_obj = session.get(Flow, flow_id) - if flow_obj is None: - raise ValueError(f"Flow {flow_id} not found") - graph_data = flow_obj.data + graph_data = flow.data + if predict_request.tweaks: + graph_data = process_tweaks(graph_data, predict_request.tweaks) + response = process_graph_cached(graph_data, predict_request.message) return PredictResponse( result=response.get("result", ""), diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index d573a2ae2..aae4a1df3 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -1,6 +1,6 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from langflow.database.models.flow import FlowCreate, FlowRead -from pydantic import BaseModel, validator +from pydantic import BaseModel, Field, validator class GraphData(BaseModel): @@ -23,6 +23,23 @@ class PredictRequest(BaseModel): """Predict request schema.""" message: str + tweaks: Optional[Dict[str, Dict[str, str]]] = Field(default_factory=dict) + + class Config: + schema_extra = { + "example": { + "message": "Hello, how are you?", + "tweaks": { + "dndnode_986363f0-4677-4035-9f38-74b94af5dd78": { + "name": "A tool name", + "description": "A tool description", + }, + "dndnode_986363f0-4677-4035-9f38-74b94af57378": { + "template": "A {template}", + }, + }, + } + } class PredictResponse(BaseModel): diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 3b8852e00..ae0bf28b1 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -170,3 +170,25 @@ def load_flow_from_json(path: str, build=True): fix_memory_inputs(langchain_object) return langchain_object return graph + + +def process_tweaks(graph_data: dict, tweaks: dict): + """This function is used to tweak the graph data using the node id and the tweaks dict""" + # the tweaks dict is a dict 
of dicts + # the key is the node id and the value is a dict of the tweaks + # the dict of tweaks contains the name of a certain parameter and the value to be tweaked + + # We need to process the graph data to add the tweaks + nodes = graph_data["data"]["nodes"] + for node in nodes: + node_id = node["id"] + if node_id in tweaks: + node_tweaks = tweaks[node_id] + template_data = node["data"]["node"]["template"] + for tweak_name, tweak_value in node_tweaks.items(): + if tweak_name in template_data: + template_data[tweak_name]["value"] = tweak_value + print( + f"Something changed in node {node_id} with tweak {tweak_name} and value {tweak_value}" + ) + return graph_data diff --git a/src/frontend/src/constants.tsx b/src/frontend/src/constants.tsx index 18b34618f..da000a862 100644 --- a/src/frontend/src/constants.tsx +++ b/src/frontend/src/constants.tsx @@ -52,15 +52,28 @@ export const TEXT_DIALOG_SUBTITLE = "Edit you text."; export const getPythonApiCode = (flowId: string): string => { return `import requests - FLOW_ID = "${flowId}" - API_URL = f"${window.location.protocol}//${window.location.host}/predict/{FLOW_ID}" +FLOW_ID = "${flowId}" +API_URL = f"${window.location.protocol}//${window.location.host}/predict" - def predict(message): - payload = {'message': message} - response = requests.post(API_URL, json=payload) - return response.json() +def run_flow(message, tweaks=None): - print(predict("Your message"))`; + if tweaks: + payload = {'message': message, 'tweaks': tweaks} + else: + payload = {'message': message} + + headers = {'Authorization': + f'Bearer {FLOW_ID}', + 'Content-Type': 'application/json' + } + + response = requests.post(API_URL, headers=headers, json=payload) + return response.json() + +# Setup any tweaks you want to apply to the flow +tweaks = {} # {"nodeId": {"key": "value"}, "nodeId2": {"key": "value"}} + +print(run_flow("Your message", tweaks=tweaks))`; }; /** @@ -71,8 +84,9 @@ export const getPythonApiCode = (flowId: string): string => { export const 
getCurlCode = (flowId: string): string => { return `curl -X POST \\ -H "Content-Type: application/json" \\ + -H "Authorization: Bearer ${flowId}" \\ -d '{"message": "Your message"}' \\ - ${window.location.protocol}//${window.location.host}/predict/${flowId}`; + ${window.location.protocol}//${window.location.host}/predict`; }; /**