From 2adda780c9c286cc4b22745012c5a1188f1dd481 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 10 Oct 2024 08:45:11 -0300 Subject: [PATCH] tests: update env.py and conftest so tests use unique databases (#3654) * Refactor Alembic environment script to simplify database connection logic and remove unused imports * Refactor test client fixture to use TemporaryDirectory and UUID for database path * Add `reload_engine` method to reinitialize the database engine * Add cleanup steps to test fixtures to ensure proper resource management - Added cleanup steps to various pytest fixtures to ensure temporary directories, database entries, and other resources are properly cleaned up after tests. - Modified fixtures to use `yield` for better resource management. - Ensured database connections are closed and tables are dropped after tests. - Improved temporary directory handling with context managers. * Add unit test to verify return type of list_flows method in custom component * Refactor tests in `test_custom_component.py` to remove unused imports and fixtures, and update existing fixtures for consistency. * Add debug checks for user existence in `test_user_waiting_for_approval` test * Fix import order and add flow_id validation in transaction handling - Corrected the import order in `utils.py`. - Added validation for `flow_id` in the transaction handling logic to ensure it is set correctly. * Add function to delete transactions by flow ID in CRUD module * Add cleanup for transactions and vertex builds in test teardown - Introduced `_delete_transactions_and_vertex_builds` function to remove transactions and vertex builds associated with a user. - Updated `user` fixture to call the new cleanup function before deleting the user. * Refactor flow_id assignment logic in `utils.py` to improve readability and correctness * [autofix.ci] apply automated fixes * Refactor test to use pytest's tmp_path fixture for temporary directory creation * Convert `test_user_waiting_for_approval` to an async test function --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- src/backend/base/langflow/alembic/env.py | 39 ++---- src/backend/base/langflow/graph/utils.py | 5 + .../langflow/services/database/service.py | 3 + .../services/telemetry/opentelemetry.py | 38 +++--- src/backend/tests/conftest.py | 127 ++++++++++++------ src/backend/tests/unit/test_cli.py | 11 +- .../tests/unit/test_custom_component.py | 55 +------- .../unit/test_custom_component_with_client.py | 23 ++++ src/backend/tests/unit/test_user.py | 27 +++- 9 files changed, 179 insertions(+), 149 deletions(-) diff --git a/src/backend/base/langflow/alembic/env.py b/src/backend/base/langflow/alembic/env.py index 6e4a87561..d4e130649 100644 --- a/src/backend/base/langflow/alembic/env.py +++ b/src/backend/base/langflow/alembic/env.py @@ -1,9 +1,6 @@ -import os -import warnings from logging.config import fileConfig from alembic import context -from loguru import logger from sqlalchemy import engine_from_config, pool from langflow.services.database.models import * @@ -42,8 +39,7 @@ def run_migrations_offline() -> None: script output. """ - url = os.getenv("LANGFLOW_DATABASE_URL") - url = url or config.get_main_option("sqlalchemy.url") + url = config.get_main_option("sqlalchemy.url") context.configure( url=url, target_metadata=target_metadata, @@ -63,32 +59,17 @@ def run_migrations_online() -> None: and associate a connection with the context. """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) - try: - from langflow.services.database.factory import DatabaseServiceFactory - from langflow.services.deps import get_db_service - from langflow.services.manager import initialize_settings_service, service_manager + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata, render_as_batch=True) - initialize_settings_service() - service_manager.register_factory(DatabaseServiceFactory()) - connectable = get_db_service().engine - except Exception: - logger.exception("Error getting database engine") - url = os.getenv("LANGFLOW_DATABASE_URL") - url = url or config.get_main_option("sqlalchemy.url") - if url: - config.set_main_option("sqlalchemy.url", url) - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata, render_as_batch=True) - with context.begin_transaction(): - context.run_migrations() + with context.begin_transaction(): + context.run_migrations() if context.is_offline_mode(): diff --git a/src/backend/base/langflow/graph/utils.py b/src/backend/base/langflow/graph/utils.py index 1eb3c8876..a994e2f1d 100644 --- a/src/backend/base/langflow/graph/utils.py +++ b/src/backend/base/langflow/graph/utils.py @@ -141,6 +141,11 @@ async def log_transaction( try: if not get_settings_service().settings.transactions_storage_enabled: return + if not flow_id: + if source.graph.flow_id: + flow_id = source.graph.flow_id + else: + return inputs = _vertex_to_primitive_dict(source) transaction = TransactionBase( vertex_id=source.id, diff --git a/src/backend/base/langflow/services/database/service.py b/src/backend/base/langflow/services/database/service.py index 16ba8e147..d48a2f2dd 100644 --- a/src/backend/base/langflow/services/database/service.py +++ b/src/backend/base/langflow/services/database/service.py @@ -45,6 +45,9 @@ class DatabaseService(Service): self.alembic_cfg_path = langflow_dir / "alembic.ini" self.engine = self._create_engine() + def reload_engine(self): + self.engine = self._create_engine() + def _create_engine(self) -> Engine: """Create the engine for the database.""" if self.settings_service.settings.database_url and self.settings_service.settings.database_url.startswith( diff --git a/src/backend/base/langflow/services/telemetry/opentelemetry.py b/src/backend/base/langflow/services/telemetry/opentelemetry.py index 232561c96..886e54c51 100644 --- a/src/backend/base/langflow/services/telemetry/opentelemetry.py +++ b/src/backend/base/langflow/services/telemetry/opentelemetry.py @@ -112,6 +112,8 @@ class ThreadSafeSingletonMetaUsingWeakref(type): class OpenTelemetry(metaclass=ThreadSafeSingletonMetaUsingWeakref): _metrics_registry: dict[str, Metric] = {} + _metrics: dict[str, Counter | ObservableGaugeWrapper | Histogram | UpDownCounter] = {} + _meter_provider: MeterProvider | None = None def _add_metric(self, name: str, description: str, unit: str, metric_type: MetricType, labels: dict[str, bool]): metric = Metric(name=name, description=description, metric_type=metric_type, unit=unit, labels=labels) @@ -140,33 +142,37 @@ class OpenTelemetry(metaclass=ThreadSafeSingletonMetaUsingWeakref): labels={"flow_id": mandatory_label}, ) - _metrics: dict[str, Counter | ObservableGaugeWrapper | Histogram | UpDownCounter] = {} - def __init__(self, prometheus_enabled: bool = True): - self._register_metric() + if not self._metrics_registry: + self._register_metric() - resource = Resource.create({"service.name": "langflow"}) - metric_readers = [] + if self._meter_provider is None: + resource = Resource.create({"service.name": "langflow"}) + metric_readers = [] - # configure prometheus exporter - self.prometheus_enabled = prometheus_enabled - if prometheus_enabled: - metric_readers.append(PrometheusMetricReader()) + # configure prometheus exporter + self.prometheus_enabled = prometheus_enabled + if prometheus_enabled: + metric_readers.append(PrometheusMetricReader()) - meter_provider = MeterProvider(resource=resource, metric_readers=metric_readers) - metrics.set_meter_provider(meter_provider) - self.meter = meter_provider.get_meter(langflow_meter_name) + self._meter_provider = MeterProvider(resource=resource, metric_readers=metric_readers) + metrics.set_meter_provider(self._meter_provider) + + self.meter = self._meter_provider.get_meter(langflow_meter_name) for name, metric in self._metrics_registry.items(): if name != metric.name: msg = f"Key '{name}' does not match metric name '{metric.name}'" raise ValueError(msg) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - self._metrics[metric.name] = self._create_metric(metric) + if name not in self._metrics: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self._metrics[metric.name] = self._create_metric(metric) def _create_metric(self, metric): + if metric.name in self._metrics: + return self._metrics[metric.name] + if metric.type == MetricType.COUNTER: return self.meter.create_counter( name=metric.name, diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index f6a159e3f..c776064e5 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator from contextlib import contextmanager, suppress from pathlib import Path from typing import TYPE_CHECKING +from uuid import UUID import orjson import pytest @@ -29,7 +30,9 @@ from langflow.services.auth.utils import get_password_hash from langflow.services.database.models.api_key.model import ApiKey from langflow.services.database.models.flow.model import Flow, FlowCreate from langflow.services.database.models.folder.model import Folder -from langflow.services.database.models.user.model import User, UserCreate +from langflow.services.database.models.transactions.model import TransactionTable +from langflow.services.database.models.user.model import User, UserCreate, UserRead +from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service @@ -82,6 +85,23 @@ def get_text(): assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}" +def delete_transactions_by_flow_id(db: Session, flow_id: UUID): + stmt = select(TransactionTable).where(TransactionTable.flow_id == flow_id) + transactions = db.exec(stmt) + for transaction in transactions: + db.delete(transaction) + db.commit() + + +def _delete_transactions_and_vertex_builds(session, user: User): + flow_ids = [flow.id for flow in user.flows] + for flow_id in flow_ids: + if not flow_id: + continue + delete_vertex_builds_by_flow_id(session, flow_id) + delete_transactions_by_flow_id(session, flow_id) + + @pytest.fixture def caplog(caplog: LogCaptureFixture): handler_id = logger.add( @@ -110,6 +130,7 @@ def session_fixture(): SQLModel.metadata.create_all(engine) with Session(engine) as session: yield session + SQLModel.metadata.drop_all(engine) # Add this line to clean up tables class Config: @@ -119,8 +140,8 @@ class Config: @pytest.fixture(name="load_flows_dir") def load_flows_dir(): - tempdir = tempfile.TemporaryDirectory() - yield tempdir.name + with tempfile.TemporaryDirectory() as tempdir: + yield tempdir @pytest.fixture(name="distributed_env") @@ -143,23 +164,26 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env): from langflow.core import celery_app db_dir = tempfile.mkdtemp() - db_path = Path(db_dir) / "test.db" - monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}") - monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false") - # monkeypatch langflow.services.task.manager.USE_CELERY to True - # monkeypatch.setattr(manager, "USE_CELERY", True) - monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config)) + try: + db_path = Path(db_dir) / "test.db" + monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}") + monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false") + # monkeypatch langflow.services.task.manager.USE_CELERY to True + # monkeypatch.setattr(manager, "USE_CELERY", True) + monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config)) - # def get_session_override(): - # return session + # def get_session_override(): + # return session - from langflow.main import create_app + from langflow.main import create_app - app = create_app() + app = create_app() - # app.dependency_overrides[get_session] = get_session_override - with TestClient(app) as client: - yield client + # app.dependency_overrides[get_session] = get_session_override + with TestClient(app) as client: + yield client + finally: + shutil.rmtree(db_dir) # Clean up the temporary directory app.dependency_overrides.clear() monkeypatch.undo() @@ -279,7 +303,9 @@ async def client_fixture(session: Session, monkeypatch, request, load_flows_dir) from langflow.main import create_app app = create_app() - + db_service = get_db_service() + db_service.database_url = f"sqlite:///{db_path}" + db_service.reload_engine() # app.dependency_overrides[get_session] = get_session_override async with LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager: async with AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client: @@ -304,7 +330,7 @@ def session_getter_fixture(client): @pytest.fixture def runner(): - return CliRunner() + yield CliRunner() @pytest.fixture @@ -315,26 +341,38 @@ async def test_user(client): ) response = await client.post("api/v1/users/", json=user_data.model_dump()) assert response.status_code == 201 - return response.json() + user = response.json() + yield user + # Clean up + await client.delete(f"/api/v1/users/{user['id']}") @pytest.fixture(scope="function") def active_user(client): db_manager = get_db_service() - with session_getter(db_manager) as session: + with db_manager.with_session() as session: user = User( username="activeuser", password=get_password_hash("testpassword"), is_active=True, is_superuser=False, ) - # check if user exists if active_user := session.exec(select(User).where(User.username == user.username)).first(): - return active_user - session.add(user) + user = active_user + else: + session.add(user) + session.commit() + session.refresh(user) + user = UserRead.model_validate(user, from_attributes=True) + yield user + # Clean up + # Now cleanup transactions, vertex_build + with db_manager.with_session() as session: + user = session.get(User, user.id) + _delete_transactions_and_vertex_builds(session, user) + session.delete(user) + session.commit() - session.refresh(user) - return user @pytest.fixture @@ -344,7 +382,7 @@ async def logged_in_headers(client, active_user): assert response.status_code == 200 tokens = response.json() a_token = tokens["access_token"] - return {"Authorization": f"Bearer {a_token}"} + yield {"Authorization": f"Bearer {a_token}"} @pytest.fixture @@ -359,20 +397,22 @@ def flow(client, json_flow: str, active_user): session.add(flow) session.commit() session.refresh(flow) - - return flow + yield flow + # Clean up + session.delete(flow) + session.commit() @pytest.fixture def json_chat_input(): with open(pytest.CHAT_INPUT) as f: - return f.read() + yield f.read() @pytest.fixture def json_two_outputs(): with open(pytest.TWO_OUTPUTS) as f: - return f.read() + yield f.read() @pytest.fixture @@ -385,7 +425,7 @@ async def added_flow_webhook_test(client, json_webhook_test, logged_in_headers): assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data yield response.json() - client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) + await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) @pytest.fixture @@ -398,7 +438,7 @@ async def added_flow_chat_input(client, json_chat_input, logged_in_headers): assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data yield response.json() - client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) + await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) @pytest.fixture @@ -411,7 +451,7 @@ async def added_flow_two_outputs(client, json_two_outputs, logged_in_headers): assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data yield response.json() - client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) + await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) @pytest.fixture @@ -424,7 +464,7 @@ async def added_vector_store(client, json_vector_store, logged_in_headers): assert response.json()["name"] == vector_store.name assert response.json()["data"] == vector_store.data yield response.json() - client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) + await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) @pytest.fixture @@ -439,7 +479,7 @@ async def added_webhook_test(client, json_webhook_test, logged_in_headers): assert response.json()["name"] == webhook_test.name assert response.json()["data"] == webhook_test.data yield response.json() - client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) + await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) @pytest.fixture @@ -451,7 +491,7 @@ async def flow_component(client: TestClient, logged_in_headers): response = await client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers) assert response.status_code == 201 yield response.json() - client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) + await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) @pytest.fixture @@ -466,11 +506,15 @@ def created_api_key(active_user): db_manager = get_db_service() with session_getter(db_manager) as session: if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first(): - return existing_api_key + yield existing_api_key + return session.add(api_key) session.commit() session.refresh(api_key) - return api_key + yield api_key + # Clean up + session.delete(api_key) + session.commit() @pytest.fixture(name="simple_api_test") @@ -483,7 +527,7 @@ async def get_simple_api_test(client, logged_in_headers, json_simple_api_test): response = await client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers) assert response.status_code == 201 yield response.json() - client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) + await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) @pytest.fixture(name="starter_project") @@ -511,4 +555,7 @@ def get_starter_project(active_user): session.commit() session.refresh(new_flow) new_flow_dict = new_flow.model_dump() - return new_flow_dict + yield new_flow_dict + # Clean up + session.delete(new_flow) + session.commit() diff --git a/src/backend/tests/unit/test_cli.py b/src/backend/tests/unit/test_cli.py index 275fee279..a59b16905 100644 --- a/src/backend/tests/unit/test_cli.py +++ b/src/backend/tests/unit/test_cli.py @@ -1,6 +1,3 @@ -from pathlib import Path -from tempfile import tempdir - import pytest from langflow.__main__ import app @@ -15,13 +12,9 @@ def default_settings(): ] -def test_components_path(runner, client, default_settings): - # Create a foldr in the tmp directory - - temp_dir = Path(tempdir) +def test_components_path(runner, client, default_settings, tmp_path): # create a "components" folder - temp_dir = temp_dir / "components" - temp_dir.mkdir(exist_ok=True) + temp_dir = tmp_path / "components" result = runner.invoke( app, diff --git a/src/backend/tests/unit/test_custom_component.py b/src/backend/tests/unit/test_custom_component.py index d96f2acee..63103a379 100644 --- a/src/backend/tests/unit/test_custom_component.py +++ b/src/backend/tests/unit/test_custom_component.py @@ -1,7 +1,6 @@ import ast import types from textwrap import dedent -from uuid import uuid4 import pytest from langchain_core.documents import Document @@ -10,8 +9,6 @@ from langflow.custom import Component, CustomComponent from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError from langflow.custom.utils import build_custom_component_template -from langflow.services.database.models.flow import FlowCreate -from langflow.services.settings.feature_flags import FEATURE_FLAGS @pytest.fixture @@ -460,9 +457,8 @@ def test_build_config_no_code(): @pytest.fixture -def component(client, active_user): - return CustomComponent( - user_id=active_user.id, +def component(): + yield CustomComponent( field_config={ "fields": { "llm": {"type": "str"}, @@ -473,41 +469,6 @@ def component(client, active_user): ) -@pytest.fixture(scope="session") -def test_flow(db): - flow_data = { - "nodes": [{"id": "1"}, {"id": "2"}], - "edges": [{"source": "1", "target": "2"}], - } - - # Create flow - flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data) - - # Add to database - db.add(flow) - db.commit() - - yield flow - - # Clean up - db.delete(flow) - db.commit() - - -@pytest.fixture(scope="session") -def db(app): - # Setup database for tests - yield app.db - - # Teardown - app.db.drop_all() - - -def test_list_flows_return_type(component): - flows = component.list_flows() - assert isinstance(flows, list) - - def test_build_config_return_type(component): config = component.build_config() assert isinstance(config, dict) @@ -539,19 +500,11 @@ def test_build_config_field_value_keys(component): assert all("type" in value for value in field_values) -def test_custom_component_multiple_outputs(code_component_with_multiple_outputs, active_user): - frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id) +def test_custom_component_multiple_outputs(code_component_with_multiple_outputs): + frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs) assert frontnd_node_dict["outputs"][0]["types"] == ["Text"] -def test_feature_flags_add_toolkit_output(active_user, code_component_with_multiple_outputs): - frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id) - len_outputs = len(frontnd_node_dict["outputs"]) - FEATURE_FLAGS.add_toolkit_output = True - frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id) - assert len(frontnd_node_dict["outputs"]) == len_outputs + 1 - - def test_custom_component_subclass_from_lctoolcomponent(): # Import LCToolComponent and create a subclass code = dedent(""" diff --git a/src/backend/tests/unit/test_custom_component_with_client.py b/src/backend/tests/unit/test_custom_component_with_client.py index ab8e7859e..65b1b873c 100644 --- a/src/backend/tests/unit/test_custom_component_with_client.py +++ b/src/backend/tests/unit/test_custom_component_with_client.py @@ -1,7 +1,17 @@ import pytest +from langflow.custom import Component from langflow.custom.custom_component.custom_component import CustomComponent +from langflow.custom.utils import build_custom_component_template from langflow.field_typing.constants import Data +from langflow.services.settings.feature_flags import FEATURE_FLAGS + + +@pytest.fixture +def code_component_with_multiple_outputs(): + with open("src/backend/tests/data/component_multiple_outputs.py") as f: + code = f.read() + return Component(_code=code) @pytest.fixture @@ -23,3 +33,16 @@ def test_list_flows_flow_objects(component): are_flows = [isinstance(flow, Data) for flow in flows] flow_types = [type(flow) for flow in flows] assert all(are_flows), f"Expected all flows to be Data objects, got {flow_types}" + + +def test_list_flows_return_type(component): + flows = component.list_flows() + assert isinstance(flows, list) + + +def test_feature_flags_add_toolkit_output(active_user, code_component_with_multiple_outputs): + frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id) + len_outputs = len(frontnd_node_dict["outputs"]) + FEATURE_FLAGS.add_toolkit_output = True + frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id) + assert len(frontnd_node_dict["outputs"]) == len_outputs + 1 diff --git a/src/backend/tests/unit/test_user.py b/src/backend/tests/unit/test_user.py index 461ebd7ea..974884541 100644 --- a/src/backend/tests/unit/test_user.py +++ b/src/backend/tests/unit/test_user.py @@ -2,6 +2,7 @@ from datetime import datetime import pytest from httpx import AsyncClient +from sqlmodel import select from langflow.services.auth.utils import create_super_user, get_password_hash from langflow.services.database.models.user import UserUpdate @@ -53,13 +54,23 @@ def deactivated_user(): return user -@pytest.mark.api_key_required -async def test_user_waiting_for_approval(client: AsyncClient): +async def test_user_waiting_for_approval(client): + username = "waitingforapproval" + password = "testpassword" + + # Debug: Check if the user already exists + with session_getter(get_db_service()) as session: + existing_user = session.exec(select(User).where(User.username == username)).first() + if existing_user: + pytest.fail( + f"User {username} already exists before the test. Database URL: {get_db_service().database_url}" + ) + # Create a user that is not active and has never logged in with session_getter(get_db_service()) as session: user = User( - username="waitingforapproval", - password=get_password_hash("testpassword"), + username=username, + password=get_password_hash(password), is_active=False, last_login_at=None, ) @@ -71,6 +82,14 @@ async def test_user_waiting_for_approval(client: AsyncClient): assert response.status_code == 400 assert response.json()["detail"] == "Waiting for approval" + # Debug: Check if the user still exists after the test + with session_getter(get_db_service()) as session: + existing_user = session.exec(select(User).where(User.username == username)).first() + if existing_user: + print(f"User {username} still exists after the test. This is expected.") + else: + pytest.fail(f"User {username} does not exist after the test. This is unexpected.") + @pytest.mark.api_key_required async def test_deactivated_user_cannot_login(client: AsyncClient, deactivated_user):