tests: update env.py and conftest so tests use unique databases (#3654)
* Refactor Alembic environment script to simplify database connection logic and remove unused imports * Refactor test client fixture to use TemporaryDirectory and UUID for database path * Add `reload_engine` method to reinitialize the database engine * Add cleanup steps to test fixtures to ensure proper resource management - Added cleanup steps to various pytest fixtures to ensure temporary directories, database entries, and other resources are properly cleaned up after tests. - Modified fixtures to use `yield` for better resource management. - Ensured database connections are closed and tables are dropped after tests. - Improved temporary directory handling with context managers. * Add unit test to verify return type of list_flows method in custom component * Refactor tests in `test_custom_component.py` to remove unused imports and fixtures, and update existing fixtures for consistency. * Add debug checks for user existence in `test_user_waiting_for_approval` test * Fix import order and add flow_id validation in transaction handling - Corrected the import order in `utils.py`. - Added validation for `flow_id` in the transaction handling logic to ensure it is set correctly. * Add function to delete transactions by flow ID in CRUD module * Add cleanup for transactions and vertex builds in test teardown - Introduced `_delete_transactions_and_vertex_builds` function to remove transactions and vertex builds associated with a user. - Updated `user` fixture to call the new cleanup function before deleting the user. * Refactor flow_id assignment logic in `utils.py` to improve readability and correctness * [autofix.ci] apply automated fixes * Refactor test to use pytest's tmp_path fixture for temporary directory creation * Convert `test_user_waiting_for_approval` to an async test function --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
4e36dcc2ad
commit
2adda780c9
9 changed files with 179 additions and 149 deletions
|
|
@ -1,9 +1,6 @@
|
|||
import os
|
||||
import warnings
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from loguru import logger
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from langflow.services.database.models import *
|
||||
|
|
@ -42,8 +39,7 @@ def run_migrations_offline() -> None:
|
|||
script output.
|
||||
|
||||
"""
|
||||
url = os.getenv("LANGFLOW_DATABASE_URL")
|
||||
url = url or config.get_main_option("sqlalchemy.url")
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
|
|
@ -63,32 +59,17 @@ def run_migrations_online() -> None:
|
|||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
try:
|
||||
from langflow.services.database.factory import DatabaseServiceFactory
|
||||
from langflow.services.deps import get_db_service
|
||||
from langflow.services.manager import initialize_settings_service, service_manager
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata, render_as_batch=True)
|
||||
|
||||
initialize_settings_service()
|
||||
service_manager.register_factory(DatabaseServiceFactory())
|
||||
connectable = get_db_service().engine
|
||||
except Exception:
|
||||
logger.exception("Error getting database engine")
|
||||
url = os.getenv("LANGFLOW_DATABASE_URL")
|
||||
url = url or config.get_main_option("sqlalchemy.url")
|
||||
if url:
|
||||
config.set_main_option("sqlalchemy.url", url)
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata, render_as_batch=True)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
|
|
|
|||
|
|
@ -141,6 +141,11 @@ async def log_transaction(
|
|||
try:
|
||||
if not get_settings_service().settings.transactions_storage_enabled:
|
||||
return
|
||||
if not flow_id:
|
||||
if source.graph.flow_id:
|
||||
flow_id = source.graph.flow_id
|
||||
else:
|
||||
return
|
||||
inputs = _vertex_to_primitive_dict(source)
|
||||
transaction = TransactionBase(
|
||||
vertex_id=source.id,
|
||||
|
|
|
|||
|
|
@ -45,6 +45,9 @@ class DatabaseService(Service):
|
|||
self.alembic_cfg_path = langflow_dir / "alembic.ini"
|
||||
self.engine = self._create_engine()
|
||||
|
||||
def reload_engine(self):
|
||||
self.engine = self._create_engine()
|
||||
|
||||
def _create_engine(self) -> Engine:
|
||||
"""Create the engine for the database."""
|
||||
if self.settings_service.settings.database_url and self.settings_service.settings.database_url.startswith(
|
||||
|
|
|
|||
|
|
@ -112,6 +112,8 @@ class ThreadSafeSingletonMetaUsingWeakref(type):
|
|||
|
||||
class OpenTelemetry(metaclass=ThreadSafeSingletonMetaUsingWeakref):
|
||||
_metrics_registry: dict[str, Metric] = {}
|
||||
_metrics: dict[str, Counter | ObservableGaugeWrapper | Histogram | UpDownCounter] = {}
|
||||
_meter_provider: MeterProvider | None = None
|
||||
|
||||
def _add_metric(self, name: str, description: str, unit: str, metric_type: MetricType, labels: dict[str, bool]):
|
||||
metric = Metric(name=name, description=description, metric_type=metric_type, unit=unit, labels=labels)
|
||||
|
|
@ -140,33 +142,37 @@ class OpenTelemetry(metaclass=ThreadSafeSingletonMetaUsingWeakref):
|
|||
labels={"flow_id": mandatory_label},
|
||||
)
|
||||
|
||||
_metrics: dict[str, Counter | ObservableGaugeWrapper | Histogram | UpDownCounter] = {}
|
||||
|
||||
def __init__(self, prometheus_enabled: bool = True):
|
||||
self._register_metric()
|
||||
if not self._metrics_registry:
|
||||
self._register_metric()
|
||||
|
||||
resource = Resource.create({"service.name": "langflow"})
|
||||
metric_readers = []
|
||||
if self._meter_provider is None:
|
||||
resource = Resource.create({"service.name": "langflow"})
|
||||
metric_readers = []
|
||||
|
||||
# configure prometheus exporter
|
||||
self.prometheus_enabled = prometheus_enabled
|
||||
if prometheus_enabled:
|
||||
metric_readers.append(PrometheusMetricReader())
|
||||
# configure prometheus exporter
|
||||
self.prometheus_enabled = prometheus_enabled
|
||||
if prometheus_enabled:
|
||||
metric_readers.append(PrometheusMetricReader())
|
||||
|
||||
meter_provider = MeterProvider(resource=resource, metric_readers=metric_readers)
|
||||
metrics.set_meter_provider(meter_provider)
|
||||
self.meter = meter_provider.get_meter(langflow_meter_name)
|
||||
self._meter_provider = MeterProvider(resource=resource, metric_readers=metric_readers)
|
||||
metrics.set_meter_provider(self._meter_provider)
|
||||
|
||||
self.meter = self._meter_provider.get_meter(langflow_meter_name)
|
||||
|
||||
for name, metric in self._metrics_registry.items():
|
||||
if name != metric.name:
|
||||
msg = f"Key '{name}' does not match metric name '{metric.name}'"
|
||||
raise ValueError(msg)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
self._metrics[metric.name] = self._create_metric(metric)
|
||||
if name not in self._metrics:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
self._metrics[metric.name] = self._create_metric(metric)
|
||||
|
||||
def _create_metric(self, metric):
|
||||
if metric.name in self._metrics:
|
||||
return self._metrics[metric.name]
|
||||
|
||||
if metric.type == MetricType.COUNTER:
|
||||
return self.meter.create_counter(
|
||||
name=metric.name,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator
|
|||
from contextlib import contextmanager, suppress
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
|
|
@ -29,7 +30,9 @@ from langflow.services.auth.utils import get_password_hash
|
|||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
from langflow.services.database.models.flow.model import Flow, FlowCreate
|
||||
from langflow.services.database.models.folder.model import Folder
|
||||
from langflow.services.database.models.user.model import User, UserCreate
|
||||
from langflow.services.database.models.transactions.model import TransactionTable
|
||||
from langflow.services.database.models.user.model import User, UserCreate, UserRead
|
||||
from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
|
||||
|
|
@ -82,6 +85,23 @@ def get_text():
|
|||
assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}"
|
||||
|
||||
|
||||
def delete_transactions_by_flow_id(db: Session, flow_id: UUID):
|
||||
stmt = select(TransactionTable).where(TransactionTable.flow_id == flow_id)
|
||||
transactions = db.exec(stmt)
|
||||
for transaction in transactions:
|
||||
db.delete(transaction)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _delete_transactions_and_vertex_builds(session, user: User):
|
||||
flow_ids = [flow.id for flow in user.flows]
|
||||
for flow_id in flow_ids:
|
||||
if not flow_id:
|
||||
continue
|
||||
delete_vertex_builds_by_flow_id(session, flow_id)
|
||||
delete_transactions_by_flow_id(session, flow_id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def caplog(caplog: LogCaptureFixture):
|
||||
handler_id = logger.add(
|
||||
|
|
@ -110,6 +130,7 @@ def session_fixture():
|
|||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
SQLModel.metadata.drop_all(engine) # Add this line to clean up tables
|
||||
|
||||
|
||||
class Config:
|
||||
|
|
@ -119,8 +140,8 @@ class Config:
|
|||
|
||||
@pytest.fixture(name="load_flows_dir")
|
||||
def load_flows_dir():
|
||||
tempdir = tempfile.TemporaryDirectory()
|
||||
yield tempdir.name
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
yield tempdir
|
||||
|
||||
|
||||
@pytest.fixture(name="distributed_env")
|
||||
|
|
@ -143,23 +164,26 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
|
|||
from langflow.core import celery_app
|
||||
|
||||
db_dir = tempfile.mkdtemp()
|
||||
db_path = Path(db_dir) / "test.db"
|
||||
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
|
||||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
|
||||
# monkeypatch langflow.services.task.manager.USE_CELERY to True
|
||||
# monkeypatch.setattr(manager, "USE_CELERY", True)
|
||||
monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config))
|
||||
try:
|
||||
db_path = Path(db_dir) / "test.db"
|
||||
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
|
||||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
|
||||
# monkeypatch langflow.services.task.manager.USE_CELERY to True
|
||||
# monkeypatch.setattr(manager, "USE_CELERY", True)
|
||||
monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config))
|
||||
|
||||
# def get_session_override():
|
||||
# return session
|
||||
# def get_session_override():
|
||||
# return session
|
||||
|
||||
from langflow.main import create_app
|
||||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
app = create_app()
|
||||
|
||||
# app.dependency_overrides[get_session] = get_session_override
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
# app.dependency_overrides[get_session] = get_session_override
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
finally:
|
||||
shutil.rmtree(db_dir) # Clean up the temporary directory
|
||||
app.dependency_overrides.clear()
|
||||
monkeypatch.undo()
|
||||
|
||||
|
|
@ -279,7 +303,9 @@ async def client_fixture(session: Session, monkeypatch, request, load_flows_dir)
|
|||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
|
||||
db_service = get_db_service()
|
||||
db_service.database_url = f"sqlite:///{db_path}"
|
||||
db_service.reload_engine()
|
||||
# app.dependency_overrides[get_session] = get_session_override
|
||||
async with LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager:
|
||||
async with AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client:
|
||||
|
|
@ -304,7 +330,7 @@ def session_getter_fixture(client):
|
|||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
yield CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -315,26 +341,38 @@ async def test_user(client):
|
|||
)
|
||||
response = await client.post("api/v1/users/", json=user_data.model_dump())
|
||||
assert response.status_code == 201
|
||||
return response.json()
|
||||
user = response.json()
|
||||
yield user
|
||||
# Clean up
|
||||
await client.delete(f"/api/v1/users/{user['id']}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def active_user(client):
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
with db_manager.with_session() as session:
|
||||
user = User(
|
||||
username="activeuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
# check if user exists
|
||||
if active_user := session.exec(select(User).where(User.username == user.username)).first():
|
||||
return active_user
|
||||
session.add(user)
|
||||
user = active_user
|
||||
else:
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user = UserRead.model_validate(user, from_attributes=True)
|
||||
yield user
|
||||
# Clean up
|
||||
# Now cleanup transactions, vertex_build
|
||||
with db_manager.with_session() as session:
|
||||
user = session.get(User, user.id)
|
||||
_delete_transactions_and_vertex_builds(session, user)
|
||||
session.delete(user)
|
||||
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -344,7 +382,7 @@ async def logged_in_headers(client, active_user):
|
|||
assert response.status_code == 200
|
||||
tokens = response.json()
|
||||
a_token = tokens["access_token"]
|
||||
return {"Authorization": f"Bearer {a_token}"}
|
||||
yield {"Authorization": f"Bearer {a_token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -359,20 +397,22 @@ def flow(client, json_flow: str, active_user):
|
|||
session.add(flow)
|
||||
session.commit()
|
||||
session.refresh(flow)
|
||||
|
||||
return flow
|
||||
yield flow
|
||||
# Clean up
|
||||
session.delete(flow)
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_chat_input():
|
||||
with open(pytest.CHAT_INPUT) as f:
|
||||
return f.read()
|
||||
yield f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_two_outputs():
|
||||
with open(pytest.TWO_OUTPUTS) as f:
|
||||
return f.read()
|
||||
yield f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -385,7 +425,7 @@ async def added_flow_webhook_test(client, json_webhook_test, logged_in_headers):
|
|||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -398,7 +438,7 @@ async def added_flow_chat_input(client, json_chat_input, logged_in_headers):
|
|||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -411,7 +451,7 @@ async def added_flow_two_outputs(client, json_two_outputs, logged_in_headers):
|
|||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -424,7 +464,7 @@ async def added_vector_store(client, json_vector_store, logged_in_headers):
|
|||
assert response.json()["name"] == vector_store.name
|
||||
assert response.json()["data"] == vector_store.data
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -439,7 +479,7 @@ async def added_webhook_test(client, json_webhook_test, logged_in_headers):
|
|||
assert response.json()["name"] == webhook_test.name
|
||||
assert response.json()["data"] == webhook_test.data
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -451,7 +491,7 @@ async def flow_component(client: TestClient, logged_in_headers):
|
|||
response = await client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -466,11 +506,15 @@ def created_api_key(active_user):
|
|||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first():
|
||||
return existing_api_key
|
||||
yield existing_api_key
|
||||
return
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
session.refresh(api_key)
|
||||
return api_key
|
||||
yield api_key
|
||||
# Clean up
|
||||
session.delete(api_key)
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture(name="simple_api_test")
|
||||
|
|
@ -483,7 +527,7 @@ async def get_simple_api_test(client, logged_in_headers, json_simple_api_test):
|
|||
response = await client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
await client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture(name="starter_project")
|
||||
|
|
@ -511,4 +555,7 @@ def get_starter_project(active_user):
|
|||
session.commit()
|
||||
session.refresh(new_flow)
|
||||
new_flow_dict = new_flow.model_dump()
|
||||
return new_flow_dict
|
||||
yield new_flow_dict
|
||||
# Clean up
|
||||
session.delete(new_flow)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,3 @@
|
|||
from pathlib import Path
|
||||
from tempfile import tempdir
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.__main__ import app
|
||||
|
|
@ -15,13 +12,9 @@ def default_settings():
|
|||
]
|
||||
|
||||
|
||||
def test_components_path(runner, client, default_settings):
|
||||
# Create a foldr in the tmp directory
|
||||
|
||||
temp_dir = Path(tempdir)
|
||||
def test_components_path(runner, client, default_settings, tmp_path):
|
||||
# create a "components" folder
|
||||
temp_dir = temp_dir / "components"
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
temp_dir = tmp_path / "components"
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import ast
|
||||
import types
|
||||
from textwrap import dedent
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
|
@ -10,8 +9,6 @@ from langflow.custom import Component, CustomComponent
|
|||
from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
|
||||
from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError
|
||||
from langflow.custom.utils import build_custom_component_template
|
||||
from langflow.services.database.models.flow import FlowCreate
|
||||
from langflow.services.settings.feature_flags import FEATURE_FLAGS
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -460,9 +457,8 @@ def test_build_config_no_code():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def component(client, active_user):
|
||||
return CustomComponent(
|
||||
user_id=active_user.id,
|
||||
def component():
|
||||
yield CustomComponent(
|
||||
field_config={
|
||||
"fields": {
|
||||
"llm": {"type": "str"},
|
||||
|
|
@ -473,41 +469,6 @@ def component(client, active_user):
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_flow(db):
|
||||
flow_data = {
|
||||
"nodes": [{"id": "1"}, {"id": "2"}],
|
||||
"edges": [{"source": "1", "target": "2"}],
|
||||
}
|
||||
|
||||
# Create flow
|
||||
flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data)
|
||||
|
||||
# Add to database
|
||||
db.add(flow)
|
||||
db.commit()
|
||||
|
||||
yield flow
|
||||
|
||||
# Clean up
|
||||
db.delete(flow)
|
||||
db.commit()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def db(app):
|
||||
# Setup database for tests
|
||||
yield app.db
|
||||
|
||||
# Teardown
|
||||
app.db.drop_all()
|
||||
|
||||
|
||||
def test_list_flows_return_type(component):
|
||||
flows = component.list_flows()
|
||||
assert isinstance(flows, list)
|
||||
|
||||
|
||||
def test_build_config_return_type(component):
|
||||
config = component.build_config()
|
||||
assert isinstance(config, dict)
|
||||
|
|
@ -539,19 +500,11 @@ def test_build_config_field_value_keys(component):
|
|||
assert all("type" in value for value in field_values)
|
||||
|
||||
|
||||
def test_custom_component_multiple_outputs(code_component_with_multiple_outputs, active_user):
|
||||
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
|
||||
def test_custom_component_multiple_outputs(code_component_with_multiple_outputs):
|
||||
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs)
|
||||
assert frontnd_node_dict["outputs"][0]["types"] == ["Text"]
|
||||
|
||||
|
||||
def test_feature_flags_add_toolkit_output(active_user, code_component_with_multiple_outputs):
|
||||
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
|
||||
len_outputs = len(frontnd_node_dict["outputs"])
|
||||
FEATURE_FLAGS.add_toolkit_output = True
|
||||
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
|
||||
assert len(frontnd_node_dict["outputs"]) == len_outputs + 1
|
||||
|
||||
|
||||
def test_custom_component_subclass_from_lctoolcomponent():
|
||||
# Import LCToolComponent and create a subclass
|
||||
code = dedent("""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,17 @@
|
|||
import pytest
|
||||
|
||||
from langflow.custom import Component
|
||||
from langflow.custom.custom_component.custom_component import CustomComponent
|
||||
from langflow.custom.utils import build_custom_component_template
|
||||
from langflow.field_typing.constants import Data
|
||||
from langflow.services.settings.feature_flags import FEATURE_FLAGS
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def code_component_with_multiple_outputs():
|
||||
with open("src/backend/tests/data/component_multiple_outputs.py") as f:
|
||||
code = f.read()
|
||||
return Component(_code=code)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -23,3 +33,16 @@ def test_list_flows_flow_objects(component):
|
|||
are_flows = [isinstance(flow, Data) for flow in flows]
|
||||
flow_types = [type(flow) for flow in flows]
|
||||
assert all(are_flows), f"Expected all flows to be Data objects, got {flow_types}"
|
||||
|
||||
|
||||
def test_list_flows_return_type(component):
|
||||
flows = component.list_flows()
|
||||
assert isinstance(flows, list)
|
||||
|
||||
|
||||
def test_feature_flags_add_toolkit_output(active_user, code_component_with_multiple_outputs):
|
||||
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
|
||||
len_outputs = len(frontnd_node_dict["outputs"])
|
||||
FEATURE_FLAGS.add_toolkit_output = True
|
||||
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
|
||||
assert len(frontnd_node_dict["outputs"]) == len_outputs + 1
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from datetime import datetime
|
|||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.services.auth.utils import create_super_user, get_password_hash
|
||||
from langflow.services.database.models.user import UserUpdate
|
||||
|
|
@ -53,13 +54,23 @@ def deactivated_user():
|
|||
return user
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
async def test_user_waiting_for_approval(client: AsyncClient):
|
||||
async def test_user_waiting_for_approval(client):
|
||||
username = "waitingforapproval"
|
||||
password = "testpassword"
|
||||
|
||||
# Debug: Check if the user already exists
|
||||
with session_getter(get_db_service()) as session:
|
||||
existing_user = session.exec(select(User).where(User.username == username)).first()
|
||||
if existing_user:
|
||||
pytest.fail(
|
||||
f"User {username} already exists before the test. Database URL: {get_db_service().database_url}"
|
||||
)
|
||||
|
||||
# Create a user that is not active and has never logged in
|
||||
with session_getter(get_db_service()) as session:
|
||||
user = User(
|
||||
username="waitingforapproval",
|
||||
password=get_password_hash("testpassword"),
|
||||
username=username,
|
||||
password=get_password_hash(password),
|
||||
is_active=False,
|
||||
last_login_at=None,
|
||||
)
|
||||
|
|
@ -71,6 +82,14 @@ async def test_user_waiting_for_approval(client: AsyncClient):
|
|||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Waiting for approval"
|
||||
|
||||
# Debug: Check if the user still exists after the test
|
||||
with session_getter(get_db_service()) as session:
|
||||
existing_user = session.exec(select(User).where(User.username == username)).first()
|
||||
if existing_user:
|
||||
print(f"User {username} still exists after the test. This is expected.")
|
||||
else:
|
||||
pytest.fail(f"User {username} does not exist after the test. This is unexpected.")
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
async def test_deactivated_user_cannot_login(client: AsyncClient, deactivated_user):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue