Fix import statements and formatting issues
This commit is contained in:
parent
47397153f4
commit
a6a69a0f3a
24 changed files with 924 additions and 527 deletions
|
|
@ -1,48 +1,75 @@
|
|||
from contextlib import contextmanager
|
||||
import json
|
||||
|
||||
# we need to import tmpdir
|
||||
import tempfile
|
||||
from contextlib import contextmanager, suppress
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, AsyncGenerator
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
from typing import AsyncGenerator, TYPE_CHECKING
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.flow.model import Flow, FlowCreate
|
||||
from langflow.services.database.models.user.model import User, UserCreate
|
||||
from langflow.services.database.models.flow.flow import Flow, FlowCreate
|
||||
from langflow.services.database.models.user.user import User, UserCreate
|
||||
import orjson
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from langflow.services.getters import get_db_service
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
from sqlmodel import SQLModel, Session, create_engine
|
||||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
|
||||
# we need to import tmpdir
|
||||
import tempfile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.service import DatabaseService
|
||||
from langflow.services.database.manager import DatabaseService
|
||||
|
||||
|
||||
def pytest_configure():
|
||||
pytest.BASIC_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "basic_example.json"
|
||||
pytest.COMPLEX_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "complex_example.json"
|
||||
pytest.OPENAPI_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "Openapi.json"
|
||||
pytest.GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "grouped_chat.json"
|
||||
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "one_group_chat.json"
|
||||
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json"
|
||||
pytest.BASIC_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "basic_example.json"
|
||||
)
|
||||
pytest.COMPLEX_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "complex_example.json"
|
||||
)
|
||||
pytest.OPENAPI_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "Openapi.json"
|
||||
)
|
||||
pytest.GROUPED_CHAT_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "grouped_chat.json"
|
||||
)
|
||||
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "one_group_chat.json"
|
||||
)
|
||||
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json"
|
||||
)
|
||||
|
||||
pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = (
|
||||
Path(__file__).parent.absolute() / "data" / "BasicChatwithPromptandHistory.json"
|
||||
Path(__file__).parent.absolute() / "data" / "BasicChatWithPromptAndHistory.json"
|
||||
)
|
||||
pytest.CHAT_INPUT = Path(__file__).parent.absolute() / "data" / "ChatInputTest.json"
|
||||
pytest.TWO_OUTPUTS = (
|
||||
Path(__file__).parent.absolute() / "data" / "TwoOutputsTest.json"
|
||||
)
|
||||
pytest.VECTOR_STORE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "Vector_store.json"
|
||||
)
|
||||
pytest.VECTOR_STORE_PATH = Path(__file__).parent.absolute() / "data" / "Vector_store.json"
|
||||
pytest.CODE_WITH_SYNTAX_ERROR = """
|
||||
def get_text():
|
||||
retun "Hello World"
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def check_openai_api_key_in_environment_variables():
|
||||
import os
|
||||
|
||||
assert (
|
||||
os.environ.get("OPENAI_API_KEY") is not None
|
||||
), "OPENAI_API_KEY is not set in environment variables"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def async_client() -> AsyncGenerator:
|
||||
from langflow.main import create_app
|
||||
|
|
@ -54,7 +81,9 @@ async def async_client() -> AsyncGenerator:
|
|||
|
||||
@pytest.fixture(name="session")
|
||||
def session_fixture():
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
engine = create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
|
@ -90,7 +119,9 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
|
|||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
|
||||
# monkeypatch langflow.services.task.manager.USE_CELERY to True
|
||||
# monkeypatch.setattr(manager, "USE_CELERY", True)
|
||||
monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config))
|
||||
monkeypatch.setattr(
|
||||
celery_app, "celery_app", celery_app.make_celery("langflow", Config)
|
||||
)
|
||||
|
||||
# def get_session_override():
|
||||
# return session
|
||||
|
|
@ -225,7 +256,7 @@ def test_user(client):
|
|||
username="testuser",
|
||||
password="testpassword",
|
||||
)
|
||||
response = client.post("/api/v1/users", json=user_data.model_dump())
|
||||
response = client.post("/api/v1/users", json=user_data.dict())
|
||||
assert response.status_code == 201
|
||||
return response.json()
|
||||
|
||||
|
|
@ -241,7 +272,11 @@ def active_user(client):
|
|||
is_superuser=False,
|
||||
)
|
||||
# check if user exists
|
||||
if active_user := session.query(User).filter(User.username == user.username).first():
|
||||
if (
|
||||
active_user := session.query(User)
|
||||
.filter(User.username == user.username)
|
||||
.first()
|
||||
):
|
||||
return active_user
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
|
@ -261,16 +296,13 @@ def logged_in_headers(client, active_user):
|
|||
|
||||
@pytest.fixture
|
||||
def flow(client, json_flow: str, active_user):
|
||||
from langflow.services.database.models.flow.model import FlowCreate
|
||||
from langflow.services.database.models.flow.flow import FlowCreate
|
||||
|
||||
loaded_json = json.loads(json_flow)
|
||||
flow_data = FlowCreate(
|
||||
name="test_flow",
|
||||
data=loaded_json.get("data"),
|
||||
user_id=active_user.id,
|
||||
description="description",
|
||||
name="test_flow", data=loaded_json.get("data"), user_id=active_user.id
|
||||
)
|
||||
flow = Flow.model_validate(flow_data.model_dump())
|
||||
flow = Flow(**flow_data.dict())
|
||||
with session_getter(get_db_service()) as session:
|
||||
session.add(flow)
|
||||
session.commit()
|
||||
|
|
@ -280,11 +312,31 @@ def flow(client, json_flow: str, active_user):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def added_flow(client, json_flow_with_prompt_and_history, logged_in_headers):
|
||||
def json_flow_with_prompt_and_history():
|
||||
with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_chat_input():
|
||||
with open(pytest.CHAT_INPUT, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_two_outputs():
|
||||
with open(pytest.TWO_OUTPUTS, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def added_flow_with_prompt_and_history(
|
||||
client, json_flow_with_prompt_and_history, logged_in_headers
|
||||
):
|
||||
flow = orjson.loads(json_flow_with_prompt_and_history)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Basic Chat", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
|
@ -292,28 +344,37 @@ def added_flow(client, json_flow_with_prompt_and_history, logged_in_headers):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def added_vector_store(client, json_vector_store, logged_in_headers):
|
||||
vector_store = orjson.loads(json_vector_store)
|
||||
data = vector_store["data"]
|
||||
vector_store = FlowCreate(name="Vector Store", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=vector_store.model_dump(), headers=logged_in_headers)
|
||||
def added_flow_chat_input(client, json_chat_input, logged_in_headers):
|
||||
flow = orjson.loads(json_chat_input)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Chat Input", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == vector_store.name
|
||||
assert response.json()["data"] == vector_store.data
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_component_code():
|
||||
path = Path(__file__).parent.absolute() / "data" / "component.py"
|
||||
# load the content as a string
|
||||
with open(path, "r") as f:
|
||||
return f.read()
|
||||
def added_flow_two_outputs(client, json_two_outputs, logged_in_headers):
|
||||
flow = orjson.loads(json_two_outputs)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Two Outputs", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_component_with_templatefield_code():
|
||||
path = Path(__file__).parent.absolute() / "data" / "component_with_templatefield.py"
|
||||
# load the content as a string
|
||||
with open(path, "r") as f:
|
||||
return f.read()
|
||||
def added_vector_store(client, json_vector_store, logged_in_headers):
|
||||
vector_store = orjson.loads(json_vector_store)
|
||||
data = vector_store["data"]
|
||||
vector_store = FlowCreate(name="Vector Store", description="description", data=data)
|
||||
response = client.post(
|
||||
"api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == vector_store.name
|
||||
assert response.json()["data"] == vector_store.data
|
||||
return response.json()
|
||||
|
|
|
|||
1
tests/data/ChatInputTest.json
Normal file
1
tests/data/ChatInputTest.json
Normal file
File diff suppressed because one or more lines are too long
1
tests/data/TwoOutputsTest.json
Normal file
1
tests/data/TwoOutputsTest.json
Normal file
File diff suppressed because one or more lines are too long
|
|
@ -12,7 +12,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"ZeroShotAgent",
|
||||
"BaseSingleActionAgent",
|
||||
"Agent",
|
||||
"Callable",
|
||||
"function",
|
||||
}
|
||||
template = zero_shot_agent["template"]
|
||||
|
||||
|
|
@ -28,7 +28,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
# Additional assertions for other template variables
|
||||
|
|
@ -44,7 +43,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
|
|
@ -58,7 +56,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["output_parser"] == {
|
||||
"required": False,
|
||||
|
|
@ -72,7 +69,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["input_variables"] == {
|
||||
"required": False,
|
||||
|
|
@ -86,7 +82,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["prefix"] == {
|
||||
"required": False,
|
||||
|
|
@ -101,7 +96,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["suffix"] == {
|
||||
"required": False,
|
||||
|
|
@ -116,7 +110,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -142,9 +135,6 @@ def test_json_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
|
|
@ -159,9 +149,6 @@ def test_json_agent(client: TestClient, logged_in_headers):
|
|||
"advanced": False,
|
||||
"display_name": "LLM",
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -182,12 +169,87 @@ def test_csv_agent(client: TestClient, logged_in_headers):
|
|||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "",
|
||||
"fileTypes": [".csv"],
|
||||
"suffixes": [".csv"],
|
||||
"fileTypes": ["csv"],
|
||||
"password": False,
|
||||
"name": "path",
|
||||
"type": "file",
|
||||
"list": False,
|
||||
"file_path": "",
|
||||
"file_path": None,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "llm",
|
||||
"type": "BaseLanguageModel",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"display_name": "LLM",
|
||||
"info": "",
|
||||
}
|
||||
|
||||
|
||||
def test_initialize_agent(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
agents = json_response["agents"]
|
||||
|
||||
initialize_agent = agents["AgentInitializer"]
|
||||
assert initialize_agent["base_classes"] == ["AgentExecutor", "function"]
|
||||
template = initialize_agent["template"]
|
||||
|
||||
assert template["agent"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "zero-shot-react-description",
|
||||
"password": False,
|
||||
"options": [
|
||||
"zero-shot-react-description",
|
||||
"react-docstore",
|
||||
"self-ask-with-search",
|
||||
"conversational-react-description",
|
||||
"openai-functions",
|
||||
"openai-multi-functions",
|
||||
],
|
||||
"name": "agent",
|
||||
"type": "str",
|
||||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["memory"] == {
|
||||
"required": False,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "memory",
|
||||
"type": "BaseChatMemory",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["tools"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "tools",
|
||||
"type": "Tool",
|
||||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
|
|
@ -204,7 +266,4 @@ def test_csv_agent(client: TestClient, logged_in_headers):
|
|||
"advanced": False,
|
||||
"display_name": "LLM",
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ from langflow.services.database.models.api_key import ApiKeyCreate
|
|||
def api_key(client, logged_in_headers, active_user):
|
||||
api_key = ApiKeyCreate(name="test-api-key")
|
||||
|
||||
response = client.post("api/v1/api_key", data=api_key.json(), headers=logged_in_headers)
|
||||
response = client.post(
|
||||
"api/v1/api_key", data=api_key.json(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
return response.json()
|
||||
|
||||
|
|
@ -26,7 +28,9 @@ def test_get_api_keys(client, logged_in_headers, api_key):
|
|||
|
||||
def test_create_api_key(client, logged_in_headers):
|
||||
api_key_name = "test-api-key"
|
||||
response = client.post("api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers)
|
||||
response = client.post(
|
||||
"api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "name" in data and data["name"] == api_key_name
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
import json
|
||||
from langflow.graph import Graph
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.graph import Graph
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# def test_chains_settings(client: TestClient, logged_in_headers):
|
||||
# response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
# assert response.status_code == 200
|
||||
|
|
@ -8,53 +9,21 @@ from fastapi.testclient import TestClient
|
|||
# assert set(chains.keys()) == set(settings.chains)
|
||||
|
||||
|
||||
def test_llm_checker_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
chain = chains["LLMCheckerChain"]
|
||||
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"Callable",
|
||||
"LLMCheckerChain",
|
||||
"Chain",
|
||||
}
|
||||
|
||||
template = chain["template"]
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "llm",
|
||||
"type": "BaseLanguageModel",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["_type"] == "LLMCheckerChain"
|
||||
|
||||
# Test the description object
|
||||
assert chain["description"] == ""
|
||||
|
||||
|
||||
def test_llm_math_chain(client: TestClient, logged_in_headers):
|
||||
# Test the ConversationChain object
|
||||
def test_conversation_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
|
||||
chain = chains["LLMMathChain"]
|
||||
chain = chains["ConversationChain"]
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"Callable",
|
||||
"LLMMathChain",
|
||||
"ConversationChain",
|
||||
"LLMChain",
|
||||
"Chain",
|
||||
"function",
|
||||
"Text",
|
||||
}
|
||||
|
||||
template = chain["template"]
|
||||
|
|
@ -70,7 +39,195 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["verbose"] == {
|
||||
"required": False,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "verbose",
|
||||
"type": "bool",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "llm",
|
||||
"type": "BaseLanguageModel",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["input_key"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "input",
|
||||
"password": False,
|
||||
"name": "input_key",
|
||||
"type": "str",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
}
|
||||
assert template["output_key"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "response",
|
||||
"password": False,
|
||||
"name": "output_key",
|
||||
"type": "str",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
}
|
||||
assert template["_type"] == "ConversationChain"
|
||||
|
||||
# Test the description object
|
||||
assert (
|
||||
chain["description"]
|
||||
== "Chain to have a conversation and load context from memory."
|
||||
)
|
||||
|
||||
|
||||
def test_llm_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
chain = chains["LLMChain"]
|
||||
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {"function", "LLMChain", "Chain", "Text"}
|
||||
|
||||
template = chain["template"]
|
||||
assert template["memory"] == {
|
||||
"required": False,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "memory",
|
||||
"type": "BaseMemory",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["verbose"] == {
|
||||
"required": False,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
"value": False,
|
||||
"password": False,
|
||||
"name": "verbose",
|
||||
"type": "bool",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "llm",
|
||||
"type": "BaseLanguageModel",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["output_key"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "text",
|
||||
"password": False,
|
||||
"name": "output_key",
|
||||
"type": "str",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
}
|
||||
|
||||
|
||||
def test_llm_checker_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
chain = chains["LLMCheckerChain"]
|
||||
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"function",
|
||||
"LLMCheckerChain",
|
||||
"Chain",
|
||||
"Text",
|
||||
}
|
||||
|
||||
template = chain["template"]
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "llm",
|
||||
"type": "BaseLanguageModel",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["_type"] == "LLMCheckerChain"
|
||||
|
||||
# Test the description object
|
||||
assert chain["description"] == ""
|
||||
|
||||
|
||||
def test_llm_math_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
|
||||
chain = chains["LLMMathChain"]
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {"function", "LLMMathChain", "Chain", "Text"}
|
||||
|
||||
template = chain["template"]
|
||||
assert template["memory"] == {
|
||||
"required": False,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "memory",
|
||||
"type": "BaseMemory",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["verbose"] == {
|
||||
"required": False,
|
||||
|
|
@ -85,7 +242,6 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
|
|
@ -99,7 +255,6 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["input_key"] == {
|
||||
"required": True,
|
||||
|
|
@ -114,7 +269,6 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["output_key"] == {
|
||||
"required": True,
|
||||
|
|
@ -129,12 +283,14 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["_type"] == "LLMMathChain"
|
||||
|
||||
# Test the description object
|
||||
assert chain["description"] == "Chain that interprets a prompt and executes python code to do math."
|
||||
assert (
|
||||
chain["description"]
|
||||
== "Chain that interprets a prompt and executes python code to do math."
|
||||
)
|
||||
|
||||
|
||||
def test_series_character_chain(client: TestClient, logged_in_headers):
|
||||
|
|
@ -147,7 +303,7 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
|
|||
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"Callable",
|
||||
"function",
|
||||
"LLMChain",
|
||||
"BaseCustomChain",
|
||||
"Chain",
|
||||
|
|
@ -169,9 +325,6 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
"file_path": "",
|
||||
"value": "",
|
||||
}
|
||||
assert template["character"] == {
|
||||
"required": True,
|
||||
|
|
@ -185,9 +338,6 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
"file_path": "",
|
||||
"value": "",
|
||||
}
|
||||
assert template["series"] == {
|
||||
"required": True,
|
||||
|
|
@ -201,9 +351,6 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
"file_path": "",
|
||||
"value": "",
|
||||
}
|
||||
assert template["_type"] == "SeriesCharacterChain"
|
||||
|
||||
|
|
@ -247,12 +394,12 @@ def test_mid_journey_prompt_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
# Test the description object
|
||||
assert chain["description"] == "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."
|
||||
assert (
|
||||
chain["description"]
|
||||
== "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."
|
||||
)
|
||||
|
||||
|
||||
def test_time_travel_guide_chain(client: TestClient, logged_in_headers):
|
||||
|
|
@ -288,9 +435,6 @@ def test_time_travel_guide_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
assert template["memory"] == {
|
||||
"required": False,
|
||||
|
|
@ -304,9 +448,6 @@ def test_time_travel_guide_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
|
||||
assert chain["description"] == "Time travel guide chain."
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
from pathlib import Path
|
||||
from tempfile import tempdir
|
||||
|
||||
from langflow.__main__ import app
|
||||
import pytest
|
||||
|
||||
from langflow.__main__ import app
|
||||
from langflow.services import deps
|
||||
from langflow.services import getters
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
@ -27,12 +26,11 @@ def test_components_path(runner, client, default_settings):
|
|||
["run", "--components-path", str(temp_dir), *default_settings],
|
||||
)
|
||||
assert result.exit_code == 0, result.stdout
|
||||
settings_service = deps.get_settings_service()
|
||||
settings_service = getters.get_settings_service()
|
||||
assert str(temp_dir) in settings_service.settings.COMPONENTS_PATH
|
||||
|
||||
|
||||
def test_superuser(runner, client, session):
|
||||
result = runner.invoke(app, ["superuser"], input="admin\nadmin\n")
|
||||
assert result.exit_code == 0, result.stdout
|
||||
assert "Superuser creation failed." not in result.output, result.output
|
||||
assert "Superuser created successfully." in result.output, result.output
|
||||
assert "Superuser created successfully." in result.stdout
|
||||
|
|
|
|||
|
|
@ -1,13 +1,19 @@
|
|||
import ast
|
||||
import pytest
|
||||
import types
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from langflow.interface.custom.base import CustomComponent
|
||||
from langflow.interface.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
|
||||
from langflow.interface.custom.custom_component.component import Component, ComponentCodeNullError
|
||||
from langflow.interface.custom.utils import build_custom_component_template
|
||||
|
||||
from fastapi import HTTPException
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate
|
||||
from langflow.interface.custom.base import CustomComponent
|
||||
from langflow.interface.custom.component import (
|
||||
Component,
|
||||
ComponentCodeNullError,
|
||||
ComponentFunctionEntrypointNameNullError,
|
||||
)
|
||||
from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError
|
||||
|
||||
|
||||
code_default = """
|
||||
from langflow import Prompt
|
||||
|
|
@ -47,7 +53,7 @@ def test_code_parser_get_tree():
|
|||
Test the __get_tree method of the CodeParser class.
|
||||
"""
|
||||
parser = CodeParser(code_default)
|
||||
tree = parser.get_tree()
|
||||
tree = parser._CodeParser__get_tree()
|
||||
assert isinstance(tree, ast.AST)
|
||||
|
||||
|
||||
|
|
@ -60,23 +66,23 @@ def test_code_parser_syntax_error():
|
|||
|
||||
parser = CodeParser(code_syntax_error)
|
||||
with pytest.raises(CodeSyntaxError):
|
||||
parser.get_tree()
|
||||
parser._CodeParser__get_tree()
|
||||
|
||||
|
||||
def test_component_init():
|
||||
"""
|
||||
Test the initialization of the Component class.
|
||||
"""
|
||||
component = Component(code=code_default, _function_entrypoint_name="build")
|
||||
component = Component(code=code_default, function_entrypoint_name="build")
|
||||
assert component.code == code_default
|
||||
assert component._function_entrypoint_name == "build"
|
||||
assert component.function_entrypoint_name == "build"
|
||||
|
||||
|
||||
def test_component_get_code_tree():
|
||||
"""
|
||||
Test the get_code_tree method of the Component class.
|
||||
"""
|
||||
component = Component(code=code_default, _function_entrypoint_name="build")
|
||||
component = Component(code=code_default, function_entrypoint_name="build")
|
||||
tree = component.get_code_tree(component.code)
|
||||
assert "imports" in tree
|
||||
|
||||
|
|
@ -86,20 +92,19 @@ def test_component_code_null_error():
|
|||
Test the get_function method raises the
|
||||
ComponentCodeNullError when the code is empty.
|
||||
"""
|
||||
component = Component(code="", _function_entrypoint_name="")
|
||||
component = Component(code="", function_entrypoint_name="")
|
||||
with pytest.raises(ComponentCodeNullError):
|
||||
component.get_function()
|
||||
|
||||
|
||||
# TODO: Validate if we should remove this
|
||||
# def test_component_function_entrypoint_name_null_error():
|
||||
# """
|
||||
# Test the get_function method raises the ComponentFunctionEntrypointNameNullError
|
||||
# when the function_entrypoint_name is empty.
|
||||
# """
|
||||
# component = Component(code=code_default, _function_entrypoint_name="")
|
||||
# with pytest.raises(ComponentFunctionEntrypointNameNullError):
|
||||
# component.get_function()
|
||||
def test_component_function_entrypoint_name_null_error():
|
||||
"""
|
||||
Test the get_function method raises the ComponentFunctionEntrypointNameNullError
|
||||
when the function_entrypoint_name is empty.
|
||||
"""
|
||||
component = Component(code=code_default, function_entrypoint_name="")
|
||||
with pytest.raises(ComponentFunctionEntrypointNameNullError):
|
||||
component.get_function()
|
||||
|
||||
|
||||
def test_custom_component_init():
|
||||
|
|
@ -108,7 +113,9 @@ def test_custom_component_init():
|
|||
"""
|
||||
function_entrypoint_name = "build"
|
||||
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name=function_entrypoint_name)
|
||||
custom_component = CustomComponent(
|
||||
code=code_default, function_entrypoint_name=function_entrypoint_name
|
||||
)
|
||||
assert custom_component.code == code_default
|
||||
assert custom_component.function_entrypoint_name == function_entrypoint_name
|
||||
|
||||
|
|
@ -117,8 +124,10 @@ def test_custom_component_build_template_config():
|
|||
"""
|
||||
Test the build_template_config property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
config = custom_component.template_config
|
||||
custom_component = CustomComponent(
|
||||
code=code_default, function_entrypoint_name="build"
|
||||
)
|
||||
config = custom_component.build_template_config
|
||||
assert isinstance(config, dict)
|
||||
|
||||
|
||||
|
|
@ -126,7 +135,9 @@ def test_custom_component_get_function():
|
|||
"""
|
||||
Test the get_function property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
|
||||
custom_component = CustomComponent(
|
||||
code="def build(): pass", function_entrypoint_name="build"
|
||||
)
|
||||
my_function = custom_component.get_function
|
||||
assert isinstance(my_function, types.FunctionType)
|
||||
|
||||
|
|
@ -137,7 +148,7 @@ def test_code_parser_parse_imports_import():
|
|||
class with an import statement.
|
||||
"""
|
||||
parser = CodeParser(code_default)
|
||||
tree = parser.get_tree()
|
||||
tree = parser._CodeParser__get_tree()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
parser.parse_imports(node)
|
||||
|
|
@ -150,7 +161,7 @@ def test_code_parser_parse_imports_importfrom():
|
|||
class with an import from statement.
|
||||
"""
|
||||
parser = CodeParser("from os import path")
|
||||
tree = parser.get_tree()
|
||||
tree = parser._CodeParser__get_tree()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ImportFrom):
|
||||
parser.parse_imports(node)
|
||||
|
|
@ -162,7 +173,7 @@ def test_code_parser_parse_functions():
|
|||
Test the parse_functions method of the CodeParser class.
|
||||
"""
|
||||
parser = CodeParser("def test(): pass")
|
||||
tree = parser.get_tree()
|
||||
tree = parser._CodeParser__get_tree()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
parser.parse_functions(node)
|
||||
|
|
@ -175,7 +186,7 @@ def test_code_parser_parse_classes():
|
|||
Test the parse_classes method of the CodeParser class.
|
||||
"""
|
||||
parser = CodeParser("class Test: pass")
|
||||
tree = parser.get_tree()
|
||||
tree = parser._CodeParser__get_tree()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
parser.parse_classes(node)
|
||||
|
|
@ -188,7 +199,7 @@ def test_code_parser_parse_global_vars():
|
|||
Test the parse_global_vars method of the CodeParser class.
|
||||
"""
|
||||
parser = CodeParser("x = 1")
|
||||
tree = parser.get_tree()
|
||||
tree = parser._CodeParser__get_tree()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Assign):
|
||||
parser.parse_global_vars(node)
|
||||
|
|
@ -201,7 +212,7 @@ def test_component_get_function_valid():
|
|||
Test the get_function method of the Component
|
||||
class with valid code and function_entrypoint_name.
|
||||
"""
|
||||
component = Component(code="def build(): pass", _function_entrypoint_name="build")
|
||||
component = Component(code="def build(): pass", function_entrypoint_name="build")
|
||||
my_function = component.get_function()
|
||||
assert callable(my_function)
|
||||
|
||||
|
|
@ -211,7 +222,9 @@ def test_custom_component_get_function_entrypoint_args():
|
|||
Test the get_function_entrypoint_args
|
||||
property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
custom_component = CustomComponent(
|
||||
code=code_default, function_entrypoint_name="build"
|
||||
)
|
||||
args = custom_component.get_function_entrypoint_args
|
||||
assert len(args) == 4
|
||||
assert args[0]["name"] == "self"
|
||||
|
|
@ -224,18 +237,20 @@ def test_custom_component_get_function_entrypoint_return_type():
|
|||
Test the get_function_entrypoint_return_type
|
||||
property of the CustomComponent class.
|
||||
"""
|
||||
from langchain.schema import Document
|
||||
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
custom_component = CustomComponent(
|
||||
code=code_default, function_entrypoint_name="build"
|
||||
)
|
||||
return_type = custom_component.get_function_entrypoint_return_type
|
||||
assert return_type == [Document]
|
||||
assert return_type == ["Document"]
|
||||
|
||||
|
||||
def test_custom_component_get_main_class_name():
|
||||
"""
|
||||
Test the get_main_class_name property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
custom_component = CustomComponent(
|
||||
code=code_default, function_entrypoint_name="build"
|
||||
)
|
||||
class_name = custom_component.get_main_class_name
|
||||
assert class_name == "YourComponent"
|
||||
|
||||
|
|
@ -245,7 +260,9 @@ def test_custom_component_get_function_valid():
|
|||
Test the get_function property of the CustomComponent
|
||||
class with valid code and function_entrypoint_name.
|
||||
"""
|
||||
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
|
||||
custom_component = CustomComponent(
|
||||
code="def build(): pass", function_entrypoint_name="build"
|
||||
)
|
||||
my_function = custom_component.get_function
|
||||
assert callable(my_function)
|
||||
|
||||
|
|
@ -280,7 +297,9 @@ def test_code_parser_parse_callable_details_no_args():
|
|||
parser = CodeParser("")
|
||||
node = ast.FunctionDef(
|
||||
name="test",
|
||||
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
|
||||
args=ast.arguments(
|
||||
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
|
||||
),
|
||||
body=[],
|
||||
decorator_list=[],
|
||||
returns=None,
|
||||
|
|
@ -309,7 +328,7 @@ def test_code_parser_parse_ann_assign():
|
|||
stmt = ast.AnnAssign(
|
||||
target=ast.Name(id="x", ctx=ast.Store()),
|
||||
annotation=ast.Name(id="int", ctx=ast.Load()),
|
||||
value=ast.Constant(n=1),
|
||||
value=ast.Num(n=1),
|
||||
simple=1,
|
||||
)
|
||||
result = parser.parse_ann_assign(stmt)
|
||||
|
|
@ -326,7 +345,9 @@ def test_code_parser_parse_function_def_not_init():
|
|||
parser = CodeParser("")
|
||||
stmt = ast.FunctionDef(
|
||||
name="test",
|
||||
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
|
||||
args=ast.arguments(
|
||||
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
|
||||
),
|
||||
body=[],
|
||||
decorator_list=[],
|
||||
returns=None,
|
||||
|
|
@ -344,7 +365,9 @@ def test_code_parser_parse_function_def_init():
|
|||
parser = CodeParser("")
|
||||
stmt = ast.FunctionDef(
|
||||
name="__init__",
|
||||
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
|
||||
args=ast.arguments(
|
||||
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
|
||||
),
|
||||
body=[],
|
||||
decorator_list=[],
|
||||
returns=None,
|
||||
|
|
@ -359,17 +382,29 @@ def test_component_get_code_tree_syntax_error():
|
|||
Test the get_code_tree method of the Component class
|
||||
raises the CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
component = Component(code="import os as", _function_entrypoint_name="build")
|
||||
component = Component(code="import os as", function_entrypoint_name="build")
|
||||
with pytest.raises(CodeSyntaxError):
|
||||
component.get_code_tree(component.code)
|
||||
|
||||
|
||||
def test_custom_component_class_template_validation_no_code():
|
||||
"""
|
||||
Test the _class_template_validation method of the CustomComponent class
|
||||
raises the HTTPException when the code is None.
|
||||
"""
|
||||
custom_component = CustomComponent(code=None, function_entrypoint_name="build")
|
||||
with pytest.raises(HTTPException):
|
||||
custom_component._class_template_validation(custom_component.code)
|
||||
|
||||
|
||||
def test_custom_component_get_code_tree_syntax_error():
|
||||
"""
|
||||
Test the get_code_tree method of the CustomComponent class
|
||||
raises the CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
custom_component = CustomComponent(code="import os as", function_entrypoint_name="build")
|
||||
custom_component = CustomComponent(
|
||||
code="import os as", function_entrypoint_name="build"
|
||||
)
|
||||
with pytest.raises(CodeSyntaxError):
|
||||
custom_component.get_code_tree(custom_component.code)
|
||||
|
||||
|
|
@ -423,7 +458,9 @@ def test_custom_component_build_not_implemented():
|
|||
Test the build method of the CustomComponent
|
||||
class raises the NotImplementedError.
|
||||
"""
|
||||
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
|
||||
custom_component = CustomComponent(
|
||||
code="def build(): pass", function_entrypoint_name="build"
|
||||
)
|
||||
with pytest.raises(NotImplementedError):
|
||||
custom_component.build()
|
||||
|
||||
|
|
@ -431,7 +468,7 @@ def test_custom_component_build_not_implemented():
|
|||
def test_build_config_no_code():
|
||||
component = CustomComponent(code=None)
|
||||
|
||||
assert component.get_function_entrypoint_args == []
|
||||
assert component.get_function_entrypoint_args == ""
|
||||
assert component.get_function_entrypoint_return_type == []
|
||||
|
||||
|
||||
|
|
@ -457,7 +494,9 @@ def test_flow(db):
|
|||
}
|
||||
|
||||
# Create flow
|
||||
flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data)
|
||||
flow = FlowCreate(
|
||||
id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data
|
||||
)
|
||||
|
||||
# Add to database
|
||||
db.add(flow)
|
||||
|
|
@ -518,36 +557,3 @@ def test_build_config_field_value_keys(component):
|
|||
config = component.build_config()
|
||||
field_values = config["fields"].values()
|
||||
assert all("type" in value for value in field_values)
|
||||
|
||||
|
||||
def test_create_and_validate_component_valid_code(test_component_code):
|
||||
component = CustomComponent(code=test_component_code)
|
||||
assert isinstance(component, CustomComponent)
|
||||
|
||||
|
||||
def test_build_langchain_template_custom_component_valid_code(test_component_code):
|
||||
component = CustomComponent(code=test_component_code)
|
||||
frontend_node = build_custom_component_template(component)
|
||||
assert isinstance(frontend_node, dict)
|
||||
template = frontend_node["template"]
|
||||
assert isinstance(template, dict)
|
||||
assert "param" in template
|
||||
param_options = template["param"]["options"]
|
||||
# Now run it again with an update field
|
||||
frontend_node = build_custom_component_template(component, update_field="param")
|
||||
new_param_options = frontend_node["template"]["param"]["options"]
|
||||
assert param_options != new_param_options
|
||||
|
||||
|
||||
def test_build_langchain_template_custom_component_templatefield(test_component_with_templatefield_code):
|
||||
component = CustomComponent(code=test_component_with_templatefield_code)
|
||||
frontend_node = build_custom_component_template(component)
|
||||
assert isinstance(frontend_node, dict)
|
||||
template = frontend_node["template"]
|
||||
assert isinstance(template, dict)
|
||||
assert "param" in template
|
||||
param_options = template["param"]["options"]
|
||||
# Now run it again with an update field
|
||||
frontend_node = build_custom_component_template(component, update_field="param")
|
||||
new_param_options = frontend_node["template"]["param"]["options"]
|
||||
assert param_options != new_param_options
|
||||
|
|
|
|||
|
|
@ -18,7 +18,9 @@ def test_python_function_tool():
|
|||
with pytest.raises(SyntaxError):
|
||||
code = pytest.CODE_WITH_SYNTAX_ERROR
|
||||
func = get_function(code)
|
||||
func = PythonFunctionTool(name="Test", description="Testing", code=code, func=func)
|
||||
func = PythonFunctionTool(
|
||||
name="Test", description="Testing", code=code, func=func
|
||||
)
|
||||
|
||||
|
||||
def test_python_function():
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
from uuid import UUID, uuid4
|
||||
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service
|
||||
import orjson
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from langflow.api.v1.schemas import FlowListCreate
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
|
||||
from uuid import UUID, uuid4
|
||||
from sqlmodel import Session
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from langflow.api.v1.schemas import FlowListCreate
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def json_style():
|
||||
|
|
@ -25,17 +27,21 @@ def json_style():
|
|||
)
|
||||
|
||||
|
||||
def test_create_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
def test_create_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
# flow is optional so we can create a flow without a flow
|
||||
flow = FlowCreate(name="Test Flow", description="description")
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(exclude_unset=True), headers=logged_in_headers)
|
||||
flow = FlowCreate(name="Test Flow")
|
||||
response = client.post(
|
||||
"api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
|
@ -45,13 +51,13 @@ def test_read_flows(client: TestClient, json_flow: str, active_user, logged_in_h
|
|||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
|
@ -65,7 +71,7 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he
|
|||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
flow_id = response.json()["id"] # flow_id should be a UUID but is a string
|
||||
# turn it into a UUID
|
||||
flow_id = UUID(flow_id)
|
||||
|
|
@ -76,12 +82,14 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he
|
|||
assert response.json()["data"] == flow.data
|
||||
|
||||
|
||||
def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
def test_update_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
|
||||
flow_id = response.json()["id"]
|
||||
updated_flow = FlowUpdate(
|
||||
|
|
@ -89,7 +97,9 @@ def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_
|
|||
description="updated description",
|
||||
data=data,
|
||||
)
|
||||
response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
response = client.patch(
|
||||
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == updated_flow.name
|
||||
|
|
@ -97,18 +107,22 @@ def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_
|
|||
# assert response.json()["data"] == updated_flow.data
|
||||
|
||||
|
||||
def test_delete_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
def test_delete_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
flow_id = response.json()["id"]
|
||||
response = client.delete(f"api/v1/flows/{flow_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Flow deleted successfully"
|
||||
|
||||
|
||||
def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers):
|
||||
def test_create_flows(
|
||||
client: TestClient, session: Session, json_flow: str, logged_in_headers
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
|
|
@ -119,7 +133,9 @@ def test_create_flows(client: TestClient, session: Session, json_flow: str, logg
|
|||
]
|
||||
)
|
||||
# Make request to endpoint
|
||||
response = client.post("api/v1/flows/batch/", json=flow_list.model_dump(), headers=logged_in_headers)
|
||||
response = client.post(
|
||||
"api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers
|
||||
)
|
||||
# Check response status code
|
||||
assert response.status_code == 201
|
||||
# Check response data
|
||||
|
|
@ -133,7 +149,9 @@ def test_create_flows(client: TestClient, session: Session, json_flow: str, logg
|
|||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
def test_upload_file(client: TestClient, session: Session, json_flow: str, logged_in_headers):
|
||||
def test_upload_file(
|
||||
client: TestClient, session: Session, json_flow: str, logged_in_headers
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
|
|
@ -143,7 +161,7 @@ def test_upload_file(client: TestClient, session: Session, json_flow: str, logge
|
|||
FlowCreate(name="Flow 2", description="description", data=data),
|
||||
]
|
||||
)
|
||||
file_contents = orjson_dumps(flow_list.model_dump())
|
||||
file_contents = orjson_dumps(flow_list.dict())
|
||||
response = client.post(
|
||||
"api/v1/flows/upload/",
|
||||
files={"file": ("examples.json", file_contents, "application/json")},
|
||||
|
|
@ -182,7 +200,7 @@ def test_download_file(
|
|||
with session_getter(db_manager) as session:
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = active_user.id
|
||||
db_flow = Flow.model_validate(flow, from_attributes=True)
|
||||
db_flow = Flow.from_orm(flow)
|
||||
session.add(db_flow)
|
||||
session.commit()
|
||||
# Make request to endpoint
|
||||
|
|
@ -200,7 +218,9 @@ def test_download_file(
|
|||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
def test_create_flow_with_invalid_data(client: TestClient, active_user, logged_in_headers):
|
||||
def test_create_flow_with_invalid_data(
|
||||
client: TestClient, active_user, logged_in_headers
|
||||
):
|
||||
flow = {"name": "a" * 256, "data": "Invalid flow data"}
|
||||
response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers)
|
||||
assert response.status_code == 422
|
||||
|
|
@ -212,19 +232,29 @@ def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers
|
|||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_update_flow_idempotency(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
def test_update_flow_idempotency(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
flow_data = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow_data.model_dump(), headers=logged_in_headers)
|
||||
response = client.post(
|
||||
"api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
flow_id = response.json()["id"]
|
||||
updated_flow = FlowCreate(name="Updated Flow", description="description", data=data)
|
||||
response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
response1 = client.put(
|
||||
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
response2 = client.put(
|
||||
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response1.json() == response2.json()
|
||||
|
||||
|
||||
def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
def test_update_nonexistent_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
uuid = uuid4()
|
||||
|
|
@ -233,7 +263,9 @@ def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user
|
|||
description="description",
|
||||
data=data,
|
||||
)
|
||||
response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
response = client.patch(
|
||||
f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,18 @@
|
|||
import time
|
||||
import uuid
|
||||
from collections import namedtuple
|
||||
|
||||
import uuid
|
||||
from langflow.processing.process import Result
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.api_key import ApiKey
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from langflow.interface.tools.constants import CUSTOM_TOOLS
|
||||
from langflow.processing.process import Result
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from langflow.template.frontend_node.chains import TimeTravelGuideChainNode
|
||||
|
||||
import time
|
||||
|
||||
|
||||
def run_post(client, flow_id, headers, post_data):
|
||||
response = client.post(
|
||||
|
|
@ -24,13 +25,16 @@ def run_post(client, flow_id, headers, post_data):
|
|||
|
||||
|
||||
# Helper function to poll task status
|
||||
def poll_task_status(client, headers, href, max_attempts=20, sleep_time=2):
|
||||
def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1):
|
||||
for _ in range(max_attempts):
|
||||
task_status_response = client.get(
|
||||
href,
|
||||
headers=headers,
|
||||
)
|
||||
if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS":
|
||||
if (
|
||||
task_status_response.status_code == 200
|
||||
and task_status_response.json()["status"] == "SUCCESS"
|
||||
):
|
||||
return task_status_response.json()
|
||||
time.sleep(sleep_time)
|
||||
return None # Return None if task did not complete in time
|
||||
|
|
@ -72,14 +76,11 @@ PROMPT_REQUEST = {
|
|||
"text-ada-001",
|
||||
],
|
||||
"ChatOpenAI": [
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
],
|
||||
"Anthropic": [
|
||||
"claude-v1",
|
||||
|
|
@ -126,7 +127,11 @@ def created_api_key(active_user):
|
|||
)
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
|
||||
if (
|
||||
existing_api_key := session.query(ApiKey)
|
||||
.filter(ApiKey.api_key == api_key.api_key)
|
||||
.first()
|
||||
):
|
||||
return existing_api_key
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
|
|
@ -185,7 +190,9 @@ def test_process_flow_invalid_id(client, monkeypatch, created_api_key):
|
|||
}
|
||||
|
||||
invalid_id = uuid.uuid4()
|
||||
response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data)
|
||||
response = client.post(
|
||||
f"api/v1/process/{invalid_id}", headers=headers, json=post_data
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert f"Flow {invalid_id} not found" in response.json()["detail"]
|
||||
|
|
@ -226,7 +233,9 @@ def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_k
|
|||
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task)
|
||||
monkeypatch.setattr(
|
||||
endpoints, "process_graph_cached_task", mock_process_graph_cached_task
|
||||
)
|
||||
|
||||
api_key = created_api_key.api_key
|
||||
headers = {"x-api-key": api_key}
|
||||
|
|
@ -410,6 +419,105 @@ def test_various_prompts(client, prompt, expected_input_variables):
|
|||
assert response.json()["input_variables"] == expected_input_variables
|
||||
|
||||
|
||||
def test_get_vertices_flow_not_found(client, logged_in_headers):
|
||||
response = client.get(
|
||||
"/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers
|
||||
)
|
||||
assert (
|
||||
response.status_code == 500
|
||||
) # Or whatever status code you've set for invalid ID
|
||||
|
||||
|
||||
def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers):
|
||||
flow_id = added_flow_with_prompt_and_history["id"]
|
||||
response = client.get(
|
||||
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "ids" in response.json()
|
||||
# The response should contain the list in this order
|
||||
# ['ConversationBufferMemory-Lu2Nb', 'PromptTemplate-5Q0W8', 'ChatOpenAI-vy7fV', 'LLMChain-UjBh1']
|
||||
# The important part is before the - (ConversationBufferMemory, PromptTemplate, ChatOpenAI, LLMChain)
|
||||
ids = [id.split("-")[0] for id in response.json()["ids"]]
|
||||
assert ids == [
|
||||
"ConversationBufferMemory",
|
||||
"PromptTemplate",
|
||||
"ChatOpenAI",
|
||||
"LLMChain",
|
||||
]
|
||||
|
||||
|
||||
def test_build_vertex_invalid_flow_id(client, logged_in_headers):
|
||||
response = client.post(
|
||||
"/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_build_vertex_invalid_vertex_id(
|
||||
client, added_flow_with_prompt_and_history, logged_in_headers
|
||||
):
|
||||
flow_id = added_flow_with_prompt_and_history["id"]
|
||||
response = client.post(
|
||||
f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_build_all_vertices_in_sequence_with_chat_input(
|
||||
client, added_flow_chat_input, logged_in_headers
|
||||
):
|
||||
flow_id = added_flow_chat_input["id"]
|
||||
|
||||
# First, get all the vertices in the correct sequence
|
||||
response = client.get(
|
||||
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "ids" in response.json()
|
||||
vertex_ids = response.json()["ids"]
|
||||
|
||||
# Now, iterate through each vertex and build it
|
||||
for vertex_id in vertex_ids:
|
||||
response = client.post(
|
||||
f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers
|
||||
)
|
||||
json_response = response.json()
|
||||
assert (
|
||||
response.status_code == 200
|
||||
), f"Failed at vertex {vertex_id}: {json_response}"
|
||||
assert "valid" in json_response
|
||||
assert json_response["valid"], json_response["params"]
|
||||
|
||||
|
||||
def test_build_all_vertices_in_sequence_with_two_outputs(
|
||||
client, added_flow_two_outputs, logged_in_headers
|
||||
):
|
||||
"""This tests the case where a node has two outputs, one of which is Text and the other (in this case) is
|
||||
a LLMChain. We need to make sure the correct output is passed in both cases."""
|
||||
flow_id = added_flow_two_outputs["id"]
|
||||
|
||||
# First, get all the vertices in the correct sequence
|
||||
response = client.get(
|
||||
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "ids" in response.json()
|
||||
vertex_ids = response.json()["ids"]
|
||||
|
||||
# Now, iterate through each vertex and build it
|
||||
for vertex_id in vertex_ids:
|
||||
response = client.post(
|
||||
f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers
|
||||
)
|
||||
json_response = response.json()
|
||||
assert (
|
||||
response.status_code == 200
|
||||
), f"Failed at vertex {vertex_id}: {json_response}"
|
||||
assert "valid" in json_response
|
||||
assert json_response["valid"], json_response["params"]
|
||||
|
||||
|
||||
def test_basic_chat_in_process(client, added_flow, created_api_key):
|
||||
# Run the /api/v1/process/{flow_id} endpoint
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
|
|
@ -498,7 +606,9 @@ def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_a
|
|||
|
||||
|
||||
@pytest.mark.async_test
|
||||
def test_vector_store_in_process(distributed_client, added_vector_store, created_api_key):
|
||||
def test_vector_store_in_process(
|
||||
distributed_client, added_vector_store, created_api_key
|
||||
):
|
||||
# Run the /api/v1/process/{flow_id} endpoint
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
post_data = {"inputs": {"input": "What is Langflow?"}}
|
||||
|
|
@ -549,7 +659,9 @@ def test_async_task_processing(distributed_client, added_flow, created_api_key):
|
|||
|
||||
# Test function without loop
|
||||
@pytest.mark.async_test
|
||||
def test_async_task_processing_vector_store(client, added_vector_store, created_api_key):
|
||||
def test_async_task_processing_vector_store(
|
||||
client, added_vector_store, created_api_key
|
||||
):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
post_data = {"inputs": {"input": "How do I upload examples?"}}
|
||||
|
||||
|
|
@ -578,4 +690,6 @@ def test_async_task_processing_vector_store(client, added_vector_store, created_
|
|||
# Validate that the task completed successfully and the result is as expected
|
||||
assert "result" in task_status_json, task_status_json
|
||||
assert "output" in task_status_json["result"], task_status_json["result"]
|
||||
assert "Langflow" in task_status_json["result"]["output"], task_status_json["result"]
|
||||
assert "Langflow" in task_status_json["result"]["output"], task_status_json[
|
||||
"result"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -31,14 +31,17 @@ def test_template_field_defaults(sample_template_field: TemplateField):
|
|||
assert sample_template_field.is_list is False
|
||||
assert sample_template_field.show is True
|
||||
assert sample_template_field.multiline is False
|
||||
assert sample_template_field.value == ""
|
||||
assert sample_template_field.value is None
|
||||
assert sample_template_field.suffixes == []
|
||||
assert sample_template_field.file_types == []
|
||||
assert sample_template_field.file_path == ""
|
||||
assert sample_template_field.file_path is None
|
||||
assert sample_template_field.password is False
|
||||
assert sample_template_field.name == "test_field"
|
||||
|
||||
|
||||
def test_template_to_dict(sample_template: Template, sample_template_field: TemplateField):
|
||||
def test_template_to_dict(
|
||||
sample_template: Template, sample_template_field: TemplateField
|
||||
):
|
||||
template_dict = sample_template.to_dict()
|
||||
assert template_dict["_type"] == "test_template"
|
||||
assert len(template_dict) == 2 # _type and test_field
|
||||
|
|
|
|||
|
|
@ -1,31 +1,32 @@
|
|||
import copy
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from typing import Type, Union
|
||||
|
||||
import pytest
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langchain.agents import AgentExecutor
|
||||
import pytest
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
|
||||
from langflow.graph import Graph
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.vertex.types import (
|
||||
FileToolVertex,
|
||||
LLMVertex,
|
||||
ToolkitVertex,
|
||||
)
|
||||
from langflow.processing.process import get_result_and_thought
|
||||
from langflow.utils.payload import get_root_node
|
||||
from langflow.graph.graph.utils import (
|
||||
find_last_node,
|
||||
process_flow,
|
||||
set_new_target_handle,
|
||||
ungroup_node,
|
||||
process_flow,
|
||||
update_source_handle,
|
||||
update_target_handle,
|
||||
update_template,
|
||||
)
|
||||
from langflow.graph.utils import UnbuiltObject
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex.types import FileToolVertex, LLMVertex, ToolkitVertex
|
||||
from langflow.processing.process import get_result_and_thought
|
||||
from langflow.utils.payload import get_root_vertex
|
||||
|
||||
# Test cases for the graph module
|
||||
|
||||
|
|
@ -46,7 +47,13 @@ def sample_nodes():
|
|||
return [
|
||||
{
|
||||
"id": "node1",
|
||||
"data": {"node": {"template": {"some_field": {"show": True, "advanced": False, "name": "Name1"}}}},
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"some_field": {"show": True, "advanced": False, "name": "Name1"}
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
|
|
@ -64,7 +71,11 @@ def sample_nodes():
|
|||
},
|
||||
{
|
||||
"id": "node3",
|
||||
"data": {"node": {"template": {"unrelated_field": {"show": True, "advanced": True}}}},
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {"unrelated_field": {"show": True, "advanced": True}}
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
|
@ -82,8 +93,8 @@ def test_graph_structure(basic_graph):
|
|||
assert isinstance(node, Vertex)
|
||||
for edge in basic_graph.edges:
|
||||
assert isinstance(edge, Edge)
|
||||
assert edge.source_id in basic_graph.vertex_map.keys()
|
||||
assert edge.target_id in basic_graph.vertex_map.keys()
|
||||
assert edge.source in basic_graph.vertices
|
||||
assert edge.target in basic_graph.vertices
|
||||
|
||||
|
||||
def test_circular_dependencies(basic_graph):
|
||||
|
|
@ -91,7 +102,7 @@ def test_circular_dependencies(basic_graph):
|
|||
|
||||
def check_circular(node, visited):
|
||||
visited.add(node)
|
||||
neighbors = basic_graph.get_vertices_with_target(node)
|
||||
neighbors = basic_graph.get_nodes_with_target(node)
|
||||
for neighbor in neighbors:
|
||||
if neighbor in visited:
|
||||
return True
|
||||
|
|
@ -124,13 +135,13 @@ def test_invalid_node_types():
|
|||
Graph(graph_data["nodes"], graph_data["edges"])
|
||||
|
||||
|
||||
def test_get_vertices_with_target(basic_graph):
|
||||
def test_get_nodes_with_target(basic_graph):
|
||||
"""Test getting connected nodes"""
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_vertex(basic_graph)
|
||||
root = get_root_node(basic_graph)
|
||||
assert root is not None
|
||||
connected_nodes = basic_graph.get_vertices_with_target(root.id)
|
||||
connected_nodes = basic_graph.get_nodes_with_target(root)
|
||||
assert connected_nodes is not None
|
||||
|
||||
|
||||
|
|
@ -139,66 +150,23 @@ def test_get_node_neighbors_basic(basic_graph):
|
|||
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_vertex(basic_graph)
|
||||
root = get_root_node(basic_graph)
|
||||
assert root is not None
|
||||
neighbors = basic_graph.get_vertex_neighbors(root)
|
||||
neighbors = basic_graph.get_node_neighbors(root)
|
||||
assert neighbors is not None
|
||||
assert isinstance(neighbors, dict)
|
||||
# Root Node is an Agent, it requires an LLMChain and tools
|
||||
# We need to check if there is a Chain in the one of the neighbors'
|
||||
# data attribute in the type key
|
||||
assert any("ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
|
||||
assert any(
|
||||
"ConversationBufferMemory" in neighbor.data["type"]
|
||||
for neighbor, val in neighbors.items()
|
||||
if val
|
||||
)
|
||||
|
||||
assert any("OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
|
||||
|
||||
|
||||
# def test_get_node_neighbors_complex(complex_graph):
|
||||
# """Test getting node neighbors"""
|
||||
# assert isinstance(complex_graph, Graph)
|
||||
# # Get root node
|
||||
# root = get_root_node(complex_graph)
|
||||
# assert root is not None
|
||||
# neighbors = complex_graph.get_nodes_with_target(root)
|
||||
# assert neighbors is not None
|
||||
# # Neighbors should be a list of nodes
|
||||
# assert isinstance(neighbors, list)
|
||||
# # Root Node is an Agent, it requires an LLMChain and tools
|
||||
# # We need to check if there is a Chain in the one of the neighbors'
|
||||
# assert any("Chain" in neighbor.data["type"] for neighbor in neighbors)
|
||||
# # assert Tool is in the neighbors
|
||||
# assert any("Tool" in neighbor.data["type"] for neighbor in neighbors)
|
||||
# # Now on to the Chain's neighbors
|
||||
# chain = next(neighbor for neighbor in neighbors if "Chain" in neighbor.data["type"])
|
||||
# chain_neighbors = complex_graph.get_nodes_with_target(chain)
|
||||
# assert chain_neighbors is not None
|
||||
# # Check if there is a LLM in the chain's neighbors
|
||||
# assert any("OpenAI" in neighbor.data["type"] for neighbor in chain_neighbors)
|
||||
# # Chain should have a Prompt as a neighbor
|
||||
# assert any("Prompt" in neighbor.data["type"] for neighbor in chain_neighbors)
|
||||
# # Now on to the Tool's neighbors
|
||||
# tool = next(neighbor for neighbor in neighbors if "Tool" in neighbor.data["type"])
|
||||
# tool_neighbors = complex_graph.get_nodes_with_target(tool)
|
||||
# assert tool_neighbors is not None
|
||||
# # Check if there is an Agent in the tool's neighbors
|
||||
# assert any("Agent" in neighbor.data["type"] for neighbor in tool_neighbors)
|
||||
# # This Agent has a Tool that has a PythonFunction as func
|
||||
# agent = next(
|
||||
# neighbor for neighbor in tool_neighbors if "Agent" in neighbor.data["type"]
|
||||
# )
|
||||
# agent_neighbors = complex_graph.get_nodes_with_target(agent)
|
||||
# assert agent_neighbors is not None
|
||||
# # Check if there is a Tool in the agent's neighbors
|
||||
# assert any("Tool" in neighbor.data["type"] for neighbor in agent_neighbors)
|
||||
# # This Tool has a PythonFunction as func
|
||||
# tool = next(
|
||||
# neighbor for neighbor in agent_neighbors if "Tool" in neighbor.data["type"]
|
||||
# )
|
||||
# tool_neighbors = complex_graph.get_nodes_with_target(tool)
|
||||
# assert tool_neighbors is not None
|
||||
# # Check if there is a PythonFunction in the tool's neighbors
|
||||
# assert any(
|
||||
# "PythonFunctionTool" in neighbor.data["type"] for neighbor in tool_neighbors
|
||||
# )
|
||||
assert any(
|
||||
"OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
|
||||
)
|
||||
|
||||
|
||||
def test_get_node(basic_graph):
|
||||
|
|
@ -212,7 +180,7 @@ def test_get_node(basic_graph):
|
|||
def test_build_nodes(basic_graph):
|
||||
"""Test building nodes"""
|
||||
|
||||
assert len(basic_graph.vertices) == len(basic_graph._vertices)
|
||||
assert len(basic_graph.vertices) == len(basic_graph._nodes)
|
||||
for node in basic_graph.vertices:
|
||||
assert isinstance(node, Vertex)
|
||||
|
||||
|
|
@ -222,21 +190,20 @@ def test_build_edges(basic_graph):
|
|||
assert len(basic_graph.edges) == len(basic_graph._edges)
|
||||
for edge in basic_graph.edges:
|
||||
assert isinstance(edge, Edge)
|
||||
|
||||
assert isinstance(edge.source_id, str)
|
||||
assert isinstance(edge.target_id, str)
|
||||
assert isinstance(edge.source, Vertex)
|
||||
assert isinstance(edge.target, Vertex)
|
||||
|
||||
|
||||
def test_get_root_vertex(client, basic_graph, complex_graph):
|
||||
def test_get_root_node(client, basic_graph, complex_graph):
|
||||
"""Test getting root node"""
|
||||
assert isinstance(basic_graph, Graph)
|
||||
root = get_root_vertex(basic_graph)
|
||||
root = get_root_node(basic_graph)
|
||||
assert root is not None
|
||||
assert isinstance(root, Vertex)
|
||||
assert root.data["type"] == "TimeTravelGuideChain"
|
||||
# For complex example, the root node is a ZeroShotAgent too
|
||||
assert isinstance(complex_graph, Graph)
|
||||
root = get_root_vertex(complex_graph)
|
||||
root = get_root_node(complex_graph)
|
||||
assert root is not None
|
||||
assert isinstance(root, Vertex)
|
||||
assert root.data["type"] == "ZeroShotAgent"
|
||||
|
|
@ -272,7 +239,7 @@ def test_build_params(basic_graph):
|
|||
# The matched_type attribute should be in the source_types attr
|
||||
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
|
||||
# Get the root node
|
||||
root = get_root_vertex(basic_graph)
|
||||
root = get_root_node(basic_graph)
|
||||
# Root node is a TimeTravelGuideChain
|
||||
# which requires an llm and memory
|
||||
assert root is not None
|
||||
|
|
@ -281,32 +248,29 @@ def test_build_params(basic_graph):
|
|||
assert "memory" in root.params
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build(basic_graph):
|
||||
def test_build(basic_graph):
|
||||
"""Test Node's build method"""
|
||||
await assert_agent_was_built(basic_graph)
|
||||
assert_agent_was_built(basic_graph)
|
||||
|
||||
|
||||
async def assert_agent_was_built(graph):
|
||||
def assert_agent_was_built(graph):
|
||||
"""Assert that the agent was built"""
|
||||
assert isinstance(graph, Graph)
|
||||
# Now we test the build method
|
||||
# Build the Agent
|
||||
result = await graph.build()
|
||||
result = graph.build()
|
||||
# The agent should be a AgentExecutor
|
||||
assert isinstance(result, Chain)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_node_build(basic_graph):
|
||||
def test_llm_node_build(basic_graph):
|
||||
llm_node = get_node_by_type(basic_graph, LLMVertex)
|
||||
assert llm_node is not None
|
||||
built_object = await llm_node.build()
|
||||
assert built_object is not UnbuiltObject()
|
||||
built_object = llm_node.build()
|
||||
assert built_object is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_toolkit_node_build(client, openapi_graph):
|
||||
def test_toolkit_node_build(client, openapi_graph):
|
||||
# Write a file to the disk
|
||||
file_path = "api-with-examples.yaml"
|
||||
with open(file_path, "w") as f:
|
||||
|
|
@ -314,31 +278,36 @@ async def test_toolkit_node_build(client, openapi_graph):
|
|||
|
||||
toolkit_node = get_node_by_type(openapi_graph, ToolkitVertex)
|
||||
assert toolkit_node is not None
|
||||
built_object = await toolkit_node.build()
|
||||
assert built_object is not UnbuiltObject
|
||||
built_object = toolkit_node.build()
|
||||
assert built_object is not None
|
||||
# Remove the file
|
||||
os.remove(file_path)
|
||||
assert not Path(file_path).exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_tool_node_build(client, openapi_graph):
|
||||
def test_file_tool_node_build(client, openapi_graph):
|
||||
file_path = "api-with-examples.yaml"
|
||||
with open(file_path, "w") as f:
|
||||
f.write("openapi: 3.0.0")
|
||||
|
||||
assert Path(file_path).exists()
|
||||
file_tool_node = get_node_by_type(openapi_graph, FileToolVertex)
|
||||
assert file_tool_node is not UnbuiltObject and file_tool_node is not None
|
||||
built_object = await file_tool_node.build()
|
||||
assert built_object is not UnbuiltObject
|
||||
assert file_tool_node is not None
|
||||
built_object = file_tool_node.build()
|
||||
assert built_object is not None
|
||||
# Remove the file
|
||||
os.remove(file_path)
|
||||
assert not Path(file_path).exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_result_and_thought(basic_graph):
|
||||
# def test_wrapper_node_build(openapi_graph):
|
||||
# wrapper_node = get_node_by_type(openapi_graph, WrapperVertex)
|
||||
# assert wrapper_node is not None
|
||||
# built_object = wrapper_node.build()
|
||||
# assert built_object is not None
|
||||
|
||||
|
||||
def test_get_result_and_thought(basic_graph):
|
||||
"""Test the get_result_and_thought method"""
|
||||
responses = [
|
||||
"Final Answer: I am a response",
|
||||
|
|
@ -350,7 +319,7 @@ async def test_get_result_and_thought(basic_graph):
|
|||
assert llm_node is not None
|
||||
llm_node._built_object = FakeListLLM(responses=responses)
|
||||
llm_node._built = True
|
||||
langchain_object = await basic_graph.build()
|
||||
langchain_object = basic_graph.build()
|
||||
# assert all nodes are built
|
||||
assert all(node._built for node in basic_graph.vertices)
|
||||
# now build again and check if FakeListLLM was used
|
||||
|
|
@ -370,7 +339,9 @@ def test_find_last_node(grouped_chat_json_flow):
|
|||
|
||||
def test_ungroup_node(grouped_chat_json_flow):
|
||||
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
|
||||
group_node = grouped_chat_data["nodes"][2] # Assuming the first node is a group node
|
||||
group_node = grouped_chat_data["nodes"][
|
||||
2
|
||||
] # Assuming the first node is a group node
|
||||
base_flow = copy.deepcopy(grouped_chat_data)
|
||||
ungroup_node(group_node["data"], base_flow)
|
||||
# after ungroup_node is called, the base_flow and grouped_chat_data should be different
|
||||
|
|
@ -422,9 +393,14 @@ def test_process_flow_one_group(one_grouped_chat_json_flow):
|
|||
assert "edges" in processed_flow
|
||||
|
||||
# Now get the node that has ChatOpenAI in its id
|
||||
chat_openai_node = next((node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None)
|
||||
chat_openai_node = next(
|
||||
(node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None
|
||||
)
|
||||
assert chat_openai_node is not None
|
||||
assert chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] == "test"
|
||||
assert (
|
||||
chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"]
|
||||
== "test"
|
||||
)
|
||||
|
||||
|
||||
def test_process_flow_vector_store_grouped(vector_store_grouped_json_flow):
|
||||
|
|
@ -471,15 +447,19 @@ def test_update_template(sample_template, sample_nodes):
|
|||
node2_updated = next((n for n in nodes_copy if n["id"] == "node2"), None)
|
||||
node3_updated = next((n for n in nodes_copy if n["id"] == "node3"), None)
|
||||
|
||||
assert node1_updated is not None
|
||||
assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True
|
||||
assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False
|
||||
assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1"
|
||||
assert (
|
||||
node1_updated["data"]["node"]["template"]["some_field"]["display_name"]
|
||||
== "Name1"
|
||||
)
|
||||
|
||||
assert node2_updated is not None
|
||||
assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False
|
||||
assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True
|
||||
assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2"
|
||||
assert (
|
||||
node2_updated["data"]["node"]["template"]["other_field"]["display_name"]
|
||||
== "DisplayName2"
|
||||
)
|
||||
|
||||
# Ensure node3 remains unchanged
|
||||
assert node3_updated == sample_nodes[2]
|
||||
|
|
@ -510,7 +490,9 @@ def test_set_new_target_handle():
|
|||
"data": {
|
||||
"node": {
|
||||
"flow": True,
|
||||
"template": {"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}},
|
||||
"template": {
|
||||
"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -530,34 +512,34 @@ def test_update_source_handle():
|
|||
"nodes": [{"id": "some_node"}, {"id": "last_node"}],
|
||||
"edges": [{"source": "some_node"}],
|
||||
}
|
||||
updated_edge = update_source_handle(new_edge, flow_data["nodes"], flow_data["edges"])
|
||||
updated_edge = update_source_handle(
|
||||
new_edge, flow_data["nodes"], flow_data["edges"]
|
||||
)
|
||||
assert updated_edge["source"] == "last_node"
|
||||
assert updated_edge["data"]["sourceHandle"]["id"] == "last_node"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pickle_graph(json_vector_store):
|
||||
def test_pickle_graph(json_vector_store):
|
||||
loaded_json = json.loads(json_vector_store)
|
||||
graph = Graph.from_payload(loaded_json)
|
||||
assert isinstance(graph, Graph)
|
||||
first_result = await graph.build()
|
||||
first_result = graph.build()
|
||||
assert isinstance(first_result, AgentExecutor)
|
||||
pickled = pickle.dumps(graph)
|
||||
assert pickled is not UnbuiltObject
|
||||
assert pickled is not None
|
||||
unpickled = pickle.loads(pickled)
|
||||
assert unpickled is not UnbuiltObject
|
||||
result = await unpickled.build()
|
||||
assert unpickled is not None
|
||||
result = unpickled.build()
|
||||
assert isinstance(result, AgentExecutor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pickle_each_vertex(json_vector_store):
|
||||
def test_pickle_each_vertex(json_vector_store):
|
||||
loaded_json = json.loads(json_vector_store)
|
||||
graph = Graph.from_payload(loaded_json)
|
||||
assert isinstance(graph, Graph)
|
||||
for vertex in graph.vertices:
|
||||
await vertex.build()
|
||||
for vertex in graph.nodes:
|
||||
vertex.build()
|
||||
pickled = pickle.dumps(vertex)
|
||||
assert pickled is not UnbuiltObject
|
||||
assert pickled is not None
|
||||
unpickled = pickle.loads(pickled)
|
||||
assert unpickled is not UnbuiltObject
|
||||
assert unpickled is not None
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["verbose"] == {
|
||||
"required": False,
|
||||
|
|
@ -36,7 +35,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["client"] == {
|
||||
"required": False,
|
||||
|
|
@ -50,7 +48,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["model_name"] == {
|
||||
"required": False,
|
||||
|
|
@ -72,7 +69,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
# Add more assertions for other properties here
|
||||
assert template["temperature"] == {
|
||||
|
|
@ -88,8 +84,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["max_tokens"] == {
|
||||
"required": False,
|
||||
|
|
@ -104,7 +98,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["top_p"] == {
|
||||
"required": False,
|
||||
|
|
@ -119,8 +112,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["frequency_penalty"] == {
|
||||
"required": False,
|
||||
|
|
@ -135,8 +126,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["presence_penalty"] == {
|
||||
"required": False,
|
||||
|
|
@ -151,8 +140,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["n"] == {
|
||||
"required": False,
|
||||
|
|
@ -167,7 +154,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["best_of"] == {
|
||||
"required": False,
|
||||
|
|
@ -182,7 +168,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["model_kwargs"] == {
|
||||
"required": False,
|
||||
|
|
@ -196,7 +181,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["openai_api_key"] == {
|
||||
"required": False,
|
||||
|
|
@ -212,7 +196,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["batch_size"] == {
|
||||
"required": False,
|
||||
|
|
@ -227,7 +210,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["request_timeout"] == {
|
||||
"required": False,
|
||||
|
|
@ -241,8 +223,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["logit_bias"] == {
|
||||
"required": False,
|
||||
|
|
@ -256,7 +236,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["max_retries"] == {
|
||||
"required": False,
|
||||
|
|
@ -264,14 +243,13 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"placeholder": "",
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
"value": 2,
|
||||
"value": 6,
|
||||
"password": False,
|
||||
"name": "max_retries",
|
||||
"type": "int",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["streaming"] == {
|
||||
"required": False,
|
||||
|
|
@ -286,7 +264,6 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -312,7 +289,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["client"] == {
|
||||
"required": False,
|
||||
|
|
@ -326,7 +302,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["model_name"] == {
|
||||
"required": False,
|
||||
|
|
@ -334,11 +309,10 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "gpt-4-1106-preview",
|
||||
"value": "gpt-3.5-turbo-0613",
|
||||
"password": False,
|
||||
"options": [
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-3.5-turbo",
|
||||
|
|
@ -349,7 +323,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["temperature"] == {
|
||||
"required": False,
|
||||
|
|
@ -364,8 +337,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["model_kwargs"] == {
|
||||
"required": False,
|
||||
|
|
@ -379,7 +350,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["openai_api_key"] == {
|
||||
"required": False,
|
||||
|
|
@ -395,7 +365,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["request_timeout"] == {
|
||||
"required": False,
|
||||
|
|
@ -409,8 +378,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["max_retries"] == {
|
||||
"required": False,
|
||||
|
|
@ -418,14 +385,13 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"placeholder": "",
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
"value": 2,
|
||||
"value": 6,
|
||||
"password": False,
|
||||
"name": "max_retries",
|
||||
"type": "int",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["streaming"] == {
|
||||
"required": False,
|
||||
|
|
@ -440,7 +406,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["n"] == {
|
||||
"required": False,
|
||||
|
|
@ -455,7 +420,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["max_tokens"] == {
|
||||
|
|
@ -470,7 +434,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["_type"] == "ChatOpenAI"
|
||||
assert (
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from langchain.chains.base import Chain
|
||||
from langflow.processing.process import load_flow_from_json
|
||||
from langflow.graph import Graph
|
||||
from langflow.processing.load import load_flow_from_json
|
||||
from langflow.utils.payload import get_root_vertex
|
||||
from langflow.utils.payload import get_root_node
|
||||
|
||||
|
||||
def test_load_flow_from_json():
|
||||
|
|
@ -16,22 +15,21 @@ def test_load_flow_from_json():
|
|||
|
||||
def test_load_flow_from_json_with_tweaks():
|
||||
"""Test loading a flow from a json file and applying tweaks"""
|
||||
tweaks = {"dndnode_82": {"model_name": "test model"}}
|
||||
tweaks = {"dndnode_82": {"model_name": "gpt-3.5-turbo-16k-0613"}}
|
||||
loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH, tweaks=tweaks)
|
||||
assert loaded is not None
|
||||
assert isinstance(loaded, Chain)
|
||||
assert loaded.llm.model_name == "test model"
|
||||
assert loaded.llm.model_name == "gpt-3.5-turbo-16k-0613"
|
||||
|
||||
|
||||
def test_get_root_vertex():
|
||||
def test_get_root_node():
|
||||
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
root = get_root_vertex(graph)
|
||||
root = get_root_node(graph)
|
||||
assert root is not None
|
||||
assert hasattr(root, "id")
|
||||
assert hasattr(root, "data")
|
||||
assert hasattr(root, "data")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from langflow.services.getters import get_db_service
|
||||
import pytest
|
||||
from langflow.services.database.models.user import User
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
|
|
@ -9,7 +9,9 @@ from langflow.services.auth.utils import get_password_hash
|
|||
def test_user():
|
||||
return User(
|
||||
username="testuser",
|
||||
password=get_password_hash("testpassword"), # Assuming password needs to be hashed
|
||||
password=get_password_hash(
|
||||
"testpassword"
|
||||
), # Assuming password needs to be hashed
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
|
@ -21,13 +23,17 @@ def test_login_successful(client, test_user):
|
|||
session.add(test_user)
|
||||
session.commit()
|
||||
|
||||
response = client.post("api/v1/login", data={"username": "testuser", "password": "testpassword"})
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "testuser", "password": "testpassword"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "access_token" in response.json()
|
||||
|
||||
|
||||
def test_login_unsuccessful_wrong_username(client):
|
||||
response = client.post("api/v1/login", data={"username": "wrongusername", "password": "testpassword"})
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "wrongusername", "password": "testpassword"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Incorrect username or password"
|
||||
|
||||
|
|
@ -37,6 +43,8 @@ def test_login_unsuccessful_wrong_password(client, test_user, session):
|
|||
session.add(test_user)
|
||||
session.commit()
|
||||
|
||||
response = client.post("api/v1/login", data={"username": "testuser", "password": "wrongpassword"})
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "testuser", "password": "wrongpassword"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Incorrect username or password"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import pytest
|
||||
from langflow.processing.process import process_tweaks
|
||||
from langflow.services.deps import get_session_service
|
||||
from langflow.services.getters import get_session_service
|
||||
|
||||
|
||||
def test_no_tweaks():
|
||||
|
|
@ -198,42 +197,39 @@ def test_tweak_not_in_template():
|
|||
assert result == graph_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_langchain_object_with_cached_session(client, basic_graph_data):
|
||||
def test_load_langchain_object_with_cached_session(client, basic_graph_data):
|
||||
# Provide a non-existent session_id
|
||||
session_service = get_session_service()
|
||||
session_id1 = "non-existent-session-id"
|
||||
graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data)
|
||||
graph1, artifacts1 = session_service.load_session(session_id1, basic_graph_data)
|
||||
# Use the new session_id to get the langchain_object again
|
||||
graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data)
|
||||
graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data)
|
||||
|
||||
assert graph1 == graph2
|
||||
assert artifacts1 == artifacts2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_langchain_object_with_no_cached_session(client, basic_graph_data):
|
||||
def test_load_langchain_object_with_no_cached_session(client, basic_graph_data):
|
||||
# Provide a non-existent session_id
|
||||
session_service = get_session_service()
|
||||
session_id1 = "non-existent-session-id"
|
||||
session_id = session_service.build_key(session_id1, basic_graph_data)
|
||||
graph1, artifacts1 = await session_service.load_session(session_id, basic_graph_data)
|
||||
graph1, artifacts1 = session_service.load_session(session_id, basic_graph_data)
|
||||
# Clear the cache
|
||||
session_service.clear_session(session_id)
|
||||
# Use the new session_id to get the langchain_object again
|
||||
graph2, artifacts2 = await session_service.load_session(session_id, basic_graph_data)
|
||||
graph2, artifacts2 = session_service.load_session(session_id, basic_graph_data)
|
||||
|
||||
assert id(graph1) != id(graph2)
|
||||
# Since the cache was cleared, objects should be different
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_langchain_object_without_session_id(client, basic_graph_data):
|
||||
def test_load_langchain_object_without_session_id(client, basic_graph_data):
|
||||
# Provide a non-existent session_id
|
||||
session_service = get_session_service()
|
||||
session_id1 = None
|
||||
graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data)
|
||||
graph1, artifacts1 = session_service.load_session(session_id1, basic_graph_data)
|
||||
# Use the new session_id to get the langchain_object again
|
||||
graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data)
|
||||
graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data)
|
||||
|
||||
assert graph1 == graph2
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
|
||||
def test_prompts_settings(client: TestClient, logged_in_headers):
|
||||
|
|
@ -32,7 +31,6 @@ def test_prompt_template(client: TestClient, logged_in_headers):
|
|||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["output_parser"] == {
|
||||
|
|
@ -47,7 +45,6 @@ def test_prompt_template(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["partial_variables"] == {
|
||||
|
|
@ -62,7 +59,6 @@ def test_prompt_template(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["template"] == {
|
||||
|
|
@ -77,7 +73,6 @@ def test_prompt_template(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["template_format"] == {
|
||||
|
|
@ -93,7 +88,6 @@ def test_prompt_template(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["validate_template"] == {
|
||||
|
|
@ -102,12 +96,11 @@ def test_prompt_template(client: TestClient, logged_in_headers):
|
|||
"placeholder": "",
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
"value": False,
|
||||
"value": True,
|
||||
"password": False,
|
||||
"name": "validate_template",
|
||||
"type": "bool",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,17 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch, MagicMock
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.settings.constants import (
|
||||
DEFAULT_SUPERUSER,
|
||||
DEFAULT_SUPERUSER_PASSWORD,
|
||||
)
|
||||
from langflow.services.utils import (
|
||||
teardown_superuser,
|
||||
)
|
||||
|
||||
from langflow.services.settings.constants import DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD
|
||||
from langflow.services.utils import teardown_superuser
|
||||
|
||||
# @patch("langflow.services.deps.get_session")
|
||||
# @patch("langflow.services.getters.get_session")
|
||||
# @patch("langflow.services.utils.create_super_user")
|
||||
# @patch("langflow.services.deps.get_settings_service")
|
||||
# @patch("langflow.services.getters.get_settings_service")
|
||||
# # @patch("langflow.services.utils.verify_password")
|
||||
# def test_setup_superuser(
|
||||
# mock_get_session, mock_create_super_user, mock_get_settings_service
|
||||
|
|
@ -86,9 +92,11 @@ from langflow.services.utils import teardown_superuser
|
|||
# assert str(actual_expr) == str(expected_expr)
|
||||
|
||||
|
||||
@patch("langflow.services.deps.get_settings_service")
|
||||
@patch("langflow.services.deps.get_session")
|
||||
def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
@patch("langflow.services.getters.get_settings_service")
|
||||
@patch("langflow.services.getters.get_session")
|
||||
def test_teardown_superuser_default_superuser(
|
||||
mock_get_session, mock_get_settings_service
|
||||
):
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = True
|
||||
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
|
||||
|
|
@ -103,12 +111,20 @@ def test_teardown_superuser_default_superuser(mock_get_session, mock_get_setting
|
|||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.query.assert_called_once_with(User)
|
||||
actual_expr = mock_session.query.return_value.filter.call_args[0][0]
|
||||
expected_expr = User.username == DEFAULT_SUPERUSER
|
||||
|
||||
assert str(actual_expr) == str(expected_expr)
|
||||
mock_session.delete.assert_called_once_with(mock_user)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@patch("langflow.services.deps.get_settings_service")
|
||||
@patch("langflow.services.deps.get_session")
|
||||
def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
@patch("langflow.services.getters.get_settings_service")
|
||||
@patch("langflow.services.getters.get_session")
|
||||
def test_teardown_superuser_no_default_superuser(
|
||||
mock_get_session, mock_get_settings_service
|
||||
):
|
||||
ADMIN_USER_NAME = "admin_user"
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = False
|
||||
|
|
@ -119,11 +135,11 @@ def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_sett
|
|||
mock_session = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_user.is_superuser = False
|
||||
mock_session.exec.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_get_session.return_value = [mock_session]
|
||||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.exec.assert_called_once()
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.delete.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ import importlib
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.utils.constants import CHAT_OPENAI_MODELS, OPENAI_MODELS
|
||||
from langflow.utils.util import (
|
||||
build_template_from_class,
|
||||
|
|
@ -12,6 +10,7 @@ from langflow.utils.util import (
|
|||
get_base_classes,
|
||||
get_default_factory,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Dummy classes for testing purposes
|
||||
|
|
@ -66,9 +65,11 @@ def test_build_template_from_function():
|
|||
assert "base_classes" in result
|
||||
|
||||
# Test with add_function=True
|
||||
result_with_function = build_template_from_function("ExampleClass1", type_to_loader_dict, add_function=True)
|
||||
result_with_function = build_template_from_function(
|
||||
"ExampleClass1", type_to_loader_dict, add_function=True
|
||||
)
|
||||
assert result_with_function is not None
|
||||
assert "Callable" in result_with_function["base_classes"]
|
||||
assert "function" in result_with_function["base_classes"]
|
||||
|
||||
# Test with invalid name
|
||||
with pytest.raises(ValueError, match=r".* not found"):
|
||||
|
|
@ -236,7 +237,7 @@ def test_format_dict():
|
|||
"password": False,
|
||||
"multiline": False,
|
||||
"options": CHAT_OPENAI_MODELS,
|
||||
"value": "gpt-4-1106-preview",
|
||||
"value": "gpt-4-1106-preview"",
|
||||
},
|
||||
}
|
||||
assert format_dict(input_dict, "OpenAI") == expected_output_openai
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.services.auth.utils import create_super_user, get_password_hash
|
||||
from langflow.services.database.models.user import UserUpdate
|
||||
from langflow.services.database.models.user.model import User
|
||||
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from langflow.services.getters import get_db_service, get_settings_service
|
||||
import pytest
|
||||
from langflow.services.database.models.user import UserUpdate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -86,11 +85,15 @@ def test_deactivated_user_cannot_access(client, deactivated_user, logged_in_head
|
|||
assert response.json()["detail"] == "The user doesn't have enough privileges"
|
||||
|
||||
|
||||
def test_data_consistency_after_update(client, active_user, logged_in_headers, super_user_headers):
|
||||
def test_data_consistency_after_update(
|
||||
client, active_user, logged_in_headers, super_user_headers
|
||||
):
|
||||
user_id = active_user.id
|
||||
update_data = UserUpdate(is_active=False)
|
||||
|
||||
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=super_user_headers)
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=super_user_headers
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
# Fetch the updated user from the database
|
||||
|
|
@ -117,7 +120,7 @@ def test_inactive_user(client):
|
|||
username="inactiveuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=False,
|
||||
last_login_at=datetime.now(),
|
||||
last_login_at="2023-01-01T00:00:00", # Set to a valid datetime string
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
|
@ -164,13 +167,17 @@ def test_patch_user(client, active_user, logged_in_headers):
|
|||
username="newname",
|
||||
)
|
||||
|
||||
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers)
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
update_data = UserUpdate(
|
||||
profile_image="new_image",
|
||||
)
|
||||
|
||||
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers)
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
|
||||
|
|
@ -182,7 +189,7 @@ def test_patch_reset_password(client, active_user, logged_in_headers):
|
|||
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}/reset-password",
|
||||
json=update_data.model_dump(),
|
||||
json=update_data.dict(),
|
||||
headers=logged_in_headers,
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
|
@ -198,13 +205,19 @@ def test_patch_user_wrong_id(client, active_user, logged_in_headers):
|
|||
username="newname",
|
||||
)
|
||||
|
||||
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers)
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 422, response.json()
|
||||
json_response = response.json()
|
||||
detail = json_response["detail"]
|
||||
assert detail[0]["type"] == "uuid_parsing"
|
||||
assert detail[0]["loc"] == ["path", "user_id"]
|
||||
assert detail[0]["input"] == "wrong_id"
|
||||
assert response.json() == {
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["path", "user_id"],
|
||||
"msg": "value is not a valid uuid",
|
||||
"type": "type_error.uuid",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_delete_user(client, test_user, super_user_headers):
|
||||
|
|
@ -218,11 +231,15 @@ def test_delete_user_wrong_id(client, test_user, super_user_headers):
|
|||
user_id = "wrong_id"
|
||||
response = client.delete(f"/api/v1/users/{user_id}", headers=super_user_headers)
|
||||
assert response.status_code == 422
|
||||
json_response = response.json()
|
||||
detail = json_response["detail"]
|
||||
assert detail[0]["type"] == "uuid_parsing"
|
||||
assert detail[0]["loc"] == ["path", "user_id"]
|
||||
assert detail[0]["input"] == "wrong_id"
|
||||
assert response.json() == {
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["path", "user_id"],
|
||||
"msg": "value is not a valid uuid",
|
||||
"type": "type_error.uuid",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_normal_user_cant_delete_user(client, test_user, logged_in_headers):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from fastapi.testclient import TestClient
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
|
||||
# check that all agents are in settings.agents
|
||||
|
|
|
|||
|
|
@ -31,7 +31,9 @@ def test_websocket_endpoint(client: TestClient, active_user, logged_in_headers):
|
|||
# Assuming your websocket_endpoint uses chat_service which caches data from stream_build
|
||||
access_token = logged_in_headers["Authorization"].split(" ")[1]
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect(f"api/v1/chat/non_existing_client_id?token={access_token}") as websocket:
|
||||
with client.websocket_connect(
|
||||
f"api/v1/chat/non_existing_client_id?token={access_token}"
|
||||
) as websocket:
|
||||
websocket.send_json({"type": "test"})
|
||||
data = websocket.receive_json()
|
||||
assert "Please, build the flow before sending messages" in data["message"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue