langflow/tests/conftest.py
Cristhian Zanforlin Lousa 452eb224d7
Enhanced Component Testing & General Test Fixes (#1585)
This pull request introduces comprehensive new tests for our components,
significantly improving our project's test coverage. In addition to the
new component tests, this PR addresses and resolves several issues that
were affecting our general test suite, enhancing its reliability and
efficiency.
2024-04-08 14:08:15 -03:00

403 lines
13 KiB
Python

import json
# we need to import tmpdir
import tempfile
from contextlib import contextmanager, suppress
from pathlib import Path
from typing import TYPE_CHECKING, AsyncGenerator
import orjson
import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
from langflow.graph.graph.base import Graph
from langflow.initial_setup.setup import STARTER_FOLDER_NAME
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.models.flow.model import Flow, FlowCreate
from langflow.services.database.models.user.model import User, UserCreate
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
def pytest_configure():
pytest.BASIC_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "basic_example.json"
pytest.COMPLEX_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "complex_example.json"
pytest.COMPLEX_DEPS_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "complex_deps_example.json"
pytest.OPENAPI_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "Openapi.json"
pytest.GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "grouped_chat.json"
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "one_group_chat.json"
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json"
pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = (
Path(__file__).parent.absolute() / "data" / "BasicChatWithPromptAndHistory.json"
)
pytest.CHAT_INPUT = Path(__file__).parent.absolute() / "data" / "ChatInputTest.json"
pytest.TWO_OUTPUTS = Path(__file__).parent.absolute() / "data" / "TwoOutputsTest.json"
pytest.VECTOR_STORE_PATH = Path(__file__).parent.absolute() / "data" / "Vector_store.json"
pytest.CODE_WITH_SYNTAX_ERROR = """
def get_text():
retun "Hello World"
"""
@pytest.fixture(autouse=True)
def check_openai_api_key_in_environment_variables():
import os
assert os.environ.get("OPENAI_API_KEY") is not None, "OPENAI_API_KEY is not set in environment variables"
@pytest.fixture()
async def async_client() -> AsyncGenerator:
from langflow.main import create_app
app = create_app()
async with AsyncClient(app=app, base_url="http://testserver") as client:
yield client
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
class Config:
broker_url = "redis://localhost:6379/0"
result_backend = "redis://localhost:6379/0"
@pytest.fixture(name="distributed_env")
def setup_env(monkeypatch):
monkeypatch.setenv("LANGFLOW_CACHE_TYPE", "redis")
monkeypatch.setenv("LANGFLOW_REDIS_HOST", "result_backend")
monkeypatch.setenv("LANGFLOW_REDIS_PORT", "6379")
monkeypatch.setenv("LANGFLOW_REDIS_DB", "0")
monkeypatch.setenv("LANGFLOW_REDIS_EXPIRE", "3600")
monkeypatch.setenv("LANGFLOW_REDIS_PASSWORD", "")
monkeypatch.setenv("FLOWER_UNAUTHENTICATED_API", "True")
monkeypatch.setenv("BROKER_URL", "redis://result_backend:6379/0")
monkeypatch.setenv("RESULT_BACKEND", "redis://result_backend:6379/0")
monkeypatch.setenv("C_FORCE_ROOT", "true")
@pytest.fixture(name="distributed_client")
def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
# Here we load the .env from ../deploy/.env
from langflow.core import celery_app
db_dir = tempfile.mkdtemp()
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
# monkeypatch langflow.services.task.manager.USE_CELERY to True
# monkeypatch.setattr(manager, "USE_CELERY", True)
monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config))
# def get_session_override():
# return session
from langflow.main import create_app
app = create_app()
# app.dependency_overrides[get_session] = get_session_override
with TestClient(app) as client:
yield client
app.dependency_overrides.clear()
monkeypatch.undo()
def get_graph(_type="basic"):
"""Get a graph from a json file"""
if _type == "basic":
path = pytest.BASIC_EXAMPLE_PATH
elif _type == "complex":
path = pytest.COMPLEX_EXAMPLE_PATH
elif _type == "openapi":
path = pytest.OPENAPI_EXAMPLE_PATH
with open(path, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
return Graph(nodes, edges)
@pytest.fixture
def basic_graph_data():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
return json.load(f)
@pytest.fixture
def basic_graph():
return get_graph()
@pytest.fixture
def complex_graph():
return get_graph("complex")
@pytest.fixture
def openapi_graph():
return get_graph("openapi")
@pytest.fixture
def json_flow():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
return f.read()
@pytest.fixture
def grouped_chat_json_flow():
with open(pytest.GROUPED_CHAT_EXAMPLE_PATH, "r") as f:
return f.read()
@pytest.fixture
def one_grouped_chat_json_flow():
with open(pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH, "r") as f:
return f.read()
@pytest.fixture
def vector_store_grouped_json_flow():
with open(pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH, "r") as f:
return f.read()
@pytest.fixture
def json_flow_with_prompt_and_history():
with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY, "r") as f:
return f.read()
@pytest.fixture
def json_vector_store():
with open(pytest.VECTOR_STORE_PATH, "r") as f:
return f.read()
@pytest.fixture
def complex_graph_with_groups():
with open(pytest.COMPLEX_DEPS_EXAMPLE_PATH, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
return Graph(nodes, edges)
@pytest.fixture(name="client", autouse=True)
def client_fixture(session: Session, monkeypatch):
# Set the database url to a test database
db_dir = tempfile.mkdtemp()
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
from langflow.main import create_app
app = create_app()
# app.dependency_overrides[get_session] = get_session_override
with TestClient(app) as client:
yield client
# app.dependency_overrides.clear()
monkeypatch.undo()
# clear the temp db
with suppress(FileNotFoundError):
db_path.unlink()
# create a fixture for session_getter above
@pytest.fixture(name="session_getter")
def session_getter_fixture(client):
@contextmanager
def blank_session_getter(db_service: "DatabaseService"):
with Session(db_service.engine) as session:
yield session
yield blank_session_getter
@pytest.fixture
def runner():
return CliRunner()
@pytest.fixture
def test_user(client):
user_data = UserCreate(
username="testuser",
password="testpassword",
)
response = client.post("/api/v1/users", json=user_data.dict())
assert response.status_code == 201
return response.json()
@pytest.fixture(scope="function")
def active_user(client):
db_manager = get_db_service()
with session_getter(db_manager) as session:
user = User(
username="activeuser",
password=get_password_hash("testpassword"),
is_active=True,
is_superuser=False,
)
# check if user exists
if active_user := session.exec(select(User).where(User.username == user.username)).first():
return active_user
session.add(user)
session.commit()
session.refresh(user)
return user
@pytest.fixture
def logged_in_headers(client, active_user):
login_data = {"username": active_user.username, "password": "testpassword"}
response = client.post("/api/v1/login", data=login_data)
assert response.status_code == 200
tokens = response.json()
a_token = tokens["access_token"]
return {"Authorization": f"Bearer {a_token}"}
@pytest.fixture
def flow(client, json_flow: str, active_user):
from langflow.services.database.models.flow.model import FlowCreate
loaded_json = json.loads(json_flow)
flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id)
flow = Flow.model_validate(flow_data)
with session_getter(get_db_service()) as session:
session.add(flow)
session.commit()
session.refresh(flow)
return flow
@pytest.fixture
def json_chat_input():
with open(pytest.CHAT_INPUT, "r") as f:
return f.read()
@pytest.fixture
def json_two_outputs():
with open(pytest.TWO_OUTPUTS, "r") as f:
return f.read()
@pytest.fixture
def added_flow_with_prompt_and_history(client, json_flow_with_prompt_and_history, logged_in_headers):
flow = orjson.loads(json_flow_with_prompt_and_history)
data = flow["data"]
flow = FlowCreate(name="Basic Chat", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
return response.json()
@pytest.fixture
def added_flow_chat_input(client, json_chat_input, logged_in_headers):
flow = orjson.loads(json_chat_input)
data = flow["data"]
flow = FlowCreate(name="Chat Input", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
return response.json()
@pytest.fixture
def added_flow_two_outputs(client, json_two_outputs, logged_in_headers):
flow = orjson.loads(json_two_outputs)
data = flow["data"]
flow = FlowCreate(name="Two Outputs", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
return response.json()
@pytest.fixture
def added_vector_store(client, json_vector_store, logged_in_headers):
vector_store = orjson.loads(json_vector_store)
data = vector_store["data"]
vector_store = FlowCreate(name="Vector Store", description="description", data=data)
response = client.post("api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == vector_store.name
assert response.json()["data"] == vector_store.data
return response.json()
@pytest.fixture
def created_api_key(active_user):
hashed = get_password_hash("random_key")
api_key = ApiKey(
name="test_api_key",
user_id=active_user.id,
api_key="random_key",
hashed_api_key=hashed,
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first():
return existing_api_key
session.add(api_key)
session.commit()
session.refresh(api_key)
return api_key
@pytest.fixture(name="starter_project")
def get_starter_project(active_user):
# once the client is created, we can get the starter project
with session_getter(get_db_service()) as session:
flow = session.exec(
select(Flow).where(Flow.folder == STARTER_FOLDER_NAME).where(Flow.name == "Basic Prompting (Hello, World)")
).first()
if not flow:
raise ValueError("No starter project found")
new_flow_create = FlowCreate(
name=flow.name,
description=flow.description,
data=flow.data,
user_id=active_user.id,
)
new_flow = Flow.model_validate(new_flow_create, from_attributes=True)
session.add(new_flow)
session.commit()
session.refresh(new_flow)
new_flow_dict = new_flow.model_dump()
return new_flow_dict