fix: add user dependency to webhook endpoint for flow execution (#3685)

* Add helper function to retrieve user by flow ID or endpoint name in user.py

- Introduced `get_user_by_flow_id_or_endpoint_name` function to fetch user details based on flow ID or endpoint name.
- Added error handling for cases where the flow or user is not found.
- Utilized `get_db_service` for database session management.

* Add user dependency to webhook endpoint for flow execution

- Import `get_user_by_flow_id_or_endpoint_name` helper function.
- Add `user` parameter to `webhook_run_flow` endpoint.
- Pass `user` to `simple_run_flow_task` for API key association.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-09-04 20:00:47 -03:00 committed by GitHub
commit 46a66a57c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 33 additions and 0 deletions

View file

@ -28,6 +28,7 @@ from langflow.exceptions.api import APIException, InvalidChatInputException
from langflow.graph.graph.base import Graph
from langflow.graph.schema import RunOutputs
from langflow.helpers.flow import get_flow_by_id_or_endpoint_name
from langflow.helpers.user import get_user_by_flow_id_or_endpoint_name
from langflow.interface.initialize.loading import update_params_with_load_from_db_fields
from langflow.processing.process import process_tweaks, run_graph_internal
from langflow.schema.graph import Tweaks
@ -286,6 +287,7 @@ async def simplified_run_flow(
@router.post("/webhook/{flow_id_or_name}", response_model=dict, status_code=HTTPStatus.ACCEPTED)
async def webhook_run_flow(
flow: Annotated[Flow, Depends(get_flow_by_id_or_endpoint_name)],
user: Annotated[User, Depends(get_user_by_flow_id_or_endpoint_name)],
request: Request,
background_tasks: BackgroundTasks,
telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
@ -329,11 +331,13 @@ async def webhook_run_flow(
tweaks=tweaks,
session_id=None,
)
logger.debug("Starting background task")
background_tasks.add_task( # type: ignore
simple_run_flow_task,
flow=flow,
input_request=input_request,
api_key_user=user,
)
background_tasks.add_task(
telemetry_service.log_package_run,

View file

@ -0,0 +1,29 @@
from uuid import UUID
from fastapi import HTTPException
from sqlmodel import select
from langflow.services.database.models.flow.model import Flow
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import get_db_service
def get_user_by_flow_id_or_endpoint_name(flow_id_or_name: str) -> UserRead | None:
user_read = None
with get_db_service().with_session() as session:
try:
flow_id = UUID(flow_id_or_name)
flow = session.get(Flow, flow_id)
except ValueError:
stmt = select(Flow).where(Flow.endpoint_name == flow_id_or_name)
flow = session.exec(stmt).first()
if flow is None:
raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found")
user = session.get(User, flow.user_id)
if user is None:
raise HTTPException(status_code=404, detail=f"User for flow {flow_id_or_name} not found")
user_read = UserRead.model_validate(user, from_attributes=True)
return user_read