diff --git a/src/backend/langflow/services/database/base.py b/src/backend/langflow/services/database/base.py index cfc434f25..9f92c6c25 100644 --- a/src/backend/langflow/services/database/base.py +++ b/src/backend/langflow/services/database/base.py @@ -70,7 +70,7 @@ class DatabaseManager(Service): @contextmanager def session_getter(db_manager: DatabaseManager): try: - session = Session(DatabaseManager.engine) + session = Session(db_manager.engine) yield session except Exception as e: print("Session rollback because of exception:", e) diff --git a/tests/conftest.py b/tests/conftest.py index e6cc2a855..a97270c7c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ from contextlib import contextmanager import json from pathlib import Path -from typing import AsyncGenerator +from typing import AsyncGenerator, TYPE_CHECKING from langflow.api.v1.flows import get_session from langflow.graph.graph.base import Graph @@ -11,6 +11,9 @@ from httpx import AsyncClient from sqlmodel import SQLModel, Session, create_engine from sqlmodel.pool import StaticPool +if TYPE_CHECKING: + from langflow.services.database.base import DatabaseManager + def pytest_configure(): pytest.BASIC_EXAMPLE_PATH = ( @@ -134,15 +137,15 @@ def client_fixture(session: Session): # create a fixture for session_getter above @pytest.fixture(name="session_getter") -def session_getter_fixture(): +def session_getter_fixture(client): engine = create_engine( "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool ) SQLModel.metadata.create_all(engine) @contextmanager - def blank_session_getter(): - with Session(engine) as session: + def blank_session_getter(db_manager: "DatabaseManager"): + with Session(db_manager.engine) as session: yield session yield blank_session_getter