test: add astra integration test (#2189)
* add first astra integ test framework * use fixtures * remove old tests from merge * Add correct sender type * chore: Update unit test command in GitHub workflow --------- Co-authored-by: ogabrielluiz <gabriel@langflow.org>
This commit is contained in:
parent
5a04adfa1f
commit
ca660cf8df
31 changed files with 211 additions and 12 deletions
0
tests/unit/__init__.py
Normal file
0
tests/unit/__init__.py
Normal file
428
tests/unit/conftest.py
Normal file
428
tests/unit/conftest.py
Normal file
|
|
@ -0,0 +1,428 @@
|
|||
import json
|
||||
import os.path
|
||||
import shutil
|
||||
|
||||
# we need to import tmpdir
|
||||
import tempfile
|
||||
from contextlib import contextmanager, suppress
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, AsyncGenerator
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.initial_setup.setup import STARTER_FOLDER_NAME
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
from langflow.services.database.models.flow.model import Flow, FlowCreate
|
||||
from langflow.services.database.models.folder.model import Folder
|
||||
from langflow.services.database.models.user.model import User, UserCreate
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.service import DatabaseService
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "noclient: don't create a client for this test")
|
||||
data_path = Path(__file__).parent.absolute() / "data"
|
||||
|
||||
pytest.BASIC_EXAMPLE_PATH = data_path / "basic_example.json"
|
||||
pytest.COMPLEX_EXAMPLE_PATH = data_path / "complex_example.json"
|
||||
pytest.OPENAPI_EXAMPLE_PATH = data_path / "Openapi.json"
|
||||
pytest.GROUPED_CHAT_EXAMPLE_PATH = data_path / "grouped_chat.json"
|
||||
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = data_path / "one_group_chat.json"
|
||||
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = data_path / "vector_store_grouped.json"
|
||||
|
||||
pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = data_path / "BasicChatwithPromptandHistory.json"
|
||||
pytest.CHAT_INPUT = data_path / "ChatInputTest.json"
|
||||
pytest.TWO_OUTPUTS = data_path / "TwoOutputsTest.json"
|
||||
pytest.VECTOR_STORE_PATH = data_path / "Vector_store.json"
|
||||
pytest.CODE_WITH_SYNTAX_ERROR = """
|
||||
def get_text():
|
||||
retun "Hello World"
|
||||
"""
|
||||
|
||||
# validate that all the paths are correct and the files exist
|
||||
for path in [
|
||||
pytest.BASIC_EXAMPLE_PATH,
|
||||
pytest.COMPLEX_EXAMPLE_PATH,
|
||||
pytest.OPENAPI_EXAMPLE_PATH,
|
||||
pytest.GROUPED_CHAT_EXAMPLE_PATH,
|
||||
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH,
|
||||
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH,
|
||||
pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY,
|
||||
pytest.CHAT_INPUT,
|
||||
pytest.TWO_OUTPUTS,
|
||||
pytest.VECTOR_STORE_PATH,
|
||||
]:
|
||||
assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def check_openai_api_key_in_environment_variables():
|
||||
import os
|
||||
|
||||
assert os.environ.get("OPENAI_API_KEY") is not None, "OPENAI_API_KEY is not set in environment variables"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def async_client() -> AsyncGenerator:
|
||||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
async with AsyncClient(app=app, base_url="http://testserver") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(name="session")
|
||||
def session_fixture():
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
class Config:
|
||||
broker_url = "redis://localhost:6379/0"
|
||||
result_backend = "redis://localhost:6379/0"
|
||||
|
||||
|
||||
@pytest.fixture(name="load_flows_dir")
|
||||
def load_flows_dir():
|
||||
tempdir = tempfile.TemporaryDirectory()
|
||||
yield tempdir.name
|
||||
|
||||
|
||||
@pytest.fixture(name="distributed_env")
|
||||
def setup_env(monkeypatch):
|
||||
monkeypatch.setenv("LANGFLOW_CACHE_TYPE", "redis")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_HOST", "result_backend")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_PORT", "6379")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_DB", "0")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_EXPIRE", "3600")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_PASSWORD", "")
|
||||
monkeypatch.setenv("FLOWER_UNAUTHENTICATED_API", "True")
|
||||
monkeypatch.setenv("BROKER_URL", "redis://result_backend:6379/0")
|
||||
monkeypatch.setenv("RESULT_BACKEND", "redis://result_backend:6379/0")
|
||||
monkeypatch.setenv("C_FORCE_ROOT", "true")
|
||||
|
||||
|
||||
@pytest.fixture(name="distributed_client")
|
||||
def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
|
||||
# Here we load the .env from ../deploy/.env
|
||||
from langflow.core import celery_app
|
||||
|
||||
db_dir = tempfile.mkdtemp()
|
||||
db_path = Path(db_dir) / "test.db"
|
||||
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
|
||||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
|
||||
# monkeypatch langflow.services.task.manager.USE_CELERY to True
|
||||
# monkeypatch.setattr(manager, "USE_CELERY", True)
|
||||
monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config))
|
||||
|
||||
# def get_session_override():
|
||||
# return session
|
||||
|
||||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
|
||||
# app.dependency_overrides[get_session] = get_session_override
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
app.dependency_overrides.clear()
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
|
||||
if _type == "basic":
|
||||
path = pytest.BASIC_EXAMPLE_PATH
|
||||
elif _type == "complex":
|
||||
path = pytest.COMPLEX_EXAMPLE_PATH
|
||||
elif _type == "openapi":
|
||||
path = pytest.OPENAPI_EXAMPLE_PATH
|
||||
|
||||
with open(path, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
return Graph(nodes, edges)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_graph_data():
|
||||
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_graph():
|
||||
return get_graph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_graph():
|
||||
return get_graph("complex")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_graph():
|
||||
return get_graph("openapi")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_flow():
|
||||
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def grouped_chat_json_flow():
|
||||
with open(pytest.GROUPED_CHAT_EXAMPLE_PATH, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def one_grouped_chat_json_flow():
|
||||
with open(pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store_grouped_json_flow():
|
||||
with open(pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_flow_with_prompt_and_history():
|
||||
with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_vector_store():
|
||||
with open(pytest.VECTOR_STORE_PATH, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture(name="client", autouse=True)
|
||||
def client_fixture(session: Session, monkeypatch, request, load_flows_dir):
|
||||
# Set the database url to a test database
|
||||
if "noclient" in request.keywords:
|
||||
yield
|
||||
else:
|
||||
db_dir = tempfile.mkdtemp()
|
||||
db_path = Path(db_dir) / "test.db"
|
||||
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
|
||||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
|
||||
if "load_flows" in request.keywords:
|
||||
shutil.copyfile(
|
||||
pytest.BASIC_EXAMPLE_PATH, os.path.join(load_flows_dir, "c54f9130-f2fa-4a3e-b22a-3856d946351b.json")
|
||||
)
|
||||
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
|
||||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")
|
||||
|
||||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
|
||||
# app.dependency_overrides[get_session] = get_session_override
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
# app.dependency_overrides.clear()
|
||||
monkeypatch.undo()
|
||||
# clear the temp db
|
||||
with suppress(FileNotFoundError):
|
||||
db_path.unlink()
|
||||
|
||||
|
||||
# create a fixture for session_getter above
|
||||
@pytest.fixture(name="session_getter")
|
||||
def session_getter_fixture(client):
|
||||
@contextmanager
|
||||
def blank_session_getter(db_service: "DatabaseService"):
|
||||
with Session(db_service.engine) as session:
|
||||
yield session
|
||||
|
||||
yield blank_session_getter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user(client):
|
||||
user_data = UserCreate(
|
||||
username="testuser",
|
||||
password="testpassword",
|
||||
)
|
||||
response = client.post("/api/v1/users", json=user_data.dict())
|
||||
assert response.status_code == 201
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def active_user(client):
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
user = User(
|
||||
username="activeuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
# check if user exists
|
||||
if active_user := session.exec(select(User).where(User.username == user.username)).first():
|
||||
return active_user
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logged_in_headers(client, active_user):
|
||||
login_data = {"username": active_user.username, "password": "testpassword"}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 200
|
||||
tokens = response.json()
|
||||
a_token = tokens["access_token"]
|
||||
return {"Authorization": f"Bearer {a_token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flow(client, json_flow: str, active_user):
|
||||
from langflow.services.database.models.flow.model import FlowCreate
|
||||
|
||||
loaded_json = json.loads(json_flow)
|
||||
flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id)
|
||||
|
||||
flow = Flow.model_validate(flow_data)
|
||||
with session_getter(get_db_service()) as session:
|
||||
session.add(flow)
|
||||
session.commit()
|
||||
session.refresh(flow)
|
||||
|
||||
return flow
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_chat_input():
|
||||
with open(pytest.CHAT_INPUT, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_two_outputs():
|
||||
with open(pytest.TWO_OUTPUTS, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def added_flow_with_prompt_and_history(client, json_flow_with_prompt_and_history, logged_in_headers):
|
||||
flow = orjson.loads(json_flow_with_prompt_and_history)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Basic Chat", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def added_flow_chat_input(client, json_chat_input, logged_in_headers):
|
||||
flow = orjson.loads(json_chat_input)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Chat Input", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def added_flow_two_outputs(client, json_two_outputs, logged_in_headers):
|
||||
flow = orjson.loads(json_two_outputs)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Two Outputs", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def added_vector_store(client, json_vector_store, logged_in_headers):
|
||||
vector_store = orjson.loads(json_vector_store)
|
||||
data = vector_store["data"]
|
||||
vector_store = FlowCreate(name="Vector Store", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == vector_store.name
|
||||
assert response.json()["data"] == vector_store.data
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def created_api_key(active_user):
|
||||
hashed = get_password_hash("random_key")
|
||||
api_key = ApiKey(
|
||||
name="test_api_key",
|
||||
user_id=active_user.id,
|
||||
api_key="random_key",
|
||||
hashed_api_key=hashed,
|
||||
)
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first():
|
||||
return existing_api_key
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
session.refresh(api_key)
|
||||
return api_key
|
||||
|
||||
|
||||
@pytest.fixture(name="starter_project")
|
||||
def get_starter_project(active_user):
|
||||
# once the client is created, we can get the starter project
|
||||
with session_getter(get_db_service()) as session:
|
||||
flow = session.exec(
|
||||
select(Flow)
|
||||
.where(Flow.folder.has(Folder.name == STARTER_FOLDER_NAME))
|
||||
.where(Flow.name == "Basic Prompting (Hello, World)")
|
||||
).first()
|
||||
if not flow:
|
||||
raise ValueError("No starter project found")
|
||||
|
||||
new_flow_create = FlowCreate(
|
||||
name=flow.name,
|
||||
description=flow.description,
|
||||
data=flow.data,
|
||||
user_id=active_user.id,
|
||||
)
|
||||
new_flow = Flow.model_validate(new_flow_create, from_attributes=True)
|
||||
session.add(new_flow)
|
||||
session.commit()
|
||||
session.refresh(new_flow)
|
||||
new_flow_dict = new_flow.model_dump()
|
||||
return new_flow_dict
|
||||
45
tests/unit/test_api_key.py
Normal file
45
tests/unit/test_api_key.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import pytest
|
||||
from langflow.services.database.models.api_key import ApiKeyCreate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key(client, logged_in_headers, active_user):
|
||||
api_key = ApiKeyCreate(name="test-api-key")
|
||||
|
||||
response = client.post("api/v1/api_key", data=api_key.model_dump_json(), headers=logged_in_headers)
|
||||
assert response.status_code == 200, response.text
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_get_api_keys(client, logged_in_headers, api_key):
|
||||
response = client.get("api/v1/api_key", headers=logged_in_headers)
|
||||
assert response.status_code == 200, response.text
|
||||
data = response.json()
|
||||
assert "total_count" in data
|
||||
assert "user_id" in data
|
||||
assert "api_keys" in data
|
||||
assert any("test-api-key" in api_key["name"] for api_key in data["api_keys"])
|
||||
# assert all api keys in data["api_keys"] are masked
|
||||
assert all("**" in api_key["api_key"] for api_key in data["api_keys"])
|
||||
|
||||
|
||||
def test_create_api_key(client, logged_in_headers):
|
||||
api_key_name = "test-api-key"
|
||||
response = client.post("api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "name" in data and data["name"] == api_key_name
|
||||
assert "api_key" in data
|
||||
# When creating the API key is returned which is
|
||||
# the only time the API key is unmasked
|
||||
assert "**" not in data["api_key"]
|
||||
|
||||
|
||||
def test_delete_api_key(client, logged_in_headers, active_user, api_key):
|
||||
# Assuming a function to create a test API key, returning the key ID
|
||||
api_key_id = api_key["id"]
|
||||
response = client.delete(f"api/v1/api_key/{api_key_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["detail"] == "API Key deleted"
|
||||
# Optionally, add a follow-up check to ensure that the key is actually removed from the database
|
||||
47
tests/unit/test_cache.py
Normal file
47
tests/unit/test_cache.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.graph import Graph
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
if _type == "basic":
|
||||
path = pytest.BASIC_EXAMPLE_PATH
|
||||
elif _type == "complex":
|
||||
path = pytest.COMPLEX_EXAMPLE_PATH
|
||||
elif _type == "openapi":
|
||||
path = pytest.OPENAPI_EXAMPLE_PATH
|
||||
|
||||
with open(path, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
return flow_graph["data"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_data_graph():
|
||||
return get_graph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_data_graph():
|
||||
return get_graph("complex")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_data_graph():
|
||||
return get_graph("openapi")
|
||||
|
||||
|
||||
def langchain_objects_are_equal(obj1, obj2):
|
||||
return str(obj1) == str(obj2)
|
||||
|
||||
|
||||
# Test build_graph
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_graph(client, basic_data_graph):
|
||||
graph = Graph.from_payload(basic_data_graph)
|
||||
assert graph is not None
|
||||
assert len(graph.vertices) == len(basic_data_graph["nodes"])
|
||||
assert len(graph.edges) == len(basic_data_graph["edges"])
|
||||
37
tests/unit/test_cli.py
Normal file
37
tests/unit/test_cli.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from pathlib import Path
|
||||
from tempfile import tempdir
|
||||
|
||||
import pytest
|
||||
from langflow.__main__ import app
|
||||
from langflow.services import deps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_settings():
|
||||
return [
|
||||
"--backend-only",
|
||||
"--no-open-browser",
|
||||
]
|
||||
|
||||
|
||||
def test_components_path(runner, client, default_settings):
|
||||
# Create a foldr in the tmp directory
|
||||
|
||||
temp_dir = Path(tempdir)
|
||||
# create a "components" folder
|
||||
temp_dir = temp_dir / "components"
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["run", "--components-path", str(temp_dir), *default_settings],
|
||||
)
|
||||
assert result.exit_code == 0, result.stdout
|
||||
settings_service = deps.get_settings_service()
|
||||
assert str(temp_dir) in settings_service.settings.components_path
|
||||
|
||||
|
||||
def test_superuser(runner, client, session):
|
||||
result = runner.invoke(app, ["superuser"], input="admin\nadmin\n")
|
||||
assert result.exit_code == 0, result.stdout
|
||||
assert "Superuser created successfully." in result.stdout
|
||||
519
tests/unit/test_custom_component.py
Normal file
519
tests/unit/test_custom_component.py
Normal file
|
|
@ -0,0 +1,519 @@
|
|||
import ast
|
||||
import types
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
|
||||
from langflow.custom.custom_component.component import Component, ComponentCodeNullError
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate
|
||||
|
||||
code_default = """
|
||||
from langflow.field_typing import Prompt
|
||||
from langflow.custom import CustomComponent
|
||||
|
||||
from langflow.field_typing import BaseLanguageModel
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.documents import Document
|
||||
|
||||
import requests
|
||||
|
||||
class YourComponent(CustomComponent):
|
||||
display_name: str = "Your Component"
|
||||
description: str = "Your description"
|
||||
field_config = { "url": { "multiline": True, "required": True } }
|
||||
|
||||
def build(self, url: str, llm: BaseLanguageModel, template: Prompt) -> Document:
|
||||
response = requests.get(url)
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
result = chain.run(response.text[:300])
|
||||
return Document(page_content=str(result))
|
||||
"""
|
||||
|
||||
|
||||
def test_code_parser_init():
|
||||
"""
|
||||
Test the initialization of the CodeParser class.
|
||||
"""
|
||||
parser = CodeParser(code_default)
|
||||
assert parser.code == code_default
|
||||
|
||||
|
||||
def test_code_parser_get_tree():
|
||||
"""
|
||||
Test the __get_tree method of the CodeParser class.
|
||||
"""
|
||||
parser = CodeParser(code_default)
|
||||
tree = parser.get_tree()
|
||||
assert isinstance(tree, ast.AST)
|
||||
|
||||
|
||||
def test_code_parser_syntax_error():
|
||||
"""
|
||||
Test the __get_tree method raises the
|
||||
CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
code_syntax_error = "zzz import os"
|
||||
|
||||
parser = CodeParser(code_syntax_error)
|
||||
with pytest.raises(CodeSyntaxError):
|
||||
parser.get_tree()
|
||||
|
||||
|
||||
def test_component_init():
|
||||
"""
|
||||
Test the initialization of the Component class.
|
||||
"""
|
||||
component = Component(code=code_default, function_entrypoint_name="build")
|
||||
assert component.code == code_default
|
||||
assert component.function_entrypoint_name == "build"
|
||||
|
||||
|
||||
def test_component_get_code_tree():
|
||||
"""
|
||||
Test the get_code_tree method of the Component class.
|
||||
"""
|
||||
component = Component(code=code_default, function_entrypoint_name="build")
|
||||
tree = component.get_code_tree(component.code)
|
||||
assert "imports" in tree
|
||||
|
||||
|
||||
def test_component_code_null_error():
|
||||
"""
|
||||
Test the get_function method raises the
|
||||
ComponentCodeNullError when the code is empty.
|
||||
"""
|
||||
component = Component(code="", function_entrypoint_name="")
|
||||
with pytest.raises(ComponentCodeNullError):
|
||||
component.get_function()
|
||||
|
||||
|
||||
def test_custom_component_init():
|
||||
"""
|
||||
Test the initialization of the CustomComponent class.
|
||||
"""
|
||||
function_entrypoint_name = "build"
|
||||
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name=function_entrypoint_name)
|
||||
assert custom_component.code == code_default
|
||||
assert custom_component.function_entrypoint_name == function_entrypoint_name
|
||||
|
||||
|
||||
def test_custom_component_build_template_config():
|
||||
"""
|
||||
Test the build_template_config property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
config = custom_component.build_template_config()
|
||||
assert isinstance(config, dict)
|
||||
|
||||
|
||||
def test_custom_component_get_function():
|
||||
"""
|
||||
Test the get_function property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
|
||||
my_function = custom_component.get_function()
|
||||
assert isinstance(my_function, types.FunctionType)
|
||||
|
||||
|
||||
def test_code_parser_parse_imports_import():
|
||||
"""
|
||||
Test the parse_imports method of the CodeParser
|
||||
class with an import statement.
|
||||
"""
|
||||
parser = CodeParser(code_default)
|
||||
tree = parser.get_tree()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
parser.parse_imports(node)
|
||||
assert "requests" in parser.data["imports"]
|
||||
|
||||
|
||||
def test_code_parser_parse_imports_importfrom():
|
||||
"""
|
||||
Test the parse_imports method of the CodeParser
|
||||
class with an import from statement.
|
||||
"""
|
||||
parser = CodeParser("from os import path")
|
||||
tree = parser.get_tree()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ImportFrom):
|
||||
parser.parse_imports(node)
|
||||
assert ("os", "path") in parser.data["imports"]
|
||||
|
||||
|
||||
def test_code_parser_parse_functions():
|
||||
"""
|
||||
Test the parse_functions method of the CodeParser class.
|
||||
"""
|
||||
parser = CodeParser("def test(): pass")
|
||||
tree = parser.get_tree()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
parser.parse_functions(node)
|
||||
assert len(parser.data["functions"]) == 1
|
||||
assert parser.data["functions"][0]["name"] == "test"
|
||||
|
||||
|
||||
def test_code_parser_parse_classes():
|
||||
"""
|
||||
Test the parse_classes method of the CodeParser class.
|
||||
"""
|
||||
parser = CodeParser("class Test: pass")
|
||||
tree = parser.get_tree()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
parser.parse_classes(node)
|
||||
assert len(parser.data["classes"]) == 1
|
||||
assert parser.data["classes"][0]["name"] == "Test"
|
||||
|
||||
|
||||
def test_code_parser_parse_global_vars():
|
||||
"""
|
||||
Test the parse_global_vars method of the CodeParser class.
|
||||
"""
|
||||
parser = CodeParser("x = 1")
|
||||
tree = parser.get_tree()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Assign):
|
||||
parser.parse_global_vars(node)
|
||||
assert len(parser.data["global_vars"]) == 1
|
||||
assert parser.data["global_vars"][0]["targets"] == ["x"]
|
||||
|
||||
|
||||
def test_component_get_function_valid():
|
||||
"""
|
||||
Test the get_function method of the Component
|
||||
class with valid code and function_entrypoint_name.
|
||||
"""
|
||||
component = Component(code="def build(): pass", function_entrypoint_name="build")
|
||||
my_function = component.get_function()
|
||||
assert callable(my_function)
|
||||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_args():
|
||||
"""
|
||||
Test the get_function_entrypoint_args
|
||||
property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
args = custom_component.get_function_entrypoint_args
|
||||
assert len(args) == 4
|
||||
assert args[0]["name"] == "self"
|
||||
assert args[1]["name"] == "url"
|
||||
assert args[2]["name"] == "llm"
|
||||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_return_type():
|
||||
"""
|
||||
Test the get_function_entrypoint_return_type
|
||||
property of the CustomComponent class.
|
||||
"""
|
||||
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
return_type = custom_component.get_function_entrypoint_return_type
|
||||
assert return_type == [Document]
|
||||
|
||||
|
||||
def test_custom_component_get_main_class_name():
|
||||
"""
|
||||
Test the get_main_class_name property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
class_name = custom_component.get_main_class_name
|
||||
assert class_name == "YourComponent"
|
||||
|
||||
|
||||
def test_custom_component_get_function_valid():
|
||||
"""
|
||||
Test the get_function property of the CustomComponent
|
||||
class with valid code and function_entrypoint_name.
|
||||
"""
|
||||
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
|
||||
my_function = custom_component.get_function
|
||||
assert callable(my_function)
|
||||
|
||||
|
||||
def test_code_parser_parse_arg_no_annotation():
|
||||
"""
|
||||
Test the parse_arg method of the CodeParser class without an annotation.
|
||||
"""
|
||||
parser = CodeParser("")
|
||||
arg = ast.arg(arg="x", annotation=None)
|
||||
result = parser.parse_arg(arg, None)
|
||||
assert result["name"] == "x"
|
||||
assert "type" not in result
|
||||
|
||||
|
||||
def test_code_parser_parse_arg_with_annotation():
|
||||
"""
|
||||
Test the parse_arg method of the CodeParser class with an annotation.
|
||||
"""
|
||||
parser = CodeParser("")
|
||||
arg = ast.arg(arg="x", annotation=ast.Name(id="int", ctx=ast.Load()))
|
||||
result = parser.parse_arg(arg, None)
|
||||
assert result["name"] == "x"
|
||||
assert result["type"] == "int"
|
||||
|
||||
|
||||
def test_code_parser_parse_callable_details_no_args():
|
||||
"""
|
||||
Test the parse_callable_details method of the
|
||||
CodeParser class with a function with no arguments.
|
||||
"""
|
||||
parser = CodeParser("")
|
||||
node = ast.FunctionDef(
|
||||
name="test",
|
||||
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
|
||||
body=[],
|
||||
decorator_list=[],
|
||||
returns=None,
|
||||
)
|
||||
result = parser.parse_callable_details(node)
|
||||
assert result["name"] == "test"
|
||||
assert len(result["args"]) == 0
|
||||
|
||||
|
||||
def test_code_parser_parse_assign():
|
||||
"""
|
||||
Test the parse_assign method of the CodeParser class.
|
||||
"""
|
||||
parser = CodeParser("")
|
||||
stmt = ast.Assign(targets=[ast.Name(id="x", ctx=ast.Store())], value=ast.Num(n=1))
|
||||
result = parser.parse_assign(stmt)
|
||||
assert result["name"] == "x"
|
||||
assert result["value"] == "1"
|
||||
|
||||
|
||||
def test_code_parser_parse_ann_assign():
|
||||
"""
|
||||
Test the parse_ann_assign method of the CodeParser class.
|
||||
"""
|
||||
parser = CodeParser("")
|
||||
stmt = ast.AnnAssign(
|
||||
target=ast.Name(id="x", ctx=ast.Store()),
|
||||
annotation=ast.Name(id="int", ctx=ast.Load()),
|
||||
value=ast.Num(n=1),
|
||||
simple=1,
|
||||
)
|
||||
result = parser.parse_ann_assign(stmt)
|
||||
assert result["name"] == "x"
|
||||
assert result["value"] == "1"
|
||||
assert result["annotation"] == "int"
|
||||
|
||||
|
||||
def test_code_parser_parse_function_def_not_init():
|
||||
"""
|
||||
Test the parse_function_def method of the
|
||||
CodeParser class with a function that is not __init__.
|
||||
"""
|
||||
parser = CodeParser("")
|
||||
stmt = ast.FunctionDef(
|
||||
name="test",
|
||||
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
|
||||
body=[],
|
||||
decorator_list=[],
|
||||
returns=None,
|
||||
)
|
||||
result, is_init = parser.parse_function_def(stmt)
|
||||
assert result["name"] == "test"
|
||||
assert not is_init
|
||||
|
||||
|
||||
def test_code_parser_parse_function_def_init():
|
||||
"""
|
||||
Test the parse_function_def method of the
|
||||
CodeParser class with an __init__ function.
|
||||
"""
|
||||
parser = CodeParser("")
|
||||
stmt = ast.FunctionDef(
|
||||
name="__init__",
|
||||
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
|
||||
body=[],
|
||||
decorator_list=[],
|
||||
returns=None,
|
||||
)
|
||||
result, is_init = parser.parse_function_def(stmt)
|
||||
assert result["name"] == "__init__"
|
||||
assert is_init
|
||||
|
||||
|
||||
def test_component_get_code_tree_syntax_error():
|
||||
"""
|
||||
Test the get_code_tree method of the Component class
|
||||
raises the CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
component = Component(code="import os as", function_entrypoint_name="build")
|
||||
with pytest.raises(CodeSyntaxError):
|
||||
component.get_code_tree(component.code)
|
||||
|
||||
|
||||
def test_custom_component_class_template_validation_no_code():
|
||||
"""
|
||||
Test the _class_template_validation method of the CustomComponent class
|
||||
raises the HTTPException when the code is None.
|
||||
"""
|
||||
custom_component = CustomComponent(code=None, function_entrypoint_name="build")
|
||||
with pytest.raises(TypeError):
|
||||
custom_component.get_function()
|
||||
|
||||
|
||||
def test_custom_component_get_code_tree_syntax_error():
|
||||
"""
|
||||
Test the get_code_tree method of the CustomComponent class
|
||||
raises the CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
custom_component = CustomComponent(code="import os as", function_entrypoint_name="build")
|
||||
with pytest.raises(CodeSyntaxError):
|
||||
custom_component.get_code_tree(custom_component.code)
|
||||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_args_no_args():
|
||||
"""
|
||||
Test the get_function_entrypoint_args property of
|
||||
the CustomComponent class with a build method with no arguments.
|
||||
"""
|
||||
my_code = """
|
||||
class MyMainClass(CustomComponent):
|
||||
def build():
|
||||
pass"""
|
||||
|
||||
custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
|
||||
args = custom_component.get_function_entrypoint_args
|
||||
assert len(args) == 0
|
||||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_return_type_no_return_type():
|
||||
"""
|
||||
Test the get_function_entrypoint_return_type property of the
|
||||
CustomComponent class with a build method with no return type.
|
||||
"""
|
||||
my_code = """
|
||||
class MyClass(CustomComponent):
|
||||
def build():
|
||||
pass"""
|
||||
|
||||
custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
|
||||
return_type = custom_component.get_function_entrypoint_return_type
|
||||
assert return_type == []
|
||||
|
||||
|
||||
def test_custom_component_get_main_class_name_no_main_class():
|
||||
"""
|
||||
Test the get_main_class_name property of the
|
||||
CustomComponent class when there is no main class.
|
||||
"""
|
||||
my_code = """
|
||||
def build():
|
||||
pass"""
|
||||
|
||||
custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
|
||||
class_name = custom_component.get_main_class_name
|
||||
assert class_name == ""
|
||||
|
||||
|
||||
def test_custom_component_build_not_implemented():
|
||||
"""
|
||||
Test the build method of the CustomComponent
|
||||
class raises the NotImplementedError.
|
||||
"""
|
||||
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
|
||||
with pytest.raises(NotImplementedError):
|
||||
custom_component.build()
|
||||
|
||||
|
||||
def test_build_config_no_code():
|
||||
component = CustomComponent(code=None)
|
||||
|
||||
assert component.get_function_entrypoint_args == []
|
||||
assert component.get_function_entrypoint_return_type == []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def component(client, active_user):
|
||||
return CustomComponent(
|
||||
user_id=active_user.id,
|
||||
field_config={
|
||||
"fields": {
|
||||
"llm": {"type": "str"},
|
||||
"url": {"type": "str"},
|
||||
"year": {"type": "int"},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_flow(db):
|
||||
flow_data = {
|
||||
"nodes": [{"id": "1"}, {"id": "2"}],
|
||||
"edges": [{"source": "1", "target": "2"}],
|
||||
}
|
||||
|
||||
# Create flow
|
||||
flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data)
|
||||
|
||||
# Add to database
|
||||
db.add(flow)
|
||||
db.commit()
|
||||
|
||||
yield flow
|
||||
|
||||
# Clean up
|
||||
db.delete(flow)
|
||||
db.commit()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def db(app):
|
||||
# Setup database for tests
|
||||
yield app.db
|
||||
|
||||
# Teardown
|
||||
app.db.drop_all()
|
||||
|
||||
|
||||
def test_list_flows_return_type(component):
|
||||
flows = component.list_flows()
|
||||
assert isinstance(flows, list)
|
||||
|
||||
|
||||
def test_list_flows_flow_objects(component):
|
||||
flows = component.list_flows()
|
||||
assert all(isinstance(flow, Flow) for flow in flows)
|
||||
|
||||
|
||||
def test_build_config_return_type(component):
|
||||
config = component.build_config()
|
||||
assert isinstance(config, dict)
|
||||
|
||||
|
||||
def test_build_config_has_fields(component):
|
||||
config = component.build_config()
|
||||
assert "fields" in config
|
||||
|
||||
|
||||
def test_build_config_fields_dict(component):
|
||||
config = component.build_config()
|
||||
assert isinstance(config["fields"], dict)
|
||||
|
||||
|
||||
def test_build_config_field_keys(component):
|
||||
config = component.build_config()
|
||||
assert all(isinstance(key, str) for key in config["fields"])
|
||||
|
||||
|
||||
def test_build_config_field_values_dict(component):
|
||||
config = component.build_config()
|
||||
assert all(isinstance(value, dict) for value in config["fields"].values())
|
||||
|
||||
|
||||
def test_build_config_field_value_keys(component):
|
||||
config = component.build_config()
|
||||
field_values = config["fields"].values()
|
||||
assert all("type" in value for value in field_values)
|
||||
186
tests/unit/test_data_components.py
Normal file
186
tests/unit/test_data_components.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
from dictdiffer import diff
|
||||
from httpx import Response
|
||||
|
||||
from langflow.components import data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_request():
|
||||
# This fixture provides an instance of APIRequest for each test case
|
||||
return data.APIRequest()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_successful_get_request(api_request):
|
||||
# Mocking a successful GET request
|
||||
url = "https://example.com/api/test"
|
||||
method = "GET"
|
||||
mock_response = {"success": True}
|
||||
respx.get(url).mock(return_value=Response(200, json=mock_response))
|
||||
|
||||
# Making the request
|
||||
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url)
|
||||
|
||||
# Assertions
|
||||
assert result.data["status_code"] == 200
|
||||
assert result.data["result"] == mock_response
|
||||
|
||||
|
||||
def test_parse_curl(api_request):
|
||||
# Arrange
|
||||
field_value = (
|
||||
"curl -X GET https://example.com/api/test -H 'Content-Type: application/json' -d '{\"key\": \"value\"}'"
|
||||
)
|
||||
build_config = {
|
||||
"method": {"value": ""},
|
||||
"urls": {"value": []},
|
||||
"headers": {},
|
||||
"body": {},
|
||||
}
|
||||
# Act
|
||||
new_build_config = api_request.parse_curl(field_value, build_config.copy())
|
||||
|
||||
# Assert
|
||||
assert new_build_config["method"]["value"] == "GET"
|
||||
assert new_build_config["urls"]["value"] == ["https://example.com/api/test"]
|
||||
assert new_build_config["headers"]["value"] == {"Content-Type": "application/json"}
|
||||
assert new_build_config["body"]["value"] == {"key": "value"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_failed_request(api_request):
|
||||
# Mocking a failed GET request
|
||||
url = "https://example.com/api/test"
|
||||
method = "GET"
|
||||
respx.get(url).mock(return_value=Response(404))
|
||||
|
||||
# Making the request
|
||||
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url)
|
||||
|
||||
# Assertions
|
||||
assert result.data["status_code"] == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_timeout(api_request):
|
||||
# Mocking a timeout
|
||||
url = "https://example.com/api/timeout"
|
||||
method = "GET"
|
||||
respx.get(url).mock(side_effect=httpx.TimeoutException(message="Timeout", request=None))
|
||||
|
||||
# Making the request
|
||||
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url, timeout=1)
|
||||
|
||||
# Assertions
|
||||
assert result.data["status_code"] == 408
|
||||
assert result.data["error"] == "Request timed out"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_build_with_multiple_urls(api_request):
|
||||
# This test depends on having a working internet connection and accessible URLs
|
||||
# It's better to mock these requests using respx or a similar library
|
||||
|
||||
# Setup for multiple URLs
|
||||
method = "GET"
|
||||
urls = ["https://example.com/api/one", "https://example.com/api/two"]
|
||||
# You would mock these requests similarly to the single request tests
|
||||
for url in urls:
|
||||
respx.get(url).mock(return_value=Response(200, json={"success": True}))
|
||||
|
||||
# Do I have to mock the async client?
|
||||
#
|
||||
|
||||
# Execute the build method
|
||||
results = await api_request.build(method=method, urls=urls)
|
||||
|
||||
# Assertions
|
||||
assert len(results) == len(urls)
|
||||
|
||||
|
||||
@patch("langflow.components.data.Directory.parallel_load_records")
|
||||
@patch("langflow.components.data.Directory.retrieve_file_paths")
|
||||
@patch("langflow.components.data.DirectoryComponent.resolve_path")
|
||||
def test_directory_component_build_with_multithreading(
|
||||
mock_resolve_path, mock_retrieve_file_paths, mock_parallel_load_records
|
||||
):
|
||||
# Arrange
|
||||
directory_component = data.DirectoryComponent()
|
||||
path = os.path.dirname(os.path.abspath(__file__))
|
||||
depth = 1
|
||||
max_concurrency = 2
|
||||
load_hidden = False
|
||||
recursive = True
|
||||
silent_errors = False
|
||||
use_multithreading = True
|
||||
|
||||
mock_resolve_path.return_value = path
|
||||
mock_retrieve_file_paths.return_value = [
|
||||
os.path.join(path, file) for file in os.listdir(path) if file.endswith(".py")
|
||||
]
|
||||
mock_parallel_load_records.return_value = [Mock()]
|
||||
|
||||
# Act
|
||||
directory_component.build(
|
||||
path,
|
||||
depth,
|
||||
max_concurrency,
|
||||
load_hidden,
|
||||
recursive,
|
||||
silent_errors,
|
||||
use_multithreading,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_resolve_path.assert_called_once_with(path)
|
||||
mock_retrieve_file_paths.assert_called_once_with(path, load_hidden, recursive, depth)
|
||||
mock_parallel_load_records.assert_called_once_with(
|
||||
mock_retrieve_file_paths.return_value, silent_errors, max_concurrency
|
||||
)
|
||||
|
||||
|
||||
def test_directory_without_mocks():
|
||||
directory_component = data.DirectoryComponent()
|
||||
from langflow.initial_setup import setup
|
||||
from langflow.initial_setup.setup import load_starter_projects
|
||||
|
||||
_, projects = zip(*load_starter_projects())
|
||||
# the setup module has a folder where the projects are stored
|
||||
# the contents of that folder are in the projects variable
|
||||
# the directory component can be used to load the projects
|
||||
# and we can validate if the contents are the same as the projects variable
|
||||
setup_path = Path(setup.__file__).parent / "starter_projects"
|
||||
results = directory_component.build(str(setup_path), use_multithreading=False)
|
||||
assert len(results) == len(projects)
|
||||
# each result is a Record that contains the content attribute
|
||||
# each are dict that are exactly the same as one of the projects
|
||||
for i, result in enumerate(results):
|
||||
assert result.text in projects, list(diff(result.text, projects[i]))
|
||||
|
||||
# in ../docs/docs/components there are many mdx files
|
||||
# check if the directory component can load them
|
||||
# just check if the number of results is the same as the number of files
|
||||
docs_path = Path(__file__).parent.parent / "docs" / "docs" / "components"
|
||||
results = directory_component.build(str(docs_path), use_multithreading=False)
|
||||
docs_files = list(docs_path.glob("*.mdx"))
|
||||
assert len(results) == len(docs_files)
|
||||
|
||||
|
||||
def test_url_component():
|
||||
url_component = data.URLComponent()
|
||||
# the url component can be used to load the contents of a website
|
||||
records = url_component.build(["https://langflow.org"])
|
||||
assert all(record.data for record in records)
|
||||
assert all(record.text for record in records)
|
||||
assert all(record.source for record in records)
|
||||
277
tests/unit/test_database.py
Normal file
277
tests/unit/test_database.py
Normal file
|
|
@ -0,0 +1,277 @@
|
|||
from uuid import UUID, uuid4
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.api.v1.schemas import FlowListCreate
|
||||
from langflow.initial_setup.setup import load_starter_projects
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def json_style():
|
||||
# class FlowStyleBase(SQLModel):
|
||||
# color: str = Field(index=True)
|
||||
# emoji: str = Field(index=False)
|
||||
# flow_id: UUID = Field(default=None, foreign_key="flow.id")
|
||||
return orjson_dumps(
|
||||
{
|
||||
"color": "red",
|
||||
"emoji": "👍",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_create_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name=str(uuid4()), description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
# flow is optional so we can create a flow without a flow
|
||||
flow = FlowCreate(name="Test Flow")
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(exclude_unset=True), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
||||
|
||||
def test_read_flows(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
flow = FlowCreate(name=str(uuid4()), description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
||||
flow = FlowCreate(name=str(uuid4()), description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
||||
response = client.get("api/v1/flows/", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) > 0
|
||||
|
||||
|
||||
def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
flow_id = response.json()["id"] # flow_id should be a UUID but is a string
|
||||
# turn it into a UUID
|
||||
flow_id = UUID(flow_id)
|
||||
|
||||
response = client.get(f"api/v1/flows/{flow_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
||||
|
||||
def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
|
||||
flow_id = response.json()["id"]
|
||||
updated_flow = FlowUpdate(
|
||||
name="Updated Flow",
|
||||
description="updated description",
|
||||
data=data,
|
||||
)
|
||||
response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == updated_flow.name
|
||||
assert response.json()["description"] == updated_flow.description
|
||||
# assert response.json()["data"] == updated_flow.data
|
||||
|
||||
|
||||
def test_delete_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
flow_id = response.json()["id"]
|
||||
response = client.delete(f"api/v1/flows/{flow_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Flow deleted successfully"
|
||||
|
||||
|
||||
def test_delete_flows(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
# Create ten flows
|
||||
number_of_flows = 10
|
||||
flows = [FlowCreate(name=f"Flow {i}", description="description", data={}) for i in range(number_of_flows)]
|
||||
flow_ids = []
|
||||
for flow in flows:
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
flow_ids.append(response.json()["id"])
|
||||
|
||||
response = client.request("DELETE", "api/v1/flows/", headers=logged_in_headers, json=flow_ids)
|
||||
assert response.status_code == 200, response.content
|
||||
assert response.json().get("deleted") == number_of_flows
|
||||
|
||||
|
||||
def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
flow_list = FlowListCreate(
|
||||
flows=[
|
||||
FlowCreate(name="Flow 1", description="description", data=data),
|
||||
FlowCreate(name="Flow 2", description="description", data=data),
|
||||
]
|
||||
)
|
||||
# Make request to endpoint
|
||||
response = client.post("api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers)
|
||||
# Check response status code
|
||||
assert response.status_code == 201
|
||||
# Check response data
|
||||
response_data = response.json()
|
||||
assert len(response_data) == 2
|
||||
assert response_data[0]["name"] == "Flow 1"
|
||||
assert response_data[0]["description"] == "description"
|
||||
assert response_data[0]["data"] == data
|
||||
assert response_data[1]["name"] == "Flow 2"
|
||||
assert response_data[1]["description"] == "description"
|
||||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
def test_upload_file(client: TestClient, session: Session, json_flow: str, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
flow_list = FlowListCreate(
|
||||
flows=[
|
||||
FlowCreate(name="Flow 1", description="description", data=data),
|
||||
FlowCreate(name="Flow 2", description="description", data=data),
|
||||
]
|
||||
)
|
||||
file_contents = orjson_dumps(flow_list.dict())
|
||||
response = client.post(
|
||||
"api/v1/flows/upload/",
|
||||
files={"file": ("examples.json", file_contents, "application/json")},
|
||||
headers=logged_in_headers,
|
||||
)
|
||||
# Check response status code
|
||||
assert response.status_code == 201
|
||||
# Check response data
|
||||
response_data = response.json()
|
||||
assert len(response_data) == 2
|
||||
assert response_data[0]["name"] == "Flow 1"
|
||||
assert response_data[0]["description"] == "description"
|
||||
assert response_data[0]["data"] == data
|
||||
assert response_data[1]["name"] == "Flow 2"
|
||||
assert response_data[1]["description"] == "description"
|
||||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
def test_download_file(
|
||||
client: TestClient,
|
||||
session: Session,
|
||||
json_flow,
|
||||
active_user,
|
||||
logged_in_headers,
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
flow_list = FlowListCreate(
|
||||
flows=[
|
||||
FlowCreate(name="Flow 1", description="description", data=data),
|
||||
FlowCreate(name="Flow 2", description="description", data=data),
|
||||
]
|
||||
)
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = active_user.id
|
||||
db_flow = Flow.model_validate(flow, from_attributes=True)
|
||||
session.add(db_flow)
|
||||
session.commit()
|
||||
# Make request to endpoint
|
||||
response = client.get("api/v1/flows/download/", headers=logged_in_headers)
|
||||
# Check response status code
|
||||
assert response.status_code == 200, response.json()
|
||||
# Check response data
|
||||
response_data = response.json()["flows"]
|
||||
starter_projects = load_starter_projects()
|
||||
number_of_projects = len(starter_projects) + len(flow_list.flows)
|
||||
assert len(response_data) == number_of_projects, response_data
|
||||
assert response_data[0]["name"] == "Flow 1"
|
||||
assert response_data[0]["description"] == "description"
|
||||
assert response_data[0]["data"] == data
|
||||
assert response_data[1]["name"] == "Flow 2"
|
||||
assert response_data[1]["description"] == "description"
|
||||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
def test_create_flow_with_invalid_data(client: TestClient, active_user, logged_in_headers):
|
||||
flow = {"name": "a" * 256, "data": "Invalid flow data"}
|
||||
response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers):
|
||||
uuid = uuid4()
|
||||
response = client.get(f"api/v1/flows/{uuid}", headers=logged_in_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_update_flow_idempotency(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
flow_data = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers)
|
||||
flow_id = response.json()["id"]
|
||||
updated_flow = FlowCreate(name="Updated Flow", description="description", data=data)
|
||||
response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
assert response1.json() == response2.json()
|
||||
|
||||
|
||||
def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
uuid = uuid4()
|
||||
updated_flow = FlowCreate(
|
||||
name="Updated Flow",
|
||||
description="description",
|
||||
data=data,
|
||||
)
|
||||
response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 404, response.text
|
||||
|
||||
|
||||
def test_delete_nonexistent_flow(client: TestClient, active_user, logged_in_headers):
|
||||
uuid = uuid4()
|
||||
response = client.delete(f"api/v1/flows/{uuid}", headers=logged_in_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_read_only_starter_projects(client: TestClient, active_user, logged_in_headers):
|
||||
response = client.get("api/v1/flows/", headers=logged_in_headers)
|
||||
starter_projects = load_starter_projects()
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == len(starter_projects)
|
||||
|
||||
|
||||
@pytest.mark.load_flows
|
||||
def test_load_flows(client: TestClient, load_flows_dir):
|
||||
client.get("/api/v1/auto_login")
|
||||
response = client.get("api/v1/flows/c54f9130-f2fa-4a3e-b22a-3856d946351b")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "BasicExample"
|
||||
104
tests/unit/test_files.py
Normal file
104
tests/unit/test_files.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.services.deps import get_storage_service
|
||||
from langflow.services.storage.service import StorageService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage_service():
|
||||
# Create a mock instance of StorageService
|
||||
service = MagicMock(spec=StorageService)
|
||||
# Setup mock behaviors for the service methods as needed
|
||||
service.save_file.return_value = None
|
||||
service.get_file.return_value = b"file content" # Binary content for files
|
||||
service.list_files.return_value = ["file1.txt", "file2.jpg"]
|
||||
service.delete_file.return_value = None
|
||||
return service
|
||||
|
||||
|
||||
def test_upload_file(client, mock_storage_service, created_api_key, flow):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
# Replace the actual storage service with the mock
|
||||
client.app.dependency_overrides[get_storage_service] = lambda: mock_storage_service
|
||||
|
||||
response = client.post(
|
||||
f"api/v1/files/upload/{flow.id}",
|
||||
files={"file": ("test.txt", b"test content")},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {
|
||||
"flowId": str(flow.id),
|
||||
"file_path": f"{flow.id}/test.txt",
|
||||
}
|
||||
|
||||
|
||||
def test_download_file(client, mock_storage_service, created_api_key, flow):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
client.app.dependency_overrides[get_storage_service] = lambda: mock_storage_service
|
||||
|
||||
response = client.get(f"api/v1/files/download/{flow.id}/test.txt", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"file content"
|
||||
|
||||
|
||||
def test_list_files(client, mock_storage_service, created_api_key, flow):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
client.app.dependency_overrides[get_storage_service] = lambda: mock_storage_service
|
||||
|
||||
response = client.get(f"api/v1/files/list/{flow.id}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"files": ["file1.txt", "file2.jpg"]}
|
||||
|
||||
|
||||
def test_delete_file(client, mock_storage_service, created_api_key, flow):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
client.app.dependency_overrides[get_storage_service] = lambda: mock_storage_service
|
||||
|
||||
response = client.delete(f"api/v1/files/delete/{flow.id}/test.txt", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "File test.txt deleted successfully"}
|
||||
|
||||
|
||||
def test_file_operations(client, created_api_key, flow):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
flow_id = flow.id
|
||||
file_name = "test.txt"
|
||||
file_content = b"Hello, world!"
|
||||
|
||||
# Step 1: Upload the file
|
||||
response = client.post(
|
||||
f"api/v1/files/upload/{flow_id}",
|
||||
files={"file": (file_name, file_content)},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {
|
||||
"flowId": str(flow_id),
|
||||
"file_path": f"{flow_id}/{file_name}",
|
||||
}
|
||||
|
||||
# Step 2: List files in the folder
|
||||
response = client.get(f"api/v1/files/list/{flow_id}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert file_name in response.json()["files"]
|
||||
|
||||
# Step 3: Download the file and verify its content
|
||||
|
||||
response = client.get(f"api/v1/files/download/{flow_id}/{file_name}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.content == file_content
|
||||
# the headers are application/octet-stream
|
||||
assert response.headers["content-type"] == "application/octet-stream"
|
||||
# mime_type is inside media_type
|
||||
|
||||
# Step 4: Delete the file
|
||||
response = client.delete(f"api/v1/files/delete/{flow_id}/{file_name}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": f"File {file_name} deleted successfully"}
|
||||
|
||||
# Verify that the file is indeed deleted
|
||||
response = client.get(f"api/v1/files/list/{flow_id}", headers=headers)
|
||||
assert file_name not in response.json()["files"]
|
||||
56
tests/unit/test_frontend_nodes.py
Normal file
56
tests/unit/test_frontend_nodes.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
import pytest
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.template.base import Template
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_template_field() -> TemplateField:
|
||||
return TemplateField(name="test_field", field_type="str")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_template(sample_template_field: TemplateField) -> Template:
|
||||
return Template(type_name="test_template", fields=[sample_template_field])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_frontend_node(sample_template: Template) -> FrontendNode:
|
||||
return FrontendNode(
|
||||
template=sample_template,
|
||||
description="test description",
|
||||
base_classes=["base_class1", "base_class2"],
|
||||
name="test_frontend_node",
|
||||
)
|
||||
|
||||
|
||||
def test_template_field_defaults(sample_template_field: TemplateField):
|
||||
assert sample_template_field.field_type == "str"
|
||||
assert sample_template_field.required is False
|
||||
assert sample_template_field.placeholder == ""
|
||||
assert sample_template_field.is_list is False
|
||||
assert sample_template_field.show is True
|
||||
assert sample_template_field.multiline is False
|
||||
assert sample_template_field.value == ""
|
||||
assert sample_template_field.file_types == []
|
||||
assert sample_template_field.file_path == ""
|
||||
assert sample_template_field.password is False
|
||||
assert sample_template_field.name == "test_field"
|
||||
|
||||
|
||||
def test_template_to_dict(sample_template: Template, sample_template_field: TemplateField):
|
||||
template_dict = sample_template.to_dict()
|
||||
assert template_dict["_type"] == "test_template"
|
||||
assert len(template_dict) == 2 # _type and test_field
|
||||
assert "test_field" in template_dict
|
||||
assert "type" in template_dict["test_field"]
|
||||
assert "required" in template_dict["test_field"]
|
||||
|
||||
|
||||
def test_frontend_node_to_dict(sample_frontend_node: FrontendNode):
|
||||
node_dict = sample_frontend_node.to_dict()
|
||||
assert len(node_dict) == 1
|
||||
assert "test_frontend_node" in node_dict
|
||||
assert "description" in node_dict["test_frontend_node"]
|
||||
assert "template" in node_dict["test_frontend_node"]
|
||||
assert "base_classes" in node_dict["test_frontend_node"]
|
||||
418
tests/unit/test_graph.py
Normal file
418
tests/unit/test_graph.py
Normal file
|
|
@ -0,0 +1,418 @@
|
|||
import copy
|
||||
import json
|
||||
import pickle
|
||||
from typing import Type, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.graph import Graph
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.graph.utils import (
|
||||
find_last_node,
|
||||
process_flow,
|
||||
set_new_target_handle,
|
||||
ungroup_node,
|
||||
update_source_handle,
|
||||
update_target_handle,
|
||||
update_template,
|
||||
)
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.initial_setup.setup import load_starter_projects
|
||||
from langflow.utils.payload import get_root_vertex
|
||||
|
||||
# Test cases for the graph module
|
||||
|
||||
# now we have three types of graph:
|
||||
# BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_template():
|
||||
return {
|
||||
"field1": {"proxy": {"field": "some_field", "id": "node1"}},
|
||||
"field2": {"proxy": {"field": "other_field", "id": "node2"}},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_nodes():
|
||||
return [
|
||||
{
|
||||
"id": "node1",
|
||||
"data": {"node": {"template": {"some_field": {"show": True, "advanced": False, "name": "Name1"}}}},
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"other_field": {
|
||||
"show": False,
|
||||
"advanced": True,
|
||||
"display_name": "DisplayName2",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node3",
|
||||
"data": {"node": {"template": {"unrelated_field": {"show": True, "advanced": True}}}},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_node_by_type(graph, node_type: Type[Vertex]) -> Union[Vertex, None]:
|
||||
"""Get a node by type"""
|
||||
return next((node for node in graph.vertices if isinstance(node, node_type)), None)
|
||||
|
||||
|
||||
def test_graph_structure(basic_graph):
|
||||
assert isinstance(basic_graph, Graph)
|
||||
assert len(basic_graph.vertices) > 0
|
||||
assert len(basic_graph.edges) > 0
|
||||
for node in basic_graph.vertices:
|
||||
assert isinstance(node, Vertex)
|
||||
for edge in basic_graph.edges:
|
||||
assert isinstance(edge, Edge)
|
||||
source_vertex = basic_graph.get_vertex(edge.source_id)
|
||||
target_vertex = basic_graph.get_vertex(edge.target_id)
|
||||
assert source_vertex in basic_graph.vertices
|
||||
assert target_vertex in basic_graph.vertices
|
||||
|
||||
|
||||
def test_circular_dependencies(basic_graph):
|
||||
assert isinstance(basic_graph, Graph)
|
||||
|
||||
def check_circular(node, visited):
|
||||
visited.add(node)
|
||||
neighbors = basic_graph.get_vertices_with_target(node)
|
||||
for neighbor in neighbors:
|
||||
if neighbor in visited:
|
||||
return True
|
||||
if check_circular(neighbor, visited.copy()):
|
||||
return True
|
||||
return False
|
||||
|
||||
for node in basic_graph.vertices:
|
||||
assert not check_circular(node, set())
|
||||
|
||||
|
||||
def test_invalid_node_types():
|
||||
graph_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"data": {
|
||||
"node": {
|
||||
"base_classes": ["BaseClass"],
|
||||
"template": {
|
||||
"_type": "InvalidNodeType",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
with pytest.raises(Exception):
|
||||
Graph(graph_data["nodes"], graph_data["edges"])
|
||||
|
||||
|
||||
def test_get_vertices_with_target(basic_graph):
|
||||
"""Test getting connected nodes"""
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_vertex(basic_graph)
|
||||
assert root is not None
|
||||
connected_nodes = basic_graph.get_vertices_with_target(root.id)
|
||||
assert connected_nodes is not None
|
||||
|
||||
|
||||
def test_get_node_neighbors_basic(basic_graph):
|
||||
"""Test getting node neighbors"""
|
||||
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_vertex(basic_graph)
|
||||
assert root is not None
|
||||
neighbors = basic_graph.get_vertex_neighbors(root)
|
||||
assert neighbors is not None
|
||||
assert isinstance(neighbors, dict)
|
||||
# Root Node is an Agent, it requires an LLMChain and tools
|
||||
# We need to check if there is a Chain in the one of the neighbors'
|
||||
# data attribute in the type key
|
||||
assert any("ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
|
||||
|
||||
assert any("OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
|
||||
|
||||
|
||||
def test_get_node(basic_graph):
|
||||
"""Test getting a single node"""
|
||||
node_id = basic_graph.vertices[0].id
|
||||
node = basic_graph.get_vertex(node_id)
|
||||
assert isinstance(node, Vertex)
|
||||
assert node.id == node_id
|
||||
|
||||
|
||||
def test_build_nodes(basic_graph):
|
||||
"""Test building nodes"""
|
||||
|
||||
assert len(basic_graph.vertices) == len(basic_graph._vertices)
|
||||
for node in basic_graph.vertices:
|
||||
assert isinstance(node, Vertex)
|
||||
|
||||
|
||||
def test_build_edges(basic_graph):
|
||||
"""Test building edges"""
|
||||
assert len(basic_graph.edges) == len(basic_graph._edges)
|
||||
for edge in basic_graph.edges:
|
||||
assert isinstance(edge, Edge)
|
||||
assert isinstance(edge.source_id, str)
|
||||
assert isinstance(edge.target_id, str)
|
||||
|
||||
|
||||
def test_get_root_vertex(client, basic_graph, complex_graph):
|
||||
"""Test getting root node"""
|
||||
assert isinstance(basic_graph, Graph)
|
||||
root = get_root_vertex(basic_graph)
|
||||
assert root is not None
|
||||
assert isinstance(root, Vertex)
|
||||
assert root.data["type"] == "TimeTravelGuideChain"
|
||||
# For complex example, the root node is a ZeroShotAgent too
|
||||
assert isinstance(complex_graph, Graph)
|
||||
root = get_root_vertex(complex_graph)
|
||||
assert root is not None
|
||||
assert isinstance(root, Vertex)
|
||||
assert root.data["type"] == "ZeroShotAgent"
|
||||
|
||||
|
||||
def test_validate_edges(basic_graph):
|
||||
"""Test validating edges"""
|
||||
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in basic_graph.edges)
|
||||
|
||||
|
||||
def test_matched_type(basic_graph):
|
||||
"""Test matched type attribute in Edge"""
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in basic_graph.edges)
|
||||
# all edges should have a matched_type attribute
|
||||
assert all(hasattr(edge, "matched_type") for edge in basic_graph.edges)
|
||||
# The matched_type attribute should be in the source_types attr
|
||||
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
|
||||
|
||||
|
||||
def test_build_params(basic_graph):
|
||||
"""Test building params"""
|
||||
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in basic_graph.edges)
|
||||
# all edges should have a matched_type attribute
|
||||
assert all(hasattr(edge, "matched_type") for edge in basic_graph.edges)
|
||||
# The matched_type attribute should be in the source_types attr
|
||||
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
|
||||
# Get the root node
|
||||
root = get_root_vertex(basic_graph)
|
||||
# Root node is a TimeTravelGuideChain
|
||||
# which requires an llm and memory
|
||||
assert root is not None
|
||||
assert isinstance(root.params, dict)
|
||||
assert "llm" in root.params
|
||||
assert "memory" in root.params
|
||||
|
||||
|
||||
# def test_wrapper_node_build(openapi_graph):
|
||||
# wrapper_node = get_node_by_type(openapi_graph, WrapperVertex)
|
||||
# assert wrapper_node is not None
|
||||
# built_object = wrapper_node.build()
|
||||
# assert built_object is not None
|
||||
|
||||
|
||||
def test_find_last_node(grouped_chat_json_flow):
|
||||
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
|
||||
nodes, edges = grouped_chat_data["nodes"], grouped_chat_data["edges"]
|
||||
last_node = find_last_node(nodes, edges)
|
||||
assert last_node is not None # Replace with the actual expected value
|
||||
assert last_node["id"] == "LLMChain-pimAb" # Replace with the actual expected value
|
||||
|
||||
|
||||
def test_ungroup_node(grouped_chat_json_flow):
|
||||
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
|
||||
group_node = grouped_chat_data["nodes"][2] # Assuming the first node is a group node
|
||||
base_flow = copy.deepcopy(grouped_chat_data)
|
||||
ungroup_node(group_node["data"], base_flow)
|
||||
# after ungroup_node is called, the base_flow and grouped_chat_data should be different
|
||||
assert base_flow != grouped_chat_data
|
||||
# assert node 2 is not a group node anymore
|
||||
assert base_flow["nodes"][2]["data"]["node"].get("flow") is None
|
||||
# assert the edges are updated
|
||||
assert len(base_flow["edges"]) > len(grouped_chat_data["edges"])
|
||||
assert base_flow["edges"][0]["source"] == "ConversationBufferMemory-kUMif"
|
||||
assert base_flow["edges"][0]["target"] == "LLMChain-2P369"
|
||||
assert base_flow["edges"][1]["source"] == "PromptTemplate-Wjk4g"
|
||||
assert base_flow["edges"][1]["target"] == "LLMChain-2P369"
|
||||
assert base_flow["edges"][2]["source"] == "ChatOpenAI-rUJ1b"
|
||||
assert base_flow["edges"][2]["target"] == "LLMChain-2P369"
|
||||
|
||||
|
||||
def test_process_flow(grouped_chat_json_flow):
|
||||
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
|
||||
|
||||
processed_flow = process_flow(grouped_chat_data)
|
||||
assert processed_flow is not None
|
||||
assert isinstance(processed_flow, dict)
|
||||
assert "nodes" in processed_flow
|
||||
assert "edges" in processed_flow
|
||||
|
||||
|
||||
def test_process_flow_one_group(one_grouped_chat_json_flow):
|
||||
grouped_chat_data = json.loads(one_grouped_chat_json_flow).get("data")
|
||||
# There should be only one node
|
||||
assert len(grouped_chat_data["nodes"]) == 1
|
||||
# Get the node, it should be a group node
|
||||
group_node = grouped_chat_data["nodes"][0]
|
||||
node_data = group_node["data"]["node"]
|
||||
assert node_data.get("flow") is not None
|
||||
template_data = node_data["template"]
|
||||
assert any("openai_api_key" in key for key in template_data.keys())
|
||||
# Get the openai_api_key dict
|
||||
openai_api_key = next(
|
||||
(template_data[key] for key in template_data.keys() if "openai_api_key" in key),
|
||||
None,
|
||||
)
|
||||
assert openai_api_key is not None
|
||||
assert openai_api_key["value"] == "test"
|
||||
|
||||
processed_flow = process_flow(grouped_chat_data)
|
||||
assert processed_flow is not None
|
||||
assert isinstance(processed_flow, dict)
|
||||
assert "nodes" in processed_flow
|
||||
assert "edges" in processed_flow
|
||||
|
||||
# Now get the node that has ChatOpenAI in its id
|
||||
chat_openai_node = next((node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None)
|
||||
assert chat_openai_node is not None
|
||||
assert chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] == "test"
|
||||
|
||||
|
||||
def test_process_flow_vector_store_grouped(vector_store_grouped_json_flow):
|
||||
grouped_chat_data = json.loads(vector_store_grouped_json_flow).get("data")
|
||||
nodes = grouped_chat_data["nodes"]
|
||||
assert len(nodes) == 4
|
||||
# There are two group nodes in this flow
|
||||
# One of them is inside the other totalling 7 nodes
|
||||
# 4 nodes grouped, one of these turns into 1 normal node and 1 group node
|
||||
# This group node has 2 nodes inside it
|
||||
|
||||
processed_flow = process_flow(grouped_chat_data)
|
||||
assert processed_flow is not None
|
||||
processed_nodes = processed_flow["nodes"]
|
||||
assert len(processed_nodes) == 7
|
||||
assert isinstance(processed_flow, dict)
|
||||
assert "nodes" in processed_flow
|
||||
assert "edges" in processed_flow
|
||||
edges = processed_flow["edges"]
|
||||
# Expected keywords in source and target fields
|
||||
expected_keywords = [
|
||||
{"source": "VectorStoreInfo", "target": "VectorStoreAgent"},
|
||||
{"source": "ChatOpenAI", "target": "VectorStoreAgent"},
|
||||
{"source": "OpenAIEmbeddings", "target": "Chroma"},
|
||||
{"source": "Chroma", "target": "VectorStoreInfo"},
|
||||
{"source": "WebBaseLoader", "target": "RecursiveCharacterTextSplitter"},
|
||||
{"source": "RecursiveCharacterTextSplitter", "target": "Chroma"},
|
||||
]
|
||||
|
||||
for idx, expected_keyword in enumerate(expected_keywords):
|
||||
for key, value in expected_keyword.items():
|
||||
assert (
|
||||
value in edges[idx][key].split("-")[0]
|
||||
), f"Edge {idx}, key {key} expected to contain {value} but got {edges[idx][key]}"
|
||||
|
||||
|
||||
def test_update_template(sample_template, sample_nodes):
|
||||
# Making a deep copy to keep original sample_nodes unchanged
|
||||
nodes_copy = copy.deepcopy(sample_nodes)
|
||||
update_template(sample_template, nodes_copy)
|
||||
|
||||
# Now, validate the updates.
|
||||
node1_updated = next((n for n in nodes_copy if n["id"] == "node1"), None)
|
||||
node2_updated = next((n for n in nodes_copy if n["id"] == "node2"), None)
|
||||
node3_updated = next((n for n in nodes_copy if n["id"] == "node3"), None)
|
||||
|
||||
assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True
|
||||
assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False
|
||||
assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1"
|
||||
|
||||
assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False
|
||||
assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True
|
||||
assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2"
|
||||
|
||||
# Ensure node3 remains unchanged
|
||||
assert node3_updated == sample_nodes[2]
|
||||
|
||||
|
||||
# Test `update_target_handle`
|
||||
def test_update_target_handle_proxy():
|
||||
new_edge = {
|
||||
"data": {
|
||||
"targetHandle": {
|
||||
"type": "some_type",
|
||||
"proxy": {"id": "some_id", "field": ""},
|
||||
}
|
||||
}
|
||||
}
|
||||
g_nodes = [{"id": "some_id", "data": {"node": {"flow": None}}}]
|
||||
group_node_id = "group_id"
|
||||
updated_edge = update_target_handle(new_edge, g_nodes, group_node_id)
|
||||
assert updated_edge["data"]["targetHandle"] == new_edge["data"]["targetHandle"]
|
||||
|
||||
|
||||
# Test `set_new_target_handle`
|
||||
def test_set_new_target_handle():
|
||||
proxy_id = "proxy_id"
|
||||
new_edge = {"target": None, "data": {"targetHandle": {}}}
|
||||
target_handle = {"type": "type_1", "proxy": {"field": "field_1"}}
|
||||
node = {
|
||||
"data": {
|
||||
"node": {
|
||||
"flow": True,
|
||||
"template": {"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}},
|
||||
}
|
||||
}
|
||||
}
|
||||
set_new_target_handle(proxy_id, new_edge, target_handle, node)
|
||||
assert new_edge["target"] == "proxy_id"
|
||||
assert new_edge["data"]["targetHandle"]["fieldName"] == "field_1"
|
||||
assert new_edge["data"]["targetHandle"]["proxy"] == {
|
||||
"field": "new_field",
|
||||
"id": "new_id",
|
||||
}
|
||||
|
||||
|
||||
# Test `update_source_handle`
|
||||
def test_update_source_handle():
|
||||
new_edge = {"source": None, "data": {"sourceHandle": {"id": None}}}
|
||||
flow_data = {
|
||||
"nodes": [{"id": "some_node"}, {"id": "last_node"}],
|
||||
"edges": [{"source": "some_node"}],
|
||||
}
|
||||
updated_edge = update_source_handle(new_edge, flow_data["nodes"], flow_data["edges"])
|
||||
assert updated_edge["source"] == "last_node"
|
||||
assert updated_edge["data"]["sourceHandle"]["id"] == "last_node"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pickle_graph():
|
||||
starter_projects = load_starter_projects()
|
||||
data = starter_projects[0][1]["data"]
|
||||
graph = Graph.from_payload(data)
|
||||
assert isinstance(graph, Graph)
|
||||
pickled = pickle.dumps(graph)
|
||||
assert pickled is not None
|
||||
unpickled = pickle.loads(pickled)
|
||||
assert unpickled is not None
|
||||
81
tests/unit/test_helper_components.py
Normal file
81
tests/unit/test_helper_components.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
from langchain_core.documents import Document
|
||||
|
||||
from langflow.components import helpers
|
||||
from langflow.custom.utils import build_custom_component_template
|
||||
from langflow.schema import Record
|
||||
|
||||
|
||||
def test_update_record_component():
|
||||
# Arrange
|
||||
update_record_component = helpers.UpdateRecordComponent()
|
||||
|
||||
# Act
|
||||
new_data = {"new_key": "new_value"}
|
||||
existing_record = Record(data={"existing_key": "existing_value"})
|
||||
result = update_record_component.build(existing_record, new_data)
|
||||
assert result.data == {"existing_key": "existing_value", "new_key": "new_value"}
|
||||
assert result.existing_key == "existing_value"
|
||||
assert result.new_key == "new_value"
|
||||
|
||||
|
||||
def test_document_to_record_component():
|
||||
# Arrange
|
||||
document_to_record_component = helpers.DocumentToRecordComponent()
|
||||
|
||||
# Act
|
||||
# Replace with your actual test data
|
||||
document = Document(page_content="key: value", metadata={"url": "https://example.com"})
|
||||
result = document_to_record_component.build(document)
|
||||
|
||||
# Assert
|
||||
# Replace with your actual expected result
|
||||
assert result == [Record(data={"text": "key: value", "url": "https://example.com"})]
|
||||
|
||||
|
||||
def test_uuid_generator_component():
|
||||
# Arrange
|
||||
uuid_generator_component = helpers.UUIDGeneratorComponent()
|
||||
uuid_generator_component.code = open(helpers.IDGenerator.__file__, "r").read()
|
||||
|
||||
frontend_node, _ = build_custom_component_template(uuid_generator_component)
|
||||
|
||||
# Act
|
||||
build_config = frontend_node.get("template")
|
||||
field_name = "unique_id"
|
||||
build_config = uuid_generator_component.update_build_config(build_config, None, field_name)
|
||||
unique_id = build_config["unique_id"]["value"]
|
||||
result = uuid_generator_component.build(unique_id)
|
||||
|
||||
# Assert
|
||||
# UUID should be a string of length 36
|
||||
assert isinstance(result, str)
|
||||
assert len(result) == 36
|
||||
|
||||
|
||||
def test_records_as_text_component():
|
||||
# Arrange
|
||||
records_as_text_component = helpers.RecordsToTextComponent()
|
||||
|
||||
# Act
|
||||
# Replace with your actual test data
|
||||
records = [Record(data={"key": "value", "bacon": "eggs"})]
|
||||
template = "Data:{data} -- Bacon:{bacon}"
|
||||
result = records_as_text_component.build(records, template=template)
|
||||
|
||||
# Assert
|
||||
# Replace with your actual expected result
|
||||
assert result == "Data:{'key': 'value', 'bacon': 'eggs'} -- Bacon:eggs"
|
||||
|
||||
|
||||
def test_text_to_record_component():
|
||||
# Arrange
|
||||
text_to_record_component = helpers.CreateRecordComponent()
|
||||
|
||||
# Act
|
||||
# Replace with your actual test data
|
||||
dict_with_text = {"field_1": {"key": "value"}}
|
||||
result = text_to_record_component.build(number_of_fields=1, **dict_with_text)
|
||||
|
||||
# Assert
|
||||
# Replace with your actual expected result
|
||||
assert result == Record(data={"key": "value"})
|
||||
93
tests/unit/test_initial_setup.py
Normal file
93
tests/unit/test_initial_setup.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.initial_setup.setup import (
|
||||
STARTER_FOLDER_NAME,
|
||||
create_or_update_starter_projects,
|
||||
get_project_data,
|
||||
load_starter_projects,
|
||||
)
|
||||
from langflow.services.database.models.folder.model import Folder
|
||||
from langflow.services.deps import session_scope
|
||||
|
||||
|
||||
def test_load_starter_projects():
|
||||
projects = load_starter_projects()
|
||||
assert isinstance(projects, list)
|
||||
assert all(isinstance(project[1], dict) for project in projects)
|
||||
assert all(isinstance(project[0], Path) for project in projects)
|
||||
|
||||
|
||||
def test_get_project_data():
|
||||
projects = load_starter_projects()
|
||||
for _, project in projects:
|
||||
(
|
||||
project_name,
|
||||
project_description,
|
||||
project_is_component,
|
||||
updated_at_datetime,
|
||||
project_data,
|
||||
project_icon,
|
||||
project_icon_bg_color,
|
||||
) = get_project_data(project)
|
||||
assert isinstance(project_name, str)
|
||||
assert isinstance(project_description, str)
|
||||
assert isinstance(project_is_component, bool)
|
||||
assert isinstance(updated_at_datetime, datetime)
|
||||
assert isinstance(project_data, dict)
|
||||
assert isinstance(project_icon, str) or project_icon is None
|
||||
assert isinstance(project_icon_bg_color, str) or project_icon_bg_color is None
|
||||
|
||||
|
||||
def test_create_or_update_starter_projects(client):
|
||||
with session_scope() as session:
|
||||
# Run the function to create or update projects
|
||||
create_or_update_starter_projects()
|
||||
|
||||
# Get the number of projects returned by load_starter_projects
|
||||
num_projects = len(load_starter_projects())
|
||||
|
||||
# Get the number of projects in the database
|
||||
folder = session.exec(select(Folder).where(Folder.name == STARTER_FOLDER_NAME)).first()
|
||||
num_db_projects = len(folder.flows)
|
||||
|
||||
# Check that the number of projects in the database is the same as the number of projects returned by load_starter_projects
|
||||
assert num_db_projects == num_projects
|
||||
|
||||
|
||||
# Some starter projects require integration
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_starter_projects_can_run_successfully(client):
|
||||
# with session_scope() as session:
|
||||
# # Run the function to create or update projects
|
||||
# create_or_update_starter_projects()
|
||||
|
||||
# # Get the number of projects returned by load_starter_projects
|
||||
# num_projects = len(load_starter_projects())
|
||||
|
||||
# # Get the number of projects in the database
|
||||
# num_db_projects = session.exec(select(func.count(Flow.id)).where(Flow.folder == STARTER_FOLDER_NAME)).one()
|
||||
|
||||
# # Check that the number of projects in the database is the same as the number of projects returned by load_starter_projects
|
||||
# assert num_db_projects == num_projects
|
||||
|
||||
# # Get all the starter projects
|
||||
# projects = session.exec(select(Flow).where(Flow.folder == STARTER_FOLDER_NAME)).all()
|
||||
# graphs: list[tuple[str, Graph]] = []
|
||||
# for project in projects:
|
||||
# # Add tweaks to make file_path work
|
||||
# tweaks = {"path": __file__}
|
||||
# graph_data = process_tweaks(project.data, tweaks)
|
||||
# graph_object = Graph.from_payload(graph_data, flow_id=project.id)
|
||||
# graphs.append((project.name, graph_object))
|
||||
# assert len(graphs) == len(projects)
|
||||
# for name, graph in graphs:
|
||||
# outputs = await graph.arun(
|
||||
# inputs={},
|
||||
# outputs=[],
|
||||
# session_id="test",
|
||||
# )
|
||||
# assert all(isinstance(output, RunOutputs) for output in outputs), f"Project {name} error: {outputs}"
|
||||
# delete_messages(session_id="test")
|
||||
42
tests/unit/test_loading.py
Normal file
42
tests/unit/test_loading.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import pytest
|
||||
from langflow.graph import Graph
|
||||
from langflow.graph.schema import RunOutputs
|
||||
from langflow.initial_setup.setup import load_starter_projects
|
||||
from langflow.load import load_flow_from_json, run_flow_from_json
|
||||
|
||||
|
||||
@pytest.mark.noclient
|
||||
def test_load_flow_from_json():
|
||||
"""Test loading a flow from a json file"""
|
||||
loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH)
|
||||
assert loaded is not None
|
||||
assert isinstance(loaded, Graph)
|
||||
|
||||
|
||||
@pytest.mark.noclient
|
||||
def test_load_flow_from_json_with_tweaks():
|
||||
"""Test loading a flow from a json file and applying tweaks"""
|
||||
tweaks = {"dndnode_82": {"model_name": "gpt-3.5-turbo-16k-0613"}}
|
||||
loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH, tweaks=tweaks)
|
||||
assert loaded is not None
|
||||
assert isinstance(loaded, Graph)
|
||||
|
||||
|
||||
@pytest.mark.noclient
|
||||
def test_load_flow_from_json_object():
|
||||
"""Test loading a flow from a json file and applying tweaks"""
|
||||
_, projects = zip(*load_starter_projects())
|
||||
project = projects[0]
|
||||
loaded = load_flow_from_json(project)
|
||||
assert loaded is not None
|
||||
assert isinstance(loaded, Graph)
|
||||
|
||||
|
||||
@pytest.mark.noclient
|
||||
def test_run_flow_from_json_object():
|
||||
"""Test loading a flow from a json file and applying tweaks"""
|
||||
_, projects = zip(*load_starter_projects())
|
||||
project = [project for project in projects if "Basic Prompting" in project["name"]][0]
|
||||
results = run_flow_from_json(project, input_value="test", fallback_to_env_vars=True)
|
||||
assert results is not None
|
||||
assert all(isinstance(result, RunOutputs) for result in results)
|
||||
45
tests/unit/test_login.py
Normal file
45
tests/unit/test_login.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import pytest
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.user import User
|
||||
from langflow.services.deps import session_scope
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user():
|
||||
return User(
|
||||
username="testuser",
|
||||
password=get_password_hash("testpassword"), # Assuming password needs to be hashed
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
|
||||
def test_login_successful(client, test_user):
|
||||
# Adding the test user to the database
|
||||
try:
|
||||
with session_scope() as session:
|
||||
session.add(test_user)
|
||||
session.commit()
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
response = client.post("api/v1/login", data={"username": "testuser", "password": "testpassword"})
|
||||
assert response.status_code == 200
|
||||
assert "access_token" in response.json()
|
||||
|
||||
|
||||
def test_login_unsuccessful_wrong_username(client):
|
||||
response = client.post("api/v1/login", data={"username": "wrongusername", "password": "testpassword"})
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Incorrect username or password"
|
||||
|
||||
|
||||
def test_login_unsuccessful_wrong_password(client, test_user, session):
|
||||
# Adding the test user to the database
|
||||
session.add(test_user)
|
||||
session.commit()
|
||||
|
||||
response = client.post("api/v1/login", data={"username": "testuser", "password": "wrongpassword"})
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Incorrect username or password"
|
||||
303
tests/unit/test_process.py
Normal file
303
tests/unit/test_process.py
Normal file
|
|
@ -0,0 +1,303 @@
|
|||
import pytest
|
||||
from langflow.processing.process import process_tweaks
|
||||
from langflow.services.deps import get_session_service
|
||||
|
||||
|
||||
def test_no_tweaks():
|
||||
graph_data = {
|
||||
"data": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 1},
|
||||
"param2": {"value": 2},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 3},
|
||||
"param2": {"value": 4},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
tweaks = {}
|
||||
result = process_tweaks(graph_data, tweaks)
|
||||
assert result == graph_data
|
||||
|
||||
|
||||
def test_single_tweak():
|
||||
graph_data = {
|
||||
"data": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 1, "type": "int"},
|
||||
"param2": {"value": 2, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 3, "type": "int"},
|
||||
"param2": {"value": 4, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
tweaks = {"node1": {"param1": 5}}
|
||||
expected_result = {
|
||||
"data": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 5, "type": "int"},
|
||||
"param2": {"value": 2, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 3, "type": "int"},
|
||||
"param2": {"value": 4, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
result = process_tweaks(graph_data, tweaks)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_multiple_tweaks():
|
||||
graph_data = {
|
||||
"data": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 1, "type": "int"},
|
||||
"param2": {"value": 2, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 3, "type": "int"},
|
||||
"param2": {"value": 4, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
tweaks = {
|
||||
"node1": {"param1": 5, "param2": 6},
|
||||
"node2": {"param1": 7},
|
||||
}
|
||||
expected_result = {
|
||||
"data": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 5, "type": "int"},
|
||||
"param2": {"value": 6, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 7, "type": "int"},
|
||||
"param2": {"value": 4, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
result = process_tweaks(graph_data, tweaks)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
# Test twekas that just pass the param and value but no node id.
|
||||
# This is a new feature that was added to the process_tweaks function
|
||||
def test_tweak_no_node_id():
|
||||
graph_data = {
|
||||
"data": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 1, "type": "int"},
|
||||
"param2": {"value": 2, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 3, "type": "int"},
|
||||
"param2": {"value": 4, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
tweaks = {"param1": 5}
|
||||
expected_result = {
|
||||
"data": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 5, "type": "int"},
|
||||
"param2": {"value": 2, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 5, "type": "int"},
|
||||
"param2": {"value": 4, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
result = process_tweaks(graph_data, tweaks)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_tweak_not_in_template():
|
||||
graph_data = {
|
||||
"data": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 1, "type": "int"},
|
||||
"param2": {"value": 2, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"param1": {"value": 3, "type": "int"},
|
||||
"param2": {"value": 4, "type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
tweaks = {"node1": {"param3": 5}}
|
||||
result = process_tweaks(graph_data, tweaks)
|
||||
assert result == graph_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_langchain_object_with_cached_session(client, basic_graph_data):
|
||||
# Provide a non-existent session_id
|
||||
session_service = get_session_service()
|
||||
session_id1 = "non-existent-session-id"
|
||||
graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data)
|
||||
# Use the new session_id to get the langchain_object again
|
||||
graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data)
|
||||
|
||||
assert graph1 == graph2
|
||||
assert artifacts1 == artifacts2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_langchain_object_with_no_cached_session(client, basic_graph_data):
|
||||
# Provide a non-existent session_id
|
||||
session_service = get_session_service()
|
||||
session_id1 = "non-existent-session-id"
|
||||
session_id = session_service.build_key(session_id1, basic_graph_data)
|
||||
graph1, artifacts1 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id")
|
||||
# Clear the cache
|
||||
await session_service.clear_session(session_id)
|
||||
# Use the new session_id to get the graph again
|
||||
graph2, artifacts2 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id")
|
||||
|
||||
# Since the cache was cleared, objects should be different
|
||||
assert id(graph1) != id(graph2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_langchain_object_without_session_id(client, basic_graph_data):
|
||||
# Provide a non-existent session_id
|
||||
session_service = get_session_service()
|
||||
session_id1 = None
|
||||
graph1, artifacts1 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id")
|
||||
# Use the new session_id to get the langchain_object again
|
||||
graph2, artifacts2 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id")
|
||||
|
||||
assert graph1 == graph2
|
||||
139
tests/unit/test_record.py
Normal file
139
tests/unit/test_record.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
from langchain_core.documents import Document
|
||||
|
||||
from langflow.schema import Record
|
||||
|
||||
|
||||
def test_record_initialization():
|
||||
record = Record(text_key="msg", data={"msg": "Hello, World!", "extra": "value"})
|
||||
assert record.msg == "Hello, World!"
|
||||
assert record.extra == "value"
|
||||
|
||||
|
||||
def test_validate_data_with_extra_keys():
|
||||
record = Record(dummy_key="dummy", data={"key": "value"})
|
||||
assert record.data["dummy_key"] == "dummy"
|
||||
assert "dummy_key" in record.data
|
||||
assert record.key == "value"
|
||||
|
||||
|
||||
def test_conversion_to_document():
|
||||
record = Record(data={"text": "Sample text", "meta": "data"})
|
||||
document = record.to_lc_document()
|
||||
assert document.page_content == "Sample text"
|
||||
assert document.metadata == {"meta": "data"}
|
||||
|
||||
|
||||
def test_conversion_from_document():
|
||||
document = Document(page_content="Doc content", metadata={"meta": "info"})
|
||||
record = Record.from_document(document)
|
||||
assert record.text == "Doc content"
|
||||
assert record.meta == "info"
|
||||
|
||||
|
||||
def test_add_method_for_strings():
|
||||
record1 = Record(data={"text": "Hello"})
|
||||
record2 = Record(data={"text": " World"})
|
||||
combined = record1 + record2
|
||||
assert combined.text == "Hello World"
|
||||
|
||||
|
||||
def test_add_method_for_integers():
|
||||
record1 = Record(data={"number": 5})
|
||||
record2 = Record(data={"number": 10})
|
||||
combined = record1 + record2
|
||||
assert combined.number == 15
|
||||
|
||||
|
||||
def test_add_method_with_non_overlapping_keys():
|
||||
record1 = Record(data={"text": "Hello"})
|
||||
record2 = Record(data={"number": 10})
|
||||
combined = record1 + record2
|
||||
assert combined.text == "Hello"
|
||||
assert combined.number == 10
|
||||
|
||||
|
||||
def test_custom_attribute_get_set_del():
|
||||
record = Record()
|
||||
record.custom_attr = "custom_value"
|
||||
assert record.custom_attr == "custom_value"
|
||||
del record.custom_attr
|
||||
assert record.custom_attr == record.default_value
|
||||
|
||||
|
||||
def test_deep_copy():
|
||||
import copy
|
||||
|
||||
record1 = Record(data={"text": "Hello", "number": 10})
|
||||
record2 = copy.deepcopy(record1)
|
||||
assert record2.text == "Hello"
|
||||
assert record2.number == 10
|
||||
record2.text = "World"
|
||||
assert record1.text == "Hello" # Ensure original is unchanged
|
||||
|
||||
|
||||
def test_custom_attribute_setting_and_getting():
|
||||
record = Record()
|
||||
record.dynamic_attribute = "Dynamic Value"
|
||||
assert record.dynamic_attribute == "Dynamic Value"
|
||||
|
||||
|
||||
def test_str_and_dir_methods():
|
||||
record = Record(text_key="text", data={"text": "Test Text", "key": "value"})
|
||||
assert "Test Text" in str(record)
|
||||
assert "key" in dir(record)
|
||||
assert "data" in dir(record)
|
||||
|
||||
|
||||
def test_dir_includes_data_keys():
|
||||
record = Record(data={"text": "Hello", "new_attr": "value"})
|
||||
dir_output = dir(record)
|
||||
|
||||
# Check for standard attributes
|
||||
assert "data" in dir_output
|
||||
assert "text_key" in dir_output
|
||||
assert "__add__" in dir_output # Checking for a method
|
||||
|
||||
# Check for dynamic attributes from data
|
||||
assert "text" in dir_output
|
||||
assert "new_attr" in dir_output
|
||||
|
||||
# Optionally, verify that dynamically added attributes are listed
|
||||
record.dynamic_attr = "dynamic"
|
||||
assert "dynamic_attr" in dir_output or "dynamic_attr" in dir(record) # To account for the change
|
||||
|
||||
|
||||
def test_dir_reflects_attribute_deletion():
|
||||
record = Record(data={"removable": "I can be removed"})
|
||||
assert "removable" in dir(record)
|
||||
|
||||
# Delete the attribute and check again
|
||||
del record.removable
|
||||
assert "removable" not in dir(record)
|
||||
|
||||
|
||||
def test_get_text_with_text_key():
|
||||
data = {"text": "Hello, World!"}
|
||||
schema = Record(data=data, text_key="text", default_value="default")
|
||||
result = schema.get_text()
|
||||
assert result == "Hello, World!"
|
||||
|
||||
|
||||
def test_get_text_without_text_key():
|
||||
data = {"other_key": "Hello, World!"}
|
||||
schema = Record(data=data, text_key="text", default_value="default")
|
||||
result = schema.get_text()
|
||||
assert result == "default"
|
||||
|
||||
|
||||
def test_get_text_with_empty_data():
|
||||
data = {}
|
||||
schema = Record(data=data, text_key="text", default_value="default")
|
||||
result = schema.get_text()
|
||||
assert result == "default"
|
||||
|
||||
|
||||
def test_get_text_with_none_data():
|
||||
data = None
|
||||
schema = Record(data=data, text_key="text", default_value="default")
|
||||
result = schema.get_text()
|
||||
assert result == "default"
|
||||
132
tests/unit/test_setup_superuser.py
Normal file
132
tests/unit/test_setup_superuser.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langflow.services.settings.constants import (
|
||||
DEFAULT_SUPERUSER,
|
||||
DEFAULT_SUPERUSER_PASSWORD,
|
||||
)
|
||||
from langflow.services.utils import teardown_superuser
|
||||
|
||||
# @patch("langflow.services.deps.get_session")
|
||||
# @patch("langflow.services.utils.create_super_user")
|
||||
# @patch("langflow.services.deps.get_settings_service")
|
||||
# # @patch("langflow.services.utils.verify_password")
|
||||
# def test_setup_superuser(
|
||||
# mock_get_session, mock_create_super_user, mock_get_settings_service
|
||||
# ):
|
||||
# # Test when AUTO_LOGIN is True
|
||||
# calls = []
|
||||
# mock_settings_service = Mock()
|
||||
# mock_settings_service.auth_settings.AUTO_LOGIN = True
|
||||
# mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
|
||||
# mock_settings_service.auth_settings.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD
|
||||
# mock_get_settings_service.return_value = mock_settings_service
|
||||
# mock_session = Mock()
|
||||
# mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
# mock_session
|
||||
# )
|
||||
# # return value of get_session is a generator
|
||||
# mock_get_session.return_value = iter([mock_session, mock_session, mock_session])
|
||||
# setup_superuser(mock_settings_service, mock_session)
|
||||
# mock_session.query.assert_called_once_with(User)
|
||||
# # Set return value of filter to be None
|
||||
# mock_session.query.return_value.filter.return_value.first.return_value = None
|
||||
# actual_expr = mock_session.query.return_value.filter.call_args[0][0]
|
||||
# expected_expr = User.username == DEFAULT_SUPERUSER
|
||||
|
||||
# assert str(actual_expr) == str(expected_expr)
|
||||
# create_call = call(
|
||||
# db=mock_session, username=DEFAULT_SUPERUSER, password=DEFAULT_SUPERUSER_PASSWORD
|
||||
# )
|
||||
# calls.append(create_call)
|
||||
# # mock_create_super_user.assert_has_calls(calls)
|
||||
# assert 1 == mock_create_super_user.call_count
|
||||
|
||||
# def reset_mock_credentials():
|
||||
# mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
|
||||
# mock_settings_service.auth_settings.SUPERUSER_PASSWORD = (
|
||||
# DEFAULT_SUPERUSER_PASSWORD
|
||||
# )
|
||||
|
||||
# ADMIN_USER_NAME = "admin_user"
|
||||
# # Test when username and password are default
|
||||
# mock_settings_service.auth_settings = Mock()
|
||||
# mock_settings_service.auth_settings.AUTO_LOGIN = False
|
||||
# mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME
|
||||
# mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password"
|
||||
# mock_settings_service.auth_settings.reset_credentials = Mock(
|
||||
# side_effect=reset_mock_credentials
|
||||
# )
|
||||
|
||||
# mock_get_settings_service.return_value = mock_settings_service
|
||||
|
||||
# setup_superuser(mock_settings_service, mock_session)
|
||||
# mock_session.query.assert_called_with(User)
|
||||
# actual_expr = mock_session.query.return_value.filter.call_args[0][0]
|
||||
# expected_expr = User.username == ADMIN_USER_NAME
|
||||
|
||||
# assert str(actual_expr) == str(expected_expr)
|
||||
# create_call = call(db=mock_session, username=ADMIN_USER_NAME, password="password")
|
||||
# calls.append(create_call)
|
||||
# # mock_create_super_user.assert_has_calls(calls)
|
||||
# assert 2 == mock_create_super_user.call_count
|
||||
# # Test that superuser credentials are reset
|
||||
# mock_settings_service.auth_settings.reset_credentials.assert_called_once()
|
||||
# assert mock_settings_service.auth_settings.SUPERUSER != ADMIN_USER_NAME
|
||||
# assert mock_settings_service.auth_settings.SUPERUSER_PASSWORD != "password"
|
||||
|
||||
# # Test when superuser already exists
|
||||
# mock_settings_service.auth_settings.AUTO_LOGIN = False
|
||||
# mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME
|
||||
# mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password"
|
||||
# mock_user = Mock()
|
||||
# mock_user.is_superuser = True
|
||||
# mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
# setup_superuser(mock_settings_service, mock_session)
|
||||
# mock_session.query.assert_called_with(User)
|
||||
# actual_expr = mock_session.query.return_value.filter.call_args[0][0]
|
||||
# expected_expr = User.username == ADMIN_USER_NAME
|
||||
|
||||
# assert str(actual_expr) == str(expected_expr)
|
||||
|
||||
|
||||
@patch("langflow.services.deps.get_settings_service")
|
||||
@patch("langflow.services.deps.get_session")
|
||||
def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = True
|
||||
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD
|
||||
mock_get_settings_service.return_value = mock_settings_service
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_user.is_superuser = True
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_get_session.return_value = iter([mock_session])
|
||||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.query.assert_not_called()
|
||||
|
||||
|
||||
@patch("langflow.services.deps.get_settings_service")
|
||||
@patch("langflow.services.deps.get_session")
|
||||
def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
ADMIN_USER_NAME = "admin_user"
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = False
|
||||
mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password"
|
||||
mock_get_settings_service.return_value = mock_settings_service
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_user.is_superuser = False
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_get_session.return_value = [mock_session]
|
||||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.delete.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
113
tests/unit/test_template.py
Normal file
113
tests/unit/test_template.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
import importlib
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
from langflow.interface.utils import build_template_from_class
|
||||
from langflow.utils.util import build_template_from_function, get_base_classes, get_default_factory
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Dummy classes for testing purposes
|
||||
class Parent(BaseModel):
|
||||
"""Parent Class"""
|
||||
|
||||
parent_field: str
|
||||
|
||||
|
||||
class Child(Parent):
|
||||
"""Child Class"""
|
||||
|
||||
child_field: int
|
||||
|
||||
|
||||
class ExampleClass1(BaseModel):
|
||||
"""Example class 1."""
|
||||
|
||||
def __init__(self, data: Optional[List[int]] = None):
|
||||
self.data = data or [1, 2, 3]
|
||||
|
||||
|
||||
class ExampleClass2(BaseModel):
|
||||
"""Example class 2."""
|
||||
|
||||
def __init__(self, data: Optional[Dict[str, int]] = None):
|
||||
self.data = data or {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
def example_loader_1() -> ExampleClass1:
|
||||
"""Example loader function 1."""
|
||||
return ExampleClass1()
|
||||
|
||||
|
||||
def example_loader_2() -> ExampleClass2:
|
||||
"""Example loader function 2."""
|
||||
return ExampleClass2()
|
||||
|
||||
|
||||
def test_build_template_from_function():
|
||||
type_to_loader_dict = {
|
||||
"example1": example_loader_1,
|
||||
"example2": example_loader_2,
|
||||
}
|
||||
|
||||
# Test with valid name
|
||||
result = build_template_from_function("ExampleClass1", type_to_loader_dict)
|
||||
|
||||
assert result is not None
|
||||
assert "template" in result
|
||||
assert "description" in result
|
||||
assert "base_classes" in result
|
||||
|
||||
# Test with add_function=True
|
||||
result_with_function = build_template_from_function("ExampleClass1", type_to_loader_dict, add_function=True)
|
||||
assert result_with_function is not None
|
||||
assert "Callable" in result_with_function["base_classes"]
|
||||
|
||||
# Test with invalid name
|
||||
with pytest.raises(ValueError, match=r".* not found"):
|
||||
build_template_from_function("NonExistent", type_to_loader_dict)
|
||||
|
||||
|
||||
# Test build_template_from_class
|
||||
def test_build_template_from_class():
|
||||
type_to_cls_dict: Dict[str, type] = {"parent": Parent, "child": Child}
|
||||
|
||||
# Test valid input
|
||||
result = build_template_from_class("Child", type_to_cls_dict)
|
||||
assert result is not None
|
||||
assert "template" in result
|
||||
assert "description" in result
|
||||
assert "base_classes" in result
|
||||
assert "Child" in result["base_classes"]
|
||||
assert "Parent" in result["base_classes"]
|
||||
assert result["description"] == "Child Class"
|
||||
|
||||
# Test invalid input
|
||||
with pytest.raises(ValueError, match="InvalidClass not found."):
|
||||
build_template_from_class("InvalidClass", type_to_cls_dict)
|
||||
|
||||
|
||||
# Test get_base_classes
|
||||
def test_get_base_classes():
|
||||
base_classes_parent = get_base_classes(Parent)
|
||||
base_classes_child = get_base_classes(Child)
|
||||
|
||||
assert "Parent" in base_classes_parent
|
||||
assert "Child" in base_classes_child
|
||||
assert "Parent" in base_classes_child
|
||||
|
||||
|
||||
# Test get_default_factory
|
||||
def test_get_default_factory():
|
||||
module_name = "langflow.utils.util"
|
||||
function_repr = "<function dummy_function>"
|
||||
|
||||
def dummy_function():
|
||||
return "default_value"
|
||||
|
||||
# Add dummy_function to your_module
|
||||
setattr(importlib.import_module(module_name), "dummy_function", dummy_function)
|
||||
|
||||
default_value = get_default_factory(module_name, function_repr)
|
||||
|
||||
assert default_value == "default_value"
|
||||
107
tests/unit/test_validate_code.py
Normal file
107
tests/unit/test_validate_code.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from langflow.utils.validate import (
|
||||
create_function,
|
||||
execute_function,
|
||||
extract_function_name,
|
||||
validate_code,
|
||||
)
|
||||
from requests.exceptions import MissingSchema
|
||||
|
||||
|
||||
def test_validate_code():
|
||||
# Test case with a valid import and function
|
||||
code1 = """
|
||||
import math
|
||||
|
||||
def square(x):
|
||||
return x ** 2
|
||||
"""
|
||||
errors1 = validate_code(code1)
|
||||
assert errors1 == {"imports": {"errors": []}, "function": {"errors": []}}
|
||||
|
||||
# Test case with an invalid import and valid function
|
||||
code2 = """
|
||||
import non_existent_module
|
||||
|
||||
def square(x):
|
||||
return x ** 2
|
||||
"""
|
||||
errors2 = validate_code(code2)
|
||||
assert errors2 == {
|
||||
"imports": {"errors": ["No module named 'non_existent_module'"]},
|
||||
"function": {"errors": []},
|
||||
}
|
||||
|
||||
# Test case with a valid import and invalid function syntax
|
||||
code3 = """
|
||||
import math
|
||||
|
||||
def square(x)
|
||||
return x ** 2
|
||||
"""
|
||||
errors3 = validate_code(code3)
|
||||
assert errors3 == {
|
||||
"imports": {"errors": []},
|
||||
"function": {"errors": ["expected ':' (<unknown>, line 4)"]},
|
||||
}
|
||||
|
||||
|
||||
def test_execute_function_success():
|
||||
code = """
|
||||
import math
|
||||
|
||||
def my_function(x):
|
||||
return math.sin(x) + 1
|
||||
"""
|
||||
result = execute_function(code, "my_function", 0.5)
|
||||
assert result == 1.479425538604203
|
||||
|
||||
|
||||
def test_execute_function_missing_module():
|
||||
code = """
|
||||
import some_missing_module
|
||||
|
||||
def my_function(x):
|
||||
return some_missing_module.some_function(x)
|
||||
"""
|
||||
with pytest.raises(ModuleNotFoundError):
|
||||
execute_function(code, "my_function", 0.5)
|
||||
|
||||
|
||||
def test_execute_function_missing_function():
|
||||
code = """
|
||||
import math
|
||||
|
||||
def my_function(x):
|
||||
return math.some_missing_function(x)
|
||||
"""
|
||||
with pytest.raises(AttributeError):
|
||||
execute_function(code, "my_function", 0.5)
|
||||
|
||||
|
||||
def test_execute_function_missing_schema():
|
||||
code = """
|
||||
import requests
|
||||
|
||||
def my_function(x):
|
||||
return requests.get(x).text
|
||||
"""
|
||||
with mock.patch("requests.get", side_effect=MissingSchema):
|
||||
with pytest.raises(MissingSchema):
|
||||
execute_function(code, "my_function", "invalid_url")
|
||||
|
||||
|
||||
def test_create_function():
|
||||
code = """
|
||||
import math
|
||||
|
||||
def my_function(x):
|
||||
return math.sin(x) + 1
|
||||
"""
|
||||
|
||||
function_name = extract_function_name(code)
|
||||
function = create_function(code, function_name)
|
||||
result = function(0.5)
|
||||
assert result == 1.479425538604203
|
||||
15
tests/unit/text_experimental_components.py
Normal file
15
tests/unit/text_experimental_components.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from langflow.components import experimental
|
||||
|
||||
|
||||
def test_python_function_component():
|
||||
# Arrange
|
||||
python_function_component = experimental.PythonFunctionComponent()
|
||||
|
||||
# Act
|
||||
# function must be a string representation
|
||||
function = "def function():\n return 'Hello, World!'"
|
||||
# result is the callable function
|
||||
result = python_function_component.build(function)
|
||||
|
||||
# Assert
|
||||
assert result() == "Hello, World!"
|
||||
Loading…
Add table
Add a link
Reference in a new issue