From a9db2da6bfb71225e835122414076278fccc7a90 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sun, 6 Aug 2023 12:15:29 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(base.py):=20use=20db=5Fmanag?= =?UTF-8?q?er.engine=20instead=20of=20DatabaseManager.engine=20to=20access?= =?UTF-8?q?=20the=20database=20engine=20=F0=9F=90=9B=20fix(conftest.py):?= =?UTF-8?q?=20add=20TYPE=5FCHECKING=20import=20to=20fix=20type=20hinting?= =?UTF-8?q?=20error=20=F0=9F=90=9B=20fix(conftest.py):=20pass=20db=5Fmanag?= =?UTF-8?q?er=20to=20blank=5Fsession=5Fgetter=20fixture=20to=20fix=20sessi?= =?UTF-8?q?on=20creation=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/services/database/base.py | 2 +- tests/conftest.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/backend/langflow/services/database/base.py b/src/backend/langflow/services/database/base.py index cfc434f25..9f92c6c25 100644 --- a/src/backend/langflow/services/database/base.py +++ b/src/backend/langflow/services/database/base.py @@ -70,7 +70,7 @@ class DatabaseManager(Service): @contextmanager def session_getter(db_manager: DatabaseManager): try: - session = Session(DatabaseManager.engine) + session = Session(db_manager.engine) yield session except Exception as e: print("Session rollback because of exception:", e) diff --git a/tests/conftest.py b/tests/conftest.py index e6cc2a855..a97270c7c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ from contextlib import contextmanager import json from pathlib import Path -from typing import AsyncGenerator +from typing import AsyncGenerator, TYPE_CHECKING from langflow.api.v1.flows import get_session from langflow.graph.graph.base import Graph @@ -11,6 +11,9 @@ from httpx import AsyncClient from sqlmodel import SQLModel, Session, create_engine from sqlmodel.pool import StaticPool +if TYPE_CHECKING: + from langflow.services.database.base import DatabaseManager + def pytest_configure(): pytest.BASIC_EXAMPLE_PATH = ( @@ -134,15 +137,15 @@ def client_fixture(session: Session): # create a fixture for session_getter above @pytest.fixture(name="session_getter") -def session_getter_fixture(): +def session_getter_fixture(client): engine = create_engine( "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool ) SQLModel.metadata.create_all(engine) @contextmanager - def blank_session_getter(): - with Session(engine) as session: + def blank_session_getter(db_manager: "DatabaseManager"): + with Session(db_manager.engine) as session: yield session yield blank_session_getter