From c88f9bf8a0700fc7e8b166b7fb241cb8fb242504 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 22 Sep 2023 11:03:17 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20chore(conftest.py):=20refactor?= =?UTF-8?q?=20client=20fixture=20to=20use=20dependency=20overrides=20for?= =?UTF-8?q?=20session=20and=20add=20session=20fixture=20for=20creating=20a?= =?UTF-8?q?=20session=20with=20an=20in-memory=20SQLite=20database=20?= =?UTF-8?q?=F0=9F=94=A7=20chore(conftest.py):=20add=20distributed=5Fenv=20?= =?UTF-8?q?fixture=20to=20set=20up=20environment=20variables=20for=20distr?= =?UTF-8?q?ibuted=20testing=20=F0=9F=94=A7=20chore(conftest.py):=20add=20d?= =?UTF-8?q?istributed=5Fclient=20fixture=20for=20distributed=20testing=20w?= =?UTF-8?q?ith=20Celery=20=F0=9F=94=A7=20chore(conftest.py):=20remove=20un?= =?UTF-8?q?used=20imports=20and=20fixtures=20=F0=9F=94=A7=20chore(test=5Fc?= =?UTF-8?q?ache.py):=20remove=20unused=20client=20fixture=20from=20test=5F?= =?UTF-8?q?build=5Fgraph=20=F0=9F=94=A7=20chore(test=5Fcreators.py):=20rem?= =?UTF-8?q?ove=20unused=20client=20fixture=20from=20test=5Flang=5Fchain=5F?= =?UTF-8?q?type=5Fcreator=5Fto=5Fdict=20=F0=9F=94=A7=20chore(test=5Fdataba?= =?UTF-8?q?se.py):=20remove=20unused=20client=20fixture=20from=20test=5Fdo?= =?UTF-8?q?wnload=5Ffile?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/conftest.py | 91 +++++++++++++++++++++++++++++------------- tests/test_cache.py | 2 +- tests/test_creators.py | 1 + tests/test_database.py | 6 ++- 4 files changed, 70 insertions(+), 30 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index aeaef940a..e58f55ceb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,15 +51,75 @@ async def async_client() -> AsyncGenerator: yield client -# Create client fixture for FastAPI -@pytest.fixture(scope="module", autouse=True) -def client(): +@pytest.fixture(name="session") +def session_fixture(): + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + yield session + + +@pytest.fixture(name="client") +def client_fixture(session: Session): + def get_session_override(): + return session + from langflow.main import create_app app = create_app() + app.dependency_overrides[get_session] = get_session_override with TestClient(app) as client: yield client + app.dependency_overrides.clear() + + +class Config: + broker_url = "redis://localhost:6379/0" + result_backend = "redis://localhost:6379/0" + + +@pytest.fixture(name="distributed_env") +def setup_env(monkeypatch): + monkeypatch.setenv("LANGFLOW_CACHE_TYPE", "redis") + monkeypatch.setenv("LANGFLOW_REDIS_HOST", "queue") + monkeypatch.setenv("LANGFLOW_REDIS_PORT", "6379") + monkeypatch.setenv("LANGFLOW_REDIS_DB", "0") + monkeypatch.setenv("LANGFLOW_REDIS_EXPIRE", "3600") + monkeypatch.setenv("LANGFLOW_REDIS_PASSWORD", "") + monkeypatch.setenv("FLOWER_UNAUTHENTICATED_API", "True") + monkeypatch.setenv("BROKER_URL", "redis://queue:6379/0") + monkeypatch.setenv("RESULT_BACKEND", "redis://queue:6379/0") + monkeypatch.setenv("C_FORCE_ROOT", "true") + + +@pytest.fixture(name="distributed_client") +def distributed_client_fixture(session: Session, monkeypatch, distributed_env): + # Here we load the .env from ../deploy/.env + from dotenv import load_dotenv + from langflow.services.task import manager + from langflow.core import celery_app + from langflow.services.manager import reinitialize_services, initialize_services + + # monkeypatch langflow.services.task.manager.USE_CELERY to True + monkeypatch.setattr(manager, "USE_CELERY", True) + monkeypatch.setattr( + celery_app, "celery_app", celery_app.make_celery("langflow", Config) + ) + + def get_session_override(): + return session + + from langflow.main import create_app + + app = create_app() + + app.dependency_overrides[get_session] = get_session_override + with TestClient(app) as client: + yield client + app.dependency_overrides.clear() def get_graph(_type="basic"): @@ -119,31 +179,6 @@ def json_vector_store(): return f.read() -@pytest.fixture(name="session") -def session_fixture(): - engine = create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ) - SQLModel.metadata.create_all(engine) - with Session(engine) as session: - yield session - - -@pytest.fixture(name="client") -def client_fixture(session: Session): - def get_session_override(): - return session - - from langflow.main import create_app - - app = create_app() - - app.dependency_overrides[get_session] = get_session_override - with TestClient(app) as client: - yield client - app.dependency_overrides.clear() - - # @contextmanager # def session_getter(): # try: diff --git a/tests/test_cache.py b/tests/test_cache.py index 4ceea2c2a..c2c706ee9 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -38,7 +38,7 @@ def langchain_objects_are_equal(obj1, obj2): # Test build_graph -def test_build_graph(basic_data_graph): +def test_build_graph(client, basic_data_graph): graph = Graph.from_payload(basic_data_graph) assert graph is not None assert len(graph.nodes) == len(basic_data_graph["nodes"]) diff --git a/tests/test_creators.py b/tests/test_creators.py index 2098e87cd..177dd4105 100644 --- a/tests/test_creators.py +++ b/tests/test_creators.py @@ -32,6 +32,7 @@ def sample_agent_creator() -> AgentCreator: def test_lang_chain_type_creator_to_dict( + client, sample_lang_chain_type_creator: LangChainTypeCreator, ): type_dict = sample_lang_chain_type_creator.to_dict() diff --git a/tests/test_database.py b/tests/test_database.py index e4f68ca56..058c470c5 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -179,7 +179,11 @@ def test_upload_file( def test_download_file( - client: TestClient, session: Session, json_flow, active_user, logged_in_headers + client: TestClient, + session: Session, + json_flow, + active_user, + logged_in_headers, ): flow = orjson.loads(json_flow) data = flow["data"]