refactor: release session after select (#3555)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-08-26 18:53:53 -03:00 committed by GitHub
commit c20e02e0e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 34 additions and 29 deletions

View file

@ -34,8 +34,9 @@ from langflow.schema.graph import Tweaks
from langflow.services.auth.utils import api_key_security, get_current_active_user
from langflow.services.cache.utils import save_uploaded_file
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.flow.model import FlowRead
from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow
from langflow.services.database.models.user.model import User
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import (
get_cache_service,
get_session,
@ -175,10 +176,10 @@ async def simple_run_flow_task(
@router.post("/run/{flow_id_or_name}", response_model=RunResponse, response_model_exclude_none=True)
async def simplified_run_flow(
background_tasks: BackgroundTasks,
flow: Annotated[Flow, Depends(get_flow_by_id_or_endpoint_name)],
flow: Annotated[FlowRead | None, Depends(get_flow_by_id_or_endpoint_name)],
input_request: SimplifiedAPIRequest = SimplifiedAPIRequest(),
stream: bool = False,
api_key_user: User = Depends(api_key_security),
api_key_user: UserRead = Depends(api_key_security),
telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
):
"""
@ -229,6 +230,8 @@ async def simplified_run_flow(
This endpoint provides a powerful interface for executing flows with enhanced flexibility and efficiency, supporting a wide range of applications by allowing for dynamic input and output configuration along with performance optimizations through session management and caching.
"""
if flow is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Flow not found")
start_time = time.perf_counter()
try:
result = await simple_run_flow(
@ -364,7 +367,7 @@ async def experimental_run_flow(
tweaks: Annotated[Optional[Tweaks], Body(embed=True)] = None, # noqa: F821
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
api_key_user: User = Depends(api_key_security),
api_key_user: UserRead = Depends(api_key_security),
session_service: SessionService = Depends(get_session_service),
):
"""
@ -478,7 +481,7 @@ async def process(
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
task_service: "TaskService" = Depends(get_task_service),
api_key_user: User = Depends(api_key_security),
api_key_user: UserRead = Depends(api_key_security),
sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821
session_service: SessionService = Depends(get_session_service),
):

View file

@ -1,15 +1,16 @@
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple, Type, Union, cast
from uuid import UUID
from fastapi import Depends, HTTPException
from fastapi import HTTPException
from pydantic.v1 import BaseModel, Field, create_model
from sqlmodel import Session, select
from sqlmodel import select
from langflow.graph.schema import RunOutputs
from langflow.schema import Data
from langflow.schema.schema import INPUT_FIELD_NAME
from langflow.services.database.models.flow import Flow
from langflow.services.deps import get_session, get_settings_service, session_scope
from langflow.services.database.models.flow.model import FlowRead
from langflow.services.deps import get_settings_service, session_scope
if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
@ -257,23 +258,23 @@ def get_arg_names(inputs: List["Vertex"]) -> List[dict[str, str]]:
]
def get_flow_by_id_or_endpoint_name(
flow_id_or_name: str, db: Session = Depends(get_session), user_id: Optional[UUID] = None
) -> Flow:
endpoint_name = None
try:
flow_id = UUID(flow_id_or_name)
flow = db.get(Flow, flow_id)
except ValueError:
endpoint_name = flow_id_or_name
stmt = select(Flow).where(Flow.endpoint_name == endpoint_name)
if user_id:
stmt = stmt.where(Flow.user_id == user_id)
flow = db.exec(stmt).first()
if flow is None:
raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found")
return flow
def get_flow_by_id_or_endpoint_name(flow_id_or_name: str, user_id: Optional[UUID] = None) -> FlowRead | None:
flow_read = None
with session_scope() as session:
endpoint_name = None
try:
flow_id = UUID(flow_id_or_name)
flow = session.get(Flow, flow_id)
except ValueError:
endpoint_name = flow_id_or_name
stmt = select(Flow).where(Flow.endpoint_name == endpoint_name)
if user_id:
stmt = stmt.where(Flow.user_id == user_id)
flow = session.exec(stmt).first()
if flow is None:
raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found")
flow_read = FlowRead.model_validate(flow, from_attributes=True)
return flow_read
def generate_unique_flow_name(flow_name, user_id, session):

View file

@ -16,7 +16,7 @@ from starlette.websockets import WebSocket
from langflow.services.database.models.api_key.crud import check_key
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
from langflow.services.database.models.user.model import User
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import get_session, get_settings_service
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
@ -32,7 +32,7 @@ async def api_key_security(
query_param: str = Security(api_key_query),
header_param: str = Security(api_key_header),
db: Session = Depends(get_session),
) -> Optional[User]:
) -> Optional[UserRead]:
settings_service = get_settings_service()
result: Optional[Union[ApiKey, User]] = None
if settings_service.auth_settings.AUTO_LOGIN:
@ -63,9 +63,10 @@ async def api_key_security(
detail="Invalid or missing API key",
)
if isinstance(result, ApiKey):
return result.user
return UserRead.model_validate(result.user, from_attributes=True)
elif isinstance(result, User):
return result
return UserRead.model_validate(result, from_attributes=True)
raise ValueError("Invalid result type")
async def get_current_user(