From ddafdf31b5eadc8e5214cfebc7077836b5263936 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 31 Jul 2023 17:31:40 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20fix(conftest.py):=20remove=20unu?= =?UTF-8?q?sed=20imports=20and=20commented=20out=20code=20=E2=9C=A8=20feat?= =?UTF-8?q?(conftest.py):=20add=20session=5Fgetter=20fixture=20to=20create?= =?UTF-8?q?=20a=20blank=20session=20for=20testing=20=F0=9F=94=A7=20fix(tes?= =?UTF-8?q?t=5Fcustom=5Fcomponent.py):=20pass=20session=5Fgetter=20fixture?= =?UTF-8?q?=20to=20list=5Fflows=20function=20for=20testing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/conftest.py | 30 ++++++++++++++++++++++++++++++ tests/test_custom_component.py | 8 ++++---- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index dfb2b56f3..45a8f8f1f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager import json from pathlib import Path from typing import AsyncGenerator @@ -116,3 +117,32 @@ def client_fixture(session: Session): yield TestClient(app) app.dependency_overrides.clear() + + +# @contextmanager +# def session_getter(): +# try: +# session = Session(engine) +# yield session +# except Exception as e: +# print("Session rollback because of exception:", e) +# session.rollback() +# raise +# finally: +# session.close() + + +# create a fixture for session_getter above +@pytest.fixture(name="session_getter") +def session_getter_fixture(): + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + SQLModel.metadata.create_all(engine) + + @contextmanager + def blank_session_getter(): + with Session(engine) as session: + yield session + + yield blank_session_getter diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index d57f347a1..199906dda 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -517,13 +517,13 @@ def db(app): app.db.drop_all() -def test_list_flows_return_type(component): - flows = component.list_flows() +def test_list_flows_return_type(component, session_getter): + flows = component.list_flows(get_session=session_getter) assert isinstance(flows, list) -def test_list_flows_flow_objects(component): - flows = component.list_flows() +def test_list_flows_flow_objects(component, session_getter): + flows = component.list_flows(get_session=session_getter) assert all(isinstance(flow, Flow) for flow in flows)