diff --git a/src/backend/base/langflow/api/health_check_router.py b/src/backend/base/langflow/api/health_check_router.py index 8479fc736..57c0c0897 100644 --- a/src/backend/base/langflow/api/health_check_router.py +++ b/src/backend/base/langflow/api/health_check_router.py @@ -5,7 +5,7 @@ from loguru import logger from pydantic import BaseModel from sqlmodel import select -from langflow.api.utils import DbSession +from langflow.api.utils import AsyncDbSession from langflow.services.database.models.flow import Flow from langflow.services.deps import get_chat_service @@ -38,7 +38,7 @@ async def health(): # It's a reliable health check for a langflow instance @health_check_router.get("/health_check") async def health_check( - session: DbSession, + session: AsyncDbSession, ) -> HealthResponse: response = HealthResponse() # use a fixed valid UUId that UUID collision is very unlikely @@ -46,7 +46,7 @@ async def health_check( try: # Check database to query a bogus flow stmt = select(Flow).where(Flow.id == uuid.uuid4()) - session.exec(stmt).first() + (await session.exec(stmt)).first() response.db = "ok" except Exception: # noqa: BLE001 logger.exception("Error checking database") diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 1fcaf1b36..654fd546b 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -20,7 +20,7 @@ from fastapi import ( from loguru import logger from sqlmodel import select -from langflow.api.utils import CurrentActiveUser, DbSession, parse_value +from langflow.api.utils import AsyncDbSession, CurrentActiveUser, parse_value from langflow.api.v1.schemas import ( ConfigResponse, CustomComponentRequest, @@ -379,7 +379,7 @@ async def webhook_run_flow( ) async def experimental_run_flow( *, - session: DbSession, + session: AsyncDbSession, flow_id: UUID, inputs: list[InputValueRequest] | None = None, outputs: list[str] | None = None, @@ -454,9 +454,8 @@ async def experimental_run_flow( try: # Get the flow that matches the flow_id and belongs to the user # flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() - flow = session.exec( - select(Flow).where(Flow.id == flow_id_str).where(Flow.user_id == api_key_user.id) - ).first() + stmt = select(Flow).where(Flow.id == flow_id_str).where(Flow.user_id == api_key_user.id) + flow = (await session.exec(stmt)).first() except sa.exc.StatementError as exc: # StatementError('(builtins.ValueError) badly formed hexadecimal UUID string') if "badly formed hexadecimal UUID string" in str(exc): diff --git a/src/backend/base/langflow/api/v1/users.py b/src/backend/base/langflow/api/v1/users.py index 8c43985b3..599908769 100644 --- a/src/backend/base/langflow/api/v1/users.py +++ b/src/backend/base/langflow/api/v1/users.py @@ -7,7 +7,7 @@ from sqlalchemy.exc import IntegrityError from sqlmodel import select from sqlmodel.sql.expression import SelectOfScalar -from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession +from langflow.api.utils import AsyncDbSession, CurrentActiveUser from langflow.api.v1.schemas import UsersResponse from langflow.services.auth.utils import ( get_current_active_superuser, @@ -127,7 +127,7 @@ async def reset_password( async def delete_user( user_id: UUID, current_user: Annotated[User, Depends(get_current_active_superuser)], - session: DbSession, + session: AsyncDbSession, ) -> dict: """Delete a user from the database.""" if current_user.id == user_id: @@ -135,11 +135,12 @@ async def delete_user( if not current_user.is_superuser: raise HTTPException(status_code=403, detail="Permission denied") - user_db = session.exec(select(User).where(User.id == user_id)).first() + stmt = select(User).where(User.id == user_id) + user_db = (await session.exec(stmt)).first() if not user_db: raise HTTPException(status_code=404, detail="User not found") - session.delete(user_db) - session.commit() + await session.delete(user_db) + await session.commit() return {"detail": "User deleted"} diff --git a/src/backend/base/langflow/services/socket/utils.py b/src/backend/base/langflow/services/socket/utils.py index 723acd48d..95d3a90ba 100644 --- a/src/backend/base/langflow/services/socket/utils.py +++ b/src/backend/base/langflow/services/socket/utils.py @@ -11,7 +11,7 @@ from langflow.graph.graph.base import Graph from langflow.graph.utils import log_vertex_build from langflow.graph.vertex.base import Vertex from langflow.services.database.models.flow.model import Flow -from langflow.services.deps import get_session +from langflow.services.deps import get_async_session def set_socketio_server(socketio_server) -> None: @@ -23,8 +23,9 @@ def set_socketio_server(socketio_server) -> None: async def get_vertices(sio, sid, flow_id, chat_service) -> None: try: - session = next(get_session()) - flow: Flow = session.exec(select(Flow).where(Flow.id == flow_id)).first() + session = await anext(get_async_session()) + stmt = select(Flow).where(Flow.id == flow_id) + flow: Flow = (await session.exec(stmt)).first() if not flow or not flow.data: await sio.emit("error", data="Invalid flow ID", to=sid) return