tests: update env.py and conftest so tests use unique databases (#3654)

* Refactor Alembic environment script to simplify database connection logic and remove unused imports

* Refactor test client fixture to use TemporaryDirectory and UUID for database path

* Add `reload_engine` method to reinitialize the database engine

* Add cleanup steps to test fixtures to ensure proper resource management

- Added cleanup steps to various pytest fixtures to ensure temporary directories, database entries, and other resources are properly cleaned up after tests.
- Modified fixtures to use `yield` for better resource management.
- Ensured database connections are closed and tables are dropped after tests.
- Improved temporary directory handling with context managers.

* Add unit test to verify return type of list_flows method in custom component

* Refactor tests in `test_custom_component.py` to remove unused imports and fixtures, and update existing fixtures for consistency.

* Add debug checks for user existence in `test_user_waiting_for_approval` test

* Fix import order and add flow_id validation in transaction handling

- Corrected the import order in `utils.py`.
- Added validation for `flow_id` in the transaction handling logic to ensure it is set correctly.

* Add function to delete transactions by flow ID in CRUD module

* Add cleanup for transactions and vertex builds in test teardown

- Introduced `_delete_transactions_and_vertex_builds` function to remove transactions and vertex builds associated with a user.
- Updated `user` fixture to call the new cleanup function before deleting the user.

* Refactor flow_id assignment logic in `utils.py` to improve readability and correctness

* [autofix.ci] apply automated fixes

* Refactor test to use pytest's tmp_path fixture for temporary directory creation

* Convert `test_user_waiting_for_approval` to an async test function

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-10-10 08:45:11 -03:00 committed by GitHub
commit 2adda780c9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 179 additions and 149 deletions

View file

@ -1,9 +1,6 @@
import os
import warnings
from logging.config import fileConfig
from alembic import context
from loguru import logger
from sqlalchemy import engine_from_config, pool
from langflow.services.database.models import *
@ -42,8 +39,7 @@ def run_migrations_offline() -> None:
script output.
"""
url = os.getenv("LANGFLOW_DATABASE_URL")
url = url or config.get_main_option("sqlalchemy.url")
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
@ -63,32 +59,17 @@ def run_migrations_online() -> None:
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
try:
from langflow.services.database.factory import DatabaseServiceFactory
from langflow.services.deps import get_db_service
from langflow.services.manager import initialize_settings_service, service_manager
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata, render_as_batch=True)
initialize_settings_service()
service_manager.register_factory(DatabaseServiceFactory())
connectable = get_db_service().engine
except Exception:
logger.exception("Error getting database engine")
url = os.getenv("LANGFLOW_DATABASE_URL")
url = url or config.get_main_option("sqlalchemy.url")
if url:
config.set_main_option("sqlalchemy.url", url)
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata, render_as_batch=True)
with context.begin_transaction():
context.run_migrations()
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():

View file

@ -141,6 +141,11 @@ async def log_transaction(
try:
if not get_settings_service().settings.transactions_storage_enabled:
return
if not flow_id:
if source.graph.flow_id:
flow_id = source.graph.flow_id
else:
return
inputs = _vertex_to_primitive_dict(source)
transaction = TransactionBase(
vertex_id=source.id,

View file

@ -45,6 +45,9 @@ class DatabaseService(Service):
self.alembic_cfg_path = langflow_dir / "alembic.ini"
self.engine = self._create_engine()
def reload_engine(self):
self.engine = self._create_engine()
def _create_engine(self) -> Engine:
"""Create the engine for the database."""
if self.settings_service.settings.database_url and self.settings_service.settings.database_url.startswith(

View file

@ -112,6 +112,8 @@ class ThreadSafeSingletonMetaUsingWeakref(type):
class OpenTelemetry(metaclass=ThreadSafeSingletonMetaUsingWeakref):
_metrics_registry: dict[str, Metric] = {}
_metrics: dict[str, Counter | ObservableGaugeWrapper | Histogram | UpDownCounter] = {}
_meter_provider: MeterProvider | None = None
def _add_metric(self, name: str, description: str, unit: str, metric_type: MetricType, labels: dict[str, bool]):
metric = Metric(name=name, description=description, metric_type=metric_type, unit=unit, labels=labels)
@ -140,33 +142,37 @@ class OpenTelemetry(metaclass=ThreadSafeSingletonMetaUsingWeakref):
labels={"flow_id": mandatory_label},
)
_metrics: dict[str, Counter | ObservableGaugeWrapper | Histogram | UpDownCounter] = {}
def __init__(self, prometheus_enabled: bool = True):
self._register_metric()
if not self._metrics_registry:
self._register_metric()
resource = Resource.create({"service.name": "langflow"})
metric_readers = []
if self._meter_provider is None:
resource = Resource.create({"service.name": "langflow"})
metric_readers = []
# configure prometheus exporter
self.prometheus_enabled = prometheus_enabled
if prometheus_enabled:
metric_readers.append(PrometheusMetricReader())
# configure prometheus exporter
self.prometheus_enabled = prometheus_enabled
if prometheus_enabled:
metric_readers.append(PrometheusMetricReader())
meter_provider = MeterProvider(resource=resource, metric_readers=metric_readers)
metrics.set_meter_provider(meter_provider)
self.meter = meter_provider.get_meter(langflow_meter_name)
self._meter_provider = MeterProvider(resource=resource, metric_readers=metric_readers)
metrics.set_meter_provider(self._meter_provider)
self.meter = self._meter_provider.get_meter(langflow_meter_name)
for name, metric in self._metrics_registry.items():
if name != metric.name:
msg = f"Key '{name}' does not match metric name '{metric.name}'"
raise ValueError(msg)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self._metrics[metric.name] = self._create_metric(metric)
if name not in self._metrics:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self._metrics[metric.name] = self._create_metric(metric)
def _create_metric(self, metric):
if metric.name in self._metrics:
return self._metrics[metric.name]
if metric.type == MetricType.COUNTER:
return self.meter.create_counter(
name=metric.name,

View file

@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator
from contextlib import contextmanager, suppress
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import UUID
import orjson
import pytest
@ -29,7 +30,9 @@ from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.models.flow.model import Flow, FlowCreate
from langflow.services.database.models.folder.model import Folder
from langflow.services.database.models.user.model import User, UserCreate
from langflow.services.database.models.transactions.model import TransactionTable
from langflow.services.database.models.user.model import User, UserCreate, UserRead
from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
@ -82,6 +85,23 @@ def get_text():
assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}"
def delete_transactions_by_flow_id(db: Session, flow_id: UUID):
    """Delete every TransactionTable row attached to the given flow, then commit."""
    rows = db.exec(select(TransactionTable).where(TransactionTable.flow_id == flow_id))
    for row in rows:
        db.delete(row)
    db.commit()
def _delete_transactions_and_vertex_builds(session, user: User):
    """Remove transactions and vertex builds for each of the user's flows.

    Flow ids are collected first so the cleanup does not iterate the
    relationship while related rows are being deleted; falsy ids are skipped.
    """
    collected_ids = [flow.id for flow in user.flows]
    for fid in collected_ids:
        if not fid:
            continue
        delete_vertex_builds_by_flow_id(session, fid)
        delete_transactions_by_flow_id(session, fid)
@pytest.fixture
def caplog(caplog: LogCaptureFixture):
handler_id = logger.add(
@ -110,6 +130,7 @@ def session_fixture():
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
SQLModel.metadata.drop_all(engine) # Add this line to clean up tables
class Config:
@ -119,8 +140,8 @@ class Config:
@pytest.fixture(name="load_flows_dir")
def load_flows_dir():
tempdir = tempfile.TemporaryDirectory()
yield tempdir.name
with tempfile.TemporaryDirectory() as tempdir:
yield tempdir
@pytest.fixture(name="distributed_env")
@ -143,23 +164,26 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
from langflow.core import celery_app
db_dir = tempfile.mkdtemp()
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
# monkeypatch langflow.services.task.manager.USE_CELERY to True
# monkeypatch.setattr(manager, "USE_CELERY", True)
monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config))
try:
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
# monkeypatch langflow.services.task.manager.USE_CELERY to True
# monkeypatch.setattr(manager, "USE_CELERY", True)
monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config))
# def get_session_override():
# return session
# def get_session_override():
# return session
from langflow.main import create_app
from langflow.main import create_app
app = create_app()
app = create_app()
# app.dependency_overrides[get_session] = get_session_override
with TestClient(app) as client:
yield client
# app.dependency_overrides[get_session] = get_session_override
with TestClient(app) as client:
yield client
finally:
shutil.rmtree(db_dir) # Clean up the temporary directory
app.dependency_overrides.clear()
monkeypatch.undo()
@ -279,7 +303,9 @@ async def client_fixture(session: Session, monkeypatch, request, load_flows_dir)
from langflow.main import create_app
app = create_app()
db_service = get_db_service()
db_service.database_url = f"sqlite:///{db_path}"
db_service.reload_engine()
# app.dependency_overrides[get_session] = get_session_override
async with LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager:
async with AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client:
@ -304,7 +330,7 @@ def session_getter_fixture(client):
@pytest.fixture
def runner():
return CliRunner()
yield CliRunner()
@pytest.fixture
@ -315,26 +341,38 @@ async def test_user(client):
)
response = await client.post("api/v1/users/", json=user_data.model_dump())
assert response.status_code == 201
return response.json()
user = response.json()
yield user
# Clean up
await client.delete(f"/api/v1/users/{user['id']}")
@pytest.fixture(scope="function")
def active_user(client):
db_manager = get_db_service()
with session_getter(db_manager) as session:
with db_manager.with_session() as session:
user = User(
username="activeuser",
password=get_password_hash("testpassword"),
is_active=True,
is_superuser=False,
)
# check if user exists
if active_user := session.exec(select(User).where(User.username == user.username)).first():
return active_user
session.add(user)
user = active_user
else:
session.add(user)
session.commit()
session.refresh(user)
user = UserRead.model_validate(user, from_attributes=True)
yield user
# Clean up
# Now cleanup transactions, vertex_build
with db_manager.with_session() as session:
user = session.get(User, user.id)
_delete_transactions_and_vertex_builds(session, user)
session.delete(user)
session.commit()
session.refresh(user)
return user
@pytest.fixture
@ -344,7 +382,7 @@ async def logged_in_headers(client, active_user):
assert response.status_code == 200
tokens = response.json()
a_token = tokens["access_token"]
return {"Authorization": f"Bearer {a_token}"}
yield {"Authorization": f"Bearer {a_token}"}
@pytest.fixture
@ -359,20 +397,22 @@ def flow(client, json_flow: str, active_user):
session.add(flow)
session.commit()
session.refresh(flow)
return flow
yield flow
# Clean up
session.delete(flow)
session.commit()
@pytest.fixture
def json_chat_input():
with open(pytest.CHAT_INPUT) as f:
return f.read()
yield f.read()
@pytest.fixture
def json_two_outputs():
with open(pytest.TWO_OUTPUTS) as f:
return f.read()
yield f.read()
@pytest.fixture
@ -385,7 +425,7 @@ async def added_flow_webhook_test(client, json_webhook_test, logged_in_headers):
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
yield response.json()
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
@pytest.fixture
@ -398,7 +438,7 @@ async def added_flow_chat_input(client, json_chat_input, logged_in_headers):
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
yield response.json()
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
@pytest.fixture
@ -411,7 +451,7 @@ async def added_flow_two_outputs(client, json_two_outputs, logged_in_headers):
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
yield response.json()
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
@pytest.fixture
@ -424,7 +464,7 @@ async def added_vector_store(client, json_vector_store, logged_in_headers):
assert response.json()["name"] == vector_store.name
assert response.json()["data"] == vector_store.data
yield response.json()
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
@pytest.fixture
@ -439,7 +479,7 @@ async def added_webhook_test(client, json_webhook_test, logged_in_headers):
assert response.json()["name"] == webhook_test.name
assert response.json()["data"] == webhook_test.data
yield response.json()
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
@pytest.fixture
@ -451,7 +491,7 @@ async def flow_component(client: TestClient, logged_in_headers):
response = await client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
assert response.status_code == 201
yield response.json()
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
@pytest.fixture
@ -466,11 +506,15 @@ def created_api_key(active_user):
db_manager = get_db_service()
with session_getter(db_manager) as session:
if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first():
return existing_api_key
yield existing_api_key
return
session.add(api_key)
session.commit()
session.refresh(api_key)
return api_key
yield api_key
# Clean up
session.delete(api_key)
session.commit()
@pytest.fixture(name="simple_api_test")
@ -483,7 +527,7 @@ async def get_simple_api_test(client, logged_in_headers, json_simple_api_test):
response = await client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
assert response.status_code == 201
yield response.json()
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
@pytest.fixture(name="starter_project")
@ -511,4 +555,7 @@ def get_starter_project(active_user):
session.commit()
session.refresh(new_flow)
new_flow_dict = new_flow.model_dump()
return new_flow_dict
yield new_flow_dict
# Clean up
session.delete(new_flow)
session.commit()

View file

@ -1,6 +1,3 @@
from pathlib import Path
from tempfile import tempdir
import pytest
from langflow.__main__ import app
@ -15,13 +12,9 @@ def default_settings():
]
def test_components_path(runner, client, default_settings):
# Create a folder in the tmp directory
temp_dir = Path(tempdir)
def test_components_path(runner, client, default_settings, tmp_path):
# create a "components" folder
temp_dir = temp_dir / "components"
temp_dir.mkdir(exist_ok=True)
temp_dir = tmp_path / "components"
result = runner.invoke(
app,

View file

@ -1,7 +1,6 @@
import ast
import types
from textwrap import dedent
from uuid import uuid4
import pytest
from langchain_core.documents import Document
@ -10,8 +9,6 @@ from langflow.custom import Component, CustomComponent
from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError
from langflow.custom.utils import build_custom_component_template
from langflow.services.database.models.flow import FlowCreate
from langflow.services.settings.feature_flags import FEATURE_FLAGS
@pytest.fixture
@ -460,9 +457,8 @@ def test_build_config_no_code():
@pytest.fixture
def component(client, active_user):
return CustomComponent(
user_id=active_user.id,
def component():
yield CustomComponent(
field_config={
"fields": {
"llm": {"type": "str"},
@ -473,41 +469,6 @@ def component(client, active_user):
)
@pytest.fixture(scope="session")
def test_flow(db):
flow_data = {
"nodes": [{"id": "1"}, {"id": "2"}],
"edges": [{"source": "1", "target": "2"}],
}
# Create flow
flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data)
# Add to database
db.add(flow)
db.commit()
yield flow
# Clean up
db.delete(flow)
db.commit()
@pytest.fixture(scope="session")
def db(app):
# Setup database for tests
yield app.db
# Teardown
app.db.drop_all()
def test_list_flows_return_type(component):
    """list_flows must always return a list, even when no flows exist."""
    assert isinstance(component.list_flows(), list)
def test_build_config_return_type(component):
    """build_config must return a dict describing the component's fields."""
    assert isinstance(component.build_config(), dict)
@ -539,19 +500,11 @@ def test_build_config_field_value_keys(component):
assert all("type" in value for value in field_values)
def test_custom_component_multiple_outputs(code_component_with_multiple_outputs, active_user):
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
def test_custom_component_multiple_outputs(code_component_with_multiple_outputs):
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs)
assert frontnd_node_dict["outputs"][0]["types"] == ["Text"]
def test_feature_flags_add_toolkit_output(active_user, code_component_with_multiple_outputs):
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
len_outputs = len(frontnd_node_dict["outputs"])
FEATURE_FLAGS.add_toolkit_output = True
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
assert len(frontnd_node_dict["outputs"]) == len_outputs + 1
def test_custom_component_subclass_from_lctoolcomponent():
# Import LCToolComponent and create a subclass
code = dedent("""

View file

@ -1,7 +1,17 @@
import pytest
from langflow.custom import Component
from langflow.custom.custom_component.custom_component import CustomComponent
from langflow.custom.utils import build_custom_component_template
from langflow.field_typing.constants import Data
from langflow.services.settings.feature_flags import FEATURE_FLAGS
@pytest.fixture
def code_component_with_multiple_outputs():
    """Build a Component from the multi-output test fixture source file."""
    with open("src/backend/tests/data/component_multiple_outputs.py") as source:
        return Component(_code=source.read())
@pytest.fixture
@ -23,3 +33,16 @@ def test_list_flows_flow_objects(component):
are_flows = [isinstance(flow, Data) for flow in flows]
flow_types = [type(flow) for flow in flows]
assert all(are_flows), f"Expected all flows to be Data objects, got {flow_types}"
def test_list_flows_return_type(component):
    """The flows listing API must hand back a list instance."""
    result = component.list_flows()
    assert isinstance(result, list)
def test_feature_flags_add_toolkit_output(active_user, code_component_with_multiple_outputs):
    """Enabling FEATURE_FLAGS.add_toolkit_output should add exactly one output.

    The flag is a process-global; restore its original value in a finally
    block so this test cannot leak state into other tests in the session.
    """
    frontend_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
    baseline_outputs = len(frontend_node_dict["outputs"])
    original_flag = FEATURE_FLAGS.add_toolkit_output
    FEATURE_FLAGS.add_toolkit_output = True
    try:
        frontend_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
        assert len(frontend_node_dict["outputs"]) == baseline_outputs + 1
    finally:
        FEATURE_FLAGS.add_toolkit_output = original_flag

View file

@ -2,6 +2,7 @@ from datetime import datetime
import pytest
from httpx import AsyncClient
from sqlmodel import select
from langflow.services.auth.utils import create_super_user, get_password_hash
from langflow.services.database.models.user import UserUpdate
@ -53,13 +54,23 @@ def deactivated_user():
return user
@pytest.mark.api_key_required
async def test_user_waiting_for_approval(client: AsyncClient):
async def test_user_waiting_for_approval(client):
username = "waitingforapproval"
password = "testpassword"
# Debug: Check if the user already exists
with session_getter(get_db_service()) as session:
existing_user = session.exec(select(User).where(User.username == username)).first()
if existing_user:
pytest.fail(
f"User {username} already exists before the test. Database URL: {get_db_service().database_url}"
)
# Create a user that is not active and has never logged in
with session_getter(get_db_service()) as session:
user = User(
username="waitingforapproval",
password=get_password_hash("testpassword"),
username=username,
password=get_password_hash(password),
is_active=False,
last_login_at=None,
)
@ -71,6 +82,14 @@ async def test_user_waiting_for_approval(client: AsyncClient):
assert response.status_code == 400
assert response.json()["detail"] == "Waiting for approval"
# Debug: Check if the user still exists after the test
with session_getter(get_db_service()) as session:
existing_user = session.exec(select(User).where(User.username == username)).first()
if existing_user:
print(f"User {username} still exists after the test. This is expected.")
else:
pytest.fail(f"User {username} does not exist after the test. This is unexpected.")
@pytest.mark.api_key_required
async def test_deactivated_user_cannot_login(client: AsyncClient, deactivated_user):