ref: Remove sync get_session and DbSession (#5146)

Remove sync get_session and DbSession
This commit is contained in:
Christophe Bornet 2024-12-08 11:27:19 +01:00 committed by GitHub
commit 9270f1a0f8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 71 additions and 99 deletions

View file

@ -5,7 +5,7 @@ from loguru import logger
from pydantic import BaseModel
from sqlmodel import select
from langflow.api.utils import AsyncDbSession
from langflow.api.utils import DbSession
from langflow.services.database.models.flow import Flow
from langflow.services.deps import get_chat_service
@ -38,7 +38,7 @@ async def health():
# It's a reliable health check for a langflow instance
@health_check_router.get("/health_check")
async def health_check(
session: AsyncDbSession,
session: DbSession,
) -> HealthResponse:
response = HealthResponse()
# use a fixed valid UUId that UUID collision is very unlikely

View file

@ -8,7 +8,6 @@ from fastapi import Depends, HTTPException, Query
from fastapi_pagination import Params
from loguru import logger
from sqlalchemy import delete
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.graph.graph.base import Graph
@ -17,7 +16,7 @@ from langflow.services.database.models import User
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.transactions.model import TransactionTable
from langflow.services.database.models.vertex_builds.model import VertexBuildTable
from langflow.services.deps import async_session_scope, get_async_session, get_session
from langflow.services.deps import async_session_scope, get_session
from langflow.services.store.utils import get_lf_version_from_pypi
if TYPE_CHECKING:
@ -31,8 +30,7 @@ MAX_PAGE_SIZE = 50
MIN_PAGE_SIZE = 1
CurrentActiveUser = Annotated[User, Depends(get_current_active_user)]
DbSession = Annotated[Session, Depends(get_session)]
AsyncDbSession = Annotated[AsyncSession, Depends(get_async_session)]
DbSession = Annotated[AsyncSession, Depends(get_session)]
def has_api_terms(word: str):

View file

@ -2,7 +2,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Response
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.api.v1.schemas import ApiKeyCreateRequest, ApiKeysResponse
from langflow.services.auth import utils as auth_utils
@ -16,7 +16,7 @@ router = APIRouter(tags=["APIKey"], prefix="/api_key")
@router.get("/")
async def get_api_keys_route(
db: AsyncDbSession,
db: DbSession,
current_user: CurrentActiveUser,
) -> ApiKeysResponse:
try:
@ -32,7 +32,7 @@ async def get_api_keys_route(
async def create_api_key_route(
req: ApiKeyCreate,
current_user: CurrentActiveUser,
db: AsyncDbSession,
db: DbSession,
) -> UnmaskedApiKeyRead:
try:
user_id = current_user.id
@ -44,7 +44,7 @@ async def create_api_key_route(
@router.delete("/{api_key_id}", dependencies=[Depends(auth_utils.get_current_active_user)])
async def delete_api_key_route(
api_key_id: UUID,
db: AsyncDbSession,
db: DbSession,
):
try:
await delete_api_key(db, api_key_id)
@ -58,7 +58,7 @@ async def save_store_api_key(
api_key_request: ApiKeyCreateRequest,
response: Response,
current_user: CurrentActiveUser,
db: AsyncDbSession,
db: DbSession,
):
settings_service = get_settings_service()
auth_settings = settings_service.auth_settings

View file

@ -17,8 +17,8 @@ from starlette.responses import ContentStream
from starlette.types import Receive
from langflow.api.utils import (
AsyncDbSession,
CurrentActiveUser,
DbSession,
build_and_cache_graph_from_data,
build_graph_from_data,
build_graph_from_db,
@ -44,7 +44,7 @@ from langflow.schema.schema import OutputValue
from langflow.services.cache.utils import CacheMiss
from langflow.services.chat.service import ChatService
from langflow.services.database.models.flow.model import Flow
from langflow.services.deps import async_session_scope, get_async_session, get_chat_service, get_telemetry_service
from langflow.services.deps import async_session_scope, get_chat_service, get_session, get_telemetry_service
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload
if TYPE_CHECKING:
@ -77,7 +77,7 @@ async def retrieve_vertices_order(
data: Annotated[FlowDataRequest | None, Body(embed=True)] | None = None,
stop_component_id: str | None = None,
start_component_id: str | None = None,
session: AsyncDbSession,
session: DbSession,
) -> VerticesOrderResponse:
"""Retrieve the vertices order for a given flow.
@ -153,7 +153,7 @@ async def build_flow(
start_component_id: str | None = None,
log_builds: bool | None = True,
current_user: CurrentActiveUser,
session: AsyncDbSession,
session: DbSession,
):
chat_service = get_chat_service()
telemetry_service = get_telemetry_service()
@ -512,7 +512,7 @@ async def build_vertex(
# If there's no cache
logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
graph: Graph = await build_graph_from_db(
flow_id=flow_id_str, session=await anext(get_async_session()), chat_service=chat_service
flow_id=flow_id_str, session=await anext(get_session()), chat_service=chat_service
)
else:
graph = cache.get("result")

View file

@ -20,7 +20,7 @@ from fastapi import (
from loguru import logger
from sqlmodel import select
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, parse_value
from langflow.api.utils import CurrentActiveUser, DbSession, parse_value
from langflow.api.v1.schemas import (
ConfigResponse,
CustomComponentRequest,
@ -379,7 +379,7 @@ async def webhook_run_flow(
)
async def experimental_run_flow(
*,
session: AsyncDbSession,
session: DbSession,
flow_id: UUID,
inputs: list[InputValueRequest] | None = None,
outputs: list[str] | None = None,

View file

@ -9,7 +9,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.api.v1.schemas import UploadFileResponse
from langflow.services.database.models.flow import Flow
from langflow.services.deps import get_storage_service
@ -25,7 +25,7 @@ router = APIRouter(tags=["Files"], prefix="/files")
async def get_flow_id(
flow_id: UUID,
current_user: CurrentActiveUser,
session: AsyncDbSession,
session: DbSession,
):
flow_id_str = str(flow_id)
# AttributeError: 'SelectOfScalar' object has no attribute 'first'
@ -43,7 +43,7 @@ async def upload_file(
file: UploadFile,
flow_id: Annotated[UUID, Depends(get_flow_id)],
current_user: CurrentActiveUser,
session: AsyncDbSession,
session: DbSession,
storage_service: Annotated[StorageService, Depends(get_storage_service)],
) -> UploadFileResponse:
try:

View file

@ -18,8 +18,8 @@ from sqlmodel import and_, col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.api.utils import (
AsyncDbSession,
CurrentActiveUser,
DbSession,
cascade_delete_flow,
remove_api_keys,
validate_is_component,
@ -124,7 +124,7 @@ async def _new_flow(
@router.post("/", response_model=FlowRead, status_code=201)
async def create_flow(
*,
session: AsyncDbSession,
session: DbSession,
flow: FlowCreate,
current_user: CurrentActiveUser,
):
@ -154,7 +154,7 @@ async def create_flow(
async def read_flows(
*,
current_user: CurrentActiveUser,
session: AsyncDbSession,
session: DbSession,
remove_example_flows: bool = False,
components_only: bool = False,
get_all: bool = True,
@ -261,7 +261,7 @@ async def _read_flow(
@router.get("/{flow_id}", response_model=FlowRead, status_code=200)
async def read_flow(
*,
session: AsyncDbSession,
session: DbSession,
flow_id: UUID,
current_user: CurrentActiveUser,
):
@ -274,7 +274,7 @@ async def read_flow(
@router.patch("/{flow_id}", response_model=FlowRead, status_code=200)
async def update_flow(
*,
session: AsyncDbSession,
session: DbSession,
flow_id: UUID,
flow: FlowUpdate,
current_user: CurrentActiveUser,
@ -334,7 +334,7 @@ async def update_flow(
@router.delete("/{flow_id}", status_code=200)
async def delete_flow(
*,
session: AsyncDbSession,
session: DbSession,
flow_id: UUID,
current_user: CurrentActiveUser,
):
@ -355,7 +355,7 @@ async def delete_flow(
@router.post("/batch/", response_model=list[FlowRead], status_code=201)
async def create_flows(
*,
session: AsyncDbSession,
session: DbSession,
flow_list: FlowListCreate,
current_user: CurrentActiveUser,
):
@ -375,7 +375,7 @@ async def create_flows(
@router.post("/upload/", response_model=list[FlowRead], status_code=201)
async def upload_file(
*,
session: AsyncDbSession,
session: DbSession,
file: Annotated[UploadFile, File(...)],
current_user: CurrentActiveUser,
folder_id: UUID | None = None,
@ -420,7 +420,7 @@ async def upload_file(
async def delete_multiple_flows(
flow_ids: list[UUID],
user: CurrentActiveUser,
db: AsyncDbSession,
db: DbSession,
):
"""Delete multiple flows by their IDs.
@ -458,7 +458,7 @@ async def delete_multiple_flows(
async def download_multiple_file(
flow_ids: list[UUID],
user: CurrentActiveUser,
db: AsyncDbSession,
db: DbSession,
):
"""Download all flows as a zip file."""
flows = (await db.exec(select(Flow).where(and_(Flow.user_id == user.id, Flow.id.in_(flow_ids))))).all() # type: ignore[attr-defined]
@ -499,7 +499,7 @@ async def download_multiple_file(
@router.get("/basic_examples/", response_model=list[FlowRead], status_code=200)
async def read_basic_examples(
*,
session: AsyncDbSession,
session: DbSession,
):
"""Retrieve a list of basic example flows.

View file

@ -14,7 +14,7 @@ from sqlalchemy import or_, update
from sqlalchemy.orm import selectinload
from sqlmodel import select
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, cascade_delete_flow, custom_params, remove_api_keys
from langflow.api.utils import CurrentActiveUser, DbSession, cascade_delete_flow, custom_params, remove_api_keys
from langflow.api.v1.flows import create_flows
from langflow.api.v1.schemas import FlowListCreate
from langflow.helpers.flow import generate_unique_flow_name
@ -37,7 +37,7 @@ router = APIRouter(prefix="/folders", tags=["Folders"])
@router.post("/", response_model=FolderRead, status_code=201)
async def create_folder(
*,
session: AsyncDbSession,
session: DbSession,
folder: FolderCreate,
current_user: CurrentActiveUser,
):
@ -93,7 +93,7 @@ async def create_folder(
@router.get("/", response_model=list[FolderRead], status_code=200)
async def read_folders(
*,
session: AsyncDbSession,
session: DbSession,
current_user: CurrentActiveUser,
):
try:
@ -113,7 +113,7 @@ async def read_folders(
@router.get("/{folder_id}", response_model=FolderWithPaginatedFlows | FolderReadWithFlows, status_code=200)
async def read_folder(
*,
session: AsyncDbSession,
session: DbSession,
folder_id: str,
current_user: CurrentActiveUser,
params: Annotated[Params | None, Depends(custom_params)],
@ -164,7 +164,7 @@ async def read_folder(
@router.patch("/{folder_id}", response_model=FolderRead, status_code=200)
async def update_folder(
*,
session: AsyncDbSession,
session: DbSession,
folder_id: str,
folder: FolderUpdate, # Assuming FolderUpdate is a Pydantic model defining updatable fields
current_user: CurrentActiveUser,
@ -225,7 +225,7 @@ async def update_folder(
@router.delete("/{folder_id}", status_code=204)
async def delete_folder(
*,
session: AsyncDbSession,
session: DbSession,
folder_id: str,
current_user: CurrentActiveUser,
):
@ -257,7 +257,7 @@ async def delete_folder(
@router.get("/download/{folder_id}", status_code=200)
async def download_file(
*,
session: AsyncDbSession,
session: DbSession,
folder_id: str,
current_user: CurrentActiveUser,
):
@ -305,7 +305,7 @@ async def download_file(
@router.post("/upload/", response_model=list[FlowRead], status_code=201)
async def upload_file(
*,
session: AsyncDbSession,
session: DbSession,
file: Annotated[UploadFile, File(...)],
current_user: CurrentActiveUser,
):

View file

@ -5,7 +5,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from langflow.api.utils import AsyncDbSession
from langflow.api.utils import DbSession
from langflow.api.v1.schemas import Token
from langflow.services.auth.utils import (
authenticate_user,
@ -24,7 +24,7 @@ router = APIRouter(tags=["Login"])
async def login_to_get_access_token(
response: Response,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: AsyncDbSession,
db: DbSession,
):
auth_settings = get_settings_service().auth_settings
try:
@ -78,7 +78,7 @@ async def login_to_get_access_token(
@router.get("/auto_login")
async def auto_login(response: Response, db: AsyncDbSession):
async def auto_login(response: Response, db: DbSession):
auth_settings = get_settings_service().auth_settings
if auth_settings.AUTO_LOGIN:
@ -124,7 +124,7 @@ async def auto_login(response: Response, db: AsyncDbSession):
async def refresh_token(
request: Request,
response: Response,
db: AsyncDbSession,
db: DbSession,
):
auth_settings = get_settings_service().auth_settings

View file

@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import delete
from sqlmodel import col, select
from langflow.api.utils import AsyncDbSession
from langflow.api.utils import DbSession
from langflow.schema.message import MessageResponse
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate
@ -21,7 +21,7 @@ router = APIRouter(prefix="/monitor", tags=["Monitor"])
@router.get("/builds")
async def get_vertex_builds(flow_id: Annotated[UUID, Query()], session: AsyncDbSession) -> VertexBuildMapModel:
async def get_vertex_builds(flow_id: Annotated[UUID, Query()], session: DbSession) -> VertexBuildMapModel:
try:
vertex_builds = await get_vertex_builds_by_flow_id(session, flow_id)
return VertexBuildMapModel.from_list_of_dicts(vertex_builds)
@ -30,7 +30,7 @@ async def get_vertex_builds(flow_id: Annotated[UUID, Query()], session: AsyncDbS
@router.delete("/builds", status_code=204)
async def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: AsyncDbSession) -> None:
async def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: DbSession) -> None:
try:
await delete_vertex_builds_by_flow_id(session, flow_id)
except Exception as e:
@ -39,7 +39,7 @@ async def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: Async
@router.get("/messages")
async def get_messages(
session: AsyncDbSession,
session: DbSession,
flow_id: Annotated[str | None, Query()] = None,
session_id: Annotated[str | None, Query()] = None,
sender: Annotated[str | None, Query()] = None,
@ -66,7 +66,7 @@ async def get_messages(
@router.delete("/messages", status_code=204, dependencies=[Depends(get_current_active_user)])
async def delete_messages(message_ids: list[UUID], session: AsyncDbSession) -> None:
async def delete_messages(message_ids: list[UUID], session: DbSession) -> None:
try:
await session.exec(delete(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore[attr-defined]
await session.commit()
@ -78,7 +78,7 @@ async def delete_messages(message_ids: list[UUID], session: AsyncDbSession) -> N
async def update_message(
message_id: UUID,
message: MessageUpdate,
session: AsyncDbSession,
session: DbSession,
):
try:
db_message = await session.get(MessageTable, message_id)
@ -108,7 +108,7 @@ async def update_message(
async def update_session_id(
old_session_id: str,
new_session_id: Annotated[str, Query(..., description="The new session ID to update to")],
session: AsyncDbSession,
session: DbSession,
) -> list[MessageResponse]:
try:
# Get all messages with the old session ID
@ -141,7 +141,7 @@ async def update_session_id(
@router.delete("/messages/session/{session_id}", status_code=204)
async def delete_messages_session(
session_id: str,
session: AsyncDbSession,
session: DbSession,
):
try:
await session.exec(
@ -159,7 +159,7 @@ async def delete_messages_session(
@router.get("/transactions")
async def get_transactions(
flow_id: Annotated[UUID, Query()],
session: AsyncDbSession,
session: DbSession,
) -> list[TransactionReadResponse]:
try:
transactions = await get_transactions_by_flow_id(session, flow_id)

View file

@ -7,7 +7,7 @@ from sqlalchemy.exc import IntegrityError
from sqlmodel import select
from sqlmodel.sql.expression import SelectOfScalar
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.api.v1.schemas import UsersResponse
from langflow.services.auth.utils import (
get_current_active_superuser,
@ -25,7 +25,7 @@ router = APIRouter(tags=["Users"], prefix="/users")
@router.post("/", response_model=UserRead, status_code=201)
async def add_user(
user: UserCreate,
session: AsyncDbSession,
session: DbSession,
) -> User:
"""Add a new user to the database."""
new_user = User.model_validate(user, from_attributes=True)
@ -58,7 +58,7 @@ async def read_all_users(
*,
skip: int = 0,
limit: int = 10,
session: AsyncDbSession,
session: DbSession,
) -> UsersResponse:
"""Retrieve a list of users from the database with pagination."""
query: SelectOfScalar = select(User).offset(skip).limit(limit)
@ -78,7 +78,7 @@ async def patch_user(
user_id: UUID,
user_update: UserUpdate,
user: CurrentActiveUser,
session: AsyncDbSession,
session: DbSession,
) -> User:
"""Update an existing user's data."""
update_password = bool(user_update.password)
@ -105,7 +105,7 @@ async def reset_password(
user_id: UUID,
user_update: UserUpdate,
user: CurrentActiveUser,
session: AsyncDbSession,
session: DbSession,
) -> User:
"""Reset a user's password."""
if user_id != user.id:
@ -127,7 +127,7 @@ async def reset_password(
async def delete_user(
user_id: UUID,
current_user: Annotated[User, Depends(get_current_active_superuser)],
session: AsyncDbSession,
session: DbSession,
) -> dict:
"""Delete a user from the database."""
if current_user.id == user_id:

View file

@ -3,7 +3,7 @@ from uuid import UUID
from fastapi import APIRouter, HTTPException
from sqlalchemy.exc import NoResultFound
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_variable_service
from langflow.services.variable.constants import GENERIC_TYPE
@ -15,7 +15,7 @@ router = APIRouter(prefix="/variables", tags=["Variables"])
@router.post("/", response_model=VariableRead, status_code=201)
async def create_variable(
*,
session: AsyncDbSession,
session: DbSession,
variable: VariableCreate,
current_user: CurrentActiveUser,
):
@ -50,7 +50,7 @@ async def create_variable(
@router.get("/", response_model=list[VariableRead], status_code=200)
async def read_variables(
*,
session: AsyncDbSession,
session: DbSession,
current_user: CurrentActiveUser,
):
"""Read all variables."""
@ -67,7 +67,7 @@ async def read_variables(
@router.patch("/{variable_id}", response_model=VariableRead, status_code=200)
async def update_variable(
*,
session: AsyncDbSession,
session: DbSession,
variable_id: UUID,
variable: VariableUpdate,
current_user: CurrentActiveUser,
@ -94,7 +94,7 @@ async def update_variable(
@router.delete("/{variable_id}", status_code=204)
async def delete_variable(
*,
session: AsyncDbSession,
session: DbSession,
variable_id: UUID,
current_user: CurrentActiveUser,
) -> None:

View file

@ -17,7 +17,7 @@ from starlette.websockets import WebSocket
from langflow.services.database.models.api_key.crud import check_key
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import get_async_session, get_db_service, get_settings_service
from langflow.services.deps import get_db_service, get_session, get_settings_service
from langflow.services.settings.service import SettingsService
if TYPE_CHECKING:
@ -79,7 +79,7 @@ async def get_current_user(
token: Annotated[str, Security(oauth2_login)],
query_param: Annotated[str, Security(api_key_query)],
header_param: Annotated[str, Security(api_key_header)],
db: Annotated[AsyncSession, Depends(get_async_session)],
db: Annotated[AsyncSession, Depends(get_session)],
) -> User:
if token:
return await get_current_user_by_jwt(token, db)
@ -156,7 +156,7 @@ async def get_current_user_by_jwt(
async def get_current_user_for_websocket(
websocket: WebSocket,
db: Annotated[AsyncSession, Depends(get_async_session)],
db: Annotated[AsyncSession, Depends(get_session)],
query_param: Annotated[str, Security(api_key_query)],
) -> User | None:
token = websocket.query_params.get("token")

View file

@ -1,23 +1,8 @@
from typing import Annotated
from fastapi import Depends
from sqlmodel import Session
from langflow.services.deps import get_session
from langflow.utils.version import get_version_info
from .model import Flow
def get_flow_by_id(session: Annotated[Session, Depends(get_session)], flow_id: str | None = None) -> Flow | None:
"""Get flow by id."""
if flow_id is None:
msg = "Flow id is required."
raise ValueError(msg)
return session.get(Flow, flow_id)
def get_webhook_component_in_flow(flow_data: dict):
"""Get webhook component in flow data."""
for node in flow_data.get("nodes", []):

View file

@ -142,22 +142,11 @@ def get_db_service() -> DatabaseService:
return get_service(ServiceType.DATABASE_SERVICE, DatabaseServiceFactory())
def get_session() -> Generator[Session, None, None]:
"""Retrieves a session from the database service.
Yields:
Session: A session object.
"""
with get_db_service().with_session() as session:
yield session
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""Retrieves an async session from the database service.
Yields:
Session: An async session object.
AsyncSession: An async session object.
"""
async with get_db_service().with_async_session() as session:
@ -173,7 +162,7 @@ def session_scope() -> Generator[Session, None, None]:
and rolled back if an exception is raised.
Yields:
session: The session object.
Session: The session object.
Raises:
Exception: If an error occurs during the session scope.
@ -199,7 +188,7 @@ async def async_session_scope() -> AsyncGenerator[AsyncSession, None]:
and rolled back if an exception is raised.
Yields:
session: The async session object.
AsyncSession: The async session object.
Raises:
Exception: If an error occurs during the session scope.

View file

@ -11,7 +11,7 @@ from langflow.graph.graph.base import Graph
from langflow.graph.utils import log_vertex_build
from langflow.graph.vertex.base import Vertex
from langflow.services.database.models.flow.model import Flow
from langflow.services.deps import get_async_session
from langflow.services.deps import get_session
def set_socketio_server(socketio_server) -> None:
@ -23,7 +23,7 @@ def set_socketio_server(socketio_server) -> None:
async def get_vertices(sio, sid, flow_id, chat_service) -> None:
try:
session = await anext(get_async_session())
session = await anext(get_session())
stmt = select(Flow).where(Flow.id == flow_id)
flow: Flow = (await session.exec(stmt)).first()
if not flow or not flow.data: