From c20e02e0e28caef8730961300c561e5ec185d520 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 26 Aug 2024 18:53:53 -0300 Subject: [PATCH] refactor: release session after select (#3555) --- src/backend/base/langflow/api/v1/endpoints.py | 13 +++--- src/backend/base/langflow/helpers/flow.py | 41 ++++++++++--------- .../base/langflow/services/auth/utils.py | 9 ++-- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 10c1f12e1..926c7ecb6 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -34,8 +34,9 @@ from langflow.schema.graph import Tweaks from langflow.services.auth.utils import api_key_security, get_current_active_user from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow +from langflow.services.database.models.flow.model import FlowRead from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow -from langflow.services.database.models.user.model import User +from langflow.services.database.models.user.model import User, UserRead from langflow.services.deps import ( get_cache_service, get_session, @@ -175,10 +176,10 @@ async def simple_run_flow_task( @router.post("/run/{flow_id_or_name}", response_model=RunResponse, response_model_exclude_none=True) async def simplified_run_flow( background_tasks: BackgroundTasks, - flow: Annotated[Flow, Depends(get_flow_by_id_or_endpoint_name)], + flow: Annotated[FlowRead | None, Depends(get_flow_by_id_or_endpoint_name)], input_request: SimplifiedAPIRequest = SimplifiedAPIRequest(), stream: bool = False, - api_key_user: User = Depends(api_key_security), + api_key_user: UserRead = Depends(api_key_security), telemetry_service: "TelemetryService" = Depends(get_telemetry_service), ): """ @@ -229,6 +230,8 @@ async def simplified_run_flow( This endpoint provides a powerful interface for executing flows with enhanced flexibility and efficiency, supporting a wide range of applications by allowing for dynamic input and output configuration along with performance optimizations through session management and caching. """ + if flow is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Flow not found") start_time = time.perf_counter() try: result = await simple_run_flow( @@ -364,7 +367,7 @@ async def experimental_run_flow( tweaks: Annotated[Optional[Tweaks], Body(embed=True)] = None, # noqa: F821 stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821 session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 - api_key_user: User = Depends(api_key_security), + api_key_user: UserRead = Depends(api_key_security), session_service: SessionService = Depends(get_session_service), ): """ @@ -478,7 +481,7 @@ async def process( clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821 session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 task_service: "TaskService" = Depends(get_task_service), - api_key_user: User = Depends(api_key_security), + api_key_user: UserRead = Depends(api_key_security), sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821 session_service: SessionService = Depends(get_session_service), ): diff --git a/src/backend/base/langflow/helpers/flow.py b/src/backend/base/langflow/helpers/flow.py index 72075253f..b753620c5 100644 --- a/src/backend/base/langflow/helpers/flow.py +++ b/src/backend/base/langflow/helpers/flow.py @@ -1,15 +1,16 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple, Type, Union, cast from uuid import UUID -from fastapi import Depends, HTTPException +from fastapi import HTTPException from pydantic.v1 import BaseModel, Field, create_model -from sqlmodel import Session, select +from sqlmodel import select from langflow.graph.schema import RunOutputs from langflow.schema import Data from langflow.schema.schema import INPUT_FIELD_NAME from langflow.services.database.models.flow import Flow -from langflow.services.deps import get_session, get_settings_service, session_scope +from langflow.services.database.models.flow.model import FlowRead +from langflow.services.deps import get_settings_service, session_scope if TYPE_CHECKING: from langflow.graph.graph.base import Graph @@ -257,23 +258,23 @@ def get_arg_names(inputs: List["Vertex"]) -> List[dict[str, str]]: ] -def get_flow_by_id_or_endpoint_name( - flow_id_or_name: str, db: Session = Depends(get_session), user_id: Optional[UUID] = None -) -> Flow: - endpoint_name = None - try: - flow_id = UUID(flow_id_or_name) - flow = db.get(Flow, flow_id) - except ValueError: - endpoint_name = flow_id_or_name - stmt = select(Flow).where(Flow.endpoint_name == endpoint_name) - if user_id: - stmt = stmt.where(Flow.user_id == user_id) - flow = db.exec(stmt).first() - if flow is None: - raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found") - - return flow +def get_flow_by_id_or_endpoint_name(flow_id_or_name: str, user_id: Optional[UUID] = None) -> FlowRead | None: + flow_read = None + with session_scope() as session: + endpoint_name = None + try: + flow_id = UUID(flow_id_or_name) + flow = session.get(Flow, flow_id) + except ValueError: + endpoint_name = flow_id_or_name + stmt = select(Flow).where(Flow.endpoint_name == endpoint_name) + if user_id: + stmt = stmt.where(Flow.user_id == user_id) + flow = session.exec(stmt).first() + if flow is None: + raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found") + flow_read = FlowRead.model_validate(flow, from_attributes=True) + return flow_read def generate_unique_flow_name(flow_name, user_id, session): diff --git a/src/backend/base/langflow/services/auth/utils.py b/src/backend/base/langflow/services/auth/utils.py index bf12c9a91..61af0d1f2 100644 --- a/src/backend/base/langflow/services/auth/utils.py +++ b/src/backend/base/langflow/services/auth/utils.py @@ -16,7 +16,7 @@ from starlette.websockets import WebSocket from langflow.services.database.models.api_key.crud import check_key from langflow.services.database.models.api_key.model import ApiKey from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at -from langflow.services.database.models.user.model import User +from langflow.services.database.models.user.model import User, UserRead from langflow.services.deps import get_session, get_settings_service oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False) @@ -32,7 +32,7 @@ async def api_key_security( query_param: str = Security(api_key_query), header_param: str = Security(api_key_header), db: Session = Depends(get_session), -) -> Optional[User]: +) -> Optional[UserRead]: settings_service = get_settings_service() result: Optional[Union[ApiKey, User]] = None if settings_service.auth_settings.AUTO_LOGIN: @@ -63,9 +63,10 @@ async def api_key_security( detail="Invalid or missing API key", ) if isinstance(result, ApiKey): - return result.user + return UserRead.model_validate(result.user, from_attributes=True) elif isinstance(result, User): - return result + return UserRead.model_validate(result, from_attributes=True) + raise ValueError("Invalid result type") async def get_current_user(