backend: validate flow_id on api (#1927)

This commit is contained in:
Nicolò Boschi 2024-05-23 12:34:34 +02:00 committed by GitHub
commit dd344ce6c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 27 additions and 15 deletions

View file

@ -53,7 +53,7 @@ async def try_running_celery_task(vertex, user_id):
@router.post("/build/{flow_id}/vertices", response_model=VerticesOrderResponse)
async def retrieve_vertices_order(
flow_id: str,
flow_id: uuid.UUID,
data: Optional[Annotated[Optional[FlowDataRequest], Body(embed=True)]] = None,
stop_component_id: Optional[str] = None,
start_component_id: Optional[str] = None,
@ -78,6 +78,7 @@ async def retrieve_vertices_order(
HTTPException: If there is an error checking the build status.
"""
try:
flow_id = str(flow_id)
# First, we need to check if the flow_id is in the cache
if not data:
graph = await build_and_cache_graph_from_db(flow_id=flow_id, session=session, chat_service=chat_service)
@ -119,7 +120,7 @@ async def retrieve_vertices_order(
@router.post("/build/{flow_id}/vertices/{vertex_id}")
async def build_vertex(
flow_id: str,
flow_id: uuid.UUID,
vertex_id: str,
background_tasks: BackgroundTasks,
inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None,
@ -143,8 +144,8 @@ async def build_vertex(
HTTPException: If there is an error building the vertex.
"""
flow_id = str(flow_id)
start_time = time.perf_counter()
next_runnable_vertices = []
top_level_vertices = []
try:
@ -158,8 +159,7 @@ async def build_vertex(
)
else:
graph = cache.get("result")
result_data_response = ResultDataResponse(results={})
duration = ""
ResultDataResponse(results={})
vertex = graph.get_vertex(vertex_id)
try:
lock = chat_service._cache_locks[flow_id]
@ -240,7 +240,7 @@ async def build_vertex(
@router.get("/build/{flow_id}/{vertex_id}/stream", response_class=StreamingResponse)
async def build_vertex_stream(
flow_id: str,
flow_id: uuid.UUID,
vertex_id: str,
session_id: Optional[str] = None,
chat_service: "ChatService" = Depends(get_chat_service),
@ -272,6 +272,7 @@ async def build_vertex_stream(
HTTPException: If an error occurs while building the vertex.
"""
try:
flow_id = str(flow_id)
async def stream_vertex():
try:

View file

@ -1,5 +1,6 @@
from http import HTTPStatus
from typing import Annotated, List, Optional, Union
from uuid import UUID
import sqlalchemy as sa
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status
@ -54,7 +55,7 @@ def get_all(
@router.post("/run/{flow_id}", response_model=RunResponse, response_model_exclude_none=True)
async def simplified_run_flow(
db: Annotated[Session, Depends(get_session)],
flow_id: str,
flow_id: UUID,
input_request: SimplifiedAPIRequest = SimplifiedAPIRequest(),
stream: bool = False,
api_key_user: User = Depends(api_key_security),
@ -111,6 +112,7 @@ async def simplified_run_flow(
session_id = input_request.session_id
try:
flow_id = str(flow_id)
task_result: List[RunOutputs] = []
artifacts = {}
if input_request.session_id:
@ -187,7 +189,7 @@ async def simplified_run_flow(
@router.post("/run/advanced/{flow_id}", response_model=RunResponse, response_model_exclude_none=True)
async def experimental_run_flow(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
flow_id: UUID,
inputs: Optional[List[InputValueRequest]] = [InputValueRequest(components=[], input_value="")],
outputs: Optional[List[str]] = [],
tweaks: Annotated[Optional[Tweaks], Body(embed=True)] = None, # noqa: F821
@ -235,6 +237,7 @@ async def experimental_run_flow(
This endpoint facilitates complex flow executions with customized inputs, outputs, and configurations, catering to diverse application requirements.
"""
try:
flow_id = str(flow_id)
if outputs is None:
outputs = []
@ -357,9 +360,10 @@ async def get_task_status(task_id: str):
)
async def create_upload_file(
file: UploadFile,
flow_id: str,
flow_id: UUID,
):
try:
flow_id = str(flow_id)
file_path = save_uploaded_file(file, folder_name=flow_id)
return UploadFileResponse(

View file

@ -1,6 +1,7 @@
import hashlib
from http import HTTPStatus
from io import BytesIO
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
@ -20,10 +21,11 @@ router = APIRouter(tags=["Files"], prefix="/files")
# then finds it in the database and returns it while
# using the current user as the owner
def get_flow_id(
flow_id: str,
flow_id: UUID,
current_user=Depends(get_current_active_user),
session=Depends(get_session),
):
flow_id = str(flow_id)
# AttributeError: 'SelectOfScalar' object has no attribute 'first'
flow = session.get(Flow, flow_id)
if not flow:
@ -36,10 +38,11 @@ def get_flow_id(
@router.post("/upload/{flow_id}", status_code=HTTPStatus.CREATED)
async def upload_file(
file: UploadFile,
flow_id: str = Depends(get_flow_id),
flow_id: UUID = Depends(get_flow_id),
storage_service: StorageService = Depends(get_storage_service),
):
try:
flow_id = str(flow_id)
file_content = await file.read()
file_name = file.filename or hashlib.sha256(file_content).hexdigest()
folder = flow_id
@ -50,8 +53,9 @@ async def upload_file(
@router.get("/download/{flow_id}/{file_name}")
async def download_file(file_name: str, flow_id: str, storage_service: StorageService = Depends(get_storage_service)):
async def download_file(file_name: str, flow_id: UUID, storage_service: StorageService = Depends(get_storage_service)):
try:
flow_id = str(flow_id)
extension = file_name.split(".")[-1]
if not extension:
@ -74,9 +78,10 @@ async def download_file(file_name: str, flow_id: str, storage_service: StorageSe
@router.get("/images/{flow_id}/{file_name}")
async def download_image(file_name: str, flow_id: str, storage_service: StorageService = Depends(get_storage_service)):
async def download_image(file_name: str, flow_id: UUID, storage_service: StorageService = Depends(get_storage_service)):
try:
extension = file_name.split(".")[-1]
flow_id = str(flow_id)
if not extension:
raise HTTPException(status_code=500, detail=f"Extension not found for file {file_name}")
@ -96,9 +101,10 @@ async def download_image(file_name: str, flow_id: str, storage_service: StorageS
@router.get("/list/{flow_id}")
async def list_files(
flow_id: str = Depends(get_flow_id), storage_service: StorageService = Depends(get_storage_service)
flow_id: UUID = Depends(get_flow_id), storage_service: StorageService = Depends(get_storage_service)
):
try:
flow_id = str(flow_id)
files = await storage_service.list_files(flow_id=flow_id)
return {"files": files}
except Exception as e:
@ -107,9 +113,10 @@ async def list_files(
@router.delete("/delete/{flow_id}/{file_name}")
async def delete_file(
file_name: str, flow_id: str = Depends(get_flow_id), storage_service: StorageService = Depends(get_storage_service)
file_name: str, flow_id: UUID = Depends(get_flow_id), storage_service: StorageService = Depends(get_storage_service)
):
try:
flow_id = str(flow_id)
await storage_service.delete_file(flow_id=flow_id, file_name=file_name)
return {"message": f"File {file_name} deleted successfully"}
except Exception as e: