fix: Use AsyncSession in delete_vertex_builds (#4653)

Use AsyncSession in delete_vertex_builds
This commit is contained in:
Christophe Bornet 2024-11-17 12:51:24 +01:00 committed by GitHub
commit a7aa3ab03f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 39 additions and 34 deletions

View file

@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import delete
from sqlmodel import col, select
from langflow.api.utils import AsyncDbSession, DbSession
from langflow.api.utils import AsyncDbSession
from langflow.schema.message import MessageResponse
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate
@ -30,9 +30,9 @@ async def get_vertex_builds(flow_id: Annotated[UUID, Query()], session: AsyncDbS
@router.delete("/builds", status_code=204)
def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: DbSession) -> None:
async def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: AsyncDbSession) -> None:
try:
delete_vertex_builds_by_flow_id(session, flow_id)
await delete_vertex_builds_by_flow_id(session, flow_id)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

View file

@ -32,6 +32,7 @@ def log_vertex_build(db: Session, vertex_build: VertexBuildBase) -> VertexBuildT
return table
def delete_vertex_builds_by_flow_id(db: Session, flow_id: UUID) -> None:
db.exec(delete(VertexBuildTable).where(VertexBuildTable.flow_id == flow_id))
db.commit()
async def delete_vertex_builds_by_flow_id(db: AsyncSession, flow_id: UUID) -> None:
stmt = delete(VertexBuildTable).where(VertexBuildTable.flow_id == flow_id)
await db.exec(stmt)
await db.commit()

View file

@ -28,7 +28,9 @@ from langflow.services.database.models.vertex_builds.crud import delete_vertex_b
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from loguru import logger
from sqlalchemy.orm import selectinload
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
@ -85,21 +87,21 @@ def get_text():
assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}"
def delete_transactions_by_flow_id(db: Session, flow_id: UUID):
async def delete_transactions_by_flow_id(db: AsyncSession, flow_id: UUID):
stmt = select(TransactionTable).where(TransactionTable.flow_id == flow_id)
transactions = db.exec(stmt)
transactions = await db.exec(stmt)
for transaction in transactions:
db.delete(transaction)
db.commit()
await db.delete(transaction)
await db.commit()
def _delete_transactions_and_vertex_builds(session, user: User):
flow_ids = [flow.id for flow in user.flows]
async def _delete_transactions_and_vertex_builds(session, flows: list[Flow]):
flow_ids = [flow.id for flow in flows]
for flow_id in flow_ids:
if not flow_id:
continue
delete_vertex_builds_by_flow_id(session, flow_id)
delete_transactions_by_flow_id(session, flow_id)
await delete_vertex_builds_by_flow_id(session, flow_id)
await delete_transactions_by_flow_id(session, flow_id)
@pytest.fixture
@ -361,31 +363,32 @@ async def test_user(client):
@pytest.fixture
def active_user(client): # noqa: ARG001
async def active_user(client): # noqa: ARG001
db_manager = get_db_service()
with db_manager.with_session() as session:
async with db_manager.with_async_session() as session:
user = User(
username="activeuser",
password=get_password_hash("testpassword"),
is_active=True,
is_superuser=False,
)
if active_user := session.exec(select(User).where(User.username == user.username)).first():
stmt = select(User).where(User.username == user.username)
if active_user := (await session.exec(stmt)).first():
user = active_user
else:
session.add(user)
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)
user = UserRead.model_validate(user, from_attributes=True)
yield user
# Clean up
# Now cleanup transactions, vertex_build
with db_manager.with_session() as session:
user = session.get(User, user.id)
_delete_transactions_and_vertex_builds(session, user)
session.delete(user)
async with db_manager.with_async_session() as session:
user = await session.get(User, user.id, options=[selectinload(User.flows)])
await _delete_transactions_and_vertex_builds(session, user.flows)
await session.delete(user)
session.commit()
await session.commit()
@pytest.fixture
@ -399,31 +402,32 @@ async def logged_in_headers(client, active_user):
@pytest.fixture
def active_super_user(client): # noqa: ARG001
async def active_super_user(client): # noqa: ARG001
db_manager = get_db_service()
with db_manager.with_session() as session:
async with db_manager.with_async_session() as session:
user = User(
username="activeuser",
password=get_password_hash("testpassword"),
is_active=True,
is_superuser=True,
)
if active_user := session.exec(select(User).where(User.username == user.username)).first():
stmt = select(User).where(User.username == user.username)
if active_user := (await session.exec(stmt)).first():
user = active_user
else:
session.add(user)
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)
user = UserRead.model_validate(user, from_attributes=True)
yield user
# Clean up
# Now cleanup transactions, vertex_build
with db_manager.with_session() as session:
user = session.get(User, user.id)
_delete_transactions_and_vertex_builds(session, user)
session.delete(user)
async with db_manager.with_async_session() as session:
user = await session.get(User, user.id, options=[selectinload(User.flows)])
await _delete_transactions_and_vertex_builds(session, user.flows)
await session.delete(user)
session.commit()
await session.commit()
@pytest.fixture