From 517529ca4011b2eeece2d289b082eeed9ee5c191 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 25 Jan 2024 16:15:54 -0300 Subject: [PATCH] Refactor imports and fix formatting in conftest.py --- tests/conftest.py | 111 ++++++++++++++++------------------------------ 1 file changed, 37 insertions(+), 74 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 500419fe9..6538ddbe1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,60 +1,43 @@ -from contextlib import contextmanager import json -from contextlib import suppress -from pathlib import Path -from typing import AsyncGenerator, TYPE_CHECKING - -from langflow.graph.graph.base import Graph -from langflow.services.auth.utils import get_password_hash -from langflow.services.database.models.flow.flow import Flow, FlowCreate -from langflow.services.database.models.user.user import User, UserCreate -import orjson -from langflow.services.database.utils import session_getter -from langflow.services.getters import get_db_service -import pytest -from fastapi.testclient import TestClient -from httpx import AsyncClient -from sqlmodel import SQLModel, Session, create_engine -from sqlmodel.pool import StaticPool -from typer.testing import CliRunner # we need to import tmpdir import tempfile +from contextlib import contextmanager, suppress +from pathlib import Path +from typing import TYPE_CHECKING, AsyncGenerator + +import orjson +import pytest +from fastapi.testclient import TestClient +from httpx import AsyncClient +from langflow.graph.graph.base import Graph +from langflow.services.auth.utils import get_password_hash +from langflow.services.database.models.flow.model import Flow, FlowCreate +from langflow.services.database.models.user.model import User, UserCreate +from langflow.services.database.utils import session_getter +from langflow.services.deps import get_db_service +from sqlmodel import Session, SQLModel, create_engine +from sqlmodel.pool import StaticPool +from typer.testing import CliRunner if TYPE_CHECKING: - from langflow.services.database.manager import DatabaseService + from langflow.services.database.service import DatabaseService def pytest_configure(): - pytest.BASIC_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "basic_example.json" - ) - pytest.COMPLEX_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "complex_example.json" - ) - pytest.OPENAPI_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "Openapi.json" - ) - pytest.GROUPED_CHAT_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "grouped_chat.json" - ) - pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "one_group_chat.json" - ) - pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json" - ) + pytest.BASIC_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "basic_example.json" + pytest.COMPLEX_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "complex_example.json" + pytest.OPENAPI_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "Openapi.json" + pytest.GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "grouped_chat.json" + pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "one_group_chat.json" + pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json" pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = ( Path(__file__).parent.absolute() / "data" / "BasicChatWithPromptAndHistory.json" ) pytest.CHAT_INPUT = Path(__file__).parent.absolute() / "data" / "ChatInputTest.json" - pytest.TWO_OUTPUTS = ( - Path(__file__).parent.absolute() / "data" / "TwoOutputsTest.json" - ) - pytest.VECTOR_STORE_PATH = ( - Path(__file__).parent.absolute() / "data" / "Vector_store.json" - ) + pytest.TWO_OUTPUTS = Path(__file__).parent.absolute() / "data" / "TwoOutputsTest.json" + pytest.VECTOR_STORE_PATH = Path(__file__).parent.absolute() / "data" / "Vector_store.json" pytest.CODE_WITH_SYNTAX_ERROR = """ def get_text(): retun "Hello World" @@ -65,9 +48,7 @@ def get_text(): def check_openai_api_key_in_environment_variables(): import os - assert ( - os.environ.get("OPENAI_API_KEY") is not None - ), "OPENAI_API_KEY is not set in environment variables" + assert os.environ.get("OPENAI_API_KEY") is not None, "OPENAI_API_KEY is not set in environment variables" @pytest.fixture() @@ -81,9 +62,7 @@ async def async_client() -> AsyncGenerator: @pytest.fixture(name="session") def session_fixture(): - engine = create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ) + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) SQLModel.metadata.create_all(engine) with Session(engine) as session: yield session @@ -119,9 +98,7 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env): monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false") # monkeypatch langflow.services.task.manager.USE_CELERY to True # monkeypatch.setattr(manager, "USE_CELERY", True) - monkeypatch.setattr( - celery_app, "celery_app", celery_app.make_celery("langflow", Config) - ) + monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config)) # def get_session_override(): # return session @@ -272,11 +249,7 @@ def active_user(client): is_superuser=False, ) # check if user exists - if ( - active_user := session.query(User) - .filter(User.username == user.username) - .first() - ): + if active_user := session.query(User).filter(User.username == user.username).first(): return active_user session.add(user) session.commit() @@ -296,13 +269,12 @@ def logged_in_headers(client, active_user): @pytest.fixture def flow(client, json_flow: str, active_user): - from langflow.services.database.models.flow.flow import FlowCreate + from langflow.services.database.models.flow.model import FlowCreate loaded_json = json.loads(json_flow) - flow_data = FlowCreate( - name="test_flow", data=loaded_json.get("data"), user_id=active_user.id - ) - flow = Flow(**flow_data.dict()) + flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id) + + flow = Flow.model_validate(flow_data) with session_getter(get_db_service()) as session: session.add(flow) session.commit() @@ -311,12 +283,6 @@ def flow(client, json_flow: str, active_user): return flow -@pytest.fixture -def json_flow_with_prompt_and_history(): - with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY, "r") as f: - return f.read() - - @pytest.fixture def json_chat_input(): with open(pytest.CHAT_INPUT, "r") as f: @@ -330,9 +296,7 @@ def json_two_outputs(): @pytest.fixture -def added_flow_with_prompt_and_history( - client, json_flow_with_prompt_and_history, logged_in_headers -): +def added_flow_with_prompt_and_history(client, json_flow_with_prompt_and_history, logged_in_headers): flow = orjson.loads(json_flow_with_prompt_and_history) data = flow["data"] flow = FlowCreate(name="Basic Chat", description="description", data=data) @@ -364,6 +328,7 @@ def added_flow_two_outputs(client, json_two_outputs, logged_in_headers): assert response.status_code == 201 assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data + return response.json() @pytest.fixture @@ -371,9 +336,7 @@ def added_vector_store(client, json_vector_store, logged_in_headers): vector_store = orjson.loads(json_vector_store) data = vector_store["data"] vector_store = FlowCreate(name="Vector Store", description="description", data=data) - response = client.post( - "api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers - ) + response = client.post("api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers) assert response.status_code == 201 assert response.json()["name"] == vector_store.name assert response.json()["data"] == vector_store.data