From 46a66a57c133ac562acc321b07e732c2379de631 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 4 Sep 2024 20:00:47 -0300 Subject: [PATCH] fix: add user dependency to webhook endpoint for flow execution (#3685) * Add helper function to retrieve user by flow ID or endpoint name in user.py - Introduced `get_user_by_flow_id_or_endpoint_name` function to fetch user details based on flow ID or endpoint name. - Added error handling for cases where the flow or user is not found. - Utilized `get_db_service` for database session management. * Add user dependency to webhook endpoint for flow execution - Import `get_user_by_flow_id_or_endpoint_name` helper function. - Add `user` parameter to `webhook_run_flow` endpoint. - Pass `user` to `simple_run_flow_task` for API key association. --- src/backend/base/langflow/api/v1/endpoints.py | 4 +++ src/backend/base/langflow/helpers/user.py | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 src/backend/base/langflow/helpers/user.py diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 0b2f4833b..78960f456 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -28,6 +28,7 @@ from langflow.exceptions.api import APIException, InvalidChatInputException from langflow.graph.graph.base import Graph from langflow.graph.schema import RunOutputs from langflow.helpers.flow import get_flow_by_id_or_endpoint_name +from langflow.helpers.user import get_user_by_flow_id_or_endpoint_name from langflow.interface.initialize.loading import update_params_with_load_from_db_fields from langflow.processing.process import process_tweaks, run_graph_internal from langflow.schema.graph import Tweaks @@ -286,6 +287,7 @@ async def simplified_run_flow( @router.post("/webhook/{flow_id_or_name}", response_model=dict, status_code=HTTPStatus.ACCEPTED) async def webhook_run_flow( flow: Annotated[Flow, Depends(get_flow_by_id_or_endpoint_name)], + user: Annotated[User, Depends(get_user_by_flow_id_or_endpoint_name)], request: Request, background_tasks: BackgroundTasks, telemetry_service: "TelemetryService" = Depends(get_telemetry_service), @@ -329,11 +331,13 @@ async def webhook_run_flow( tweaks=tweaks, session_id=None, ) + logger.debug("Starting background task") background_tasks.add_task( # type: ignore simple_run_flow_task, flow=flow, input_request=input_request, + api_key_user=user, ) background_tasks.add_task( telemetry_service.log_package_run, diff --git a/src/backend/base/langflow/helpers/user.py b/src/backend/base/langflow/helpers/user.py new file mode 100644 index 000000000..d2cd6a8d9 --- /dev/null +++ b/src/backend/base/langflow/helpers/user.py @@ -0,0 +1,29 @@ +from uuid import UUID + +from fastapi import HTTPException +from sqlmodel import select + +from langflow.services.database.models.flow.model import Flow +from langflow.services.database.models.user.model import User, UserRead +from langflow.services.deps import get_db_service + + +def get_user_by_flow_id_or_endpoint_name(flow_id_or_name: str) -> UserRead | None: + user_read = None + with get_db_service().with_session() as session: + try: + flow_id = UUID(flow_id_or_name) + flow = session.get(Flow, flow_id) + except ValueError: + stmt = select(Flow).where(Flow.endpoint_name == flow_id_or_name) + flow = session.exec(stmt).first() + + if flow is None: + raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found") + + user = session.get(User, flow.user_id) + if user is None: + raise HTTPException(status_code=404, detail=f"User for flow {flow_id_or_name} not found") + + user_read = UserRead.model_validate(user, from_attributes=True) + return user_read