fix: Use AsyncSession in some API endpoints (#4650)

* Use AsyncSession in some API endpoints

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Christophe Bornet 2024-12-04 16:14:01 +01:00 committed by GitHub
commit ba9dea5547
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 17 additions and 16 deletions

View file

@ -5,7 +5,7 @@ from loguru import logger
from pydantic import BaseModel
from sqlmodel import select
from langflow.api.utils import DbSession
from langflow.api.utils import AsyncDbSession
from langflow.services.database.models.flow import Flow
from langflow.services.deps import get_chat_service
@ -38,7 +38,7 @@ async def health():
# It's a reliable health check for a langflow instance
@health_check_router.get("/health_check")
async def health_check(
session: DbSession,
session: AsyncDbSession,
) -> HealthResponse:
response = HealthResponse()
# use a fixed valid UUId that UUID collision is very unlikely
@ -46,7 +46,7 @@ async def health_check(
try:
# Check database to query a bogus flow
stmt = select(Flow).where(Flow.id == uuid.uuid4())
session.exec(stmt).first()
(await session.exec(stmt)).first()
response.db = "ok"
except Exception: # noqa: BLE001
logger.exception("Error checking database")

View file

@ -20,7 +20,7 @@ from fastapi import (
from loguru import logger
from sqlmodel import select
from langflow.api.utils import CurrentActiveUser, DbSession, parse_value
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, parse_value
from langflow.api.v1.schemas import (
ConfigResponse,
CustomComponentRequest,
@ -379,7 +379,7 @@ async def webhook_run_flow(
)
async def experimental_run_flow(
*,
session: DbSession,
session: AsyncDbSession,
flow_id: UUID,
inputs: list[InputValueRequest] | None = None,
outputs: list[str] | None = None,
@ -454,9 +454,8 @@ async def experimental_run_flow(
try:
# Get the flow that matches the flow_id and belongs to the user
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
flow = session.exec(
select(Flow).where(Flow.id == flow_id_str).where(Flow.user_id == api_key_user.id)
).first()
stmt = select(Flow).where(Flow.id == flow_id_str).where(Flow.user_id == api_key_user.id)
flow = (await session.exec(stmt)).first()
except sa.exc.StatementError as exc:
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):

View file

@ -7,7 +7,7 @@ from sqlalchemy.exc import IntegrityError
from sqlmodel import select
from sqlmodel.sql.expression import SelectOfScalar
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
from langflow.api.v1.schemas import UsersResponse
from langflow.services.auth.utils import (
get_current_active_superuser,
@ -127,7 +127,7 @@ async def reset_password(
async def delete_user(
user_id: UUID,
current_user: Annotated[User, Depends(get_current_active_superuser)],
session: DbSession,
session: AsyncDbSession,
) -> dict:
"""Delete a user from the database."""
if current_user.id == user_id:
@ -135,11 +135,12 @@ async def delete_user(
if not current_user.is_superuser:
raise HTTPException(status_code=403, detail="Permission denied")
user_db = session.exec(select(User).where(User.id == user_id)).first()
stmt = select(User).where(User.id == user_id)
user_db = (await session.exec(stmt)).first()
if not user_db:
raise HTTPException(status_code=404, detail="User not found")
session.delete(user_db)
session.commit()
await session.delete(user_db)
await session.commit()
return {"detail": "User deleted"}

View file

@ -11,7 +11,7 @@ from langflow.graph.graph.base import Graph
from langflow.graph.utils import log_vertex_build
from langflow.graph.vertex.base import Vertex
from langflow.services.database.models.flow.model import Flow
from langflow.services.deps import get_session
from langflow.services.deps import get_async_session
def set_socketio_server(socketio_server) -> None:
@ -23,8 +23,9 @@ def set_socketio_server(socketio_server) -> None:
async def get_vertices(sio, sid, flow_id, chat_service) -> None:
try:
session = next(get_session())
flow: Flow = session.exec(select(Flow).where(Flow.id == flow_id)).first()
session = await anext(get_async_session())
stmt = select(Flow).where(Flow.id == flow_id)
flow: Flow = (await session.exec(stmt)).first()
if not flow or not flow.data:
await sio.emit("error", data="Invalid flow ID", to=sid)
return