🔧 chore(conftest.py): refactor client fixture to use dependency overrides for session and add session fixture for creating a session with an in-memory SQLite database
🔧 chore(conftest.py): add distributed_env fixture to set up environment variables for distributed testing
🔧 chore(conftest.py): add distributed_client fixture for distributed testing with Celery
🔧 chore(conftest.py): remove unused imports and fixtures
🔧 chore(test_cache.py): remove unused client fixture from test_build_graph
🔧 chore(test_creators.py): remove unused client fixture from test_lang_chain_type_creator_to_dict
🔧 chore(test_database.py): remove unused client fixture from test_download_file
This commit is contained in:
parent
f8b38ee162
commit
c88f9bf8a0
4 changed files with 70 additions and 30 deletions
|
|
@ -51,15 +51,75 @@ async def async_client() -> AsyncGenerator:
|
|||
yield client
|
||||
|
||||
|
||||
# Create client fixture for FastAPI
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def client():
|
||||
@pytest.fixture(name="session")
def session_fixture():
    """Yield a SQLModel session bound to a throwaway in-memory SQLite DB."""
    db_engine = create_engine(
        "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
    )
    # Create all tables declared on the SQLModel metadata before handing
    # the session to the test.
    SQLModel.metadata.create_all(db_engine)
    with Session(db_engine) as db_session:
        yield db_session
|
||||
|
||||
|
||||
@pytest.fixture(name="client")
def client_fixture(session: Session):
    """Yield a TestClient whose DB dependency is overridden with *session*."""
    from langflow.main import create_app

    app = create_app()
    # Route every request through the shared in-memory test session.
    app.dependency_overrides[get_session] = lambda: session
    with TestClient(app) as test_client:
        yield test_client
    app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class Config:
    # Celery settings used by the distributed test fixtures: the message
    # broker and the result backend both point at a local Redis instance
    # (database 0).
    broker_url = "redis://localhost:6379/0"
    result_backend = "redis://localhost:6379/0"
|
||||
|
||||
|
||||
@pytest.fixture(name="distributed_env")
def setup_env(monkeypatch):
    """Set the env vars that point caching/Celery at the test Redis broker."""
    env_vars = {
        "LANGFLOW_CACHE_TYPE": "redis",
        "LANGFLOW_REDIS_HOST": "queue",
        "LANGFLOW_REDIS_PORT": "6379",
        "LANGFLOW_REDIS_DB": "0",
        "LANGFLOW_REDIS_EXPIRE": "3600",
        "LANGFLOW_REDIS_PASSWORD": "",
        "FLOWER_UNAUTHENTICATED_API": "True",
        "BROKER_URL": "redis://queue:6379/0",
        "RESULT_BACKEND": "redis://queue:6379/0",
        "C_FORCE_ROOT": "true",
    }
    # monkeypatch.setenv undoes each assignment automatically at teardown.
    for name, value in env_vars.items():
        monkeypatch.setenv(name, value)
|
||||
|
||||
|
||||
@pytest.fixture(name="distributed_client")
def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
    """Yield a TestClient configured for distributed (Celery/Redis) tests.

    Depends on the `distributed_env` fixture so the Redis/Celery environment
    variables are in place before the app is created.
    """
    # Here we load the .env from ../deploy/.env
    from dotenv import load_dotenv
    from langflow.services.task import manager
    from langflow.core import celery_app
    from langflow.services.manager import reinitialize_services, initialize_services

    # NOTE(review): load_dotenv, reinitialize_services and initialize_services
    # are imported but never called in this fixture — confirm whether the
    # imports are needed for side effects or are simply dead.

    # monkeypatch langflow.services.task.manager.USE_CELERY to True
    monkeypatch.setattr(manager, "USE_CELERY", True)
    # Swap in a Celery app built from the test Config (local Redis broker).
    monkeypatch.setattr(
        celery_app, "celery_app", celery_app.make_celery("langflow", Config)
    )

    def get_session_override():
        # Reuse the shared in-memory test session for every request.
        return session

    from langflow.main import create_app

    app = create_app()

    app.dependency_overrides[get_session] = get_session_override
    with TestClient(app) as client:
        yield client
    # Drop the override so later tests see the real session dependency.
    app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
|
|
@ -119,31 +179,6 @@ def json_vector_store():
|
|||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture(name="session")
def session_fixture():
    """Yield a session against a fresh in-memory SQLite database.

    check_same_thread=False allows the TestClient's worker threads to share
    the connection; StaticPool keeps one connection open so the ":memory:"
    database is not discarded between sessions.
    """
    engine = create_engine(
        "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
    )
    # Build the full schema before the test touches the DB.
    SQLModel.metadata.create_all(engine)
    with Session(engine) as session:
        yield session
|
||||
|
||||
|
||||
@pytest.fixture(name="client")
def client_fixture(session: Session):
    """Yield a FastAPI TestClient wired to the in-memory test session."""

    def get_session_override():
        return session

    from langflow.main import create_app

    app = create_app()

    # Replace the app's get_session dependency so endpoints use the test DB.
    app.dependency_overrides[get_session] = get_session_override
    with TestClient(app) as client:
        yield client
    # Clean up so other tests see the original dependency.
    app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# @contextmanager
|
||||
# def session_getter():
|
||||
# try:
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def langchain_objects_are_equal(obj1, obj2):
|
|||
|
||||
|
||||
# Test build_graph
def test_build_graph(basic_data_graph):
    """Graph.from_payload builds one node per entry in the payload's "nodes".

    The diff rendering showed both the old signature (with the unused
    `client` fixture) and the new one; per the commit message the `client`
    parameter was removed, so the parameter-free signature is kept.
    """
    graph = Graph.from_payload(basic_data_graph)
    assert graph is not None
    assert len(graph.nodes) == len(basic_data_graph["nodes"])
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ def sample_agent_creator() -> AgentCreator:
|
|||
|
||||
|
||||
def test_lang_chain_type_creator_to_dict(
|
||||
client,
|
||||
sample_lang_chain_type_creator: LangChainTypeCreator,
|
||||
):
|
||||
type_dict = sample_lang_chain_type_creator.to_dict()
|
||||
|
|
|
|||
|
|
@ -179,7 +179,11 @@ def test_upload_file(
|
|||
|
||||
|
||||
def test_download_file(
|
||||
client: TestClient, session: Session, json_flow, active_user, logged_in_headers
|
||||
client: TestClient,
|
||||
session: Session,
|
||||
json_flow,
|
||||
active_user,
|
||||
logged_in_headers,
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue