ref: Use AsyncSession in some tests (#5151)

This commit is contained in:
Christophe Bornet 2024-12-08 02:09:43 +01:00 committed by GitHub
commit 624a2dde5d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 53 additions and 47 deletions

View file

@ -74,7 +74,7 @@ def session_getter(db_service: DatabaseService):
@asynccontextmanager
async def async_session_getter(db_service: DatabaseService):
try:
session = AsyncSession(db_service.async_engine)
session = AsyncSession(db_service.async_engine, expire_on_commit=False)
yield session
except Exception:
logger.exception("Session rollback because of exception")

View file

@ -26,7 +26,7 @@ from langflow.services.database.models.folder.model import Folder
from langflow.services.database.models.transactions.model import TransactionTable
from langflow.services.database.models.user.model import User, UserCreate, UserRead
from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id
from langflow.services.database.utils import session_getter
from langflow.services.database.utils import async_session_getter
from langflow.services.deps import get_db_service
from loguru import logger
from sqlalchemy.ext.asyncio import create_async_engine
@ -157,7 +157,7 @@ async def async_session():
engine = create_async_engine("sqlite+aiosqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async with AsyncSession(engine) as session:
async with AsyncSession(engine, expire_on_commit=False) as session:
yield session
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
@ -469,7 +469,7 @@ async def logged_in_headers_super_user(client, active_super_user):
@pytest.fixture
def flow(
async def flow(
client, # noqa: ARG001
json_flow: str,
active_user,
@ -480,14 +480,14 @@ def flow(
flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id)
flow = Flow.model_validate(flow_data)
with session_getter(get_db_service()) as session:
async with async_session_getter(get_db_service()) as session:
session.add(flow)
session.commit()
session.refresh(flow)
await session.commit()
await session.refresh(flow)
yield flow
# Clean up
session.delete(flow)
session.commit()
await session.delete(flow)
await session.commit()
@pytest.fixture
@ -582,7 +582,7 @@ async def flow_component(client: AsyncClient, logged_in_headers):
@pytest.fixture
def created_api_key(active_user):
async def created_api_key(active_user):
hashed = get_password_hash("random_key")
api_key = ApiKey(
name="test_api_key",
@ -591,17 +591,18 @@ def created_api_key(active_user):
hashed_api_key=hashed,
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first():
async with async_session_getter(db_manager) as session:
stmt = select(ApiKey).where(ApiKey.api_key == api_key.api_key)
if existing_api_key := (await session.exec(stmt)).first():
yield existing_api_key
return
session.add(api_key)
session.commit()
session.refresh(api_key)
await session.commit()
await session.refresh(api_key)
yield api_key
# Clean up
session.delete(api_key)
session.commit()
await session.delete(api_key)
await session.commit()
@pytest.fixture(name="simple_api_test")
@ -618,14 +619,15 @@ async def get_simple_api_test(client, logged_in_headers, json_simple_api_test):
@pytest.fixture(name="starter_project")
def get_starter_project(active_user):
async def get_starter_project(active_user):
# once the client is created, we can get the starter project
with session_getter(get_db_service()) as session:
flow = session.exec(
async with async_session_getter(get_db_service()) as session:
stmt = (
select(Flow)
.where(Flow.folder.has(Folder.name == STARTER_FOLDER_NAME))
.where(Flow.name == "Basic Prompting (Hello, World)")
).first()
)
flow = (await session.exec(stmt)).first()
if not flow:
msg = "No starter project found"
raise ValueError(msg)
@ -640,10 +642,10 @@ def get_starter_project(active_user):
)
new_flow = Flow.model_validate(new_flow_create, from_attributes=True)
session.add(new_flow)
session.commit()
session.refresh(new_flow)
await session.commit()
await session.refresh(new_flow)
new_flow_dict = new_flow.model_dump()
yield new_flow_dict
# Clean up
session.delete(new_flow)
session.commit()
await session.delete(new_flow)
await session.commit()

View file

@ -24,7 +24,7 @@ async def session():
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async with AsyncSession(engine) as session:
async with AsyncSession(engine, expire_on_commit=False) as session:
yield session

View file

@ -12,7 +12,7 @@ from langflow.initial_setup.setup import load_starter_projects
from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
from langflow.services.database.models.folder.model import FolderCreate
from langflow.services.database.utils import session_getter
from langflow.services.database.utils import async_session_getter
from langflow.services.deps import get_db_service
@ -530,14 +530,14 @@ async def test_download_file(
]
)
db_manager = get_db_service()
with session_getter(db_manager) as _session:
async with async_session_getter(db_manager) as _session:
saved_flows = []
for flow in flow_list.flows:
flow.user_id = active_user.id
db_flow = Flow.model_validate(flow, from_attributes=True)
_session.add(db_flow)
saved_flows.append(db_flow)
_session.commit()
await _session.commit()
# Make request to endpoint inside the session context
flow_ids = [str(db_flow.id) for db_flow in saved_flows] # Convert UUIDs to strings
flow_ids_json = json.dumps(flow_ids)

View file

@ -12,7 +12,8 @@ from langflow.initial_setup.setup import (
)
from langflow.interface.types import aget_all_types_dict
from langflow.services.database.models.folder.model import Folder
from langflow.services.deps import session_scope
from langflow.services.deps import async_session_scope
from sqlalchemy.orm import selectinload
from sqlmodel import select
@ -52,12 +53,13 @@ def test_get_project_data():
@pytest.mark.usefixtures("client")
async def test_create_or_update_starter_projects():
with session_scope() as session:
async with async_session_scope() as session:
# Get the number of projects returned by load_starter_projects
num_projects = len(await asyncio.to_thread(load_starter_projects))
# Get the number of projects in the database
folder = session.exec(select(Folder).where(Folder.name == STARTER_FOLDER_NAME)).first()
stmt = select(Folder).options(selectinload(Folder.flows)).where(Folder.name == STARTER_FOLDER_NAME)
folder = (await session.exec(stmt)).first()
assert folder is not None
num_db_projects = len(folder.flows)

View file

@ -1,7 +1,7 @@
import pytest
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.user import User
from langflow.services.deps import session_scope
from langflow.services.deps import async_session_scope
from sqlalchemy.exc import IntegrityError
@ -18,9 +18,9 @@ def test_user():
async def test_login_successful(client, test_user):
# Adding the test user to the database
try:
with session_scope() as session:
async with async_session_scope() as session:
session.add(test_user)
session.commit()
await session.commit()
except IntegrityError:
pass

View file

@ -5,7 +5,7 @@ from httpx import AsyncClient
from langflow.services.auth.utils import create_super_user, get_password_hash
from langflow.services.database.models.user import UserUpdate
from langflow.services.database.models.user.model import User
from langflow.services.database.utils import async_session_getter, session_getter
from langflow.services.database.utils import async_session_getter
from langflow.services.deps import get_db_service, get_settings_service
from sqlmodel import select
@ -41,8 +41,8 @@ async def super_user_headers(
@pytest.fixture
def deactivated_user(client): # noqa: ARG001
with session_getter(get_db_service()) as session:
async def deactivated_user(client): # noqa: ARG001
async with async_session_getter(get_db_service()) as session:
user = User(
username="deactivateduser",
password=get_password_hash("testpassword"),
@ -51,8 +51,8 @@ def deactivated_user(client): # noqa: ARG001
last_login_at=datetime.now(tz=timezone.utc),
)
session.add(user)
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)
return user
@ -61,15 +61,16 @@ async def test_user_waiting_for_approval(client):
password = "testpassword" # noqa: S105
# Debug: Check if the user already exists
with session_getter(get_db_service()) as session:
existing_user = session.exec(select(User).where(User.username == username)).first()
async with async_session_getter(get_db_service()) as session:
stmt = select(User).where(User.username == username)
existing_user = (await session.exec(stmt)).first()
if existing_user:
pytest.fail(
f"User {username} already exists before the test. Database URL: {get_db_service().database_url}"
)
# Create a user that is not active and has never logged in
with session_getter(get_db_service()) as session:
async with async_session_getter(get_db_service()) as session:
user = User(
username=username,
password=get_password_hash(password),
@ -77,7 +78,7 @@ async def test_user_waiting_for_approval(client):
last_login_at=None,
)
session.add(user)
session.commit()
await session.commit()
login_data = {"username": "waitingforapproval", "password": "testpassword"}
response = await client.post("api/v1/login", data=login_data)
@ -85,8 +86,9 @@ async def test_user_waiting_for_approval(client):
assert response.json()["detail"] == "Waiting for approval"
# Debug: Check if the user still exists after the test
with session_getter(get_db_service()) as session:
existing_user = session.exec(select(User).where(User.username == username)).first()
async with async_session_getter(get_db_service()) as session:
stmt = select(User).where(User.username == username)
existing_user = (await session.exec(stmt)).first()
if existing_user:
pass
else:
@ -138,7 +140,7 @@ async def test_data_consistency_after_delete(client: AsyncClient, test_user, sup
@pytest.mark.api_key_required
async def test_inactive_user(client: AsyncClient):
# Create a user that is not active and has a last_login_at value
with session_getter(get_db_service()) as session:
async with async_session_getter(get_db_service()) as session:
user = User(
username="inactiveuser",
password=get_password_hash("testpassword"),
@ -146,7 +148,7 @@ async def test_inactive_user(client: AsyncClient):
last_login_at=datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
)
session.add(user)
session.commit()
await session.commit()
login_data = {"username": "inactiveuser", "password": "testpassword"}
response = await client.post("api/v1/login", data=login_data)