From dd344ce6c688526a3e87bef68b3c1c6af355e748 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Thu, 23 May 2024 12:34:34 +0200 Subject: [PATCH] backend: validate flow_id on api (#1927) --- src/backend/base/langflow/api/v1/chat.py | 13 +++++++------ src/backend/base/langflow/api/v1/endpoints.py | 10 +++++++--- src/backend/base/langflow/api/v1/files.py | 19 +++++++++++++------ 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 40290afa7..6a38f4981 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -53,7 +53,7 @@ async def try_running_celery_task(vertex, user_id): @router.post("/build/{flow_id}/vertices", response_model=VerticesOrderResponse) async def retrieve_vertices_order( - flow_id: str, + flow_id: uuid.UUID, data: Optional[Annotated[Optional[FlowDataRequest], Body(embed=True)]] = None, stop_component_id: Optional[str] = None, start_component_id: Optional[str] = None, @@ -78,6 +78,7 @@ async def retrieve_vertices_order( HTTPException: If there is an error checking the build status. """ try: + flow_id = str(flow_id) # First, we need to check if the flow_id is in the cache if not data: graph = await build_and_cache_graph_from_db(flow_id=flow_id, session=session, chat_service=chat_service) @@ -119,7 +120,7 @@ async def retrieve_vertices_order( @router.post("/build/{flow_id}/vertices/{vertex_id}") async def build_vertex( - flow_id: str, + flow_id: uuid.UUID, vertex_id: str, background_tasks: BackgroundTasks, inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None, @@ -143,8 +144,8 @@ async def build_vertex( HTTPException: If there is an error building the vertex. """ + flow_id = str(flow_id) - start_time = time.perf_counter() next_runnable_vertices = [] top_level_vertices = [] try: @@ -158,8 +159,7 @@ async def build_vertex( ) else: graph = cache.get("result") - result_data_response = ResultDataResponse(results={}) - duration = "" + ResultDataResponse(results={}) vertex = graph.get_vertex(vertex_id) try: lock = chat_service._cache_locks[flow_id] @@ -240,7 +240,7 @@ async def build_vertex( @router.get("/build/{flow_id}/{vertex_id}/stream", response_class=StreamingResponse) async def build_vertex_stream( - flow_id: str, + flow_id: uuid.UUID, vertex_id: str, session_id: Optional[str] = None, chat_service: "ChatService" = Depends(get_chat_service), @@ -272,6 +272,7 @@ async def build_vertex_stream( HTTPException: If an error occurs while building the vertex. """ try: + flow_id = str(flow_id) async def stream_vertex(): try: diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 35ef9d617..b84fff034 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -1,5 +1,6 @@ from http import HTTPStatus from typing import Annotated, List, Optional, Union +from uuid import UUID import sqlalchemy as sa from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status @@ -54,7 +55,7 @@ def get_all( @router.post("/run/{flow_id}", response_model=RunResponse, response_model_exclude_none=True) async def simplified_run_flow( db: Annotated[Session, Depends(get_session)], - flow_id: str, + flow_id: UUID, input_request: SimplifiedAPIRequest = SimplifiedAPIRequest(), stream: bool = False, api_key_user: User = Depends(api_key_security), @@ -111,6 +112,7 @@ async def simplified_run_flow( session_id = input_request.session_id try: + flow_id = str(flow_id) task_result: List[RunOutputs] = [] artifacts = {} if input_request.session_id: @@ -187,7 +189,7 @@ async def simplified_run_flow( @router.post("/run/advanced/{flow_id}", response_model=RunResponse, response_model_exclude_none=True) async def experimental_run_flow( session: Annotated[Session, Depends(get_session)], - flow_id: str, + flow_id: UUID, inputs: Optional[List[InputValueRequest]] = [InputValueRequest(components=[], input_value="")], outputs: Optional[List[str]] = [], tweaks: Annotated[Optional[Tweaks], Body(embed=True)] = None, # noqa: F821 @@ -235,6 +237,7 @@ async def experimental_run_flow( This endpoint facilitates complex flow executions with customized inputs, outputs, and configurations, catering to diverse application requirements. """ try: + flow_id = str(flow_id) if outputs is None: outputs = [] @@ -357,9 +360,10 @@ async def get_task_status(task_id: str): ) async def create_upload_file( file: UploadFile, - flow_id: str, + flow_id: UUID, ): try: + flow_id = str(flow_id) file_path = save_uploaded_file(file, folder_name=flow_id) return UploadFileResponse( diff --git a/src/backend/base/langflow/api/v1/files.py b/src/backend/base/langflow/api/v1/files.py index 435aea826..762d39da9 100644 --- a/src/backend/base/langflow/api/v1/files.py +++ b/src/backend/base/langflow/api/v1/files.py @@ -1,6 +1,7 @@ import hashlib from http import HTTPStatus from io import BytesIO +from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, UploadFile from fastapi.responses import StreamingResponse @@ -20,10 +21,11 @@ router = APIRouter(tags=["Files"], prefix="/files") # then finds it in the database and returns it while # using the current user as the owner def get_flow_id( - flow_id: str, + flow_id: UUID, current_user=Depends(get_current_active_user), session=Depends(get_session), ): + flow_id = str(flow_id) # AttributeError: 'SelectOfScalar' object has no attribute 'first' flow = session.get(Flow, flow_id) if not flow: @@ -36,10 +38,11 @@ def get_flow_id( @router.post("/upload/{flow_id}", status_code=HTTPStatus.CREATED) async def upload_file( file: UploadFile, - flow_id: str = Depends(get_flow_id), + flow_id: UUID = Depends(get_flow_id), storage_service: StorageService = Depends(get_storage_service), ): try: + flow_id = str(flow_id) file_content = await file.read() file_name = file.filename or hashlib.sha256(file_content).hexdigest() folder = flow_id @@ -50,8 +53,9 @@ async def upload_file( @router.get("/download/{flow_id}/{file_name}") -async def download_file(file_name: str, flow_id: str, storage_service: StorageService = Depends(get_storage_service)): +async def download_file(file_name: str, flow_id: UUID, storage_service: StorageService = Depends(get_storage_service)): try: + flow_id = str(flow_id) extension = file_name.split(".")[-1] if not extension: @@ -74,9 +78,10 @@ async def download_file(file_name: str, flow_id: str, storage_service: StorageSe @router.get("/images/{flow_id}/{file_name}") -async def download_image(file_name: str, flow_id: str, storage_service: StorageService = Depends(get_storage_service)): +async def download_image(file_name: str, flow_id: UUID, storage_service: StorageService = Depends(get_storage_service)): try: extension = file_name.split(".")[-1] + flow_id = str(flow_id) if not extension: raise HTTPException(status_code=500, detail=f"Extension not found for file {file_name}") @@ -96,9 +101,10 @@ async def download_image(file_name: str, flow_id: str, storage_service: StorageS @router.get("/list/{flow_id}") async def list_files( - flow_id: str = Depends(get_flow_id), storage_service: StorageService = Depends(get_storage_service) + flow_id: UUID = Depends(get_flow_id), storage_service: StorageService = Depends(get_storage_service) ): try: + flow_id = str(flow_id) files = await storage_service.list_files(flow_id=flow_id) return {"files": files} except Exception as e: @@ -107,9 +113,10 @@ async def list_files( @router.delete("/delete/{flow_id}/{file_name}") async def delete_file( - file_name: str, flow_id: str = Depends(get_flow_id), storage_service: StorageService = Depends(get_storage_service) + file_name: str, flow_id: UUID = Depends(get_flow_id), storage_service: StorageService = Depends(get_storage_service) ): try: + flow_id = str(flow_id) await storage_service.delete_file(flow_id=flow_id, file_name=file_name) return {"message": f"File {file_name} deleted successfully"} except Exception as e: