Fix import statements and formatting issues

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-01-25 13:03:42 -03:00
commit a6a69a0f3a
24 changed files with 924 additions and 527 deletions

View file

@ -1,48 +1,75 @@
from contextlib import contextmanager
import json
# we need to import tmpdir
import tempfile
from contextlib import contextmanager, suppress
from contextlib import suppress
from pathlib import Path
from typing import TYPE_CHECKING, AsyncGenerator
import orjson
import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
from typing import AsyncGenerator, TYPE_CHECKING
from langflow.graph.graph.base import Graph
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.flow.model import Flow, FlowCreate
from langflow.services.database.models.user.model import User, UserCreate
from langflow.services.database.models.flow.flow import Flow, FlowCreate
from langflow.services.database.models.user.user import User, UserCreate
import orjson
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from langflow.services.getters import get_db_service
import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
from sqlmodel import SQLModel, Session, create_engine
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
# we need to import tmpdir
import tempfile
if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
from langflow.services.database.manager import DatabaseService
def pytest_configure():
pytest.BASIC_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "basic_example.json"
pytest.COMPLEX_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "complex_example.json"
pytest.OPENAPI_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "Openapi.json"
pytest.GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "grouped_chat.json"
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "one_group_chat.json"
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json"
pytest.BASIC_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "basic_example.json"
)
pytest.COMPLEX_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "complex_example.json"
)
pytest.OPENAPI_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "Openapi.json"
)
pytest.GROUPED_CHAT_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "grouped_chat.json"
)
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "one_group_chat.json"
)
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json"
)
pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = (
Path(__file__).parent.absolute() / "data" / "BasicChatwithPromptandHistory.json"
Path(__file__).parent.absolute() / "data" / "BasicChatWithPromptAndHistory.json"
)
pytest.CHAT_INPUT = Path(__file__).parent.absolute() / "data" / "ChatInputTest.json"
pytest.TWO_OUTPUTS = (
Path(__file__).parent.absolute() / "data" / "TwoOutputsTest.json"
)
pytest.VECTOR_STORE_PATH = (
Path(__file__).parent.absolute() / "data" / "Vector_store.json"
)
pytest.VECTOR_STORE_PATH = Path(__file__).parent.absolute() / "data" / "Vector_store.json"
pytest.CODE_WITH_SYNTAX_ERROR = """
def get_text():
retun "Hello World"
"""
@pytest.fixture(autouse=True)
def check_openai_api_key_in_environment_variables():
import os
assert (
os.environ.get("OPENAI_API_KEY") is not None
), "OPENAI_API_KEY is not set in environment variables"
@pytest.fixture()
async def async_client() -> AsyncGenerator:
from langflow.main import create_app
@ -54,7 +81,9 @@ async def async_client() -> AsyncGenerator:
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
engine = create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
@ -90,7 +119,9 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
# monkeypatch langflow.services.task.manager.USE_CELERY to True
# monkeypatch.setattr(manager, "USE_CELERY", True)
monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config))
monkeypatch.setattr(
celery_app, "celery_app", celery_app.make_celery("langflow", Config)
)
# def get_session_override():
# return session
@ -225,7 +256,7 @@ def test_user(client):
username="testuser",
password="testpassword",
)
response = client.post("/api/v1/users", json=user_data.model_dump())
response = client.post("/api/v1/users", json=user_data.dict())
assert response.status_code == 201
return response.json()
@ -241,7 +272,11 @@ def active_user(client):
is_superuser=False,
)
# check if user exists
if active_user := session.query(User).filter(User.username == user.username).first():
if (
active_user := session.query(User)
.filter(User.username == user.username)
.first()
):
return active_user
session.add(user)
session.commit()
@ -261,16 +296,13 @@ def logged_in_headers(client, active_user):
@pytest.fixture
def flow(client, json_flow: str, active_user):
from langflow.services.database.models.flow.model import FlowCreate
from langflow.services.database.models.flow.flow import FlowCreate
loaded_json = json.loads(json_flow)
flow_data = FlowCreate(
name="test_flow",
data=loaded_json.get("data"),
user_id=active_user.id,
description="description",
name="test_flow", data=loaded_json.get("data"), user_id=active_user.id
)
flow = Flow.model_validate(flow_data.model_dump())
flow = Flow(**flow_data.dict())
with session_getter(get_db_service()) as session:
session.add(flow)
session.commit()
@ -280,11 +312,31 @@ def flow(client, json_flow: str, active_user):
@pytest.fixture
def added_flow(client, json_flow_with_prompt_and_history, logged_in_headers):
def json_flow_with_prompt_and_history():
with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY, "r") as f:
return f.read()
@pytest.fixture
def json_chat_input():
with open(pytest.CHAT_INPUT, "r") as f:
return f.read()
@pytest.fixture
def json_two_outputs():
with open(pytest.TWO_OUTPUTS, "r") as f:
return f.read()
@pytest.fixture
def added_flow_with_prompt_and_history(
client, json_flow_with_prompt_and_history, logged_in_headers
):
flow = orjson.loads(json_flow_with_prompt_and_history)
data = flow["data"]
flow = FlowCreate(name="Basic Chat", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
@ -292,28 +344,37 @@ def added_flow(client, json_flow_with_prompt_and_history, logged_in_headers):
@pytest.fixture
def added_vector_store(client, json_vector_store, logged_in_headers):
vector_store = orjson.loads(json_vector_store)
data = vector_store["data"]
vector_store = FlowCreate(name="Vector Store", description="description", data=data)
response = client.post("api/v1/flows/", json=vector_store.model_dump(), headers=logged_in_headers)
def added_flow_chat_input(client, json_chat_input, logged_in_headers):
flow = orjson.loads(json_chat_input)
data = flow["data"]
flow = FlowCreate(name="Chat Input", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == vector_store.name
assert response.json()["data"] == vector_store.data
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
return response.json()
@pytest.fixture
def test_component_code():
path = Path(__file__).parent.absolute() / "data" / "component.py"
# load the content as a string
with open(path, "r") as f:
return f.read()
def added_flow_two_outputs(client, json_two_outputs, logged_in_headers):
flow = orjson.loads(json_two_outputs)
data = flow["data"]
flow = FlowCreate(name="Two Outputs", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
@pytest.fixture
def test_component_with_templatefield_code():
path = Path(__file__).parent.absolute() / "data" / "component_with_templatefield.py"
# load the content as a string
with open(path, "r") as f:
return f.read()
def added_vector_store(client, json_vector_store, logged_in_headers):
vector_store = orjson.loads(json_vector_store)
data = vector_store["data"]
vector_store = FlowCreate(name="Vector Store", description="description", data=data)
response = client.post(
"api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers
)
assert response.status_code == 201
assert response.json()["name"] == vector_store.name
assert response.json()["data"] == vector_store.data
return response.json()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -12,7 +12,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
"ZeroShotAgent",
"BaseSingleActionAgent",
"Agent",
"Callable",
"function",
}
template = zero_shot_agent["template"]
@ -28,7 +28,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
"list": True,
"advanced": False,
"info": "",
"fileTypes": [],
}
# Additional assertions for other template variables
@ -44,7 +43,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["llm"] == {
"required": True,
@ -58,7 +56,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["output_parser"] == {
"required": False,
@ -72,7 +69,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["input_variables"] == {
"required": False,
@ -86,7 +82,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
"list": True,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["prefix"] == {
"required": False,
@ -101,7 +96,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["suffix"] == {
"required": False,
@ -116,7 +110,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
@ -142,9 +135,6 @@ def test_json_agent(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"file_path": "",
"fileTypes": [],
"value": "",
}
assert template["llm"] == {
"required": True,
@ -159,9 +149,6 @@ def test_json_agent(client: TestClient, logged_in_headers):
"advanced": False,
"display_name": "LLM",
"info": "",
"file_path": "",
"fileTypes": [],
"value": "",
}
@ -182,12 +169,87 @@ def test_csv_agent(client: TestClient, logged_in_headers):
"show": True,
"multiline": False,
"value": "",
"fileTypes": [".csv"],
"suffixes": [".csv"],
"fileTypes": ["csv"],
"password": False,
"name": "path",
"type": "file",
"list": False,
"file_path": "",
"file_path": None,
"advanced": False,
"info": "",
}
assert template["llm"] == {
"required": True,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"password": False,
"name": "llm",
"type": "BaseLanguageModel",
"list": False,
"advanced": False,
"display_name": "LLM",
"info": "",
}
def test_initialize_agent(client: TestClient, logged_in_headers):
response = client.get("api/v1/all", headers=logged_in_headers)
assert response.status_code == 200
json_response = response.json()
agents = json_response["agents"]
initialize_agent = agents["AgentInitializer"]
assert initialize_agent["base_classes"] == ["AgentExecutor", "function"]
template = initialize_agent["template"]
assert template["agent"] == {
"required": True,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"value": "zero-shot-react-description",
"password": False,
"options": [
"zero-shot-react-description",
"react-docstore",
"self-ask-with-search",
"conversational-react-description",
"openai-functions",
"openai-multi-functions",
],
"name": "agent",
"type": "str",
"list": True,
"advanced": False,
"info": "",
}
assert template["memory"] == {
"required": False,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"password": False,
"name": "memory",
"type": "BaseChatMemory",
"list": False,
"advanced": False,
"info": "",
}
assert template["tools"] == {
"required": True,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"password": False,
"name": "tools",
"type": "Tool",
"list": True,
"advanced": False,
"info": "",
}
@ -204,7 +266,4 @@ def test_csv_agent(client: TestClient, logged_in_headers):
"advanced": False,
"display_name": "LLM",
"info": "",
"file_path": "",
"fileTypes": [],
"value": "",
}

View file

@ -6,7 +6,9 @@ from langflow.services.database.models.api_key import ApiKeyCreate
def api_key(client, logged_in_headers, active_user):
api_key = ApiKeyCreate(name="test-api-key")
response = client.post("api/v1/api_key", data=api_key.json(), headers=logged_in_headers)
response = client.post(
"api/v1/api_key", data=api_key.json(), headers=logged_in_headers
)
assert response.status_code == 200, response.text
return response.json()
@ -26,7 +28,9 @@ def test_get_api_keys(client, logged_in_headers, api_key):
def test_create_api_key(client, logged_in_headers):
api_key_name = "test-api-key"
response = client.post("api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers)
response = client.post(
"api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers
)
assert response.status_code == 200
data = response.json()
assert "name" in data and data["name"] == api_key_name

View file

@ -1,9 +1,8 @@
import json
from langflow.graph import Graph
import pytest
from langflow.graph import Graph
def get_graph(_type="basic"):
"""Get a graph from a json file"""

View file

@ -1,5 +1,6 @@
from fastapi.testclient import TestClient
# def test_chains_settings(client: TestClient, logged_in_headers):
# response = client.get("api/v1/all", headers=logged_in_headers)
# assert response.status_code == 200
@ -8,53 +9,21 @@ from fastapi.testclient import TestClient
# assert set(chains.keys()) == set(settings.chains)
def test_llm_checker_chain(client: TestClient, logged_in_headers):
response = client.get("api/v1/all", headers=logged_in_headers)
assert response.status_code == 200
json_response = response.json()
chains = json_response["chains"]
chain = chains["LLMCheckerChain"]
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
assert set(chain["base_classes"]) == {
"Callable",
"LLMCheckerChain",
"Chain",
}
template = chain["template"]
assert template["llm"] == {
"required": True,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"password": False,
"name": "llm",
"type": "BaseLanguageModel",
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["_type"] == "LLMCheckerChain"
# Test the description object
assert chain["description"] == ""
def test_llm_math_chain(client: TestClient, logged_in_headers):
# Test the ConversationChain object
def test_conversation_chain(client: TestClient, logged_in_headers):
response = client.get("api/v1/all", headers=logged_in_headers)
assert response.status_code == 200
json_response = response.json()
chains = json_response["chains"]
chain = chains["LLMMathChain"]
chain = chains["ConversationChain"]
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
assert set(chain["base_classes"]) == {
"Callable",
"LLMMathChain",
"ConversationChain",
"LLMChain",
"Chain",
"function",
"Text",
}
template = chain["template"]
@ -70,7 +39,195 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["verbose"] == {
"required": False,
"dynamic": False,
"placeholder": "",
"show": False,
"multiline": False,
"password": False,
"name": "verbose",
"type": "bool",
"list": False,
"advanced": True,
"info": "",
}
assert template["llm"] == {
"required": True,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"password": False,
"name": "llm",
"type": "BaseLanguageModel",
"list": False,
"advanced": False,
"info": "",
}
assert template["input_key"] == {
"required": True,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"value": "input",
"password": False,
"name": "input_key",
"type": "str",
"list": False,
"advanced": True,
"info": "",
}
assert template["output_key"] == {
"required": True,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"value": "response",
"password": False,
"name": "output_key",
"type": "str",
"list": False,
"advanced": True,
"info": "",
}
assert template["_type"] == "ConversationChain"
# Test the description object
assert (
chain["description"]
== "Chain to have a conversation and load context from memory."
)
def test_llm_chain(client: TestClient, logged_in_headers):
response = client.get("api/v1/all", headers=logged_in_headers)
assert response.status_code == 200
json_response = response.json()
chains = json_response["chains"]
chain = chains["LLMChain"]
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
assert set(chain["base_classes"]) == {"function", "LLMChain", "Chain", "Text"}
template = chain["template"]
assert template["memory"] == {
"required": False,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"password": False,
"name": "memory",
"type": "BaseMemory",
"list": False,
"advanced": False,
"info": "",
}
assert template["verbose"] == {
"required": False,
"dynamic": False,
"placeholder": "",
"show": False,
"multiline": False,
"value": False,
"password": False,
"name": "verbose",
"type": "bool",
"list": False,
"advanced": True,
"info": "",
}
assert template["llm"] == {
"required": True,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"password": False,
"name": "llm",
"type": "BaseLanguageModel",
"list": False,
"advanced": False,
"info": "",
}
assert template["output_key"] == {
"required": True,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"value": "text",
"password": False,
"name": "output_key",
"type": "str",
"list": False,
"advanced": True,
"info": "",
}
def test_llm_checker_chain(client: TestClient, logged_in_headers):
response = client.get("api/v1/all", headers=logged_in_headers)
assert response.status_code == 200
json_response = response.json()
chains = json_response["chains"]
chain = chains["LLMCheckerChain"]
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
assert set(chain["base_classes"]) == {
"function",
"LLMCheckerChain",
"Chain",
"Text",
}
template = chain["template"]
assert template["llm"] == {
"required": True,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"password": False,
"name": "llm",
"type": "BaseLanguageModel",
"list": False,
"advanced": False,
"info": "",
}
assert template["_type"] == "LLMCheckerChain"
# Test the description object
assert chain["description"] == ""
def test_llm_math_chain(client: TestClient, logged_in_headers):
response = client.get("api/v1/all", headers=logged_in_headers)
assert response.status_code == 200
json_response = response.json()
chains = json_response["chains"]
chain = chains["LLMMathChain"]
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
assert set(chain["base_classes"]) == {"function", "LLMMathChain", "Chain", "Text"}
template = chain["template"]
assert template["memory"] == {
"required": False,
"dynamic": False,
"placeholder": "",
"show": True,
"multiline": False,
"password": False,
"name": "memory",
"type": "BaseMemory",
"list": False,
"advanced": False,
"info": "",
}
assert template["verbose"] == {
"required": False,
@ -85,7 +242,6 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
"list": False,
"advanced": True,
"info": "",
"fileTypes": [],
}
assert template["llm"] == {
"required": True,
@ -99,7 +255,6 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["input_key"] == {
"required": True,
@ -114,7 +269,6 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
"list": False,
"advanced": True,
"info": "",
"fileTypes": [],
}
assert template["output_key"] == {
"required": True,
@ -129,12 +283,14 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
"list": False,
"advanced": True,
"info": "",
"fileTypes": [],
}
assert template["_type"] == "LLMMathChain"
# Test the description object
assert chain["description"] == "Chain that interprets a prompt and executes python code to do math."
assert (
chain["description"]
== "Chain that interprets a prompt and executes python code to do math."
)
def test_series_character_chain(client: TestClient, logged_in_headers):
@ -147,7 +303,7 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
assert set(chain["base_classes"]) == {
"Callable",
"function",
"LLMChain",
"BaseCustomChain",
"Chain",
@ -169,9 +325,6 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
"file_path": "",
"value": "",
}
assert template["character"] == {
"required": True,
@ -185,9 +338,6 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
"file_path": "",
"value": "",
}
assert template["series"] == {
"required": True,
@ -201,9 +351,6 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
"file_path": "",
"value": "",
}
assert template["_type"] == "SeriesCharacterChain"
@ -247,12 +394,12 @@ def test_mid_journey_prompt_chain(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"file_path": "",
"fileTypes": [],
"value": "",
}
# Test the description object
assert chain["description"] == "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."
assert (
chain["description"]
== "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."
)
def test_time_travel_guide_chain(client: TestClient, logged_in_headers):
@ -288,9 +435,6 @@ def test_time_travel_guide_chain(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"file_path": "",
"fileTypes": [],
"value": "",
}
assert template["memory"] == {
"required": False,
@ -304,9 +448,6 @@ def test_time_travel_guide_chain(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"file_path": "",
"fileTypes": [],
"value": "",
}
assert chain["description"] == "Time travel guide chain."

View file

@ -1,10 +1,9 @@
from pathlib import Path
from tempfile import tempdir
from langflow.__main__ import app
import pytest
from langflow.__main__ import app
from langflow.services import deps
from langflow.services import getters
@pytest.fixture(scope="module")
@ -27,12 +26,11 @@ def test_components_path(runner, client, default_settings):
["run", "--components-path", str(temp_dir), *default_settings],
)
assert result.exit_code == 0, result.stdout
settings_service = deps.get_settings_service()
settings_service = getters.get_settings_service()
assert str(temp_dir) in settings_service.settings.COMPONENTS_PATH
def test_superuser(runner, client, session):
result = runner.invoke(app, ["superuser"], input="admin\nadmin\n")
assert result.exit_code == 0, result.stdout
assert "Superuser creation failed." not in result.output, result.output
assert "Superuser created successfully." in result.output, result.output
assert "Superuser created successfully." in result.stdout

View file

@ -1,13 +1,19 @@
import ast
import pytest
import types
from uuid import uuid4
import pytest
from langflow.interface.custom.base import CustomComponent
from langflow.interface.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
from langflow.interface.custom.custom_component.component import Component, ComponentCodeNullError
from langflow.interface.custom.utils import build_custom_component_template
from fastapi import HTTPException
from langflow.services.database.models.flow import Flow, FlowCreate
from langflow.interface.custom.base import CustomComponent
from langflow.interface.custom.component import (
Component,
ComponentCodeNullError,
ComponentFunctionEntrypointNameNullError,
)
from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError
code_default = """
from langflow import Prompt
@ -47,7 +53,7 @@ def test_code_parser_get_tree():
Test the __get_tree method of the CodeParser class.
"""
parser = CodeParser(code_default)
tree = parser.get_tree()
tree = parser._CodeParser__get_tree()
assert isinstance(tree, ast.AST)
@ -60,23 +66,23 @@ def test_code_parser_syntax_error():
parser = CodeParser(code_syntax_error)
with pytest.raises(CodeSyntaxError):
parser.get_tree()
parser._CodeParser__get_tree()
def test_component_init():
"""
Test the initialization of the Component class.
"""
component = Component(code=code_default, _function_entrypoint_name="build")
component = Component(code=code_default, function_entrypoint_name="build")
assert component.code == code_default
assert component._function_entrypoint_name == "build"
assert component.function_entrypoint_name == "build"
def test_component_get_code_tree():
"""
Test the get_code_tree method of the Component class.
"""
component = Component(code=code_default, _function_entrypoint_name="build")
component = Component(code=code_default, function_entrypoint_name="build")
tree = component.get_code_tree(component.code)
assert "imports" in tree
@ -86,20 +92,19 @@ def test_component_code_null_error():
Test the get_function method raises the
ComponentCodeNullError when the code is empty.
"""
component = Component(code="", _function_entrypoint_name="")
component = Component(code="", function_entrypoint_name="")
with pytest.raises(ComponentCodeNullError):
component.get_function()
# TODO: Validate if we should remove this
# def test_component_function_entrypoint_name_null_error():
# """
# Test the get_function method raises the ComponentFunctionEntrypointNameNullError
# when the function_entrypoint_name is empty.
# """
# component = Component(code=code_default, _function_entrypoint_name="")
# with pytest.raises(ComponentFunctionEntrypointNameNullError):
# component.get_function()
def test_component_function_entrypoint_name_null_error():
"""
Test the get_function method raises the ComponentFunctionEntrypointNameNullError
when the function_entrypoint_name is empty.
"""
component = Component(code=code_default, function_entrypoint_name="")
with pytest.raises(ComponentFunctionEntrypointNameNullError):
component.get_function()
def test_custom_component_init():
@ -108,7 +113,9 @@ def test_custom_component_init():
"""
function_entrypoint_name = "build"
custom_component = CustomComponent(code=code_default, function_entrypoint_name=function_entrypoint_name)
custom_component = CustomComponent(
code=code_default, function_entrypoint_name=function_entrypoint_name
)
assert custom_component.code == code_default
assert custom_component.function_entrypoint_name == function_entrypoint_name
@ -117,8 +124,10 @@ def test_custom_component_build_template_config():
"""
Test the build_template_config property of the CustomComponent class.
"""
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
config = custom_component.template_config
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
config = custom_component.build_template_config
assert isinstance(config, dict)
@ -126,7 +135,9 @@ def test_custom_component_get_function():
"""
Test the get_function property of the CustomComponent class.
"""
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
custom_component = CustomComponent(
code="def build(): pass", function_entrypoint_name="build"
)
my_function = custom_component.get_function
assert isinstance(my_function, types.FunctionType)
@ -137,7 +148,7 @@ def test_code_parser_parse_imports_import():
class with an import statement.
"""
parser = CodeParser(code_default)
tree = parser.get_tree()
tree = parser._CodeParser__get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
parser.parse_imports(node)
@ -150,7 +161,7 @@ def test_code_parser_parse_imports_importfrom():
class with an import from statement.
"""
parser = CodeParser("from os import path")
tree = parser.get_tree()
tree = parser._CodeParser__get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom):
parser.parse_imports(node)
@ -162,7 +173,7 @@ def test_code_parser_parse_functions():
Test the parse_functions method of the CodeParser class.
"""
parser = CodeParser("def test(): pass")
tree = parser.get_tree()
tree = parser._CodeParser__get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
parser.parse_functions(node)
@ -175,7 +186,7 @@ def test_code_parser_parse_classes():
Test the parse_classes method of the CodeParser class.
"""
parser = CodeParser("class Test: pass")
tree = parser.get_tree()
tree = parser._CodeParser__get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
parser.parse_classes(node)
@ -188,7 +199,7 @@ def test_code_parser_parse_global_vars():
Test the parse_global_vars method of the CodeParser class.
"""
parser = CodeParser("x = 1")
tree = parser.get_tree()
tree = parser._CodeParser__get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
parser.parse_global_vars(node)
@ -201,7 +212,7 @@ def test_component_get_function_valid():
Test the get_function method of the Component
class with valid code and function_entrypoint_name.
"""
component = Component(code="def build(): pass", _function_entrypoint_name="build")
component = Component(code="def build(): pass", function_entrypoint_name="build")
my_function = component.get_function()
assert callable(my_function)
@ -211,7 +222,9 @@ def test_custom_component_get_function_entrypoint_args():
Test the get_function_entrypoint_args
property of the CustomComponent class.
"""
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
args = custom_component.get_function_entrypoint_args
assert len(args) == 4
assert args[0]["name"] == "self"
@ -224,18 +237,20 @@ def test_custom_component_get_function_entrypoint_return_type():
Test the get_function_entrypoint_return_type
property of the CustomComponent class.
"""
from langchain.schema import Document
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
return_type = custom_component.get_function_entrypoint_return_type
assert return_type == [Document]
assert return_type == ["Document"]
def test_custom_component_get_main_class_name():
"""
Test the get_main_class_name property of the CustomComponent class.
"""
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
class_name = custom_component.get_main_class_name
assert class_name == "YourComponent"
@ -245,7 +260,9 @@ def test_custom_component_get_function_valid():
Test the get_function property of the CustomComponent
class with valid code and function_entrypoint_name.
"""
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
custom_component = CustomComponent(
code="def build(): pass", function_entrypoint_name="build"
)
my_function = custom_component.get_function
assert callable(my_function)
@ -280,7 +297,9 @@ def test_code_parser_parse_callable_details_no_args():
parser = CodeParser("")
node = ast.FunctionDef(
name="test",
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
args=ast.arguments(
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
),
body=[],
decorator_list=[],
returns=None,
@ -309,7 +328,7 @@ def test_code_parser_parse_ann_assign():
stmt = ast.AnnAssign(
target=ast.Name(id="x", ctx=ast.Store()),
annotation=ast.Name(id="int", ctx=ast.Load()),
value=ast.Constant(n=1),
value=ast.Num(n=1),
simple=1,
)
result = parser.parse_ann_assign(stmt)
@ -326,7 +345,9 @@ def test_code_parser_parse_function_def_not_init():
parser = CodeParser("")
stmt = ast.FunctionDef(
name="test",
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
args=ast.arguments(
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
),
body=[],
decorator_list=[],
returns=None,
@ -344,7 +365,9 @@ def test_code_parser_parse_function_def_init():
parser = CodeParser("")
stmt = ast.FunctionDef(
name="__init__",
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
args=ast.arguments(
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
),
body=[],
decorator_list=[],
returns=None,
@ -359,17 +382,29 @@ def test_component_get_code_tree_syntax_error():
Test the get_code_tree method of the Component class
raises the CodeSyntaxError when given incorrect syntax.
"""
component = Component(code="import os as", _function_entrypoint_name="build")
component = Component(code="import os as", function_entrypoint_name="build")
with pytest.raises(CodeSyntaxError):
component.get_code_tree(component.code)
def test_custom_component_class_template_validation_no_code():
"""
Test the _class_template_validation method of the CustomComponent class
raises the HTTPException when the code is None.
"""
custom_component = CustomComponent(code=None, function_entrypoint_name="build")
with pytest.raises(HTTPException):
custom_component._class_template_validation(custom_component.code)
def test_custom_component_get_code_tree_syntax_error():
"""
Test the get_code_tree method of the CustomComponent class
raises the CodeSyntaxError when given incorrect syntax.
"""
custom_component = CustomComponent(code="import os as", function_entrypoint_name="build")
custom_component = CustomComponent(
code="import os as", function_entrypoint_name="build"
)
with pytest.raises(CodeSyntaxError):
custom_component.get_code_tree(custom_component.code)
@ -423,7 +458,9 @@ def test_custom_component_build_not_implemented():
Test the build method of the CustomComponent
class raises the NotImplementedError.
"""
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
custom_component = CustomComponent(
code="def build(): pass", function_entrypoint_name="build"
)
with pytest.raises(NotImplementedError):
custom_component.build()
@ -431,7 +468,7 @@ def test_custom_component_build_not_implemented():
def test_build_config_no_code():
component = CustomComponent(code=None)
assert component.get_function_entrypoint_args == []
assert component.get_function_entrypoint_args == ""
assert component.get_function_entrypoint_return_type == []
@ -457,7 +494,9 @@ def test_flow(db):
}
# Create flow
flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data)
flow = FlowCreate(
id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data
)
# Add to database
db.add(flow)
@ -518,36 +557,3 @@ def test_build_config_field_value_keys(component):
config = component.build_config()
field_values = config["fields"].values()
assert all("type" in value for value in field_values)
def test_create_and_validate_component_valid_code(test_component_code):
component = CustomComponent(code=test_component_code)
assert isinstance(component, CustomComponent)
def test_build_langchain_template_custom_component_valid_code(test_component_code):
component = CustomComponent(code=test_component_code)
frontend_node = build_custom_component_template(component)
assert isinstance(frontend_node, dict)
template = frontend_node["template"]
assert isinstance(template, dict)
assert "param" in template
param_options = template["param"]["options"]
# Now run it again with an update field
frontend_node = build_custom_component_template(component, update_field="param")
new_param_options = frontend_node["template"]["param"]["options"]
assert param_options != new_param_options
def test_build_langchain_template_custom_component_templatefield(test_component_with_templatefield_code):
component = CustomComponent(code=test_component_with_templatefield_code)
frontend_node = build_custom_component_template(component)
assert isinstance(frontend_node, dict)
template = frontend_node["template"]
assert isinstance(template, dict)
assert "param" in template
param_options = template["param"]["options"]
# Now run it again with an update field
frontend_node = build_custom_component_template(component, update_field="param")
new_param_options = frontend_node["template"]["param"]["options"]
assert param_options != new_param_options

View file

@ -18,7 +18,9 @@ def test_python_function_tool():
with pytest.raises(SyntaxError):
code = pytest.CODE_WITH_SYNTAX_ERROR
func = get_function(code)
func = PythonFunctionTool(name="Test", description="Testing", code=code, func=func)
func = PythonFunctionTool(
name="Test", description="Testing", code=code, func=func
)
def test_python_function():

View file

@ -1,15 +1,17 @@
from uuid import UUID, uuid4
from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.utils import session_getter
from langflow.services.getters import get_db_service
import orjson
import pytest
from fastapi.testclient import TestClient
from langflow.api.v1.schemas import FlowListCreate
from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from uuid import UUID, uuid4
from sqlmodel import Session
from fastapi.testclient import TestClient
from langflow.api.v1.schemas import FlowListCreate
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
@pytest.fixture(scope="module")
def json_style():
@ -25,17 +27,21 @@ def json_style():
)
def test_create_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
def test_create_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
# flow is optional so we can create a flow without a flow
flow = FlowCreate(name="Test Flow", description="description")
response = client.post("api/v1/flows/", json=flow.model_dump(exclude_unset=True), headers=logged_in_headers)
flow = FlowCreate(name="Test Flow")
response = client.post(
"api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers
)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
@ -45,13 +51,13 @@ def test_read_flows(client: TestClient, json_flow: str, active_user, logged_in_h
flow_data = orjson.loads(json_flow)
data = flow_data["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
flow = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
@ -65,7 +71,7 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
flow_id = response.json()["id"] # flow_id should be a UUID but is a string
# turn it into a UUID
flow_id = UUID(flow_id)
@ -76,12 +82,14 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he
assert response.json()["data"] == flow.data
def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
def test_update_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
flow_id = response.json()["id"]
updated_flow = FlowUpdate(
@ -89,7 +97,9 @@ def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_
description="updated description",
data=data,
)
response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
response = client.patch(
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
)
assert response.status_code == 200
assert response.json()["name"] == updated_flow.name
@ -97,18 +107,22 @@ def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_
# assert response.json()["data"] == updated_flow.data
def test_delete_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
def test_delete_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
flow_id = response.json()["id"]
response = client.delete(f"api/v1/flows/{flow_id}", headers=logged_in_headers)
assert response.status_code == 200
assert response.json()["message"] == "Flow deleted successfully"
def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers):
def test_create_flows(
client: TestClient, session: Session, json_flow: str, logged_in_headers
):
flow = orjson.loads(json_flow)
data = flow["data"]
# Create test data
@ -119,7 +133,9 @@ def test_create_flows(client: TestClient, session: Session, json_flow: str, logg
]
)
# Make request to endpoint
response = client.post("api/v1/flows/batch/", json=flow_list.model_dump(), headers=logged_in_headers)
response = client.post(
"api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers
)
# Check response status code
assert response.status_code == 201
# Check response data
@ -133,7 +149,9 @@ def test_create_flows(client: TestClient, session: Session, json_flow: str, logg
assert response_data[1]["data"] == data
def test_upload_file(client: TestClient, session: Session, json_flow: str, logged_in_headers):
def test_upload_file(
client: TestClient, session: Session, json_flow: str, logged_in_headers
):
flow = orjson.loads(json_flow)
data = flow["data"]
# Create test data
@ -143,7 +161,7 @@ def test_upload_file(client: TestClient, session: Session, json_flow: str, logge
FlowCreate(name="Flow 2", description="description", data=data),
]
)
file_contents = orjson_dumps(flow_list.model_dump())
file_contents = orjson_dumps(flow_list.dict())
response = client.post(
"api/v1/flows/upload/",
files={"file": ("examples.json", file_contents, "application/json")},
@ -182,7 +200,7 @@ def test_download_file(
with session_getter(db_manager) as session:
for flow in flow_list.flows:
flow.user_id = active_user.id
db_flow = Flow.model_validate(flow, from_attributes=True)
db_flow = Flow.from_orm(flow)
session.add(db_flow)
session.commit()
# Make request to endpoint
@ -200,7 +218,9 @@ def test_download_file(
assert response_data[1]["data"] == data
def test_create_flow_with_invalid_data(client: TestClient, active_user, logged_in_headers):
def test_create_flow_with_invalid_data(
client: TestClient, active_user, logged_in_headers
):
flow = {"name": "a" * 256, "data": "Invalid flow data"}
response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers)
assert response.status_code == 422
@ -212,19 +232,29 @@ def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers
assert response.status_code == 404
def test_update_flow_idempotency(client: TestClient, json_flow: str, active_user, logged_in_headers):
def test_update_flow_idempotency(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
flow_data = orjson.loads(json_flow)
data = flow_data["data"]
flow_data = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post("api/v1/flows/", json=flow_data.model_dump(), headers=logged_in_headers)
response = client.post(
"api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers
)
flow_id = response.json()["id"]
updated_flow = FlowCreate(name="Updated Flow", description="description", data=data)
response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
response1 = client.put(
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
)
response2 = client.put(
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
)
assert response1.json() == response2.json()
def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
def test_update_nonexistent_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
flow_data = orjson.loads(json_flow)
data = flow_data["data"]
uuid = uuid4()
@ -233,7 +263,9 @@ def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user
description="description",
data=data,
)
response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.model_dump(), headers=logged_in_headers)
response = client.patch(
f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers
)
assert response.status_code == 404

View file

@ -1,17 +1,18 @@
import time
import uuid
from collections import namedtuple
import uuid
from langflow.processing.process import Result
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.api_key.api_key import ApiKey
from langflow.services.getters import get_settings_service
from langflow.services.database.utils import session_getter
from langflow.services.getters import get_db_service
import pytest
from fastapi.testclient import TestClient
from langflow.interface.tools.constants import CUSTOM_TOOLS
from langflow.processing.process import Result
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service
from langflow.template.frontend_node.chains import TimeTravelGuideChainNode
import time
def run_post(client, flow_id, headers, post_data):
response = client.post(
@ -24,13 +25,16 @@ def run_post(client, flow_id, headers, post_data):
# Helper function to poll task status
def poll_task_status(client, headers, href, max_attempts=20, sleep_time=2):
def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1):
for _ in range(max_attempts):
task_status_response = client.get(
href,
headers=headers,
)
if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS":
if (
task_status_response.status_code == 200
and task_status_response.json()["status"] == "SUCCESS"
):
return task_status_response.json()
time.sleep(sleep_time)
return None # Return None if task did not complete in time
@ -72,14 +76,11 @@ PROMPT_REQUEST = {
"text-ada-001",
],
"ChatOpenAI": [
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-16k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-32k",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
],
"Anthropic": [
"claude-v1",
@ -126,7 +127,11 @@ def created_api_key(active_user):
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
if (
existing_api_key := session.query(ApiKey)
.filter(ApiKey.api_key == api_key.api_key)
.first()
):
return existing_api_key
session.add(api_key)
session.commit()
@ -185,7 +190,9 @@ def test_process_flow_invalid_id(client, monkeypatch, created_api_key):
}
invalid_id = uuid.uuid4()
response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data)
response = client.post(
f"api/v1/process/{invalid_id}", headers=headers, json=post_data
)
assert response.status_code == 404
assert f"Flow {invalid_id} not found" in response.json()["detail"]
@ -226,7 +233,9 @@ def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_k
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task)
monkeypatch.setattr(
endpoints, "process_graph_cached_task", mock_process_graph_cached_task
)
api_key = created_api_key.api_key
headers = {"x-api-key": api_key}
@ -410,6 +419,105 @@ def test_various_prompts(client, prompt, expected_input_variables):
assert response.json()["input_variables"] == expected_input_variables
def test_get_vertices_flow_not_found(client, logged_in_headers):
response = client.get(
"/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers
)
assert (
response.status_code == 500
) # Or whatever status code you've set for invalid ID
def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers):
flow_id = added_flow_with_prompt_and_history["id"]
response = client.get(
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
)
assert response.status_code == 200
assert "ids" in response.json()
# The response should contain the list in this order
# ['ConversationBufferMemory-Lu2Nb', 'PromptTemplate-5Q0W8', 'ChatOpenAI-vy7fV', 'LLMChain-UjBh1']
# The important part is before the - (ConversationBufferMemory, PromptTemplate, ChatOpenAI, LLMChain)
ids = [id.split("-")[0] for id in response.json()["ids"]]
assert ids == [
"ConversationBufferMemory",
"PromptTemplate",
"ChatOpenAI",
"LLMChain",
]
def test_build_vertex_invalid_flow_id(client, logged_in_headers):
response = client.post(
"/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers
)
assert response.status_code == 500
def test_build_vertex_invalid_vertex_id(
client, added_flow_with_prompt_and_history, logged_in_headers
):
flow_id = added_flow_with_prompt_and_history["id"]
response = client.post(
f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers
)
assert response.status_code == 500
def test_build_all_vertices_in_sequence_with_chat_input(
client, added_flow_chat_input, logged_in_headers
):
flow_id = added_flow_chat_input["id"]
# First, get all the vertices in the correct sequence
response = client.get(
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
)
assert response.status_code == 200
assert "ids" in response.json()
vertex_ids = response.json()["ids"]
# Now, iterate through each vertex and build it
for vertex_id in vertex_ids:
response = client.post(
f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers
)
json_response = response.json()
assert (
response.status_code == 200
), f"Failed at vertex {vertex_id}: {json_response}"
assert "valid" in json_response
assert json_response["valid"], json_response["params"]
def test_build_all_vertices_in_sequence_with_two_outputs(
client, added_flow_two_outputs, logged_in_headers
):
"""This tests the case where a node has two outputs, one of which is Text and the other (in this case) is
a LLMChain. We need to make sure the correct output is passed in both cases."""
flow_id = added_flow_two_outputs["id"]
# First, get all the vertices in the correct sequence
response = client.get(
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
)
assert response.status_code == 200
assert "ids" in response.json()
vertex_ids = response.json()["ids"]
# Now, iterate through each vertex and build it
for vertex_id in vertex_ids:
response = client.post(
f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers
)
json_response = response.json()
assert (
response.status_code == 200
), f"Failed at vertex {vertex_id}: {json_response}"
assert "valid" in json_response
assert json_response["valid"], json_response["params"]
def test_basic_chat_in_process(client, added_flow, created_api_key):
# Run the /api/v1/process/{flow_id} endpoint
headers = {"x-api-key": created_api_key.api_key}
@ -498,7 +606,9 @@ def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_a
@pytest.mark.async_test
def test_vector_store_in_process(distributed_client, added_vector_store, created_api_key):
def test_vector_store_in_process(
distributed_client, added_vector_store, created_api_key
):
# Run the /api/v1/process/{flow_id} endpoint
headers = {"x-api-key": created_api_key.api_key}
post_data = {"inputs": {"input": "What is Langflow?"}}
@ -549,7 +659,9 @@ def test_async_task_processing(distributed_client, added_flow, created_api_key):
# Test function without loop
@pytest.mark.async_test
def test_async_task_processing_vector_store(client, added_vector_store, created_api_key):
def test_async_task_processing_vector_store(
client, added_vector_store, created_api_key
):
headers = {"x-api-key": created_api_key.api_key}
post_data = {"inputs": {"input": "How do I upload examples?"}}
@ -578,4 +690,6 @@ def test_async_task_processing_vector_store(client, added_vector_store, created_
# Validate that the task completed successfully and the result is as expected
assert "result" in task_status_json, task_status_json
assert "output" in task_status_json["result"], task_status_json["result"]
assert "Langflow" in task_status_json["result"]["output"], task_status_json["result"]
assert "Langflow" in task_status_json["result"]["output"], task_status_json[
"result"
]

View file

@ -31,14 +31,17 @@ def test_template_field_defaults(sample_template_field: TemplateField):
assert sample_template_field.is_list is False
assert sample_template_field.show is True
assert sample_template_field.multiline is False
assert sample_template_field.value == ""
assert sample_template_field.value is None
assert sample_template_field.suffixes == []
assert sample_template_field.file_types == []
assert sample_template_field.file_path == ""
assert sample_template_field.file_path is None
assert sample_template_field.password is False
assert sample_template_field.name == "test_field"
def test_template_to_dict(sample_template: Template, sample_template_field: TemplateField):
def test_template_to_dict(
sample_template: Template, sample_template_field: TemplateField
):
template_dict = sample_template.to_dict()
assert template_dict["_type"] == "test_template"
assert len(template_dict) == 2 # _type and test_field

View file

@ -1,31 +1,32 @@
import copy
import json
import os
import pickle
from pathlib import Path
import pickle
from typing import Type, Union
import pytest
from langflow.graph.edge.base import Edge
from langflow.graph.vertex.base import Vertex
from langchain.agents import AgentExecutor
import pytest
from langchain.chains.base import Chain
from langchain.llms.fake import FakeListLLM
from langflow.graph import Graph
from langflow.graph.edge.base import Edge
from langflow.graph.vertex.types import (
FileToolVertex,
LLMVertex,
ToolkitVertex,
)
from langflow.processing.process import get_result_and_thought
from langflow.utils.payload import get_root_node
from langflow.graph.graph.utils import (
find_last_node,
process_flow,
set_new_target_handle,
ungroup_node,
process_flow,
update_source_handle,
update_target_handle,
update_template,
)
from langflow.graph.utils import UnbuiltObject
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import FileToolVertex, LLMVertex, ToolkitVertex
from langflow.processing.process import get_result_and_thought
from langflow.utils.payload import get_root_vertex
# Test cases for the graph module
@ -46,7 +47,13 @@ def sample_nodes():
return [
{
"id": "node1",
"data": {"node": {"template": {"some_field": {"show": True, "advanced": False, "name": "Name1"}}}},
"data": {
"node": {
"template": {
"some_field": {"show": True, "advanced": False, "name": "Name1"}
}
}
},
},
{
"id": "node2",
@ -64,7 +71,11 @@ def sample_nodes():
},
{
"id": "node3",
"data": {"node": {"template": {"unrelated_field": {"show": True, "advanced": True}}}},
"data": {
"node": {
"template": {"unrelated_field": {"show": True, "advanced": True}}
}
},
},
]
@ -82,8 +93,8 @@ def test_graph_structure(basic_graph):
assert isinstance(node, Vertex)
for edge in basic_graph.edges:
assert isinstance(edge, Edge)
assert edge.source_id in basic_graph.vertex_map.keys()
assert edge.target_id in basic_graph.vertex_map.keys()
assert edge.source in basic_graph.vertices
assert edge.target in basic_graph.vertices
def test_circular_dependencies(basic_graph):
@ -91,7 +102,7 @@ def test_circular_dependencies(basic_graph):
def check_circular(node, visited):
visited.add(node)
neighbors = basic_graph.get_vertices_with_target(node)
neighbors = basic_graph.get_nodes_with_target(node)
for neighbor in neighbors:
if neighbor in visited:
return True
@ -124,13 +135,13 @@ def test_invalid_node_types():
Graph(graph_data["nodes"], graph_data["edges"])
def test_get_vertices_with_target(basic_graph):
def test_get_nodes_with_target(basic_graph):
"""Test getting connected nodes"""
assert isinstance(basic_graph, Graph)
# Get root node
root = get_root_vertex(basic_graph)
root = get_root_node(basic_graph)
assert root is not None
connected_nodes = basic_graph.get_vertices_with_target(root.id)
connected_nodes = basic_graph.get_nodes_with_target(root)
assert connected_nodes is not None
@ -139,66 +150,23 @@ def test_get_node_neighbors_basic(basic_graph):
assert isinstance(basic_graph, Graph)
# Get root node
root = get_root_vertex(basic_graph)
root = get_root_node(basic_graph)
assert root is not None
neighbors = basic_graph.get_vertex_neighbors(root)
neighbors = basic_graph.get_node_neighbors(root)
assert neighbors is not None
assert isinstance(neighbors, dict)
# Root Node is an Agent, it requires an LLMChain and tools
# We need to check if there is a Chain in the one of the neighbors'
# data attribute in the type key
assert any("ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
assert any(
"ConversationBufferMemory" in neighbor.data["type"]
for neighbor, val in neighbors.items()
if val
)
assert any("OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
# def test_get_node_neighbors_complex(complex_graph):
# """Test getting node neighbors"""
# assert isinstance(complex_graph, Graph)
# # Get root node
# root = get_root_node(complex_graph)
# assert root is not None
# neighbors = complex_graph.get_nodes_with_target(root)
# assert neighbors is not None
# # Neighbors should be a list of nodes
# assert isinstance(neighbors, list)
# # Root Node is an Agent, it requires an LLMChain and tools
# # We need to check if there is a Chain in the one of the neighbors'
# assert any("Chain" in neighbor.data["type"] for neighbor in neighbors)
# # assert Tool is in the neighbors
# assert any("Tool" in neighbor.data["type"] for neighbor in neighbors)
# # Now on to the Chain's neighbors
# chain = next(neighbor for neighbor in neighbors if "Chain" in neighbor.data["type"])
# chain_neighbors = complex_graph.get_nodes_with_target(chain)
# assert chain_neighbors is not None
# # Check if there is a LLM in the chain's neighbors
# assert any("OpenAI" in neighbor.data["type"] for neighbor in chain_neighbors)
# # Chain should have a Prompt as a neighbor
# assert any("Prompt" in neighbor.data["type"] for neighbor in chain_neighbors)
# # Now on to the Tool's neighbors
# tool = next(neighbor for neighbor in neighbors if "Tool" in neighbor.data["type"])
# tool_neighbors = complex_graph.get_nodes_with_target(tool)
# assert tool_neighbors is not None
# # Check if there is an Agent in the tool's neighbors
# assert any("Agent" in neighbor.data["type"] for neighbor in tool_neighbors)
# # This Agent has a Tool that has a PythonFunction as func
# agent = next(
# neighbor for neighbor in tool_neighbors if "Agent" in neighbor.data["type"]
# )
# agent_neighbors = complex_graph.get_nodes_with_target(agent)
# assert agent_neighbors is not None
# # Check if there is a Tool in the agent's neighbors
# assert any("Tool" in neighbor.data["type"] for neighbor in agent_neighbors)
# # This Tool has a PythonFunction as func
# tool = next(
# neighbor for neighbor in agent_neighbors if "Tool" in neighbor.data["type"]
# )
# tool_neighbors = complex_graph.get_nodes_with_target(tool)
# assert tool_neighbors is not None
# # Check if there is a PythonFunction in the tool's neighbors
# assert any(
# "PythonFunctionTool" in neighbor.data["type"] for neighbor in tool_neighbors
# )
assert any(
"OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
)
def test_get_node(basic_graph):
@ -212,7 +180,7 @@ def test_get_node(basic_graph):
def test_build_nodes(basic_graph):
"""Test building nodes"""
assert len(basic_graph.vertices) == len(basic_graph._vertices)
assert len(basic_graph.vertices) == len(basic_graph._nodes)
for node in basic_graph.vertices:
assert isinstance(node, Vertex)
@ -222,21 +190,20 @@ def test_build_edges(basic_graph):
assert len(basic_graph.edges) == len(basic_graph._edges)
for edge in basic_graph.edges:
assert isinstance(edge, Edge)
assert isinstance(edge.source_id, str)
assert isinstance(edge.target_id, str)
assert isinstance(edge.source, Vertex)
assert isinstance(edge.target, Vertex)
def test_get_root_vertex(client, basic_graph, complex_graph):
def test_get_root_node(client, basic_graph, complex_graph):
"""Test getting root node"""
assert isinstance(basic_graph, Graph)
root = get_root_vertex(basic_graph)
root = get_root_node(basic_graph)
assert root is not None
assert isinstance(root, Vertex)
assert root.data["type"] == "TimeTravelGuideChain"
# For complex example, the root node is a ZeroShotAgent too
assert isinstance(complex_graph, Graph)
root = get_root_vertex(complex_graph)
root = get_root_node(complex_graph)
assert root is not None
assert isinstance(root, Vertex)
assert root.data["type"] == "ZeroShotAgent"
@ -272,7 +239,7 @@ def test_build_params(basic_graph):
# The matched_type attribute should be in the source_types attr
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
# Get the root node
root = get_root_vertex(basic_graph)
root = get_root_node(basic_graph)
# Root node is a TimeTravelGuideChain
# which requires an llm and memory
assert root is not None
@ -281,32 +248,29 @@ def test_build_params(basic_graph):
assert "memory" in root.params
@pytest.mark.asyncio
async def test_build(basic_graph):
def test_build(basic_graph):
"""Test Node's build method"""
await assert_agent_was_built(basic_graph)
assert_agent_was_built(basic_graph)
async def assert_agent_was_built(graph):
def assert_agent_was_built(graph):
"""Assert that the agent was built"""
assert isinstance(graph, Graph)
# Now we test the build method
# Build the Agent
result = await graph.build()
result = graph.build()
# The agent should be a AgentExecutor
assert isinstance(result, Chain)
@pytest.mark.asyncio
async def test_llm_node_build(basic_graph):
def test_llm_node_build(basic_graph):
llm_node = get_node_by_type(basic_graph, LLMVertex)
assert llm_node is not None
built_object = await llm_node.build()
assert built_object is not UnbuiltObject()
built_object = llm_node.build()
assert built_object is not None
@pytest.mark.asyncio
async def test_toolkit_node_build(client, openapi_graph):
def test_toolkit_node_build(client, openapi_graph):
# Write a file to the disk
file_path = "api-with-examples.yaml"
with open(file_path, "w") as f:
@ -314,31 +278,36 @@ async def test_toolkit_node_build(client, openapi_graph):
toolkit_node = get_node_by_type(openapi_graph, ToolkitVertex)
assert toolkit_node is not None
built_object = await toolkit_node.build()
assert built_object is not UnbuiltObject
built_object = toolkit_node.build()
assert built_object is not None
# Remove the file
os.remove(file_path)
assert not Path(file_path).exists()
@pytest.mark.asyncio
async def test_file_tool_node_build(client, openapi_graph):
def test_file_tool_node_build(client, openapi_graph):
file_path = "api-with-examples.yaml"
with open(file_path, "w") as f:
f.write("openapi: 3.0.0")
assert Path(file_path).exists()
file_tool_node = get_node_by_type(openapi_graph, FileToolVertex)
assert file_tool_node is not UnbuiltObject and file_tool_node is not None
built_object = await file_tool_node.build()
assert built_object is not UnbuiltObject
assert file_tool_node is not None
built_object = file_tool_node.build()
assert built_object is not None
# Remove the file
os.remove(file_path)
assert not Path(file_path).exists()
@pytest.mark.asyncio
async def test_get_result_and_thought(basic_graph):
# def test_wrapper_node_build(openapi_graph):
# wrapper_node = get_node_by_type(openapi_graph, WrapperVertex)
# assert wrapper_node is not None
# built_object = wrapper_node.build()
# assert built_object is not None
def test_get_result_and_thought(basic_graph):
"""Test the get_result_and_thought method"""
responses = [
"Final Answer: I am a response",
@ -350,7 +319,7 @@ async def test_get_result_and_thought(basic_graph):
assert llm_node is not None
llm_node._built_object = FakeListLLM(responses=responses)
llm_node._built = True
langchain_object = await basic_graph.build()
langchain_object = basic_graph.build()
# assert all nodes are built
assert all(node._built for node in basic_graph.vertices)
# now build again and check if FakeListLLM was used
@ -370,7 +339,9 @@ def test_find_last_node(grouped_chat_json_flow):
def test_ungroup_node(grouped_chat_json_flow):
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
group_node = grouped_chat_data["nodes"][2] # Assuming the first node is a group node
group_node = grouped_chat_data["nodes"][
2
] # Assuming the first node is a group node
base_flow = copy.deepcopy(grouped_chat_data)
ungroup_node(group_node["data"], base_flow)
# after ungroup_node is called, the base_flow and grouped_chat_data should be different
@ -422,9 +393,14 @@ def test_process_flow_one_group(one_grouped_chat_json_flow):
assert "edges" in processed_flow
# Now get the node that has ChatOpenAI in its id
chat_openai_node = next((node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None)
chat_openai_node = next(
(node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None
)
assert chat_openai_node is not None
assert chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] == "test"
assert (
chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"]
== "test"
)
def test_process_flow_vector_store_grouped(vector_store_grouped_json_flow):
@ -471,15 +447,19 @@ def test_update_template(sample_template, sample_nodes):
node2_updated = next((n for n in nodes_copy if n["id"] == "node2"), None)
node3_updated = next((n for n in nodes_copy if n["id"] == "node3"), None)
assert node1_updated is not None
assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True
assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False
assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1"
assert (
node1_updated["data"]["node"]["template"]["some_field"]["display_name"]
== "Name1"
)
assert node2_updated is not None
assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False
assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True
assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2"
assert (
node2_updated["data"]["node"]["template"]["other_field"]["display_name"]
== "DisplayName2"
)
# Ensure node3 remains unchanged
assert node3_updated == sample_nodes[2]
@ -510,7 +490,9 @@ def test_set_new_target_handle():
"data": {
"node": {
"flow": True,
"template": {"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}},
"template": {
"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}
},
}
}
}
@ -530,34 +512,34 @@ def test_update_source_handle():
"nodes": [{"id": "some_node"}, {"id": "last_node"}],
"edges": [{"source": "some_node"}],
}
updated_edge = update_source_handle(new_edge, flow_data["nodes"], flow_data["edges"])
updated_edge = update_source_handle(
new_edge, flow_data["nodes"], flow_data["edges"]
)
assert updated_edge["source"] == "last_node"
assert updated_edge["data"]["sourceHandle"]["id"] == "last_node"
@pytest.mark.asyncio
async def test_pickle_graph(json_vector_store):
def test_pickle_graph(json_vector_store):
loaded_json = json.loads(json_vector_store)
graph = Graph.from_payload(loaded_json)
assert isinstance(graph, Graph)
first_result = await graph.build()
first_result = graph.build()
assert isinstance(first_result, AgentExecutor)
pickled = pickle.dumps(graph)
assert pickled is not UnbuiltObject
assert pickled is not None
unpickled = pickle.loads(pickled)
assert unpickled is not UnbuiltObject
result = await unpickled.build()
assert unpickled is not None
result = unpickled.build()
assert isinstance(result, AgentExecutor)
@pytest.mark.asyncio
async def test_pickle_each_vertex(json_vector_store):
def test_pickle_each_vertex(json_vector_store):
loaded_json = json.loads(json_vector_store)
graph = Graph.from_payload(loaded_json)
assert isinstance(graph, Graph)
for vertex in graph.vertices:
await vertex.build()
for vertex in graph.nodes:
vertex.build()
pickled = pickle.dumps(vertex)
assert pickled is not UnbuiltObject
assert pickled is not None
unpickled = pickle.loads(pickled)
assert unpickled is not UnbuiltObject
assert unpickled is not None

View file

@ -22,7 +22,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["verbose"] == {
"required": False,
@ -36,7 +35,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["client"] == {
"required": False,
@ -50,7 +48,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["model_name"] == {
"required": False,
@ -72,7 +69,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": True,
"advanced": False,
"info": "",
"fileTypes": [],
}
# Add more assertions for other properties here
assert template["temperature"] == {
@ -88,8 +84,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["max_tokens"] == {
"required": False,
@ -104,7 +98,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["top_p"] == {
"required": False,
@ -119,8 +112,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["frequency_penalty"] == {
"required": False,
@ -135,8 +126,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["presence_penalty"] == {
"required": False,
@ -151,8 +140,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["n"] == {
"required": False,
@ -167,7 +154,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["best_of"] == {
"required": False,
@ -182,7 +168,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["model_kwargs"] == {
"required": False,
@ -196,7 +181,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": True,
"info": "",
"fileTypes": [],
}
assert template["openai_api_key"] == {
"required": False,
@ -212,7 +196,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["batch_size"] == {
"required": False,
@ -227,7 +210,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["request_timeout"] == {
"required": False,
@ -241,8 +223,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["logit_bias"] == {
"required": False,
@ -256,7 +236,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["max_retries"] == {
"required": False,
@ -264,14 +243,13 @@ def test_openai(client: TestClient, logged_in_headers):
"placeholder": "",
"show": False,
"multiline": False,
"value": 2,
"value": 6,
"password": False,
"name": "max_retries",
"type": "int",
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["streaming"] == {
"required": False,
@ -286,7 +264,6 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
@ -312,7 +289,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["client"] == {
"required": False,
@ -326,7 +302,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["model_name"] == {
"required": False,
@ -334,11 +309,10 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"placeholder": "",
"show": True,
"multiline": False,
"value": "gpt-4-1106-preview",
"value": "gpt-3.5-turbo-0613",
"password": False,
"options": [
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
"gpt-4-32k",
"gpt-3.5-turbo",
@ -349,7 +323,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": True,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["temperature"] == {
"required": False,
@ -364,8 +337,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["model_kwargs"] == {
"required": False,
@ -379,7 +350,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": False,
"advanced": True,
"info": "",
"fileTypes": [],
}
assert template["openai_api_key"] == {
"required": False,
@ -395,7 +365,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["request_timeout"] == {
"required": False,
@ -409,8 +378,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["max_retries"] == {
"required": False,
@ -418,14 +385,13 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"placeholder": "",
"show": False,
"multiline": False,
"value": 2,
"value": 6,
"password": False,
"name": "max_retries",
"type": "int",
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["streaming"] == {
"required": False,
@ -440,7 +406,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["n"] == {
"required": False,
@ -455,7 +420,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["max_tokens"] == {
@ -470,7 +434,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["_type"] == "ChatOpenAI"
assert (

View file

@ -1,10 +1,9 @@
import json
import pytest
from langchain.chains.base import Chain
from langflow.processing.process import load_flow_from_json
from langflow.graph import Graph
from langflow.processing.load import load_flow_from_json
from langflow.utils.payload import get_root_vertex
from langflow.utils.payload import get_root_node
def test_load_flow_from_json():
@ -16,22 +15,21 @@ def test_load_flow_from_json():
def test_load_flow_from_json_with_tweaks():
"""Test loading a flow from a json file and applying tweaks"""
tweaks = {"dndnode_82": {"model_name": "test model"}}
tweaks = {"dndnode_82": {"model_name": "gpt-3.5-turbo-16k-0613"}}
loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH, tweaks=tweaks)
assert loaded is not None
assert isinstance(loaded, Chain)
assert loaded.llm.model_name == "test model"
assert loaded.llm.model_name == "gpt-3.5-turbo-16k-0613"
def test_get_root_vertex():
def test_get_root_node():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
graph = Graph(nodes, edges)
root = get_root_vertex(graph)
root = get_root_node(graph)
assert root is not None
assert hasattr(root, "id")
assert hasattr(root, "data")
assert hasattr(root, "data")

View file

@ -1,5 +1,5 @@
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from langflow.services.getters import get_db_service
import pytest
from langflow.services.database.models.user import User
from langflow.services.auth.utils import get_password_hash
@ -9,7 +9,9 @@ from langflow.services.auth.utils import get_password_hash
def test_user():
return User(
username="testuser",
password=get_password_hash("testpassword"), # Assuming password needs to be hashed
password=get_password_hash(
"testpassword"
), # Assuming password needs to be hashed
is_active=True,
is_superuser=False,
)
@ -21,13 +23,17 @@ def test_login_successful(client, test_user):
session.add(test_user)
session.commit()
response = client.post("api/v1/login", data={"username": "testuser", "password": "testpassword"})
response = client.post(
"api/v1/login", data={"username": "testuser", "password": "testpassword"}
)
assert response.status_code == 200
assert "access_token" in response.json()
def test_login_unsuccessful_wrong_username(client):
response = client.post("api/v1/login", data={"username": "wrongusername", "password": "testpassword"})
response = client.post(
"api/v1/login", data={"username": "wrongusername", "password": "testpassword"}
)
assert response.status_code == 401
assert response.json()["detail"] == "Incorrect username or password"
@ -37,6 +43,8 @@ def test_login_unsuccessful_wrong_password(client, test_user, session):
session.add(test_user)
session.commit()
response = client.post("api/v1/login", data={"username": "testuser", "password": "wrongpassword"})
response = client.post(
"api/v1/login", data={"username": "testuser", "password": "wrongpassword"}
)
assert response.status_code == 401
assert response.json()["detail"] == "Incorrect username or password"

View file

@ -1,6 +1,5 @@
import pytest
from langflow.processing.process import process_tweaks
from langflow.services.deps import get_session_service
from langflow.services.getters import get_session_service
def test_no_tweaks():
@ -198,42 +197,39 @@ def test_tweak_not_in_template():
assert result == graph_data
@pytest.mark.asyncio
async def test_load_langchain_object_with_cached_session(client, basic_graph_data):
def test_load_langchain_object_with_cached_session(client, basic_graph_data):
# Provide a non-existent session_id
session_service = get_session_service()
session_id1 = "non-existent-session-id"
graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data)
graph1, artifacts1 = session_service.load_session(session_id1, basic_graph_data)
# Use the new session_id to get the langchain_object again
graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data)
graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data)
assert graph1 == graph2
assert artifacts1 == artifacts2
@pytest.mark.asyncio
async def test_load_langchain_object_with_no_cached_session(client, basic_graph_data):
def test_load_langchain_object_with_no_cached_session(client, basic_graph_data):
# Provide a non-existent session_id
session_service = get_session_service()
session_id1 = "non-existent-session-id"
session_id = session_service.build_key(session_id1, basic_graph_data)
graph1, artifacts1 = await session_service.load_session(session_id, basic_graph_data)
graph1, artifacts1 = session_service.load_session(session_id, basic_graph_data)
# Clear the cache
session_service.clear_session(session_id)
# Use the new session_id to get the langchain_object again
graph2, artifacts2 = await session_service.load_session(session_id, basic_graph_data)
graph2, artifacts2 = session_service.load_session(session_id, basic_graph_data)
assert id(graph1) != id(graph2)
# Since the cache was cleared, objects should be different
@pytest.mark.asyncio
async def test_load_langchain_object_without_session_id(client, basic_graph_data):
def test_load_langchain_object_without_session_id(client, basic_graph_data):
# Provide a non-existent session_id
session_service = get_session_service()
session_id1 = None
graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data)
graph1, artifacts1 = session_service.load_session(session_id1, basic_graph_data)
# Use the new session_id to get the langchain_object again
graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data)
graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data)
assert graph1 == graph2

View file

@ -1,6 +1,5 @@
from fastapi.testclient import TestClient
from langflow.services.deps import get_settings_service
from langflow.services.getters import get_settings_service
def test_prompts_settings(client: TestClient, logged_in_headers):
@ -32,7 +31,6 @@ def test_prompt_template(client: TestClient, logged_in_headers):
"list": True,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["output_parser"] == {
@ -47,7 +45,6 @@ def test_prompt_template(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["partial_variables"] == {
@ -62,7 +59,6 @@ def test_prompt_template(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["template"] == {
@ -77,7 +73,6 @@ def test_prompt_template(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["template_format"] == {
@ -93,7 +88,6 @@ def test_prompt_template(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}
assert template["validate_template"] == {
@ -102,12 +96,11 @@ def test_prompt_template(client: TestClient, logged_in_headers):
"placeholder": "",
"show": False,
"multiline": False,
"value": False,
"value": True,
"password": False,
"name": "validate_template",
"type": "bool",
"list": False,
"advanced": False,
"info": "",
"fileTypes": [],
}

View file

@ -1,11 +1,17 @@
from unittest.mock import MagicMock, patch
from unittest.mock import patch, MagicMock
from langflow.services.database.models.user.user import User
from langflow.services.settings.constants import (
DEFAULT_SUPERUSER,
DEFAULT_SUPERUSER_PASSWORD,
)
from langflow.services.utils import (
teardown_superuser,
)
from langflow.services.settings.constants import DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD
from langflow.services.utils import teardown_superuser
# @patch("langflow.services.deps.get_session")
# @patch("langflow.services.getters.get_session")
# @patch("langflow.services.utils.create_super_user")
# @patch("langflow.services.deps.get_settings_service")
# @patch("langflow.services.getters.get_settings_service")
# # @patch("langflow.services.utils.verify_password")
# def test_setup_superuser(
# mock_get_session, mock_create_super_user, mock_get_settings_service
@ -86,9 +92,11 @@ from langflow.services.utils import teardown_superuser
# assert str(actual_expr) == str(expected_expr)
@patch("langflow.services.deps.get_settings_service")
@patch("langflow.services.deps.get_session")
def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
@patch("langflow.services.getters.get_settings_service")
@patch("langflow.services.getters.get_session")
def test_teardown_superuser_default_superuser(
mock_get_session, mock_get_settings_service
):
mock_settings_service = MagicMock()
mock_settings_service.auth_settings.AUTO_LOGIN = True
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
@ -103,12 +111,20 @@ def test_teardown_superuser_default_superuser(mock_get_session, mock_get_setting
teardown_superuser(mock_settings_service, mock_session)
mock_session.query.assert_not_called()
mock_session.query.assert_called_once_with(User)
actual_expr = mock_session.query.return_value.filter.call_args[0][0]
expected_expr = User.username == DEFAULT_SUPERUSER
assert str(actual_expr) == str(expected_expr)
mock_session.delete.assert_called_once_with(mock_user)
mock_session.commit.assert_called_once()
@patch("langflow.services.deps.get_settings_service")
@patch("langflow.services.deps.get_session")
def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service):
@patch("langflow.services.getters.get_settings_service")
@patch("langflow.services.getters.get_session")
def test_teardown_superuser_no_default_superuser(
mock_get_session, mock_get_settings_service
):
ADMIN_USER_NAME = "admin_user"
mock_settings_service = MagicMock()
mock_settings_service.auth_settings.AUTO_LOGIN = False
@ -119,11 +135,11 @@ def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_sett
mock_session = MagicMock()
mock_user = MagicMock()
mock_user.is_superuser = False
mock_session.exec.return_value.filter.return_value.first.return_value = mock_user
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
mock_get_session.return_value = [mock_session]
teardown_superuser(mock_settings_service, mock_session)
mock_session.exec.assert_called_once()
mock_session.query.assert_not_called()
mock_session.delete.assert_not_called()
mock_session.commit.assert_not_called()

View file

@ -2,8 +2,6 @@ import importlib
from typing import Dict, List, Optional
import pytest
from pydantic import BaseModel
from langflow.utils.constants import CHAT_OPENAI_MODELS, OPENAI_MODELS
from langflow.utils.util import (
build_template_from_class,
@ -12,6 +10,7 @@ from langflow.utils.util import (
get_base_classes,
get_default_factory,
)
from pydantic import BaseModel
# Dummy classes for testing purposes
@ -66,9 +65,11 @@ def test_build_template_from_function():
assert "base_classes" in result
# Test with add_function=True
result_with_function = build_template_from_function("ExampleClass1", type_to_loader_dict, add_function=True)
result_with_function = build_template_from_function(
"ExampleClass1", type_to_loader_dict, add_function=True
)
assert result_with_function is not None
assert "Callable" in result_with_function["base_classes"]
assert "function" in result_with_function["base_classes"]
# Test with invalid name
with pytest.raises(ValueError, match=r".* not found"):
@ -236,7 +237,7 @@ def test_format_dict():
"password": False,
"multiline": False,
"options": CHAT_OPENAI_MODELS,
"value": "gpt-4-1106-preview",
"value": "gpt-4-1106-preview"",
},
}
assert format_dict(input_dict, "OpenAI") == expected_output_openai

View file

@ -1,12 +1,11 @@
from datetime import datetime
import pytest
from langflow.services.auth.utils import create_super_user, get_password_hash
from langflow.services.database.models.user import UserUpdate
from langflow.services.database.models.user.model import User
from langflow.services.database.models.user.user import User
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service
from langflow.services.getters import get_db_service, get_settings_service
import pytest
from langflow.services.database.models.user import UserUpdate
@pytest.fixture
@ -86,11 +85,15 @@ def test_deactivated_user_cannot_access(client, deactivated_user, logged_in_head
assert response.json()["detail"] == "The user doesn't have enough privileges"
def test_data_consistency_after_update(client, active_user, logged_in_headers, super_user_headers):
def test_data_consistency_after_update(
client, active_user, logged_in_headers, super_user_headers
):
user_id = active_user.id
update_data = UserUpdate(is_active=False)
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=super_user_headers)
response = client.patch(
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=super_user_headers
)
assert response.status_code == 200, response.json()
# Fetch the updated user from the database
@ -117,7 +120,7 @@ def test_inactive_user(client):
username="inactiveuser",
password=get_password_hash("testpassword"),
is_active=False,
last_login_at=datetime.now(),
last_login_at="2023-01-01T00:00:00", # Set to a valid datetime string
)
session.add(user)
session.commit()
@ -164,13 +167,17 @@ def test_patch_user(client, active_user, logged_in_headers):
username="newname",
)
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers)
response = client.patch(
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
)
assert response.status_code == 200, response.json()
update_data = UserUpdate(
profile_image="new_image",
)
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers)
response = client.patch(
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
)
assert response.status_code == 200, response.json()
@ -182,7 +189,7 @@ def test_patch_reset_password(client, active_user, logged_in_headers):
response = client.patch(
f"/api/v1/users/{user_id}/reset-password",
json=update_data.model_dump(),
json=update_data.dict(),
headers=logged_in_headers,
)
assert response.status_code == 200, response.json()
@ -198,13 +205,19 @@ def test_patch_user_wrong_id(client, active_user, logged_in_headers):
username="newname",
)
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers)
response = client.patch(
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
)
assert response.status_code == 422, response.json()
json_response = response.json()
detail = json_response["detail"]
assert detail[0]["type"] == "uuid_parsing"
assert detail[0]["loc"] == ["path", "user_id"]
assert detail[0]["input"] == "wrong_id"
assert response.json() == {
"detail": [
{
"loc": ["path", "user_id"],
"msg": "value is not a valid uuid",
"type": "type_error.uuid",
}
]
}
def test_delete_user(client, test_user, super_user_headers):
@ -218,11 +231,15 @@ def test_delete_user_wrong_id(client, test_user, super_user_headers):
user_id = "wrong_id"
response = client.delete(f"/api/v1/users/{user_id}", headers=super_user_headers)
assert response.status_code == 422
json_response = response.json()
detail = json_response["detail"]
assert detail[0]["type"] == "uuid_parsing"
assert detail[0]["loc"] == ["path", "user_id"]
assert detail[0]["input"] == "wrong_id"
assert response.json() == {
"detail": [
{
"loc": ["path", "user_id"],
"msg": "value is not a valid uuid",
"type": "type_error.uuid",
}
]
}
def test_normal_user_cant_delete_user(client, test_user, logged_in_headers):

View file

@ -1,5 +1,5 @@
from fastapi.testclient import TestClient
from langflow.services.deps import get_settings_service
from langflow.services.getters import get_settings_service
# check that all agents are in settings.agents

View file

@ -31,7 +31,9 @@ def test_websocket_endpoint(client: TestClient, active_user, logged_in_headers):
# Assuming your websocket_endpoint uses chat_service which caches data from stream_build
access_token = logged_in_headers["Authorization"].split(" ")[1]
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect(f"api/v1/chat/non_existing_client_id?token={access_token}") as websocket:
with client.websocket_connect(
f"api/v1/chat/non_existing_client_id?token={access_token}"
) as websocket:
websocket.send_json({"type": "test"})
data = websocket.receive_json()
assert "Please, build the flow before sending messages" in data["message"]