refactor: release session after select (#3555)
This commit is contained in:
parent
60c586a52b
commit
c20e02e0e2
3 changed files with 34 additions and 29 deletions
|
|
@ -34,8 +34,9 @@ from langflow.schema.graph import Tweaks
|
|||
from langflow.services.auth.utils import api_key_security, get_current_active_user
|
||||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.models.flow.model import FlowRead
|
||||
from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.database.models.user.model import User, UserRead
|
||||
from langflow.services.deps import (
|
||||
get_cache_service,
|
||||
get_session,
|
||||
|
|
@ -175,10 +176,10 @@ async def simple_run_flow_task(
|
|||
@router.post("/run/{flow_id_or_name}", response_model=RunResponse, response_model_exclude_none=True)
|
||||
async def simplified_run_flow(
|
||||
background_tasks: BackgroundTasks,
|
||||
flow: Annotated[Flow, Depends(get_flow_by_id_or_endpoint_name)],
|
||||
flow: Annotated[FlowRead | None, Depends(get_flow_by_id_or_endpoint_name)],
|
||||
input_request: SimplifiedAPIRequest = SimplifiedAPIRequest(),
|
||||
stream: bool = False,
|
||||
api_key_user: User = Depends(api_key_security),
|
||||
api_key_user: UserRead = Depends(api_key_security),
|
||||
telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
|
||||
):
|
||||
"""
|
||||
|
|
@ -229,6 +230,8 @@ async def simplified_run_flow(
|
|||
|
||||
This endpoint provides a powerful interface for executing flows with enhanced flexibility and efficiency, supporting a wide range of applications by allowing for dynamic input and output configuration along with performance optimizations through session management and caching.
|
||||
"""
|
||||
if flow is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Flow not found")
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
result = await simple_run_flow(
|
||||
|
|
@ -364,7 +367,7 @@ async def experimental_run_flow(
|
|||
tweaks: Annotated[Optional[Tweaks], Body(embed=True)] = None, # noqa: F821
|
||||
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
api_key_user: User = Depends(api_key_security),
|
||||
api_key_user: UserRead = Depends(api_key_security),
|
||||
session_service: SessionService = Depends(get_session_service),
|
||||
):
|
||||
"""
|
||||
|
|
@ -478,7 +481,7 @@ async def process(
|
|||
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
task_service: "TaskService" = Depends(get_task_service),
|
||||
api_key_user: User = Depends(api_key_security),
|
||||
api_key_user: UserRead = Depends(api_key_security),
|
||||
sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821
|
||||
session_service: SessionService = Depends(get_session_service),
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1,15 +1,16 @@
|
|||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple, Type, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import HTTPException
|
||||
from pydantic.v1 import BaseModel, Field, create_model
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.graph.schema import RunOutputs
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.schema import INPUT_FIELD_NAME
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.deps import get_session, get_settings_service, session_scope
|
||||
from langflow.services.database.models.flow.model import FlowRead
|
||||
from langflow.services.deps import get_settings_service, session_scope
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.graph.base import Graph
|
||||
|
|
@ -257,23 +258,23 @@ def get_arg_names(inputs: List["Vertex"]) -> List[dict[str, str]]:
|
|||
]
|
||||
|
||||
|
||||
def get_flow_by_id_or_endpoint_name(
|
||||
flow_id_or_name: str, db: Session = Depends(get_session), user_id: Optional[UUID] = None
|
||||
) -> Flow:
|
||||
endpoint_name = None
|
||||
try:
|
||||
flow_id = UUID(flow_id_or_name)
|
||||
flow = db.get(Flow, flow_id)
|
||||
except ValueError:
|
||||
endpoint_name = flow_id_or_name
|
||||
stmt = select(Flow).where(Flow.endpoint_name == endpoint_name)
|
||||
if user_id:
|
||||
stmt = stmt.where(Flow.user_id == user_id)
|
||||
flow = db.exec(stmt).first()
|
||||
if flow is None:
|
||||
raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found")
|
||||
|
||||
return flow
|
||||
def get_flow_by_id_or_endpoint_name(flow_id_or_name: str, user_id: Optional[UUID] = None) -> FlowRead | None:
|
||||
flow_read = None
|
||||
with session_scope() as session:
|
||||
endpoint_name = None
|
||||
try:
|
||||
flow_id = UUID(flow_id_or_name)
|
||||
flow = session.get(Flow, flow_id)
|
||||
except ValueError:
|
||||
endpoint_name = flow_id_or_name
|
||||
stmt = select(Flow).where(Flow.endpoint_name == endpoint_name)
|
||||
if user_id:
|
||||
stmt = stmt.where(Flow.user_id == user_id)
|
||||
flow = session.exec(stmt).first()
|
||||
if flow is None:
|
||||
raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found")
|
||||
flow_read = FlowRead.model_validate(flow, from_attributes=True)
|
||||
return flow_read
|
||||
|
||||
|
||||
def generate_unique_flow_name(flow_name, user_id, session):
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from starlette.websockets import WebSocket
|
|||
from langflow.services.database.models.api_key.crud import check_key
|
||||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.database.models.user.model import User, UserRead
|
||||
from langflow.services.deps import get_session, get_settings_service
|
||||
|
||||
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
|
||||
|
|
@ -32,7 +32,7 @@ async def api_key_security(
|
|||
query_param: str = Security(api_key_query),
|
||||
header_param: str = Security(api_key_header),
|
||||
db: Session = Depends(get_session),
|
||||
) -> Optional[User]:
|
||||
) -> Optional[UserRead]:
|
||||
settings_service = get_settings_service()
|
||||
result: Optional[Union[ApiKey, User]] = None
|
||||
if settings_service.auth_settings.AUTO_LOGIN:
|
||||
|
|
@ -63,9 +63,10 @@ async def api_key_security(
|
|||
detail="Invalid or missing API key",
|
||||
)
|
||||
if isinstance(result, ApiKey):
|
||||
return result.user
|
||||
return UserRead.model_validate(result.user, from_attributes=True)
|
||||
elif isinstance(result, User):
|
||||
return result
|
||||
return UserRead.model_validate(result, from_attributes=True)
|
||||
raise ValueError("Invalid result type")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue