From 624a2dde5d15f9ac021567ead77f2d602bec3048 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Sun, 8 Dec 2024 02:09:43 +0100 Subject: [PATCH] ref: Use AsyncSession in some tests (#5151) --- .../base/langflow/services/database/utils.py | 2 +- src/backend/tests/conftest.py | 48 ++++++++++--------- .../unit/services/variable/test_service.py | 2 +- src/backend/tests/unit/test_database.py | 6 +-- src/backend/tests/unit/test_initial_setup.py | 8 ++-- src/backend/tests/unit/test_login.py | 6 +-- src/backend/tests/unit/test_user.py | 28 ++++++----- 7 files changed, 53 insertions(+), 47 deletions(-) diff --git a/src/backend/base/langflow/services/database/utils.py b/src/backend/base/langflow/services/database/utils.py index 576ec2140..97c800f08 100644 --- a/src/backend/base/langflow/services/database/utils.py +++ b/src/backend/base/langflow/services/database/utils.py @@ -74,7 +74,7 @@ def session_getter(db_service: DatabaseService): @asynccontextmanager async def async_session_getter(db_service: DatabaseService): try: - session = AsyncSession(db_service.async_engine) + session = AsyncSession(db_service.async_engine, expire_on_commit=False) yield session except Exception: logger.exception("Session rollback because of exception") diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index fa149895a..60c6d06ac 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -26,7 +26,7 @@ from langflow.services.database.models.folder.model import Folder from langflow.services.database.models.transactions.model import TransactionTable from langflow.services.database.models.user.model import User, UserCreate, UserRead from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id -from langflow.services.database.utils import session_getter +from langflow.services.database.utils import async_session_getter from langflow.services.deps import get_db_service from loguru import logger from sqlalchemy.ext.asyncio import create_async_engine @@ -157,7 +157,7 @@ async def async_session(): engine = create_async_engine("sqlite+aiosqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) async with engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) - async with AsyncSession(engine) as session: + async with AsyncSession(engine, expire_on_commit=False) as session: yield session async with engine.begin() as conn: await conn.run_sync(SQLModel.metadata.drop_all) @@ -469,7 +469,7 @@ async def logged_in_headers_super_user(client, active_super_user): @pytest.fixture -def flow( +async def flow( client, # noqa: ARG001 json_flow: str, active_user, @@ -480,14 +480,14 @@ def flow( flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id) flow = Flow.model_validate(flow_data) - with session_getter(get_db_service()) as session: + async with async_session_getter(get_db_service()) as session: session.add(flow) - session.commit() - session.refresh(flow) + await session.commit() + await session.refresh(flow) yield flow # Clean up - session.delete(flow) - session.commit() + await session.delete(flow) + await session.commit() @pytest.fixture @@ -582,7 +582,7 @@ async def flow_component(client: AsyncClient, logged_in_headers): @pytest.fixture -def created_api_key(active_user): +async def created_api_key(active_user): hashed = get_password_hash("random_key") api_key = ApiKey( name="test_api_key", @@ -591,17 +591,18 @@ def created_api_key(active_user): hashed_api_key=hashed, ) db_manager = get_db_service() - with session_getter(db_manager) as session: - if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first(): + async with async_session_getter(db_manager) as session: + stmt = select(ApiKey).where(ApiKey.api_key == api_key.api_key) + if existing_api_key := (await session.exec(stmt)).first(): yield existing_api_key return session.add(api_key) - session.commit() - session.refresh(api_key) + await session.commit() + await session.refresh(api_key) yield api_key # Clean up - session.delete(api_key) - session.commit() + await session.delete(api_key) + await session.commit() @pytest.fixture(name="simple_api_test") @@ -618,14 +619,15 @@ async def get_simple_api_test(client, logged_in_headers, json_simple_api_test): @pytest.fixture(name="starter_project") -def get_starter_project(active_user): +async def get_starter_project(active_user): # once the client is created, we can get the starter project - with session_getter(get_db_service()) as session: - flow = session.exec( + async with async_session_getter(get_db_service()) as session: + stmt = ( select(Flow) .where(Flow.folder.has(Folder.name == STARTER_FOLDER_NAME)) .where(Flow.name == "Basic Prompting (Hello, World)") - ).first() + ) + flow = (await session.exec(stmt)).first() if not flow: msg = "No starter project found" raise ValueError(msg) @@ -640,10 +642,10 @@ def get_starter_project(active_user): ) new_flow = Flow.model_validate(new_flow_create, from_attributes=True) session.add(new_flow) - session.commit() - session.refresh(new_flow) + await session.commit() + await session.refresh(new_flow) new_flow_dict = new_flow.model_dump() yield new_flow_dict # Clean up - session.delete(new_flow) - session.commit() + await session.delete(new_flow) + await session.commit() diff --git a/src/backend/tests/unit/services/variable/test_service.py b/src/backend/tests/unit/services/variable/test_service.py index f66da3a04..cb4e253db 100644 --- a/src/backend/tests/unit/services/variable/test_service.py +++ b/src/backend/tests/unit/services/variable/test_service.py @@ -24,7 +24,7 @@ async def session(): engine = create_async_engine("sqlite+aiosqlite:///:memory:") async with engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) - async with AsyncSession(engine) as session: + async with AsyncSession(engine, expire_on_commit=False) as session: yield session diff --git a/src/backend/tests/unit/test_database.py b/src/backend/tests/unit/test_database.py index 5d2e2bc61..f32cf2edb 100644 --- a/src/backend/tests/unit/test_database.py +++ b/src/backend/tests/unit/test_database.py @@ -12,7 +12,7 @@ from langflow.initial_setup.setup import load_starter_projects from langflow.services.database.models.base import orjson_dumps from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate from langflow.services.database.models.folder.model import FolderCreate -from langflow.services.database.utils import session_getter +from langflow.services.database.utils import async_session_getter from langflow.services.deps import get_db_service @@ -530,14 +530,14 @@ async def test_download_file( ] ) db_manager = get_db_service() - with session_getter(db_manager) as _session: + async with async_session_getter(db_manager) as _session: saved_flows = [] for flow in flow_list.flows: flow.user_id = active_user.id db_flow = Flow.model_validate(flow, from_attributes=True) _session.add(db_flow) saved_flows.append(db_flow) - _session.commit() + await _session.commit() # Make request to endpoint inside the session context flow_ids = [str(db_flow.id) for db_flow in saved_flows] # Convert UUIDs to strings flow_ids_json = json.dumps(flow_ids) diff --git a/src/backend/tests/unit/test_initial_setup.py b/src/backend/tests/unit/test_initial_setup.py index 3f2650fcf..ec1fb4a61 100644 --- a/src/backend/tests/unit/test_initial_setup.py +++ b/src/backend/tests/unit/test_initial_setup.py @@ -12,7 +12,8 @@ from langflow.initial_setup.setup import ( ) from langflow.interface.types import aget_all_types_dict from langflow.services.database.models.folder.model import Folder -from langflow.services.deps import session_scope +from langflow.services.deps import async_session_scope +from sqlalchemy.orm import selectinload from sqlmodel import select @@ -52,12 +53,13 @@ def test_get_project_data(): @pytest.mark.usefixtures("client") async def test_create_or_update_starter_projects(): - with session_scope() as session: + async with async_session_scope() as session: # Get the number of projects returned by load_starter_projects num_projects = len(await asyncio.to_thread(load_starter_projects)) # Get the number of projects in the database - folder = session.exec(select(Folder).where(Folder.name == STARTER_FOLDER_NAME)).first() + stmt = select(Folder).options(selectinload(Folder.flows)).where(Folder.name == STARTER_FOLDER_NAME) + folder = (await session.exec(stmt)).first() assert folder is not None num_db_projects = len(folder.flows) diff --git a/src/backend/tests/unit/test_login.py b/src/backend/tests/unit/test_login.py index 16864f3ca..7b01a5f0c 100644 --- a/src/backend/tests/unit/test_login.py +++ b/src/backend/tests/unit/test_login.py @@ -1,7 +1,7 @@ import pytest from langflow.services.auth.utils import get_password_hash from langflow.services.database.models.user import User -from langflow.services.deps import session_scope +from langflow.services.deps import async_session_scope from sqlalchemy.exc import IntegrityError @@ -18,9 +18,9 @@ def test_user(): async def test_login_successful(client, test_user): # Adding the test user to the database try: - with session_scope() as session: + async with async_session_scope() as session: session.add(test_user) - session.commit() + await session.commit() except IntegrityError: pass diff --git a/src/backend/tests/unit/test_user.py b/src/backend/tests/unit/test_user.py index 6caec3180..d2122109c 100644 --- a/src/backend/tests/unit/test_user.py +++ b/src/backend/tests/unit/test_user.py @@ -5,7 +5,7 @@ from httpx import AsyncClient from langflow.services.auth.utils import create_super_user, get_password_hash from langflow.services.database.models.user import UserUpdate from langflow.services.database.models.user.model import User -from langflow.services.database.utils import async_session_getter, session_getter +from langflow.services.database.utils import async_session_getter from langflow.services.deps import get_db_service, get_settings_service from sqlmodel import select @@ -41,8 +41,8 @@ async def super_user_headers( @pytest.fixture -def deactivated_user(client): # noqa: ARG001 - with session_getter(get_db_service()) as session: +async def deactivated_user(client): # noqa: ARG001 + async with async_session_getter(get_db_service()) as session: user = User( username="deactivateduser", password=get_password_hash("testpassword"), @@ -51,8 +51,8 @@ def deactivated_user(client): # noqa: ARG001 last_login_at=datetime.now(tz=timezone.utc), ) session.add(user) - session.commit() - session.refresh(user) + await session.commit() + await session.refresh(user) return user @@ -61,15 +61,16 @@ async def test_user_waiting_for_approval(client): password = "testpassword" # noqa: S105 # Debug: Check if the user already exists - with session_getter(get_db_service()) as session: - existing_user = session.exec(select(User).where(User.username == username)).first() + async with async_session_getter(get_db_service()) as session: + stmt = select(User).where(User.username == username) + existing_user = (await session.exec(stmt)).first() if existing_user: pytest.fail( f"User {username} already exists before the test. Database URL: {get_db_service().database_url}" ) # Create a user that is not active and has never logged in - with session_getter(get_db_service()) as session: + async with async_session_getter(get_db_service()) as session: user = User( username=username, password=get_password_hash(password), @@ -77,7 +78,7 @@ async def test_user_waiting_for_approval(client): last_login_at=None, ) session.add(user) - session.commit() + await session.commit() login_data = {"username": "waitingforapproval", "password": "testpassword"} response = await client.post("api/v1/login", data=login_data) @@ -85,8 +86,9 @@ async def test_user_waiting_for_approval(client): assert response.json()["detail"] == "Waiting for approval" # Debug: Check if the user still exists after the test - with session_getter(get_db_service()) as session: - existing_user = session.exec(select(User).where(User.username == username)).first() + async with async_session_getter(get_db_service()) as session: + stmt = select(User).where(User.username == username) + existing_user = (await session.exec(stmt)).first() if existing_user: pass else: @@ -138,7 +140,7 @@ async def test_data_consistency_after_delete(client: AsyncClient, test_user, sup @pytest.mark.api_key_required async def test_inactive_user(client: AsyncClient): # Create a user that is not active and has a last_login_at value - with session_getter(get_db_service()) as session: + async with async_session_getter(get_db_service()) as session: user = User( username="inactiveuser", password=get_password_hash("testpassword"), @@ -146,7 +148,7 @@ async def test_inactive_user(client: AsyncClient): last_login_at=datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc), ) session.add(user) - session.commit() + await session.commit() login_data = {"username": "inactiveuser", "password": "testpassword"} response = await client.post("api/v1/login", data=login_data)