diff --git a/tests/conftest.py b/tests/conftest.py index dfb2b56f3..45a8f8f1f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager import json from pathlib import Path from typing import AsyncGenerator @@ -116,3 +117,32 @@ def client_fixture(session: Session): yield TestClient(app) app.dependency_overrides.clear() + + +# @contextmanager +# def session_getter(): +# try: +# session = Session(engine) +# yield session +# except Exception as e: +# print("Session rollback because of exception:", e) +# session.rollback() +# raise +# finally: +# session.close() + + +# create a fixture for session_getter above +@pytest.fixture(name="session_getter") +def session_getter_fixture(): + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + SQLModel.metadata.create_all(engine) + + @contextmanager + def blank_session_getter(): + with Session(engine) as session: + yield session + + yield blank_session_getter diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index d57f347a1..199906dda 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -517,13 +517,13 @@ def db(app): app.db.drop_all() -def test_list_flows_return_type(component): - flows = component.list_flows() +def test_list_flows_return_type(component, session_getter): + flows = component.list_flows(get_session=session_getter) assert isinstance(flows, list) -def test_list_flows_flow_objects(component): - flows = component.list_flows() +def test_list_flows_flow_objects(component, session_getter): + flows = component.list_flows(get_session=session_getter) assert all(isinstance(flow, Flow) for flow in flows)