diff --git a/tests/conftest.py b/tests/conftest.py index b64851aba..2c8b9016e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,13 @@ from contextlib import contextmanager import json from pathlib import Path from typing import AsyncGenerator, TYPE_CHECKING -from langflow.api.v1.flows import get_session from langflow.graph.graph.base import Graph from langflow.services.auth.utils import get_password_hash from langflow.services.database.models.flow.flow import Flow from langflow.services.database.models.user.user import User, UserCreate +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager import pytest from fastapi.testclient import TestClient from httpx import AsyncClient @@ -120,6 +121,7 @@ def client_fixture(session: Session, monkeypatch): db_dir = tempfile.mkdtemp() db_path = Path(db_dir) / "test.db" monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}") + # monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", 1) def get_session_override(): return session @@ -128,10 +130,10 @@ def client_fixture(session: Session, monkeypatch): app = create_app() - app.dependency_overrides[get_session] = get_session_override + # app.dependency_overrides[get_session] = get_session_override with TestClient(app) as client: yield client - app.dependency_overrides.clear() + # app.dependency_overrides.clear() monkeypatch.undo() # clear the temp db db_path.unlink() @@ -153,11 +155,6 @@ def client_fixture(session: Session, monkeypatch): # create a fixture for session_getter above @pytest.fixture(name="session_getter") def session_getter_fixture(client): - engine = create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ) - SQLModel.metadata.create_all(engine) - @contextmanager def blank_session_getter(db_manager: "DatabaseManager"): with Session(db_manager.engine) as session: @@ -183,17 +180,18 @@ def test_user(client): @pytest.fixture(scope="function") -def active_user(client, session): - user = User( - username="activeuser", - password=get_password_hash( - "testpassword" - ), # Assuming password needs to be hashed - is_active=True, - is_superuser=False, - ) - session.add(user) - session.commit() +def active_user(client): + db_manager = get_db_manager() + with session_getter(db_manager) as session: + user = User( + username="activeuser", + password=get_password_hash("testpassword"), + is_active=True, + is_superuser=False, + ) + session.add(user) + session.commit() + session.refresh(user) return user @@ -208,7 +206,7 @@ def logged_in_headers(client, active_user): @pytest.fixture -def flow(client, json_flow: str, session, active_user): +def flow(client, json_flow: str, active_user): from langflow.services.database.models.flow.flow import FlowCreate loaded_json = json.loads(json_flow) @@ -216,7 +214,9 @@ def flow(client, json_flow: str, session, active_user): name="test_flow", data=loaded_json.get("data"), user_id=active_user.id ) flow = Flow(**flow_data.dict()) - session.add(flow) - session.commit() + with session_getter(get_db_manager()) as session: + session.add(flow) + session.commit() + session.refresh(flow) return flow