fix: Use AsyncSession in some API endpoints (#4650)
* Use AsyncSession in some API endpoints * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
44a3e7643a
commit
ba9dea5547
4 changed files with 17 additions and 16 deletions
|
|
@ -5,7 +5,7 @@ from loguru import logger
|
|||
from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.api.utils import DbSession
|
||||
from langflow.api.utils import AsyncDbSession
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.deps import get_chat_service
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ async def health():
|
|||
# It's a reliable health check for a langflow instance
|
||||
@health_check_router.get("/health_check")
|
||||
async def health_check(
|
||||
session: DbSession,
|
||||
session: AsyncDbSession,
|
||||
) -> HealthResponse:
|
||||
response = HealthResponse()
|
||||
# use a fixed valid UUId that UUID collision is very unlikely
|
||||
|
|
@ -46,7 +46,7 @@ async def health_check(
|
|||
try:
|
||||
# Check database to query a bogus flow
|
||||
stmt = select(Flow).where(Flow.id == uuid.uuid4())
|
||||
session.exec(stmt).first()
|
||||
(await session.exec(stmt)).first()
|
||||
response.db = "ok"
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error checking database")
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from fastapi import (
|
|||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.api.utils import CurrentActiveUser, DbSession, parse_value
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, parse_value
|
||||
from langflow.api.v1.schemas import (
|
||||
ConfigResponse,
|
||||
CustomComponentRequest,
|
||||
|
|
@ -379,7 +379,7 @@ async def webhook_run_flow(
|
|||
)
|
||||
async def experimental_run_flow(
|
||||
*,
|
||||
session: DbSession,
|
||||
session: AsyncDbSession,
|
||||
flow_id: UUID,
|
||||
inputs: list[InputValueRequest] | None = None,
|
||||
outputs: list[str] | None = None,
|
||||
|
|
@ -454,9 +454,8 @@ async def experimental_run_flow(
|
|||
try:
|
||||
# Get the flow that matches the flow_id and belongs to the user
|
||||
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
|
||||
flow = session.exec(
|
||||
select(Flow).where(Flow.id == flow_id_str).where(Flow.user_id == api_key_user.id)
|
||||
).first()
|
||||
stmt = select(Flow).where(Flow.id == flow_id_str).where(Flow.user_id == api_key_user.id)
|
||||
flow = (await session.exec(stmt)).first()
|
||||
except sa.exc.StatementError as exc:
|
||||
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
|
||||
if "badly formed hexadecimal UUID string" in str(exc):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlmodel import select
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
|
||||
from langflow.api.v1.schemas import UsersResponse
|
||||
from langflow.services.auth.utils import (
|
||||
get_current_active_superuser,
|
||||
|
|
@ -127,7 +127,7 @@ async def reset_password(
|
|||
async def delete_user(
|
||||
user_id: UUID,
|
||||
current_user: Annotated[User, Depends(get_current_active_superuser)],
|
||||
session: DbSession,
|
||||
session: AsyncDbSession,
|
||||
) -> dict:
|
||||
"""Delete a user from the database."""
|
||||
if current_user.id == user_id:
|
||||
|
|
@ -135,11 +135,12 @@ async def delete_user(
|
|||
if not current_user.is_superuser:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
|
||||
user_db = session.exec(select(User).where(User.id == user_id)).first()
|
||||
stmt = select(User).where(User.id == user_id)
|
||||
user_db = (await session.exec(stmt)).first()
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
session.delete(user_db)
|
||||
session.commit()
|
||||
await session.delete(user_db)
|
||||
await session.commit()
|
||||
|
||||
return {"detail": "User deleted"}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from langflow.graph.graph.base import Graph
|
|||
from langflow.graph.utils import log_vertex_build
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
from langflow.services.deps import get_session
|
||||
from langflow.services.deps import get_async_session
|
||||
|
||||
|
||||
def set_socketio_server(socketio_server) -> None:
|
||||
|
|
@ -23,8 +23,9 @@ def set_socketio_server(socketio_server) -> None:
|
|||
|
||||
async def get_vertices(sio, sid, flow_id, chat_service) -> None:
|
||||
try:
|
||||
session = next(get_session())
|
||||
flow: Flow = session.exec(select(Flow).where(Flow.id == flow_id)).first()
|
||||
session = await anext(get_async_session())
|
||||
stmt = select(Flow).where(Flow.id == flow_id)
|
||||
flow: Flow = (await session.exec(stmt)).first()
|
||||
if not flow or not flow.data:
|
||||
await sio.emit("error", data="Invalid flow ID", to=sid)
|
||||
return
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue