🔧 chore(conftest.py): refactor client fixture to use dependency overrides for session and add session fixture for creating a session with an in-memory SQLite database

🔧 chore(conftest.py): add distributed_env fixture to set up environment variables for distributed testing
🔧 chore(conftest.py): add distributed_client fixture for distributed testing with Celery
🔧 chore(conftest.py): remove unused imports and fixtures
🔧 chore(test_cache.py): remove unused client fixture from test_build_graph
🔧 chore(test_creators.py): remove unused client fixture from test_lang_chain_type_creator_to_dict
🔧 chore(test_database.py): remove unused client fixture from test_download_file
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-09-22 11:03:17 -03:00
commit c88f9bf8a0
4 changed files with 70 additions and 30 deletions

View file

@ -51,15 +51,75 @@ async def async_client() -> AsyncGenerator:
yield client
# Create client fixture for FastAPI
@pytest.fixture(scope="module", autouse=True)
def client():
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
@pytest.fixture(name="client")
def client_fixture(session: Session):
def get_session_override():
return session
from langflow.main import create_app
app = create_app()
app.dependency_overrides[get_session] = get_session_override
with TestClient(app) as client:
yield client
app.dependency_overrides.clear()
class Config:
broker_url = "redis://localhost:6379/0"
result_backend = "redis://localhost:6379/0"
@pytest.fixture(name="distributed_env")
def setup_env(monkeypatch):
monkeypatch.setenv("LANGFLOW_CACHE_TYPE", "redis")
monkeypatch.setenv("LANGFLOW_REDIS_HOST", "queue")
monkeypatch.setenv("LANGFLOW_REDIS_PORT", "6379")
monkeypatch.setenv("LANGFLOW_REDIS_DB", "0")
monkeypatch.setenv("LANGFLOW_REDIS_EXPIRE", "3600")
monkeypatch.setenv("LANGFLOW_REDIS_PASSWORD", "")
monkeypatch.setenv("FLOWER_UNAUTHENTICATED_API", "True")
monkeypatch.setenv("BROKER_URL", "redis://queue:6379/0")
monkeypatch.setenv("RESULT_BACKEND", "redis://queue:6379/0")
monkeypatch.setenv("C_FORCE_ROOT", "true")
@pytest.fixture(name="distributed_client")
def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
# Here we load the .env from ../deploy/.env
from dotenv import load_dotenv
from langflow.services.task import manager
from langflow.core import celery_app
from langflow.services.manager import reinitialize_services, initialize_services
# monkeypatch langflow.services.task.manager.USE_CELERY to True
monkeypatch.setattr(manager, "USE_CELERY", True)
monkeypatch.setattr(
celery_app, "celery_app", celery_app.make_celery("langflow", Config)
)
def get_session_override():
return session
from langflow.main import create_app
app = create_app()
app.dependency_overrides[get_session] = get_session_override
with TestClient(app) as client:
yield client
app.dependency_overrides.clear()
def get_graph(_type="basic"):
@ -119,31 +179,6 @@ def json_vector_store():
return f.read()
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
@pytest.fixture(name="client")
def client_fixture(session: Session):
def get_session_override():
return session
from langflow.main import create_app
app = create_app()
app.dependency_overrides[get_session] = get_session_override
with TestClient(app) as client:
yield client
app.dependency_overrides.clear()
# @contextmanager
# def session_getter():
# try:

View file

@ -38,7 +38,7 @@ def langchain_objects_are_equal(obj1, obj2):
# Test build_graph
def test_build_graph(basic_data_graph):
def test_build_graph(client, basic_data_graph):
graph = Graph.from_payload(basic_data_graph)
assert graph is not None
assert len(graph.nodes) == len(basic_data_graph["nodes"])

View file

@ -32,6 +32,7 @@ def sample_agent_creator() -> AgentCreator:
def test_lang_chain_type_creator_to_dict(
client,
sample_lang_chain_type_creator: LangChainTypeCreator,
):
type_dict = sample_lang_chain_type_creator.to_dict()

View file

@ -179,7 +179,11 @@ def test_upload_file(
def test_download_file(
client: TestClient, session: Session, json_flow, active_user, logged_in_headers
client: TestClient,
session: Session,
json_flow,
active_user,
logged_in_headers,
):
flow = orjson.loads(json_flow)
data = flow["data"]