test: add astra integration test (#2189)

* add first astra integ test framework

* use fixtures

* remove old tests from merge

* Add correct sender type

* chore: Update unit test command in GitHub workflow

---------

Co-authored-by: ogabrielluiz <gabriel@langflow.org>
This commit is contained in:
Jordan Frazier 2024-06-15 19:50:38 -07:00 committed by GitHub
commit ca660cf8df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 211 additions and 12 deletions

0
tests/unit/__init__.py Normal file
View file

428
tests/unit/conftest.py Normal file
View file

@ -0,0 +1,428 @@
import json
import os.path
import shutil
# we need to import tmpdir
import tempfile
from contextlib import contextmanager, suppress
from pathlib import Path
from typing import TYPE_CHECKING, AsyncGenerator
import orjson
import pytest
from dotenv import load_dotenv
from fastapi.testclient import TestClient
from httpx import AsyncClient
from langflow.graph.graph.base import Graph
from langflow.initial_setup.setup import STARTER_FOLDER_NAME
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.models.flow.model import Flow, FlowCreate
from langflow.services.database.models.folder.model import Folder
from langflow.services.database.models.user.model import User, UserCreate
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
load_dotenv()
def pytest_configure(config):
config.addinivalue_line("markers", "noclient: don't create a client for this test")
data_path = Path(__file__).parent.absolute() / "data"
pytest.BASIC_EXAMPLE_PATH = data_path / "basic_example.json"
pytest.COMPLEX_EXAMPLE_PATH = data_path / "complex_example.json"
pytest.OPENAPI_EXAMPLE_PATH = data_path / "Openapi.json"
pytest.GROUPED_CHAT_EXAMPLE_PATH = data_path / "grouped_chat.json"
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = data_path / "one_group_chat.json"
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = data_path / "vector_store_grouped.json"
pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = data_path / "BasicChatwithPromptandHistory.json"
pytest.CHAT_INPUT = data_path / "ChatInputTest.json"
pytest.TWO_OUTPUTS = data_path / "TwoOutputsTest.json"
pytest.VECTOR_STORE_PATH = data_path / "Vector_store.json"
pytest.CODE_WITH_SYNTAX_ERROR = """
def get_text():
retun "Hello World"
"""
# validate that all the paths are correct and the files exist
for path in [
pytest.BASIC_EXAMPLE_PATH,
pytest.COMPLEX_EXAMPLE_PATH,
pytest.OPENAPI_EXAMPLE_PATH,
pytest.GROUPED_CHAT_EXAMPLE_PATH,
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH,
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH,
pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY,
pytest.CHAT_INPUT,
pytest.TWO_OUTPUTS,
pytest.VECTOR_STORE_PATH,
]:
assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}"
@pytest.fixture(autouse=True)
def check_openai_api_key_in_environment_variables():
import os
assert os.environ.get("OPENAI_API_KEY") is not None, "OPENAI_API_KEY is not set in environment variables"
@pytest.fixture()
async def async_client() -> AsyncGenerator:
from langflow.main import create_app
app = create_app()
async with AsyncClient(app=app, base_url="http://testserver") as client:
yield client
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
class Config:
broker_url = "redis://localhost:6379/0"
result_backend = "redis://localhost:6379/0"
@pytest.fixture(name="load_flows_dir")
def load_flows_dir():
tempdir = tempfile.TemporaryDirectory()
yield tempdir.name
@pytest.fixture(name="distributed_env")
def setup_env(monkeypatch):
monkeypatch.setenv("LANGFLOW_CACHE_TYPE", "redis")
monkeypatch.setenv("LANGFLOW_REDIS_HOST", "result_backend")
monkeypatch.setenv("LANGFLOW_REDIS_PORT", "6379")
monkeypatch.setenv("LANGFLOW_REDIS_DB", "0")
monkeypatch.setenv("LANGFLOW_REDIS_EXPIRE", "3600")
monkeypatch.setenv("LANGFLOW_REDIS_PASSWORD", "")
monkeypatch.setenv("FLOWER_UNAUTHENTICATED_API", "True")
monkeypatch.setenv("BROKER_URL", "redis://result_backend:6379/0")
monkeypatch.setenv("RESULT_BACKEND", "redis://result_backend:6379/0")
monkeypatch.setenv("C_FORCE_ROOT", "true")
@pytest.fixture(name="distributed_client")
def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
# Here we load the .env from ../deploy/.env
from langflow.core import celery_app
db_dir = tempfile.mkdtemp()
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
# monkeypatch langflow.services.task.manager.USE_CELERY to True
# monkeypatch.setattr(manager, "USE_CELERY", True)
monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config))
# def get_session_override():
# return session
from langflow.main import create_app
app = create_app()
# app.dependency_overrides[get_session] = get_session_override
with TestClient(app) as client:
yield client
app.dependency_overrides.clear()
monkeypatch.undo()
def get_graph(_type="basic"):
"""Get a graph from a json file"""
if _type == "basic":
path = pytest.BASIC_EXAMPLE_PATH
elif _type == "complex":
path = pytest.COMPLEX_EXAMPLE_PATH
elif _type == "openapi":
path = pytest.OPENAPI_EXAMPLE_PATH
with open(path, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
return Graph(nodes, edges)
@pytest.fixture
def basic_graph_data():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
return json.load(f)
@pytest.fixture
def basic_graph():
return get_graph()
@pytest.fixture
def complex_graph():
return get_graph("complex")
@pytest.fixture
def openapi_graph():
return get_graph("openapi")
@pytest.fixture
def json_flow():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
return f.read()
@pytest.fixture
def grouped_chat_json_flow():
with open(pytest.GROUPED_CHAT_EXAMPLE_PATH, "r") as f:
return f.read()
@pytest.fixture
def one_grouped_chat_json_flow():
with open(pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH, "r") as f:
return f.read()
@pytest.fixture
def vector_store_grouped_json_flow():
with open(pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH, "r") as f:
return f.read()
@pytest.fixture
def json_flow_with_prompt_and_history():
with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY, "r") as f:
return f.read()
@pytest.fixture
def json_vector_store():
with open(pytest.VECTOR_STORE_PATH, "r") as f:
return f.read()
@pytest.fixture(name="client", autouse=True)
def client_fixture(session: Session, monkeypatch, request, load_flows_dir):
# Set the database url to a test database
if "noclient" in request.keywords:
yield
else:
db_dir = tempfile.mkdtemp()
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
if "load_flows" in request.keywords:
shutil.copyfile(
pytest.BASIC_EXAMPLE_PATH, os.path.join(load_flows_dir, "c54f9130-f2fa-4a3e-b22a-3856d946351b.json")
)
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")
from langflow.main import create_app
app = create_app()
# app.dependency_overrides[get_session] = get_session_override
with TestClient(app) as client:
yield client
# app.dependency_overrides.clear()
monkeypatch.undo()
# clear the temp db
with suppress(FileNotFoundError):
db_path.unlink()
# create a fixture for session_getter above
@pytest.fixture(name="session_getter")
def session_getter_fixture(client):
@contextmanager
def blank_session_getter(db_service: "DatabaseService"):
with Session(db_service.engine) as session:
yield session
yield blank_session_getter
@pytest.fixture
def runner():
return CliRunner()
@pytest.fixture
def test_user(client):
user_data = UserCreate(
username="testuser",
password="testpassword",
)
response = client.post("/api/v1/users", json=user_data.dict())
assert response.status_code == 201
return response.json()
@pytest.fixture(scope="function")
def active_user(client):
db_manager = get_db_service()
with session_getter(db_manager) as session:
user = User(
username="activeuser",
password=get_password_hash("testpassword"),
is_active=True,
is_superuser=False,
)
# check if user exists
if active_user := session.exec(select(User).where(User.username == user.username)).first():
return active_user
session.add(user)
session.commit()
session.refresh(user)
return user
@pytest.fixture
def logged_in_headers(client, active_user):
login_data = {"username": active_user.username, "password": "testpassword"}
response = client.post("/api/v1/login", data=login_data)
assert response.status_code == 200
tokens = response.json()
a_token = tokens["access_token"]
return {"Authorization": f"Bearer {a_token}"}
@pytest.fixture
def flow(client, json_flow: str, active_user):
from langflow.services.database.models.flow.model import FlowCreate
loaded_json = json.loads(json_flow)
flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id)
flow = Flow.model_validate(flow_data)
with session_getter(get_db_service()) as session:
session.add(flow)
session.commit()
session.refresh(flow)
return flow
@pytest.fixture
def json_chat_input():
with open(pytest.CHAT_INPUT, "r") as f:
return f.read()
@pytest.fixture
def json_two_outputs():
with open(pytest.TWO_OUTPUTS, "r") as f:
return f.read()
@pytest.fixture
def added_flow_with_prompt_and_history(client, json_flow_with_prompt_and_history, logged_in_headers):
flow = orjson.loads(json_flow_with_prompt_and_history)
data = flow["data"]
flow = FlowCreate(name="Basic Chat", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
return response.json()
@pytest.fixture
def added_flow_chat_input(client, json_chat_input, logged_in_headers):
flow = orjson.loads(json_chat_input)
data = flow["data"]
flow = FlowCreate(name="Chat Input", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
return response.json()
@pytest.fixture
def added_flow_two_outputs(client, json_two_outputs, logged_in_headers):
flow = orjson.loads(json_two_outputs)
data = flow["data"]
flow = FlowCreate(name="Two Outputs", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
return response.json()
@pytest.fixture
def added_vector_store(client, json_vector_store, logged_in_headers):
vector_store = orjson.loads(json_vector_store)
data = vector_store["data"]
vector_store = FlowCreate(name="Vector Store", description="description", data=data)
response = client.post("api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == vector_store.name
assert response.json()["data"] == vector_store.data
return response.json()
@pytest.fixture
def created_api_key(active_user):
hashed = get_password_hash("random_key")
api_key = ApiKey(
name="test_api_key",
user_id=active_user.id,
api_key="random_key",
hashed_api_key=hashed,
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first():
return existing_api_key
session.add(api_key)
session.commit()
session.refresh(api_key)
return api_key
@pytest.fixture(name="starter_project")
def get_starter_project(active_user):
# once the client is created, we can get the starter project
with session_getter(get_db_service()) as session:
flow = session.exec(
select(Flow)
.where(Flow.folder.has(Folder.name == STARTER_FOLDER_NAME))
.where(Flow.name == "Basic Prompting (Hello, World)")
).first()
if not flow:
raise ValueError("No starter project found")
new_flow_create = FlowCreate(
name=flow.name,
description=flow.description,
data=flow.data,
user_id=active_user.id,
)
new_flow = Flow.model_validate(new_flow_create, from_attributes=True)
session.add(new_flow)
session.commit()
session.refresh(new_flow)
new_flow_dict = new_flow.model_dump()
return new_flow_dict

View file

@ -0,0 +1,45 @@
import pytest
from langflow.services.database.models.api_key import ApiKeyCreate
@pytest.fixture
def api_key(client, logged_in_headers, active_user):
api_key = ApiKeyCreate(name="test-api-key")
response = client.post("api/v1/api_key", data=api_key.model_dump_json(), headers=logged_in_headers)
assert response.status_code == 200, response.text
return response.json()
def test_get_api_keys(client, logged_in_headers, api_key):
response = client.get("api/v1/api_key", headers=logged_in_headers)
assert response.status_code == 200, response.text
data = response.json()
assert "total_count" in data
assert "user_id" in data
assert "api_keys" in data
assert any("test-api-key" in api_key["name"] for api_key in data["api_keys"])
# assert all api keys in data["api_keys"] are masked
assert all("**" in api_key["api_key"] for api_key in data["api_keys"])
def test_create_api_key(client, logged_in_headers):
api_key_name = "test-api-key"
response = client.post("api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers)
assert response.status_code == 200
data = response.json()
assert "name" in data and data["name"] == api_key_name
assert "api_key" in data
# When creating the API key is returned which is
# the only time the API key is unmasked
assert "**" not in data["api_key"]
def test_delete_api_key(client, logged_in_headers, active_user, api_key):
# Assuming a function to create a test API key, returning the key ID
api_key_id = api_key["id"]
response = client.delete(f"api/v1/api_key/{api_key_id}", headers=logged_in_headers)
assert response.status_code == 200
data = response.json()
assert data["detail"] == "API Key deleted"
# Optionally, add a follow-up check to ensure that the key is actually removed from the database

47
tests/unit/test_cache.py Normal file
View file

@ -0,0 +1,47 @@
import json
import pytest
from langflow.graph import Graph
def get_graph(_type="basic"):
"""Get a graph from a json file"""
if _type == "basic":
path = pytest.BASIC_EXAMPLE_PATH
elif _type == "complex":
path = pytest.COMPLEX_EXAMPLE_PATH
elif _type == "openapi":
path = pytest.OPENAPI_EXAMPLE_PATH
with open(path, "r") as f:
flow_graph = json.load(f)
return flow_graph["data"]
@pytest.fixture
def basic_data_graph():
return get_graph()
@pytest.fixture
def complex_data_graph():
return get_graph("complex")
@pytest.fixture
def openapi_data_graph():
return get_graph("openapi")
def langchain_objects_are_equal(obj1, obj2):
return str(obj1) == str(obj2)
# Test build_graph
@pytest.mark.asyncio
async def test_build_graph(client, basic_data_graph):
graph = Graph.from_payload(basic_data_graph)
assert graph is not None
assert len(graph.vertices) == len(basic_data_graph["nodes"])
assert len(graph.edges) == len(basic_data_graph["edges"])

37
tests/unit/test_cli.py Normal file
View file

@ -0,0 +1,37 @@
from pathlib import Path
from tempfile import tempdir
import pytest
from langflow.__main__ import app
from langflow.services import deps
@pytest.fixture(scope="module")
def default_settings():
return [
"--backend-only",
"--no-open-browser",
]
def test_components_path(runner, client, default_settings):
# Create a foldr in the tmp directory
temp_dir = Path(tempdir)
# create a "components" folder
temp_dir = temp_dir / "components"
temp_dir.mkdir(exist_ok=True)
result = runner.invoke(
app,
["run", "--components-path", str(temp_dir), *default_settings],
)
assert result.exit_code == 0, result.stdout
settings_service = deps.get_settings_service()
assert str(temp_dir) in settings_service.settings.components_path
def test_superuser(runner, client, session):
result = runner.invoke(app, ["superuser"], input="admin\nadmin\n")
assert result.exit_code == 0, result.stdout
assert "Superuser created successfully." in result.stdout

View file

@ -0,0 +1,519 @@
import ast
import types
from uuid import uuid4
import pytest
from langchain_core.documents import Document
from langflow.custom import CustomComponent
from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
from langflow.custom.custom_component.component import Component, ComponentCodeNullError
from langflow.services.database.models.flow import Flow, FlowCreate
code_default = """
from langflow.field_typing import Prompt
from langflow.custom import CustomComponent
from langflow.field_typing import BaseLanguageModel
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document
import requests
class YourComponent(CustomComponent):
display_name: str = "Your Component"
description: str = "Your description"
field_config = { "url": { "multiline": True, "required": True } }
def build(self, url: str, llm: BaseLanguageModel, template: Prompt) -> Document:
response = requests.get(url)
prompt = PromptTemplate.from_template(template)
chain = LLMChain(llm=llm, prompt=prompt)
result = chain.run(response.text[:300])
return Document(page_content=str(result))
"""
def test_code_parser_init():
"""
Test the initialization of the CodeParser class.
"""
parser = CodeParser(code_default)
assert parser.code == code_default
def test_code_parser_get_tree():
"""
Test the __get_tree method of the CodeParser class.
"""
parser = CodeParser(code_default)
tree = parser.get_tree()
assert isinstance(tree, ast.AST)
def test_code_parser_syntax_error():
"""
Test the __get_tree method raises the
CodeSyntaxError when given incorrect syntax.
"""
code_syntax_error = "zzz import os"
parser = CodeParser(code_syntax_error)
with pytest.raises(CodeSyntaxError):
parser.get_tree()
def test_component_init():
"""
Test the initialization of the Component class.
"""
component = Component(code=code_default, function_entrypoint_name="build")
assert component.code == code_default
assert component.function_entrypoint_name == "build"
def test_component_get_code_tree():
"""
Test the get_code_tree method of the Component class.
"""
component = Component(code=code_default, function_entrypoint_name="build")
tree = component.get_code_tree(component.code)
assert "imports" in tree
def test_component_code_null_error():
"""
Test the get_function method raises the
ComponentCodeNullError when the code is empty.
"""
component = Component(code="", function_entrypoint_name="")
with pytest.raises(ComponentCodeNullError):
component.get_function()
def test_custom_component_init():
"""
Test the initialization of the CustomComponent class.
"""
function_entrypoint_name = "build"
custom_component = CustomComponent(code=code_default, function_entrypoint_name=function_entrypoint_name)
assert custom_component.code == code_default
assert custom_component.function_entrypoint_name == function_entrypoint_name
def test_custom_component_build_template_config():
"""
Test the build_template_config property of the CustomComponent class.
"""
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
config = custom_component.build_template_config()
assert isinstance(config, dict)
def test_custom_component_get_function():
"""
Test the get_function property of the CustomComponent class.
"""
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
my_function = custom_component.get_function()
assert isinstance(my_function, types.FunctionType)
def test_code_parser_parse_imports_import():
"""
Test the parse_imports method of the CodeParser
class with an import statement.
"""
parser = CodeParser(code_default)
tree = parser.get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
parser.parse_imports(node)
assert "requests" in parser.data["imports"]
def test_code_parser_parse_imports_importfrom():
"""
Test the parse_imports method of the CodeParser
class with an import from statement.
"""
parser = CodeParser("from os import path")
tree = parser.get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom):
parser.parse_imports(node)
assert ("os", "path") in parser.data["imports"]
def test_code_parser_parse_functions():
"""
Test the parse_functions method of the CodeParser class.
"""
parser = CodeParser("def test(): pass")
tree = parser.get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
parser.parse_functions(node)
assert len(parser.data["functions"]) == 1
assert parser.data["functions"][0]["name"] == "test"
def test_code_parser_parse_classes():
"""
Test the parse_classes method of the CodeParser class.
"""
parser = CodeParser("class Test: pass")
tree = parser.get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
parser.parse_classes(node)
assert len(parser.data["classes"]) == 1
assert parser.data["classes"][0]["name"] == "Test"
def test_code_parser_parse_global_vars():
"""
Test the parse_global_vars method of the CodeParser class.
"""
parser = CodeParser("x = 1")
tree = parser.get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
parser.parse_global_vars(node)
assert len(parser.data["global_vars"]) == 1
assert parser.data["global_vars"][0]["targets"] == ["x"]
def test_component_get_function_valid():
"""
Test the get_function method of the Component
class with valid code and function_entrypoint_name.
"""
component = Component(code="def build(): pass", function_entrypoint_name="build")
my_function = component.get_function()
assert callable(my_function)
def test_custom_component_get_function_entrypoint_args():
"""
Test the get_function_entrypoint_args
property of the CustomComponent class.
"""
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
args = custom_component.get_function_entrypoint_args
assert len(args) == 4
assert args[0]["name"] == "self"
assert args[1]["name"] == "url"
assert args[2]["name"] == "llm"
def test_custom_component_get_function_entrypoint_return_type():
"""
Test the get_function_entrypoint_return_type
property of the CustomComponent class.
"""
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
return_type = custom_component.get_function_entrypoint_return_type
assert return_type == [Document]
def test_custom_component_get_main_class_name():
"""
Test the get_main_class_name property of the CustomComponent class.
"""
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
class_name = custom_component.get_main_class_name
assert class_name == "YourComponent"
def test_custom_component_get_function_valid():
"""
Test the get_function property of the CustomComponent
class with valid code and function_entrypoint_name.
"""
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
my_function = custom_component.get_function
assert callable(my_function)
def test_code_parser_parse_arg_no_annotation():
"""
Test the parse_arg method of the CodeParser class without an annotation.
"""
parser = CodeParser("")
arg = ast.arg(arg="x", annotation=None)
result = parser.parse_arg(arg, None)
assert result["name"] == "x"
assert "type" not in result
def test_code_parser_parse_arg_with_annotation():
"""
Test the parse_arg method of the CodeParser class with an annotation.
"""
parser = CodeParser("")
arg = ast.arg(arg="x", annotation=ast.Name(id="int", ctx=ast.Load()))
result = parser.parse_arg(arg, None)
assert result["name"] == "x"
assert result["type"] == "int"
def test_code_parser_parse_callable_details_no_args():
"""
Test the parse_callable_details method of the
CodeParser class with a function with no arguments.
"""
parser = CodeParser("")
node = ast.FunctionDef(
name="test",
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
body=[],
decorator_list=[],
returns=None,
)
result = parser.parse_callable_details(node)
assert result["name"] == "test"
assert len(result["args"]) == 0
def test_code_parser_parse_assign():
"""
Test the parse_assign method of the CodeParser class.
"""
parser = CodeParser("")
stmt = ast.Assign(targets=[ast.Name(id="x", ctx=ast.Store())], value=ast.Num(n=1))
result = parser.parse_assign(stmt)
assert result["name"] == "x"
assert result["value"] == "1"
def test_code_parser_parse_ann_assign():
"""
Test the parse_ann_assign method of the CodeParser class.
"""
parser = CodeParser("")
stmt = ast.AnnAssign(
target=ast.Name(id="x", ctx=ast.Store()),
annotation=ast.Name(id="int", ctx=ast.Load()),
value=ast.Num(n=1),
simple=1,
)
result = parser.parse_ann_assign(stmt)
assert result["name"] == "x"
assert result["value"] == "1"
assert result["annotation"] == "int"
def test_code_parser_parse_function_def_not_init():
"""
Test the parse_function_def method of the
CodeParser class with a function that is not __init__.
"""
parser = CodeParser("")
stmt = ast.FunctionDef(
name="test",
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
body=[],
decorator_list=[],
returns=None,
)
result, is_init = parser.parse_function_def(stmt)
assert result["name"] == "test"
assert not is_init
def test_code_parser_parse_function_def_init():
"""
Test the parse_function_def method of the
CodeParser class with an __init__ function.
"""
parser = CodeParser("")
stmt = ast.FunctionDef(
name="__init__",
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
body=[],
decorator_list=[],
returns=None,
)
result, is_init = parser.parse_function_def(stmt)
assert result["name"] == "__init__"
assert is_init
def test_component_get_code_tree_syntax_error():
"""
Test the get_code_tree method of the Component class
raises the CodeSyntaxError when given incorrect syntax.
"""
component = Component(code="import os as", function_entrypoint_name="build")
with pytest.raises(CodeSyntaxError):
component.get_code_tree(component.code)
def test_custom_component_class_template_validation_no_code():
"""
Test the _class_template_validation method of the CustomComponent class
raises the HTTPException when the code is None.
"""
custom_component = CustomComponent(code=None, function_entrypoint_name="build")
with pytest.raises(TypeError):
custom_component.get_function()
def test_custom_component_get_code_tree_syntax_error():
"""
Test the get_code_tree method of the CustomComponent class
raises the CodeSyntaxError when given incorrect syntax.
"""
custom_component = CustomComponent(code="import os as", function_entrypoint_name="build")
with pytest.raises(CodeSyntaxError):
custom_component.get_code_tree(custom_component.code)
def test_custom_component_get_function_entrypoint_args_no_args():
"""
Test the get_function_entrypoint_args property of
the CustomComponent class with a build method with no arguments.
"""
my_code = """
class MyMainClass(CustomComponent):
def build():
pass"""
custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
args = custom_component.get_function_entrypoint_args
assert len(args) == 0
def test_custom_component_get_function_entrypoint_return_type_no_return_type():
"""
Test the get_function_entrypoint_return_type property of the
CustomComponent class with a build method with no return type.
"""
my_code = """
class MyClass(CustomComponent):
def build():
pass"""
custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
return_type = custom_component.get_function_entrypoint_return_type
assert return_type == []
def test_custom_component_get_main_class_name_no_main_class():
"""
Test the get_main_class_name property of the
CustomComponent class when there is no main class.
"""
my_code = """
def build():
pass"""
custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
class_name = custom_component.get_main_class_name
assert class_name == ""
def test_custom_component_build_not_implemented():
"""
Test the build method of the CustomComponent
class raises the NotImplementedError.
"""
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
with pytest.raises(NotImplementedError):
custom_component.build()
def test_build_config_no_code():
component = CustomComponent(code=None)
assert component.get_function_entrypoint_args == []
assert component.get_function_entrypoint_return_type == []
@pytest.fixture
def component(client, active_user):
return CustomComponent(
user_id=active_user.id,
field_config={
"fields": {
"llm": {"type": "str"},
"url": {"type": "str"},
"year": {"type": "int"},
}
},
)
@pytest.fixture(scope="session")
def test_flow(db):
flow_data = {
"nodes": [{"id": "1"}, {"id": "2"}],
"edges": [{"source": "1", "target": "2"}],
}
# Create flow
flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data)
# Add to database
db.add(flow)
db.commit()
yield flow
# Clean up
db.delete(flow)
db.commit()
@pytest.fixture(scope="session")
def db(app):
# Setup database for tests
yield app.db
# Teardown
app.db.drop_all()
def test_list_flows_return_type(component):
flows = component.list_flows()
assert isinstance(flows, list)
def test_list_flows_flow_objects(component):
flows = component.list_flows()
assert all(isinstance(flow, Flow) for flow in flows)
def test_build_config_return_type(component):
config = component.build_config()
assert isinstance(config, dict)
def test_build_config_has_fields(component):
config = component.build_config()
assert "fields" in config
def test_build_config_fields_dict(component):
config = component.build_config()
assert isinstance(config["fields"], dict)
def test_build_config_field_keys(component):
config = component.build_config()
assert all(isinstance(key, str) for key in config["fields"])
def test_build_config_field_values_dict(component):
config = component.build_config()
assert all(isinstance(value, dict) for value in config["fields"].values())
def test_build_config_field_value_keys(component):
config = component.build_config()
field_values = config["fields"].values()
assert all("type" in value for value in field_values)

View file

@ -0,0 +1,186 @@
import os
from pathlib import Path
from unittest.mock import Mock, patch
import httpx
import pytest
import respx
from dictdiffer import diff
from httpx import Response
from langflow.components import data
@pytest.fixture
def api_request():
# This fixture provides an instance of APIRequest for each test case
return data.APIRequest()
@pytest.mark.asyncio
@respx.mock
async def test_successful_get_request(api_request):
# Mocking a successful GET request
url = "https://example.com/api/test"
method = "GET"
mock_response = {"success": True}
respx.get(url).mock(return_value=Response(200, json=mock_response))
# Making the request
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url)
# Assertions
assert result.data["status_code"] == 200
assert result.data["result"] == mock_response
def test_parse_curl(api_request):
# Arrange
field_value = (
"curl -X GET https://example.com/api/test -H 'Content-Type: application/json' -d '{\"key\": \"value\"}'"
)
build_config = {
"method": {"value": ""},
"urls": {"value": []},
"headers": {},
"body": {},
}
# Act
new_build_config = api_request.parse_curl(field_value, build_config.copy())
# Assert
assert new_build_config["method"]["value"] == "GET"
assert new_build_config["urls"]["value"] == ["https://example.com/api/test"]
assert new_build_config["headers"]["value"] == {"Content-Type": "application/json"}
assert new_build_config["body"]["value"] == {"key": "value"}
@pytest.mark.asyncio
@respx.mock
async def test_failed_request(api_request):
# Mocking a failed GET request
url = "https://example.com/api/test"
method = "GET"
respx.get(url).mock(return_value=Response(404))
# Making the request
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url)
# Assertions
assert result.data["status_code"] == 404
@pytest.mark.asyncio
@respx.mock
async def test_timeout(api_request):
# Mocking a timeout
url = "https://example.com/api/timeout"
method = "GET"
respx.get(url).mock(side_effect=httpx.TimeoutException(message="Timeout", request=None))
# Making the request
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url, timeout=1)
# Assertions
assert result.data["status_code"] == 408
assert result.data["error"] == "Request timed out"
@pytest.mark.asyncio
@respx.mock
async def test_build_with_multiple_urls(api_request):
# This test depends on having a working internet connection and accessible URLs
# It's better to mock these requests using respx or a similar library
# Setup for multiple URLs
method = "GET"
urls = ["https://example.com/api/one", "https://example.com/api/two"]
# You would mock these requests similarly to the single request tests
for url in urls:
respx.get(url).mock(return_value=Response(200, json={"success": True}))
# Do I have to mock the async client?
#
# Execute the build method
results = await api_request.build(method=method, urls=urls)
# Assertions
assert len(results) == len(urls)
@patch("langflow.components.data.Directory.parallel_load_records")
@patch("langflow.components.data.Directory.retrieve_file_paths")
@patch("langflow.components.data.DirectoryComponent.resolve_path")
def test_directory_component_build_with_multithreading(
mock_resolve_path, mock_retrieve_file_paths, mock_parallel_load_records
):
# Arrange
directory_component = data.DirectoryComponent()
path = os.path.dirname(os.path.abspath(__file__))
depth = 1
max_concurrency = 2
load_hidden = False
recursive = True
silent_errors = False
use_multithreading = True
mock_resolve_path.return_value = path
mock_retrieve_file_paths.return_value = [
os.path.join(path, file) for file in os.listdir(path) if file.endswith(".py")
]
mock_parallel_load_records.return_value = [Mock()]
# Act
directory_component.build(
path,
depth,
max_concurrency,
load_hidden,
recursive,
silent_errors,
use_multithreading,
)
# Assert
mock_resolve_path.assert_called_once_with(path)
mock_retrieve_file_paths.assert_called_once_with(path, load_hidden, recursive, depth)
mock_parallel_load_records.assert_called_once_with(
mock_retrieve_file_paths.return_value, silent_errors, max_concurrency
)
def test_directory_without_mocks():
directory_component = data.DirectoryComponent()
from langflow.initial_setup import setup
from langflow.initial_setup.setup import load_starter_projects
_, projects = zip(*load_starter_projects())
# the setup module has a folder where the projects are stored
# the contents of that folder are in the projects variable
# the directory component can be used to load the projects
# and we can validate if the contents are the same as the projects variable
setup_path = Path(setup.__file__).parent / "starter_projects"
results = directory_component.build(str(setup_path), use_multithreading=False)
assert len(results) == len(projects)
# each result is a Record that contains the content attribute
# each are dict that are exactly the same as one of the projects
for i, result in enumerate(results):
assert result.text in projects, list(diff(result.text, projects[i]))
# in ../docs/docs/components there are many mdx files
# check if the directory component can load them
# just check if the number of results is the same as the number of files
docs_path = Path(__file__).parent.parent / "docs" / "docs" / "components"
results = directory_component.build(str(docs_path), use_multithreading=False)
docs_files = list(docs_path.glob("*.mdx"))
assert len(results) == len(docs_files)
def test_url_component():
url_component = data.URLComponent()
# the url component can be used to load the contents of a website
records = url_component.build(["https://langflow.org"])
assert all(record.data for record in records)
assert all(record.text for record in records)
assert all(record.source for record in records)

277
tests/unit/test_database.py Normal file
View file

@ -0,0 +1,277 @@
from uuid import UUID, uuid4
import orjson
import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session
from langflow.api.v1.schemas import FlowListCreate
from langflow.initial_setup.setup import load_starter_projects
from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
@pytest.fixture(scope="module")
def json_style():
# class FlowStyleBase(SQLModel):
# color: str = Field(index=True)
# emoji: str = Field(index=False)
# flow_id: UUID = Field(default=None, foreign_key="flow.id")
return orjson_dumps(
{
"color": "red",
"emoji": "👍",
}
)
def test_create_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name=str(uuid4()), description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
# flow is optional so we can create a flow without a flow
flow = FlowCreate(name="Test Flow")
response = client.post("api/v1/flows/", json=flow.model_dump(exclude_unset=True), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
def test_read_flows(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow_data = orjson.loads(json_flow)
data = flow_data["data"]
flow = FlowCreate(name=str(uuid4()), description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
flow = FlowCreate(name=str(uuid4()), description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
response = client.get("api/v1/flows/", headers=logged_in_headers)
assert response.status_code == 200
assert len(response.json()) > 0
def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
flow_id = response.json()["id"] # flow_id should be a UUID but is a string
# turn it into a UUID
flow_id = UUID(flow_id)
response = client.get(f"api/v1/flows/{flow_id}", headers=logged_in_headers)
assert response.status_code == 200
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
flow_id = response.json()["id"]
updated_flow = FlowUpdate(
name="Updated Flow",
description="updated description",
data=data,
)
response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
assert response.status_code == 200
assert response.json()["name"] == updated_flow.name
assert response.json()["description"] == updated_flow.description
# assert response.json()["data"] == updated_flow.data
def test_delete_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
flow_id = response.json()["id"]
response = client.delete(f"api/v1/flows/{flow_id}", headers=logged_in_headers)
assert response.status_code == 200
assert response.json()["message"] == "Flow deleted successfully"
def test_delete_flows(client: TestClient, json_flow: str, active_user, logged_in_headers):
# Create ten flows
number_of_flows = 10
flows = [FlowCreate(name=f"Flow {i}", description="description", data={}) for i in range(number_of_flows)]
flow_ids = []
for flow in flows:
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
assert response.status_code == 201
flow_ids.append(response.json()["id"])
response = client.request("DELETE", "api/v1/flows/", headers=logged_in_headers, json=flow_ids)
assert response.status_code == 200, response.content
assert response.json().get("deleted") == number_of_flows
def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
# Create test data
flow_list = FlowListCreate(
flows=[
FlowCreate(name="Flow 1", description="description", data=data),
FlowCreate(name="Flow 2", description="description", data=data),
]
)
# Make request to endpoint
response = client.post("api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers)
# Check response status code
assert response.status_code == 201
# Check response data
response_data = response.json()
assert len(response_data) == 2
assert response_data[0]["name"] == "Flow 1"
assert response_data[0]["description"] == "description"
assert response_data[0]["data"] == data
assert response_data[1]["name"] == "Flow 2"
assert response_data[1]["description"] == "description"
assert response_data[1]["data"] == data
def test_upload_file(client: TestClient, session: Session, json_flow: str, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
# Create test data
flow_list = FlowListCreate(
flows=[
FlowCreate(name="Flow 1", description="description", data=data),
FlowCreate(name="Flow 2", description="description", data=data),
]
)
file_contents = orjson_dumps(flow_list.dict())
response = client.post(
"api/v1/flows/upload/",
files={"file": ("examples.json", file_contents, "application/json")},
headers=logged_in_headers,
)
# Check response status code
assert response.status_code == 201
# Check response data
response_data = response.json()
assert len(response_data) == 2
assert response_data[0]["name"] == "Flow 1"
assert response_data[0]["description"] == "description"
assert response_data[0]["data"] == data
assert response_data[1]["name"] == "Flow 2"
assert response_data[1]["description"] == "description"
assert response_data[1]["data"] == data
def test_download_file(
client: TestClient,
session: Session,
json_flow,
active_user,
logged_in_headers,
):
flow = orjson.loads(json_flow)
data = flow["data"]
# Create test data
flow_list = FlowListCreate(
flows=[
FlowCreate(name="Flow 1", description="description", data=data),
FlowCreate(name="Flow 2", description="description", data=data),
]
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
for flow in flow_list.flows:
flow.user_id = active_user.id
db_flow = Flow.model_validate(flow, from_attributes=True)
session.add(db_flow)
session.commit()
# Make request to endpoint
response = client.get("api/v1/flows/download/", headers=logged_in_headers)
# Check response status code
assert response.status_code == 200, response.json()
# Check response data
response_data = response.json()["flows"]
starter_projects = load_starter_projects()
number_of_projects = len(starter_projects) + len(flow_list.flows)
assert len(response_data) == number_of_projects, response_data
assert response_data[0]["name"] == "Flow 1"
assert response_data[0]["description"] == "description"
assert response_data[0]["data"] == data
assert response_data[1]["name"] == "Flow 2"
assert response_data[1]["description"] == "description"
assert response_data[1]["data"] == data
def test_create_flow_with_invalid_data(client: TestClient, active_user, logged_in_headers):
flow = {"name": "a" * 256, "data": "Invalid flow data"}
response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers)
assert response.status_code == 422
def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers):
uuid = uuid4()
response = client.get(f"api/v1/flows/{uuid}", headers=logged_in_headers)
assert response.status_code == 404
def test_update_flow_idempotency(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow_data = orjson.loads(json_flow)
data = flow_data["data"]
flow_data = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post("api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers)
flow_id = response.json()["id"]
updated_flow = FlowCreate(name="Updated Flow", description="description", data=data)
response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
assert response1.json() == response2.json()
def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow_data = orjson.loads(json_flow)
data = flow_data["data"]
uuid = uuid4()
updated_flow = FlowCreate(
name="Updated Flow",
description="description",
data=data,
)
response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.model_dump(), headers=logged_in_headers)
assert response.status_code == 404, response.text
def test_delete_nonexistent_flow(client: TestClient, active_user, logged_in_headers):
uuid = uuid4()
response = client.delete(f"api/v1/flows/{uuid}", headers=logged_in_headers)
assert response.status_code == 404
def test_read_only_starter_projects(client: TestClient, active_user, logged_in_headers):
response = client.get("api/v1/flows/", headers=logged_in_headers)
starter_projects = load_starter_projects()
assert response.status_code == 200
assert len(response.json()) == len(starter_projects)
@pytest.mark.load_flows
def test_load_flows(client: TestClient, load_flows_dir):
client.get("/api/v1/auto_login")
response = client.get("api/v1/flows/c54f9130-f2fa-4a3e-b22a-3856d946351b")
assert response.status_code == 200
assert response.json()["name"] == "BasicExample"

104
tests/unit/test_files.py Normal file
View file

@ -0,0 +1,104 @@
from unittest.mock import MagicMock
import pytest
from langflow.services.deps import get_storage_service
from langflow.services.storage.service import StorageService
@pytest.fixture
def mock_storage_service():
# Create a mock instance of StorageService
service = MagicMock(spec=StorageService)
# Setup mock behaviors for the service methods as needed
service.save_file.return_value = None
service.get_file.return_value = b"file content" # Binary content for files
service.list_files.return_value = ["file1.txt", "file2.jpg"]
service.delete_file.return_value = None
return service
def test_upload_file(client, mock_storage_service, created_api_key, flow):
headers = {"x-api-key": created_api_key.api_key}
# Replace the actual storage service with the mock
client.app.dependency_overrides[get_storage_service] = lambda: mock_storage_service
response = client.post(
f"api/v1/files/upload/{flow.id}",
files={"file": ("test.txt", b"test content")},
headers=headers,
)
assert response.status_code == 201
assert response.json() == {
"flowId": str(flow.id),
"file_path": f"{flow.id}/test.txt",
}
def test_download_file(client, mock_storage_service, created_api_key, flow):
headers = {"x-api-key": created_api_key.api_key}
client.app.dependency_overrides[get_storage_service] = lambda: mock_storage_service
response = client.get(f"api/v1/files/download/{flow.id}/test.txt", headers=headers)
assert response.status_code == 200
assert response.content == b"file content"
def test_list_files(client, mock_storage_service, created_api_key, flow):
headers = {"x-api-key": created_api_key.api_key}
client.app.dependency_overrides[get_storage_service] = lambda: mock_storage_service
response = client.get(f"api/v1/files/list/{flow.id}", headers=headers)
assert response.status_code == 200
assert response.json() == {"files": ["file1.txt", "file2.jpg"]}
def test_delete_file(client, mock_storage_service, created_api_key, flow):
headers = {"x-api-key": created_api_key.api_key}
client.app.dependency_overrides[get_storage_service] = lambda: mock_storage_service
response = client.delete(f"api/v1/files/delete/{flow.id}/test.txt", headers=headers)
assert response.status_code == 200
assert response.json() == {"message": "File test.txt deleted successfully"}
def test_file_operations(client, created_api_key, flow):
headers = {"x-api-key": created_api_key.api_key}
flow_id = flow.id
file_name = "test.txt"
file_content = b"Hello, world!"
# Step 1: Upload the file
response = client.post(
f"api/v1/files/upload/{flow_id}",
files={"file": (file_name, file_content)},
headers=headers,
)
assert response.status_code == 201
assert response.json() == {
"flowId": str(flow_id),
"file_path": f"{flow_id}/{file_name}",
}
# Step 2: List files in the folder
response = client.get(f"api/v1/files/list/{flow_id}", headers=headers)
assert response.status_code == 200
assert file_name in response.json()["files"]
# Step 3: Download the file and verify its content
response = client.get(f"api/v1/files/download/{flow_id}/{file_name}", headers=headers)
assert response.status_code == 200
assert response.content == file_content
# the headers are application/octet-stream
assert response.headers["content-type"] == "application/octet-stream"
# mime_type is inside media_type
# Step 4: Delete the file
response = client.delete(f"api/v1/files/delete/{flow_id}/{file_name}", headers=headers)
assert response.status_code == 200
assert response.json() == {"message": f"File {file_name} deleted successfully"}
# Verify that the file is indeed deleted
response = client.get(f"api/v1/files/list/{flow_id}", headers=headers)
assert file_name not in response.json()["files"]

View file

@ -0,0 +1,56 @@
import pytest
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.template.base import Template
@pytest.fixture
def sample_template_field() -> TemplateField:
return TemplateField(name="test_field", field_type="str")
@pytest.fixture
def sample_template(sample_template_field: TemplateField) -> Template:
return Template(type_name="test_template", fields=[sample_template_field])
@pytest.fixture
def sample_frontend_node(sample_template: Template) -> FrontendNode:
return FrontendNode(
template=sample_template,
description="test description",
base_classes=["base_class1", "base_class2"],
name="test_frontend_node",
)
def test_template_field_defaults(sample_template_field: TemplateField):
assert sample_template_field.field_type == "str"
assert sample_template_field.required is False
assert sample_template_field.placeholder == ""
assert sample_template_field.is_list is False
assert sample_template_field.show is True
assert sample_template_field.multiline is False
assert sample_template_field.value == ""
assert sample_template_field.file_types == []
assert sample_template_field.file_path == ""
assert sample_template_field.password is False
assert sample_template_field.name == "test_field"
def test_template_to_dict(sample_template: Template, sample_template_field: TemplateField):
template_dict = sample_template.to_dict()
assert template_dict["_type"] == "test_template"
assert len(template_dict) == 2 # _type and test_field
assert "test_field" in template_dict
assert "type" in template_dict["test_field"]
assert "required" in template_dict["test_field"]
def test_frontend_node_to_dict(sample_frontend_node: FrontendNode):
node_dict = sample_frontend_node.to_dict()
assert len(node_dict) == 1
assert "test_frontend_node" in node_dict
assert "description" in node_dict["test_frontend_node"]
assert "template" in node_dict["test_frontend_node"]
assert "base_classes" in node_dict["test_frontend_node"]

418
tests/unit/test_graph.py Normal file
View file

@ -0,0 +1,418 @@
import copy
import json
import pickle
from typing import Type, Union
import pytest
from langflow.graph import Graph
from langflow.graph.edge.base import Edge
from langflow.graph.graph.utils import (
find_last_node,
process_flow,
set_new_target_handle,
ungroup_node,
update_source_handle,
update_target_handle,
update_template,
)
from langflow.graph.vertex.base import Vertex
from langflow.initial_setup.setup import load_starter_projects
from langflow.utils.payload import get_root_vertex
# Test cases for the graph module
# now we have three types of graph:
# BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH
@pytest.fixture
def sample_template():
return {
"field1": {"proxy": {"field": "some_field", "id": "node1"}},
"field2": {"proxy": {"field": "other_field", "id": "node2"}},
}
@pytest.fixture
def sample_nodes():
return [
{
"id": "node1",
"data": {"node": {"template": {"some_field": {"show": True, "advanced": False, "name": "Name1"}}}},
},
{
"id": "node2",
"data": {
"node": {
"template": {
"other_field": {
"show": False,
"advanced": True,
"display_name": "DisplayName2",
}
}
}
},
},
{
"id": "node3",
"data": {"node": {"template": {"unrelated_field": {"show": True, "advanced": True}}}},
},
]
def get_node_by_type(graph, node_type: Type[Vertex]) -> Union[Vertex, None]:
"""Get a node by type"""
return next((node for node in graph.vertices if isinstance(node, node_type)), None)
def test_graph_structure(basic_graph):
assert isinstance(basic_graph, Graph)
assert len(basic_graph.vertices) > 0
assert len(basic_graph.edges) > 0
for node in basic_graph.vertices:
assert isinstance(node, Vertex)
for edge in basic_graph.edges:
assert isinstance(edge, Edge)
source_vertex = basic_graph.get_vertex(edge.source_id)
target_vertex = basic_graph.get_vertex(edge.target_id)
assert source_vertex in basic_graph.vertices
assert target_vertex in basic_graph.vertices
def test_circular_dependencies(basic_graph):
assert isinstance(basic_graph, Graph)
def check_circular(node, visited):
visited.add(node)
neighbors = basic_graph.get_vertices_with_target(node)
for neighbor in neighbors:
if neighbor in visited:
return True
if check_circular(neighbor, visited.copy()):
return True
return False
for node in basic_graph.vertices:
assert not check_circular(node, set())
def test_invalid_node_types():
graph_data = {
"nodes": [
{
"id": "1",
"data": {
"node": {
"base_classes": ["BaseClass"],
"template": {
"_type": "InvalidNodeType",
},
},
},
},
],
"edges": [],
}
with pytest.raises(Exception):
Graph(graph_data["nodes"], graph_data["edges"])
def test_get_vertices_with_target(basic_graph):
"""Test getting connected nodes"""
assert isinstance(basic_graph, Graph)
# Get root node
root = get_root_vertex(basic_graph)
assert root is not None
connected_nodes = basic_graph.get_vertices_with_target(root.id)
assert connected_nodes is not None
def test_get_node_neighbors_basic(basic_graph):
"""Test getting node neighbors"""
assert isinstance(basic_graph, Graph)
# Get root node
root = get_root_vertex(basic_graph)
assert root is not None
neighbors = basic_graph.get_vertex_neighbors(root)
assert neighbors is not None
assert isinstance(neighbors, dict)
# Root Node is an Agent, it requires an LLMChain and tools
# We need to check if there is a Chain in the one of the neighbors'
# data attribute in the type key
assert any("ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
assert any("OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
def test_get_node(basic_graph):
"""Test getting a single node"""
node_id = basic_graph.vertices[0].id
node = basic_graph.get_vertex(node_id)
assert isinstance(node, Vertex)
assert node.id == node_id
def test_build_nodes(basic_graph):
"""Test building nodes"""
assert len(basic_graph.vertices) == len(basic_graph._vertices)
for node in basic_graph.vertices:
assert isinstance(node, Vertex)
def test_build_edges(basic_graph):
"""Test building edges"""
assert len(basic_graph.edges) == len(basic_graph._edges)
for edge in basic_graph.edges:
assert isinstance(edge, Edge)
assert isinstance(edge.source_id, str)
assert isinstance(edge.target_id, str)
def test_get_root_vertex(client, basic_graph, complex_graph):
"""Test getting root node"""
assert isinstance(basic_graph, Graph)
root = get_root_vertex(basic_graph)
assert root is not None
assert isinstance(root, Vertex)
assert root.data["type"] == "TimeTravelGuideChain"
# For complex example, the root node is a ZeroShotAgent too
assert isinstance(complex_graph, Graph)
root = get_root_vertex(complex_graph)
assert root is not None
assert isinstance(root, Vertex)
assert root.data["type"] == "ZeroShotAgent"
def test_validate_edges(basic_graph):
"""Test validating edges"""
assert isinstance(basic_graph, Graph)
# all edges should be valid
assert all(edge.valid for edge in basic_graph.edges)
def test_matched_type(basic_graph):
"""Test matched type attribute in Edge"""
assert isinstance(basic_graph, Graph)
# all edges should be valid
assert all(edge.valid for edge in basic_graph.edges)
# all edges should have a matched_type attribute
assert all(hasattr(edge, "matched_type") for edge in basic_graph.edges)
# The matched_type attribute should be in the source_types attr
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
def test_build_params(basic_graph):
"""Test building params"""
assert isinstance(basic_graph, Graph)
# all edges should be valid
assert all(edge.valid for edge in basic_graph.edges)
# all edges should have a matched_type attribute
assert all(hasattr(edge, "matched_type") for edge in basic_graph.edges)
# The matched_type attribute should be in the source_types attr
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
# Get the root node
root = get_root_vertex(basic_graph)
# Root node is a TimeTravelGuideChain
# which requires an llm and memory
assert root is not None
assert isinstance(root.params, dict)
assert "llm" in root.params
assert "memory" in root.params
# def test_wrapper_node_build(openapi_graph):
# wrapper_node = get_node_by_type(openapi_graph, WrapperVertex)
# assert wrapper_node is not None
# built_object = wrapper_node.build()
# assert built_object is not None
def test_find_last_node(grouped_chat_json_flow):
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
nodes, edges = grouped_chat_data["nodes"], grouped_chat_data["edges"]
last_node = find_last_node(nodes, edges)
assert last_node is not None # Replace with the actual expected value
assert last_node["id"] == "LLMChain-pimAb" # Replace with the actual expected value
def test_ungroup_node(grouped_chat_json_flow):
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
group_node = grouped_chat_data["nodes"][2] # Assuming the first node is a group node
base_flow = copy.deepcopy(grouped_chat_data)
ungroup_node(group_node["data"], base_flow)
# after ungroup_node is called, the base_flow and grouped_chat_data should be different
assert base_flow != grouped_chat_data
# assert node 2 is not a group node anymore
assert base_flow["nodes"][2]["data"]["node"].get("flow") is None
# assert the edges are updated
assert len(base_flow["edges"]) > len(grouped_chat_data["edges"])
assert base_flow["edges"][0]["source"] == "ConversationBufferMemory-kUMif"
assert base_flow["edges"][0]["target"] == "LLMChain-2P369"
assert base_flow["edges"][1]["source"] == "PromptTemplate-Wjk4g"
assert base_flow["edges"][1]["target"] == "LLMChain-2P369"
assert base_flow["edges"][2]["source"] == "ChatOpenAI-rUJ1b"
assert base_flow["edges"][2]["target"] == "LLMChain-2P369"
def test_process_flow(grouped_chat_json_flow):
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
processed_flow = process_flow(grouped_chat_data)
assert processed_flow is not None
assert isinstance(processed_flow, dict)
assert "nodes" in processed_flow
assert "edges" in processed_flow
def test_process_flow_one_group(one_grouped_chat_json_flow):
grouped_chat_data = json.loads(one_grouped_chat_json_flow).get("data")
# There should be only one node
assert len(grouped_chat_data["nodes"]) == 1
# Get the node, it should be a group node
group_node = grouped_chat_data["nodes"][0]
node_data = group_node["data"]["node"]
assert node_data.get("flow") is not None
template_data = node_data["template"]
assert any("openai_api_key" in key for key in template_data.keys())
# Get the openai_api_key dict
openai_api_key = next(
(template_data[key] for key in template_data.keys() if "openai_api_key" in key),
None,
)
assert openai_api_key is not None
assert openai_api_key["value"] == "test"
processed_flow = process_flow(grouped_chat_data)
assert processed_flow is not None
assert isinstance(processed_flow, dict)
assert "nodes" in processed_flow
assert "edges" in processed_flow
# Now get the node that has ChatOpenAI in its id
chat_openai_node = next((node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None)
assert chat_openai_node is not None
assert chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] == "test"
def test_process_flow_vector_store_grouped(vector_store_grouped_json_flow):
grouped_chat_data = json.loads(vector_store_grouped_json_flow).get("data")
nodes = grouped_chat_data["nodes"]
assert len(nodes) == 4
# There are two group nodes in this flow
# One of them is inside the other totalling 7 nodes
# 4 nodes grouped, one of these turns into 1 normal node and 1 group node
# This group node has 2 nodes inside it
processed_flow = process_flow(grouped_chat_data)
assert processed_flow is not None
processed_nodes = processed_flow["nodes"]
assert len(processed_nodes) == 7
assert isinstance(processed_flow, dict)
assert "nodes" in processed_flow
assert "edges" in processed_flow
edges = processed_flow["edges"]
# Expected keywords in source and target fields
expected_keywords = [
{"source": "VectorStoreInfo", "target": "VectorStoreAgent"},
{"source": "ChatOpenAI", "target": "VectorStoreAgent"},
{"source": "OpenAIEmbeddings", "target": "Chroma"},
{"source": "Chroma", "target": "VectorStoreInfo"},
{"source": "WebBaseLoader", "target": "RecursiveCharacterTextSplitter"},
{"source": "RecursiveCharacterTextSplitter", "target": "Chroma"},
]
for idx, expected_keyword in enumerate(expected_keywords):
for key, value in expected_keyword.items():
assert (
value in edges[idx][key].split("-")[0]
), f"Edge {idx}, key {key} expected to contain {value} but got {edges[idx][key]}"
def test_update_template(sample_template, sample_nodes):
# Making a deep copy to keep original sample_nodes unchanged
nodes_copy = copy.deepcopy(sample_nodes)
update_template(sample_template, nodes_copy)
# Now, validate the updates.
node1_updated = next((n for n in nodes_copy if n["id"] == "node1"), None)
node2_updated = next((n for n in nodes_copy if n["id"] == "node2"), None)
node3_updated = next((n for n in nodes_copy if n["id"] == "node3"), None)
assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True
assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False
assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1"
assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False
assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True
assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2"
# Ensure node3 remains unchanged
assert node3_updated == sample_nodes[2]
# Test `update_target_handle`
def test_update_target_handle_proxy():
new_edge = {
"data": {
"targetHandle": {
"type": "some_type",
"proxy": {"id": "some_id", "field": ""},
}
}
}
g_nodes = [{"id": "some_id", "data": {"node": {"flow": None}}}]
group_node_id = "group_id"
updated_edge = update_target_handle(new_edge, g_nodes, group_node_id)
assert updated_edge["data"]["targetHandle"] == new_edge["data"]["targetHandle"]
# Test `set_new_target_handle`
def test_set_new_target_handle():
proxy_id = "proxy_id"
new_edge = {"target": None, "data": {"targetHandle": {}}}
target_handle = {"type": "type_1", "proxy": {"field": "field_1"}}
node = {
"data": {
"node": {
"flow": True,
"template": {"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}},
}
}
}
set_new_target_handle(proxy_id, new_edge, target_handle, node)
assert new_edge["target"] == "proxy_id"
assert new_edge["data"]["targetHandle"]["fieldName"] == "field_1"
assert new_edge["data"]["targetHandle"]["proxy"] == {
"field": "new_field",
"id": "new_id",
}
# Test `update_source_handle`
def test_update_source_handle():
new_edge = {"source": None, "data": {"sourceHandle": {"id": None}}}
flow_data = {
"nodes": [{"id": "some_node"}, {"id": "last_node"}],
"edges": [{"source": "some_node"}],
}
updated_edge = update_source_handle(new_edge, flow_data["nodes"], flow_data["edges"])
assert updated_edge["source"] == "last_node"
assert updated_edge["data"]["sourceHandle"]["id"] == "last_node"
@pytest.mark.asyncio
async def test_pickle_graph():
starter_projects = load_starter_projects()
data = starter_projects[0][1]["data"]
graph = Graph.from_payload(data)
assert isinstance(graph, Graph)
pickled = pickle.dumps(graph)
assert pickled is not None
unpickled = pickle.loads(pickled)
assert unpickled is not None

View file

@ -0,0 +1,81 @@
from langchain_core.documents import Document
from langflow.components import helpers
from langflow.custom.utils import build_custom_component_template
from langflow.schema import Record
def test_update_record_component():
# Arrange
update_record_component = helpers.UpdateRecordComponent()
# Act
new_data = {"new_key": "new_value"}
existing_record = Record(data={"existing_key": "existing_value"})
result = update_record_component.build(existing_record, new_data)
assert result.data == {"existing_key": "existing_value", "new_key": "new_value"}
assert result.existing_key == "existing_value"
assert result.new_key == "new_value"
def test_document_to_record_component():
# Arrange
document_to_record_component = helpers.DocumentToRecordComponent()
# Act
# Replace with your actual test data
document = Document(page_content="key: value", metadata={"url": "https://example.com"})
result = document_to_record_component.build(document)
# Assert
# Replace with your actual expected result
assert result == [Record(data={"text": "key: value", "url": "https://example.com"})]
def test_uuid_generator_component():
# Arrange
uuid_generator_component = helpers.UUIDGeneratorComponent()
uuid_generator_component.code = open(helpers.IDGenerator.__file__, "r").read()
frontend_node, _ = build_custom_component_template(uuid_generator_component)
# Act
build_config = frontend_node.get("template")
field_name = "unique_id"
build_config = uuid_generator_component.update_build_config(build_config, None, field_name)
unique_id = build_config["unique_id"]["value"]
result = uuid_generator_component.build(unique_id)
# Assert
# UUID should be a string of length 36
assert isinstance(result, str)
assert len(result) == 36
def test_records_as_text_component():
# Arrange
records_as_text_component = helpers.RecordsToTextComponent()
# Act
# Replace with your actual test data
records = [Record(data={"key": "value", "bacon": "eggs"})]
template = "Data:{data} -- Bacon:{bacon}"
result = records_as_text_component.build(records, template=template)
# Assert
# Replace with your actual expected result
assert result == "Data:{'key': 'value', 'bacon': 'eggs'} -- Bacon:eggs"
def test_text_to_record_component():
# Arrange
text_to_record_component = helpers.CreateRecordComponent()
# Act
# Replace with your actual test data
dict_with_text = {"field_1": {"key": "value"}}
result = text_to_record_component.build(number_of_fields=1, **dict_with_text)
# Assert
# Replace with your actual expected result
assert result == Record(data={"key": "value"})

View file

@ -0,0 +1,93 @@
from datetime import datetime
from pathlib import Path
from sqlmodel import select
from langflow.initial_setup.setup import (
STARTER_FOLDER_NAME,
create_or_update_starter_projects,
get_project_data,
load_starter_projects,
)
from langflow.services.database.models.folder.model import Folder
from langflow.services.deps import session_scope
def test_load_starter_projects():
projects = load_starter_projects()
assert isinstance(projects, list)
assert all(isinstance(project[1], dict) for project in projects)
assert all(isinstance(project[0], Path) for project in projects)
def test_get_project_data():
projects = load_starter_projects()
for _, project in projects:
(
project_name,
project_description,
project_is_component,
updated_at_datetime,
project_data,
project_icon,
project_icon_bg_color,
) = get_project_data(project)
assert isinstance(project_name, str)
assert isinstance(project_description, str)
assert isinstance(project_is_component, bool)
assert isinstance(updated_at_datetime, datetime)
assert isinstance(project_data, dict)
assert isinstance(project_icon, str) or project_icon is None
assert isinstance(project_icon_bg_color, str) or project_icon_bg_color is None
def test_create_or_update_starter_projects(client):
with session_scope() as session:
# Run the function to create or update projects
create_or_update_starter_projects()
# Get the number of projects returned by load_starter_projects
num_projects = len(load_starter_projects())
# Get the number of projects in the database
folder = session.exec(select(Folder).where(Folder.name == STARTER_FOLDER_NAME)).first()
num_db_projects = len(folder.flows)
# Check that the number of projects in the database is the same as the number of projects returned by load_starter_projects
assert num_db_projects == num_projects
# Some starter projects require integration
# @pytest.mark.asyncio
# async def test_starter_projects_can_run_successfully(client):
# with session_scope() as session:
# # Run the function to create or update projects
# create_or_update_starter_projects()
# # Get the number of projects returned by load_starter_projects
# num_projects = len(load_starter_projects())
# # Get the number of projects in the database
# num_db_projects = session.exec(select(func.count(Flow.id)).where(Flow.folder == STARTER_FOLDER_NAME)).one()
# # Check that the number of projects in the database is the same as the number of projects returned by load_starter_projects
# assert num_db_projects == num_projects
# # Get all the starter projects
# projects = session.exec(select(Flow).where(Flow.folder == STARTER_FOLDER_NAME)).all()
# graphs: list[tuple[str, Graph]] = []
# for project in projects:
# # Add tweaks to make file_path work
# tweaks = {"path": __file__}
# graph_data = process_tweaks(project.data, tweaks)
# graph_object = Graph.from_payload(graph_data, flow_id=project.id)
# graphs.append((project.name, graph_object))
# assert len(graphs) == len(projects)
# for name, graph in graphs:
# outputs = await graph.arun(
# inputs={},
# outputs=[],
# session_id="test",
# )
# assert all(isinstance(output, RunOutputs) for output in outputs), f"Project {name} error: {outputs}"
# delete_messages(session_id="test")

View file

@ -0,0 +1,42 @@
import pytest
from langflow.graph import Graph
from langflow.graph.schema import RunOutputs
from langflow.initial_setup.setup import load_starter_projects
from langflow.load import load_flow_from_json, run_flow_from_json
@pytest.mark.noclient
def test_load_flow_from_json():
"""Test loading a flow from a json file"""
loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH)
assert loaded is not None
assert isinstance(loaded, Graph)
@pytest.mark.noclient
def test_load_flow_from_json_with_tweaks():
"""Test loading a flow from a json file and applying tweaks"""
tweaks = {"dndnode_82": {"model_name": "gpt-3.5-turbo-16k-0613"}}
loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH, tweaks=tweaks)
assert loaded is not None
assert isinstance(loaded, Graph)
@pytest.mark.noclient
def test_load_flow_from_json_object():
"""Test loading a flow from a json file and applying tweaks"""
_, projects = zip(*load_starter_projects())
project = projects[0]
loaded = load_flow_from_json(project)
assert loaded is not None
assert isinstance(loaded, Graph)
@pytest.mark.noclient
def test_run_flow_from_json_object():
"""Test loading a flow from a json file and applying tweaks"""
_, projects = zip(*load_starter_projects())
project = [project for project in projects if "Basic Prompting" in project["name"]][0]
results = run_flow_from_json(project, input_value="test", fallback_to_env_vars=True)
assert results is not None
assert all(isinstance(result, RunOutputs) for result in results)

45
tests/unit/test_login.py Normal file
View file

@ -0,0 +1,45 @@
import pytest
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.user import User
from langflow.services.deps import session_scope
from sqlalchemy.exc import IntegrityError
@pytest.fixture
def test_user():
return User(
username="testuser",
password=get_password_hash("testpassword"), # Assuming password needs to be hashed
is_active=True,
is_superuser=False,
)
def test_login_successful(client, test_user):
# Adding the test user to the database
try:
with session_scope() as session:
session.add(test_user)
session.commit()
except IntegrityError:
pass
response = client.post("api/v1/login", data={"username": "testuser", "password": "testpassword"})
assert response.status_code == 200
assert "access_token" in response.json()
def test_login_unsuccessful_wrong_username(client):
response = client.post("api/v1/login", data={"username": "wrongusername", "password": "testpassword"})
assert response.status_code == 401
assert response.json()["detail"] == "Incorrect username or password"
def test_login_unsuccessful_wrong_password(client, test_user, session):
# Adding the test user to the database
session.add(test_user)
session.commit()
response = client.post("api/v1/login", data={"username": "testuser", "password": "wrongpassword"})
assert response.status_code == 401
assert response.json()["detail"] == "Incorrect username or password"

303
tests/unit/test_process.py Normal file
View file

@ -0,0 +1,303 @@
import pytest
from langflow.processing.process import process_tweaks
from langflow.services.deps import get_session_service
def test_no_tweaks():
graph_data = {
"data": {
"nodes": [
{
"id": "node1",
"data": {
"node": {
"template": {
"param1": {"value": 1},
"param2": {"value": 2},
}
}
},
},
{
"id": "node2",
"data": {
"node": {
"template": {
"param1": {"value": 3},
"param2": {"value": 4},
}
}
},
},
]
}
}
tweaks = {}
result = process_tweaks(graph_data, tweaks)
assert result == graph_data
def test_single_tweak():
graph_data = {
"data": {
"nodes": [
{
"id": "node1",
"data": {
"node": {
"template": {
"param1": {"value": 1, "type": "int"},
"param2": {"value": 2, "type": "int"},
}
}
},
},
{
"id": "node2",
"data": {
"node": {
"template": {
"param1": {"value": 3, "type": "int"},
"param2": {"value": 4, "type": "int"},
}
}
},
},
]
}
}
tweaks = {"node1": {"param1": 5}}
expected_result = {
"data": {
"nodes": [
{
"id": "node1",
"data": {
"node": {
"template": {
"param1": {"value": 5, "type": "int"},
"param2": {"value": 2, "type": "int"},
}
}
},
},
{
"id": "node2",
"data": {
"node": {
"template": {
"param1": {"value": 3, "type": "int"},
"param2": {"value": 4, "type": "int"},
}
}
},
},
]
}
}
result = process_tweaks(graph_data, tweaks)
assert result == expected_result
def test_multiple_tweaks():
graph_data = {
"data": {
"nodes": [
{
"id": "node1",
"data": {
"node": {
"template": {
"param1": {"value": 1, "type": "int"},
"param2": {"value": 2, "type": "int"},
}
}
},
},
{
"id": "node2",
"data": {
"node": {
"template": {
"param1": {"value": 3, "type": "int"},
"param2": {"value": 4, "type": "int"},
}
}
},
},
]
}
}
tweaks = {
"node1": {"param1": 5, "param2": 6},
"node2": {"param1": 7},
}
expected_result = {
"data": {
"nodes": [
{
"id": "node1",
"data": {
"node": {
"template": {
"param1": {"value": 5, "type": "int"},
"param2": {"value": 6, "type": "int"},
}
}
},
},
{
"id": "node2",
"data": {
"node": {
"template": {
"param1": {"value": 7, "type": "int"},
"param2": {"value": 4, "type": "int"},
}
}
},
},
]
}
}
result = process_tweaks(graph_data, tweaks)
assert result == expected_result
# Test twekas that just pass the param and value but no node id.
# This is a new feature that was added to the process_tweaks function
def test_tweak_no_node_id():
graph_data = {
"data": {
"nodes": [
{
"id": "node1",
"data": {
"node": {
"template": {
"param1": {"value": 1, "type": "int"},
"param2": {"value": 2, "type": "int"},
}
}
},
},
{
"id": "node2",
"data": {
"node": {
"template": {
"param1": {"value": 3, "type": "int"},
"param2": {"value": 4, "type": "int"},
}
}
},
},
]
}
}
tweaks = {"param1": 5}
expected_result = {
"data": {
"nodes": [
{
"id": "node1",
"data": {
"node": {
"template": {
"param1": {"value": 5, "type": "int"},
"param2": {"value": 2, "type": "int"},
}
}
},
},
{
"id": "node2",
"data": {
"node": {
"template": {
"param1": {"value": 5, "type": "int"},
"param2": {"value": 4, "type": "int"},
}
}
},
},
]
}
}
result = process_tweaks(graph_data, tweaks)
assert result == expected_result
def test_tweak_not_in_template():
graph_data = {
"data": {
"nodes": [
{
"id": "node1",
"data": {
"node": {
"template": {
"param1": {"value": 1, "type": "int"},
"param2": {"value": 2, "type": "int"},
}
}
},
},
{
"id": "node2",
"data": {
"node": {
"template": {
"param1": {"value": 3, "type": "int"},
"param2": {"value": 4, "type": "int"},
}
}
},
},
]
}
}
tweaks = {"node1": {"param3": 5}}
result = process_tweaks(graph_data, tweaks)
assert result == graph_data
@pytest.mark.asyncio
async def test_load_langchain_object_with_cached_session(client, basic_graph_data):
# Provide a non-existent session_id
session_service = get_session_service()
session_id1 = "non-existent-session-id"
graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data)
# Use the new session_id to get the langchain_object again
graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data)
assert graph1 == graph2
assert artifacts1 == artifacts2
@pytest.mark.asyncio
async def test_load_langchain_object_with_no_cached_session(client, basic_graph_data):
# Provide a non-existent session_id
session_service = get_session_service()
session_id1 = "non-existent-session-id"
session_id = session_service.build_key(session_id1, basic_graph_data)
graph1, artifacts1 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id")
# Clear the cache
await session_service.clear_session(session_id)
# Use the new session_id to get the graph again
graph2, artifacts2 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id")
# Since the cache was cleared, objects should be different
assert id(graph1) != id(graph2)
@pytest.mark.asyncio
async def test_load_langchain_object_without_session_id(client, basic_graph_data):
# Provide a non-existent session_id
session_service = get_session_service()
session_id1 = None
graph1, artifacts1 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id")
# Use the new session_id to get the langchain_object again
graph2, artifacts2 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id")
assert graph1 == graph2

139
tests/unit/test_record.py Normal file
View file

@ -0,0 +1,139 @@
from langchain_core.documents import Document
from langflow.schema import Record
def test_record_initialization():
record = Record(text_key="msg", data={"msg": "Hello, World!", "extra": "value"})
assert record.msg == "Hello, World!"
assert record.extra == "value"
def test_validate_data_with_extra_keys():
record = Record(dummy_key="dummy", data={"key": "value"})
assert record.data["dummy_key"] == "dummy"
assert "dummy_key" in record.data
assert record.key == "value"
def test_conversion_to_document():
record = Record(data={"text": "Sample text", "meta": "data"})
document = record.to_lc_document()
assert document.page_content == "Sample text"
assert document.metadata == {"meta": "data"}
def test_conversion_from_document():
document = Document(page_content="Doc content", metadata={"meta": "info"})
record = Record.from_document(document)
assert record.text == "Doc content"
assert record.meta == "info"
def test_add_method_for_strings():
record1 = Record(data={"text": "Hello"})
record2 = Record(data={"text": " World"})
combined = record1 + record2
assert combined.text == "Hello World"
def test_add_method_for_integers():
record1 = Record(data={"number": 5})
record2 = Record(data={"number": 10})
combined = record1 + record2
assert combined.number == 15
def test_add_method_with_non_overlapping_keys():
record1 = Record(data={"text": "Hello"})
record2 = Record(data={"number": 10})
combined = record1 + record2
assert combined.text == "Hello"
assert combined.number == 10
def test_custom_attribute_get_set_del():
record = Record()
record.custom_attr = "custom_value"
assert record.custom_attr == "custom_value"
del record.custom_attr
assert record.custom_attr == record.default_value
def test_deep_copy():
import copy
record1 = Record(data={"text": "Hello", "number": 10})
record2 = copy.deepcopy(record1)
assert record2.text == "Hello"
assert record2.number == 10
record2.text = "World"
assert record1.text == "Hello" # Ensure original is unchanged
def test_custom_attribute_setting_and_getting():
record = Record()
record.dynamic_attribute = "Dynamic Value"
assert record.dynamic_attribute == "Dynamic Value"
def test_str_and_dir_methods():
record = Record(text_key="text", data={"text": "Test Text", "key": "value"})
assert "Test Text" in str(record)
assert "key" in dir(record)
assert "data" in dir(record)
def test_dir_includes_data_keys():
record = Record(data={"text": "Hello", "new_attr": "value"})
dir_output = dir(record)
# Check for standard attributes
assert "data" in dir_output
assert "text_key" in dir_output
assert "__add__" in dir_output # Checking for a method
# Check for dynamic attributes from data
assert "text" in dir_output
assert "new_attr" in dir_output
# Optionally, verify that dynamically added attributes are listed
record.dynamic_attr = "dynamic"
assert "dynamic_attr" in dir_output or "dynamic_attr" in dir(record) # To account for the change
def test_dir_reflects_attribute_deletion():
record = Record(data={"removable": "I can be removed"})
assert "removable" in dir(record)
# Delete the attribute and check again
del record.removable
assert "removable" not in dir(record)
def test_get_text_with_text_key():
data = {"text": "Hello, World!"}
schema = Record(data=data, text_key="text", default_value="default")
result = schema.get_text()
assert result == "Hello, World!"
def test_get_text_without_text_key():
data = {"other_key": "Hello, World!"}
schema = Record(data=data, text_key="text", default_value="default")
result = schema.get_text()
assert result == "default"
def test_get_text_with_empty_data():
data = {}
schema = Record(data=data, text_key="text", default_value="default")
result = schema.get_text()
assert result == "default"
def test_get_text_with_none_data():
data = None
schema = Record(data=data, text_key="text", default_value="default")
result = schema.get_text()
assert result == "default"

View file

@ -0,0 +1,132 @@
from unittest.mock import MagicMock, patch
from langflow.services.settings.constants import (
DEFAULT_SUPERUSER,
DEFAULT_SUPERUSER_PASSWORD,
)
from langflow.services.utils import teardown_superuser
# @patch("langflow.services.deps.get_session")
# @patch("langflow.services.utils.create_super_user")
# @patch("langflow.services.deps.get_settings_service")
# # @patch("langflow.services.utils.verify_password")
# def test_setup_superuser(
# mock_get_session, mock_create_super_user, mock_get_settings_service
# ):
# # Test when AUTO_LOGIN is True
# calls = []
# mock_settings_service = Mock()
# mock_settings_service.auth_settings.AUTO_LOGIN = True
# mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
# mock_settings_service.auth_settings.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD
# mock_get_settings_service.return_value = mock_settings_service
# mock_session = Mock()
# mock_session.query.return_value.filter.return_value.first.return_value = (
# mock_session
# )
# # return value of get_session is a generator
# mock_get_session.return_value = iter([mock_session, mock_session, mock_session])
# setup_superuser(mock_settings_service, mock_session)
# mock_session.query.assert_called_once_with(User)
# # Set return value of filter to be None
# mock_session.query.return_value.filter.return_value.first.return_value = None
# actual_expr = mock_session.query.return_value.filter.call_args[0][0]
# expected_expr = User.username == DEFAULT_SUPERUSER
# assert str(actual_expr) == str(expected_expr)
# create_call = call(
# db=mock_session, username=DEFAULT_SUPERUSER, password=DEFAULT_SUPERUSER_PASSWORD
# )
# calls.append(create_call)
# # mock_create_super_user.assert_has_calls(calls)
# assert 1 == mock_create_super_user.call_count
# def reset_mock_credentials():
# mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
# mock_settings_service.auth_settings.SUPERUSER_PASSWORD = (
# DEFAULT_SUPERUSER_PASSWORD
# )
# ADMIN_USER_NAME = "admin_user"
# # Test when username and password are default
# mock_settings_service.auth_settings = Mock()
# mock_settings_service.auth_settings.AUTO_LOGIN = False
# mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME
# mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password"
# mock_settings_service.auth_settings.reset_credentials = Mock(
# side_effect=reset_mock_credentials
# )
# mock_get_settings_service.return_value = mock_settings_service
# setup_superuser(mock_settings_service, mock_session)
# mock_session.query.assert_called_with(User)
# actual_expr = mock_session.query.return_value.filter.call_args[0][0]
# expected_expr = User.username == ADMIN_USER_NAME
# assert str(actual_expr) == str(expected_expr)
# create_call = call(db=mock_session, username=ADMIN_USER_NAME, password="password")
# calls.append(create_call)
# # mock_create_super_user.assert_has_calls(calls)
# assert 2 == mock_create_super_user.call_count
# # Test that superuser credentials are reset
# mock_settings_service.auth_settings.reset_credentials.assert_called_once()
# assert mock_settings_service.auth_settings.SUPERUSER != ADMIN_USER_NAME
# assert mock_settings_service.auth_settings.SUPERUSER_PASSWORD != "password"
# # Test when superuser already exists
# mock_settings_service.auth_settings.AUTO_LOGIN = False
# mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME
# mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password"
# mock_user = Mock()
# mock_user.is_superuser = True
# mock_session.query.return_value.filter.return_value.first.return_value = mock_user
# setup_superuser(mock_settings_service, mock_session)
# mock_session.query.assert_called_with(User)
# actual_expr = mock_session.query.return_value.filter.call_args[0][0]
# expected_expr = User.username == ADMIN_USER_NAME
# assert str(actual_expr) == str(expected_expr)
@patch("langflow.services.deps.get_settings_service")
@patch("langflow.services.deps.get_session")
def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
mock_settings_service = MagicMock()
mock_settings_service.auth_settings.AUTO_LOGIN = True
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD
mock_get_settings_service.return_value = mock_settings_service
mock_session = MagicMock()
mock_user = MagicMock()
mock_user.is_superuser = True
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
mock_get_session.return_value = iter([mock_session])
teardown_superuser(mock_settings_service, mock_session)
mock_session.query.assert_not_called()
@patch("langflow.services.deps.get_settings_service")
@patch("langflow.services.deps.get_session")
def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service):
ADMIN_USER_NAME = "admin_user"
mock_settings_service = MagicMock()
mock_settings_service.auth_settings.AUTO_LOGIN = False
mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password"
mock_get_settings_service.return_value = mock_settings_service
mock_session = MagicMock()
mock_user = MagicMock()
mock_user.is_superuser = False
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
mock_get_session.return_value = [mock_session]
teardown_superuser(mock_settings_service, mock_session)
mock_session.query.assert_not_called()
mock_session.delete.assert_not_called()
mock_session.commit.assert_not_called()

113
tests/unit/test_template.py Normal file
View file

@ -0,0 +1,113 @@
import importlib
from typing import Dict, List, Optional
import pytest
from langflow.interface.utils import build_template_from_class
from langflow.utils.util import build_template_from_function, get_base_classes, get_default_factory
from pydantic import BaseModel
# Dummy classes for testing purposes
class Parent(BaseModel):
"""Parent Class"""
parent_field: str
class Child(Parent):
"""Child Class"""
child_field: int
class ExampleClass1(BaseModel):
"""Example class 1."""
def __init__(self, data: Optional[List[int]] = None):
self.data = data or [1, 2, 3]
class ExampleClass2(BaseModel):
"""Example class 2."""
def __init__(self, data: Optional[Dict[str, int]] = None):
self.data = data or {"a": 1, "b": 2, "c": 3}
def example_loader_1() -> ExampleClass1:
"""Example loader function 1."""
return ExampleClass1()
def example_loader_2() -> ExampleClass2:
"""Example loader function 2."""
return ExampleClass2()
def test_build_template_from_function():
type_to_loader_dict = {
"example1": example_loader_1,
"example2": example_loader_2,
}
# Test with valid name
result = build_template_from_function("ExampleClass1", type_to_loader_dict)
assert result is not None
assert "template" in result
assert "description" in result
assert "base_classes" in result
# Test with add_function=True
result_with_function = build_template_from_function("ExampleClass1", type_to_loader_dict, add_function=True)
assert result_with_function is not None
assert "Callable" in result_with_function["base_classes"]
# Test with invalid name
with pytest.raises(ValueError, match=r".* not found"):
build_template_from_function("NonExistent", type_to_loader_dict)
# Test build_template_from_class
def test_build_template_from_class():
type_to_cls_dict: Dict[str, type] = {"parent": Parent, "child": Child}
# Test valid input
result = build_template_from_class("Child", type_to_cls_dict)
assert result is not None
assert "template" in result
assert "description" in result
assert "base_classes" in result
assert "Child" in result["base_classes"]
assert "Parent" in result["base_classes"]
assert result["description"] == "Child Class"
# Test invalid input
with pytest.raises(ValueError, match="InvalidClass not found."):
build_template_from_class("InvalidClass", type_to_cls_dict)
# Test get_base_classes
def test_get_base_classes():
base_classes_parent = get_base_classes(Parent)
base_classes_child = get_base_classes(Child)
assert "Parent" in base_classes_parent
assert "Child" in base_classes_child
assert "Parent" in base_classes_child
# Test get_default_factory
def test_get_default_factory():
module_name = "langflow.utils.util"
function_repr = "<function dummy_function>"
def dummy_function():
return "default_value"
# Add dummy_function to your_module
setattr(importlib.import_module(module_name), "dummy_function", dummy_function)
default_value = get_default_factory(module_name, function_repr)
assert default_value == "default_value"

View file

@ -0,0 +1,107 @@
from unittest import mock
import pytest
from langflow.utils.validate import (
create_function,
execute_function,
extract_function_name,
validate_code,
)
from requests.exceptions import MissingSchema
def test_validate_code():
# Test case with a valid import and function
code1 = """
import math
def square(x):
return x ** 2
"""
errors1 = validate_code(code1)
assert errors1 == {"imports": {"errors": []}, "function": {"errors": []}}
# Test case with an invalid import and valid function
code2 = """
import non_existent_module
def square(x):
return x ** 2
"""
errors2 = validate_code(code2)
assert errors2 == {
"imports": {"errors": ["No module named 'non_existent_module'"]},
"function": {"errors": []},
}
# Test case with a valid import and invalid function syntax
code3 = """
import math
def square(x)
return x ** 2
"""
errors3 = validate_code(code3)
assert errors3 == {
"imports": {"errors": []},
"function": {"errors": ["expected ':' (<unknown>, line 4)"]},
}
def test_execute_function_success():
code = """
import math
def my_function(x):
return math.sin(x) + 1
"""
result = execute_function(code, "my_function", 0.5)
assert result == 1.479425538604203
def test_execute_function_missing_module():
code = """
import some_missing_module
def my_function(x):
return some_missing_module.some_function(x)
"""
with pytest.raises(ModuleNotFoundError):
execute_function(code, "my_function", 0.5)
def test_execute_function_missing_function():
code = """
import math
def my_function(x):
return math.some_missing_function(x)
"""
with pytest.raises(AttributeError):
execute_function(code, "my_function", 0.5)
def test_execute_function_missing_schema():
code = """
import requests
def my_function(x):
return requests.get(x).text
"""
with mock.patch("requests.get", side_effect=MissingSchema):
with pytest.raises(MissingSchema):
execute_function(code, "my_function", "invalid_url")
def test_create_function():
code = """
import math
def my_function(x):
return math.sin(x) + 1
"""
function_name = extract_function_name(code)
function = create_function(code, function_name)
result = function(0.5)
assert result == 1.479425538604203

View file

@ -0,0 +1,15 @@
from langflow.components import experimental
def test_python_function_component():
# Arrange
python_function_component = experimental.PythonFunctionComponent()
# Act
# function must be a string representation
function = "def function():\n return 'Hello, World!'"
# result is the callable function
result = python_function_component.build(function)
# Assert
assert result() == "Hello, World!"