Merge remote-tracking branch 'origin/dev' into merge
This commit is contained in:
commit
e3a2abacae
231 changed files with 22158 additions and 6728 deletions
|
|
@ -1,46 +1,41 @@
|
|||
from contextlib import contextmanager
|
||||
import json
|
||||
from contextlib import suppress
|
||||
# we need to import tmpdir
|
||||
import tempfile
|
||||
from contextlib import contextmanager, suppress
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, AsyncGenerator
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.flow.flow import Flow, FlowCreate
|
||||
from langflow.services.database.models.user.user import User, UserCreate
|
||||
import orjson
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
from sqlmodel import SQLModel, Session, create_engine
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
|
||||
# we need to import tmpdir
|
||||
import tempfile
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.flow.model import Flow, FlowCreate
|
||||
from langflow.services.database.models.user.model import User, UserCreate
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.manager import DatabaseService
|
||||
from langflow.services.database.service import DatabaseService
|
||||
|
||||
|
||||
def pytest_configure():
|
||||
pytest.BASIC_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "basic_example.json"
|
||||
)
|
||||
pytest.COMPLEX_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "complex_example.json"
|
||||
)
|
||||
pytest.OPENAPI_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "Openapi.json"
|
||||
)
|
||||
pytest.BASIC_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "basic_example.json"
|
||||
pytest.COMPLEX_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "complex_example.json"
|
||||
pytest.OPENAPI_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "Openapi.json"
|
||||
pytest.GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "grouped_chat.json"
|
||||
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "one_group_chat.json"
|
||||
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json"
|
||||
|
||||
pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = (
|
||||
Path(__file__).parent.absolute() / "data" / "BasicChatwithPromptandHistory.json"
|
||||
)
|
||||
pytest.VECTOR_STORE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "Vector_store.json"
|
||||
)
|
||||
pytest.VECTOR_STORE_PATH = Path(__file__).parent.absolute() / "data" / "Vector_store.json"
|
||||
pytest.CODE_WITH_SYNTAX_ERROR = """
|
||||
def get_text():
|
||||
retun "Hello World"
|
||||
|
|
@ -58,9 +53,7 @@ async def async_client() -> AsyncGenerator:
|
|||
|
||||
@pytest.fixture(name="session")
|
||||
def session_fixture():
|
||||
engine = create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
)
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
|
@ -96,9 +89,7 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
|
|||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
|
||||
# monkeypatch langflow.services.task.manager.USE_CELERY to True
|
||||
# monkeypatch.setattr(manager, "USE_CELERY", True)
|
||||
monkeypatch.setattr(
|
||||
celery_app, "celery_app", celery_app.make_celery("langflow", Config)
|
||||
)
|
||||
monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config))
|
||||
|
||||
# def get_session_override():
|
||||
# return session
|
||||
|
|
@ -215,7 +206,7 @@ def test_user(client):
|
|||
username="testuser",
|
||||
password="testpassword",
|
||||
)
|
||||
response = client.post("/api/v1/users", json=user_data.dict())
|
||||
response = client.post("/api/v1/users", json=user_data.model_dump())
|
||||
assert response.status_code == 201
|
||||
return response.json()
|
||||
|
||||
|
|
@ -231,11 +222,7 @@ def active_user(client):
|
|||
is_superuser=False,
|
||||
)
|
||||
# check if user exists
|
||||
if (
|
||||
active_user := session.query(User)
|
||||
.filter(User.username == user.username)
|
||||
.first()
|
||||
):
|
||||
if active_user := session.query(User).filter(User.username == user.username).first():
|
||||
return active_user
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
|
@ -255,13 +242,16 @@ def logged_in_headers(client, active_user):
|
|||
|
||||
@pytest.fixture
|
||||
def flow(client, json_flow: str, active_user):
|
||||
from langflow.services.database.models.flow.flow import FlowCreate
|
||||
from langflow.services.database.models.flow.model import FlowCreate
|
||||
|
||||
loaded_json = json.loads(json_flow)
|
||||
flow_data = FlowCreate(
|
||||
name="test_flow", data=loaded_json.get("data"), user_id=active_user.id
|
||||
name="test_flow",
|
||||
data=loaded_json.get("data"),
|
||||
user_id=active_user.id,
|
||||
description="description",
|
||||
)
|
||||
flow = Flow(**flow_data.dict())
|
||||
flow = Flow.model_validate(flow_data.model_dump())
|
||||
with session_getter(get_db_service()) as session:
|
||||
session.add(flow)
|
||||
session.commit()
|
||||
|
|
@ -275,7 +265,7 @@ def added_flow(client, json_flow_with_prompt_and_history, logged_in_headers):
|
|||
flow = orjson.loads(json_flow_with_prompt_and_history)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Basic Chat", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
|
@ -287,10 +277,24 @@ def added_vector_store(client, json_vector_store, logged_in_headers):
|
|||
vector_store = orjson.loads(json_vector_store)
|
||||
data = vector_store["data"]
|
||||
vector_store = FlowCreate(name="Vector Store", description="description", data=data)
|
||||
response = client.post(
|
||||
"api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers
|
||||
)
|
||||
response = client.post("api/v1/flows/", json=vector_store.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == vector_store.name
|
||||
assert response.json()["data"] == vector_store.data
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_component_code():
|
||||
path = Path(__file__).parent.absolute() / "data" / "component.py"
|
||||
# load the content as a string
|
||||
with open(path, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_component_with_templatefield_code():
|
||||
path = Path(__file__).parent.absolute() / "data" / "component_with_templatefield.py"
|
||||
# load the content as a string
|
||||
with open(path, "r") as f:
|
||||
return f.read()
|
||||
|
|
|
|||
16
tests/data/component.py
Normal file
16
tests/data/component.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
import random
|
||||
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
class TestComponent(CustomComponent):
|
||||
def refresh_values(self):
|
||||
# This is a function that will be called every time the component is updated
|
||||
# and should return a list of random strings
|
||||
return [f"Random {random.randint(1, 100)}" for _ in range(5)]
|
||||
|
||||
def build_config(self):
|
||||
return {"param": {"display_name": "Param", "options": self.refresh_values}}
|
||||
|
||||
def build(self, param: int):
|
||||
return param
|
||||
17
tests/data/component_with_templatefield.py
Normal file
17
tests/data/component_with_templatefield.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
import random
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import TemplateField
|
||||
|
||||
|
||||
class TestComponent(CustomComponent):
|
||||
def refresh_values(self):
|
||||
# This is a function that will be called every time the component is updated
|
||||
# and should return a list of random strings
|
||||
return [f"Random {random.randint(1, 100)}" for _ in range(5)]
|
||||
|
||||
def build_config(self):
|
||||
return {"param": TemplateField(display_name="Param", options=self.refresh_values)}
|
||||
|
||||
def build(self, param: int):
|
||||
return param
|
||||
|
|
@ -66,9 +66,7 @@ class NameTest(FastHttpUser):
|
|||
result1, session_id = self.process(name, self.flow_id, payload1)
|
||||
|
||||
payload2 = {
|
||||
"inputs": {
|
||||
"text": "What is my name? Please, answer like this: Your name is <name>"
|
||||
},
|
||||
"inputs": {"text": "What is my name? Please, answer like this: Your name is <name>"},
|
||||
"session_id": session_id,
|
||||
"sync": False,
|
||||
}
|
||||
|
|
@ -88,9 +86,7 @@ class NameTest(FastHttpUser):
|
|||
logged_in_headers = {"Authorization": f"Bearer {a_token}"}
|
||||
print("Logged in")
|
||||
with open(
|
||||
Path(__file__).parent.parent
|
||||
/ "data"
|
||||
/ "BasicChatwithPromptandHistory.json",
|
||||
Path(__file__).parent.parent / "data" / "BasicChatwithPromptandHistory.json",
|
||||
"r",
|
||||
) as f:
|
||||
json_flow = f.read()
|
||||
|
|
@ -115,11 +111,7 @@ class NameTest(FastHttpUser):
|
|||
)
|
||||
print(response.json())
|
||||
user_id = next(
|
||||
(
|
||||
user["id"]
|
||||
for user in response.json()["users"]
|
||||
if user["username"] == "superuser"
|
||||
),
|
||||
(user["id"] for user in response.json()["users"] if user["username"] == "superuser"),
|
||||
None,
|
||||
)
|
||||
# Create api key
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"ZeroShotAgent",
|
||||
"BaseSingleActionAgent",
|
||||
"Agent",
|
||||
"function",
|
||||
"Callable",
|
||||
}
|
||||
template = zero_shot_agent["template"]
|
||||
|
||||
|
|
@ -28,6 +28,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
# Additional assertions for other template variables
|
||||
|
|
@ -43,6 +44,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
|
|
@ -56,6 +58,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["output_parser"] == {
|
||||
"required": False,
|
||||
|
|
@ -69,6 +72,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["input_variables"] == {
|
||||
"required": False,
|
||||
|
|
@ -82,6 +86,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["prefix"] == {
|
||||
"required": False,
|
||||
|
|
@ -96,6 +101,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["suffix"] == {
|
||||
"required": False,
|
||||
|
|
@ -110,6 +116,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -135,6 +142,9 @@ def test_json_agent(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
|
|
@ -149,6 +159,9 @@ def test_json_agent(client: TestClient, logged_in_headers):
|
|||
"advanced": False,
|
||||
"display_name": "LLM",
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -169,87 +182,12 @@ def test_csv_agent(client: TestClient, logged_in_headers):
|
|||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "",
|
||||
"suffixes": [".csv"],
|
||||
"fileTypes": ["csv"],
|
||||
"fileTypes": [".csv"],
|
||||
"password": False,
|
||||
"name": "path",
|
||||
"type": "file",
|
||||
"list": False,
|
||||
"file_path": None,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "llm",
|
||||
"type": "BaseLanguageModel",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"display_name": "LLM",
|
||||
"info": "",
|
||||
}
|
||||
|
||||
|
||||
def test_initialize_agent(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
agents = json_response["agents"]
|
||||
|
||||
initialize_agent = agents["AgentInitializer"]
|
||||
assert initialize_agent["base_classes"] == ["AgentExecutor", "function"]
|
||||
template = initialize_agent["template"]
|
||||
|
||||
assert template["agent"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "zero-shot-react-description",
|
||||
"password": False,
|
||||
"options": [
|
||||
"zero-shot-react-description",
|
||||
"react-docstore",
|
||||
"self-ask-with-search",
|
||||
"conversational-react-description",
|
||||
"openai-functions",
|
||||
"openai-multi-functions",
|
||||
],
|
||||
"name": "agent",
|
||||
"type": "str",
|
||||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["memory"] == {
|
||||
"required": False,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "memory",
|
||||
"type": "BaseChatMemory",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["tools"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "tools",
|
||||
"type": "Tool",
|
||||
"list": True,
|
||||
"file_path": "",
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
|
|
@ -266,4 +204,7 @@ def test_initialize_agent(client: TestClient, logged_in_headers):
|
|||
"advanced": False,
|
||||
"display_name": "LLM",
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,9 +6,7 @@ from langflow.services.database.models.api_key import ApiKeyCreate
|
|||
def api_key(client, logged_in_headers, active_user):
|
||||
api_key = ApiKeyCreate(name="test-api-key")
|
||||
|
||||
response = client.post(
|
||||
"api/v1/api_key", data=api_key.json(), headers=logged_in_headers
|
||||
)
|
||||
response = client.post("api/v1/api_key", data=api_key.json(), headers=logged_in_headers)
|
||||
assert response.status_code == 200, response.text
|
||||
return response.json()
|
||||
|
||||
|
|
@ -28,9 +26,7 @@ def test_get_api_keys(client, logged_in_headers, api_key):
|
|||
|
||||
def test_create_api_key(client, logged_in_headers):
|
||||
api_key_name = "test-api-key"
|
||||
response = client.post(
|
||||
"api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers
|
||||
)
|
||||
response = client.post("api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "name" in data and data["name"] == api_key_name
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import json
|
||||
from langflow.graph import Graph
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.graph import Graph
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
|
|
@ -41,5 +42,5 @@ def langchain_objects_are_equal(obj1, obj2):
|
|||
def test_build_graph(client, basic_data_graph):
|
||||
graph = Graph.from_payload(basic_data_graph)
|
||||
assert graph is not None
|
||||
assert len(graph.nodes) == len(basic_data_graph["nodes"])
|
||||
assert len(graph.vertices) == len(basic_data_graph["nodes"])
|
||||
assert len(graph.edges) == len(basic_data_graph["edges"])
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# def test_chains_settings(client: TestClient, logged_in_headers):
|
||||
# response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
# assert response.status_code == 200
|
||||
|
|
@ -9,170 +8,6 @@ from fastapi.testclient import TestClient
|
|||
# assert set(chains.keys()) == set(settings.chains)
|
||||
|
||||
|
||||
# Test the ConversationChain object
|
||||
def test_conversation_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
|
||||
chain = chains["ConversationChain"]
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"ConversationChain",
|
||||
"LLMChain",
|
||||
"Chain",
|
||||
"function",
|
||||
}
|
||||
|
||||
template = chain["template"]
|
||||
assert template["memory"] == {
|
||||
"required": False,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "memory",
|
||||
"type": "BaseMemory",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["verbose"] == {
|
||||
"required": False,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "verbose",
|
||||
"type": "bool",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "llm",
|
||||
"type": "BaseLanguageModel",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["input_key"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "input",
|
||||
"password": False,
|
||||
"name": "input_key",
|
||||
"type": "str",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
}
|
||||
assert template["output_key"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "response",
|
||||
"password": False,
|
||||
"name": "output_key",
|
||||
"type": "str",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
}
|
||||
assert template["_type"] == "ConversationChain"
|
||||
|
||||
# Test the description object
|
||||
assert (
|
||||
chain["description"]
|
||||
== "Chain to have a conversation and load context from memory."
|
||||
)
|
||||
|
||||
|
||||
def test_llm_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
chain = chains["LLMChain"]
|
||||
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"function",
|
||||
"LLMChain",
|
||||
"Chain",
|
||||
}
|
||||
|
||||
template = chain["template"]
|
||||
assert template["memory"] == {
|
||||
"required": False,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "memory",
|
||||
"type": "BaseMemory",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["verbose"] == {
|
||||
"required": False,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
"value": False,
|
||||
"password": False,
|
||||
"name": "verbose",
|
||||
"type": "bool",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "llm",
|
||||
"type": "BaseLanguageModel",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
}
|
||||
assert template["output_key"] == {
|
||||
"required": True,
|
||||
"dynamic": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "text",
|
||||
"password": False,
|
||||
"name": "output_key",
|
||||
"type": "str",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
}
|
||||
|
||||
|
||||
def test_llm_checker_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
|
|
@ -182,7 +17,7 @@ def test_llm_checker_chain(client: TestClient, logged_in_headers):
|
|||
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"function",
|
||||
"Callable",
|
||||
"LLMCheckerChain",
|
||||
"Chain",
|
||||
}
|
||||
|
|
@ -200,6 +35,7 @@ def test_llm_checker_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["_type"] == "LLMCheckerChain"
|
||||
|
||||
|
|
@ -216,7 +52,7 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
chain = chains["LLMMathChain"]
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"function",
|
||||
"Callable",
|
||||
"LLMMathChain",
|
||||
"Chain",
|
||||
}
|
||||
|
|
@ -234,6 +70,7 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["verbose"] == {
|
||||
"required": False,
|
||||
|
|
@ -248,6 +85,7 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["llm"] == {
|
||||
"required": True,
|
||||
|
|
@ -261,6 +99,7 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["input_key"] == {
|
||||
"required": True,
|
||||
|
|
@ -275,6 +114,7 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["output_key"] == {
|
||||
"required": True,
|
||||
|
|
@ -289,14 +129,12 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["_type"] == "LLMMathChain"
|
||||
|
||||
# Test the description object
|
||||
assert (
|
||||
chain["description"]
|
||||
== "Chain that interprets a prompt and executes python code to do math."
|
||||
)
|
||||
assert chain["description"] == "Chain that interprets a prompt and executes python code to do math."
|
||||
|
||||
|
||||
def test_series_character_chain(client: TestClient, logged_in_headers):
|
||||
|
|
@ -309,7 +147,7 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
|
|||
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"function",
|
||||
"Callable",
|
||||
"LLMChain",
|
||||
"BaseCustomChain",
|
||||
"Chain",
|
||||
|
|
@ -331,6 +169,9 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
"file_path": "",
|
||||
"value": "",
|
||||
}
|
||||
assert template["character"] == {
|
||||
"required": True,
|
||||
|
|
@ -344,6 +185,9 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
"file_path": "",
|
||||
"value": "",
|
||||
}
|
||||
assert template["series"] == {
|
||||
"required": True,
|
||||
|
|
@ -357,6 +201,9 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
"file_path": "",
|
||||
"value": "",
|
||||
}
|
||||
assert template["_type"] == "SeriesCharacterChain"
|
||||
|
||||
|
|
@ -400,12 +247,12 @@ def test_mid_journey_prompt_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
# Test the description object
|
||||
assert (
|
||||
chain["description"]
|
||||
== "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."
|
||||
)
|
||||
assert chain["description"] == "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."
|
||||
|
||||
|
||||
def test_time_travel_guide_chain(client: TestClient, logged_in_headers):
|
||||
|
|
@ -441,6 +288,9 @@ def test_time_travel_guide_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
assert template["memory"] == {
|
||||
"required": False,
|
||||
|
|
@ -454,6 +304,9 @@ def test_time_travel_guide_chain(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"file_path": "",
|
||||
"fileTypes": [],
|
||||
"value": "",
|
||||
}
|
||||
|
||||
assert chain["description"] == "Time travel guide chain."
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from pathlib import Path
|
||||
from tempfile import tempdir
|
||||
from langflow.__main__ import app
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.services import getters
|
||||
from langflow.__main__ import app
|
||||
from langflow.services import deps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
@ -26,11 +27,12 @@ def test_components_path(runner, client, default_settings):
|
|||
["run", "--components-path", str(temp_dir), *default_settings],
|
||||
)
|
||||
assert result.exit_code == 0, result.stdout
|
||||
settings_service = getters.get_settings_service()
|
||||
settings_service = deps.get_settings_service()
|
||||
assert str(temp_dir) in settings_service.settings.COMPONENTS_PATH
|
||||
|
||||
|
||||
def test_superuser(runner, client, session):
|
||||
result = runner.invoke(app, ["superuser"], input="admin\nadmin\n")
|
||||
assert result.exit_code == 0, result.stdout
|
||||
assert "Superuser created successfully." in result.stdout
|
||||
assert "Superuser creation failed." not in result.output, result.output
|
||||
assert "Superuser created successfully." in result.output, result.output
|
||||
|
|
|
|||
|
|
@ -1,19 +1,15 @@
|
|||
import ast
|
||||
import pytest
|
||||
import types
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate
|
||||
from langflow.interface.custom.base import CustomComponent
|
||||
from langflow.interface.custom.component import (
|
||||
Component,
|
||||
ComponentCodeNullError,
|
||||
ComponentFunctionEntrypointNameNullError,
|
||||
)
|
||||
from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError
|
||||
|
||||
from langflow.interface.custom.base import CustomComponent
|
||||
from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError
|
||||
from langflow.interface.custom.component import Component, ComponentCodeNullError
|
||||
from langflow.interface.types import build_custom_component_template, create_and_validate_component
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate
|
||||
|
||||
code_default = """
|
||||
from langflow import Prompt
|
||||
|
|
@ -73,16 +69,16 @@ def test_component_init():
|
|||
"""
|
||||
Test the initialization of the Component class.
|
||||
"""
|
||||
component = Component(code=code_default, function_entrypoint_name="build")
|
||||
component = Component(code=code_default, _function_entrypoint_name="build")
|
||||
assert component.code == code_default
|
||||
assert component.function_entrypoint_name == "build"
|
||||
assert component._function_entrypoint_name == "build"
|
||||
|
||||
|
||||
def test_component_get_code_tree():
|
||||
"""
|
||||
Test the get_code_tree method of the Component class.
|
||||
"""
|
||||
component = Component(code=code_default, function_entrypoint_name="build")
|
||||
component = Component(code=code_default, _function_entrypoint_name="build")
|
||||
tree = component.get_code_tree(component.code)
|
||||
assert "imports" in tree
|
||||
|
||||
|
|
@ -92,19 +88,20 @@ def test_component_code_null_error():
|
|||
Test the get_function method raises the
|
||||
ComponentCodeNullError when the code is empty.
|
||||
"""
|
||||
component = Component(code="", function_entrypoint_name="")
|
||||
component = Component(code="", _function_entrypoint_name="")
|
||||
with pytest.raises(ComponentCodeNullError):
|
||||
component.get_function()
|
||||
|
||||
|
||||
def test_component_function_entrypoint_name_null_error():
|
||||
"""
|
||||
Test the get_function method raises the ComponentFunctionEntrypointNameNullError
|
||||
when the function_entrypoint_name is empty.
|
||||
"""
|
||||
component = Component(code=code_default, function_entrypoint_name="")
|
||||
with pytest.raises(ComponentFunctionEntrypointNameNullError):
|
||||
component.get_function()
|
||||
# TODO: Validate if we should remove this
|
||||
# def test_component_function_entrypoint_name_null_error():
|
||||
# """
|
||||
# Test the get_function method raises the ComponentFunctionEntrypointNameNullError
|
||||
# when the function_entrypoint_name is empty.
|
||||
# """
|
||||
# component = Component(code=code_default, _function_entrypoint_name="")
|
||||
# with pytest.raises(ComponentFunctionEntrypointNameNullError):
|
||||
# component.get_function()
|
||||
|
||||
|
||||
def test_custom_component_init():
|
||||
|
|
@ -113,9 +110,7 @@ def test_custom_component_init():
|
|||
"""
|
||||
function_entrypoint_name = "build"
|
||||
|
||||
custom_component = CustomComponent(
|
||||
code=code_default, function_entrypoint_name=function_entrypoint_name
|
||||
)
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name=function_entrypoint_name)
|
||||
assert custom_component.code == code_default
|
||||
assert custom_component.function_entrypoint_name == function_entrypoint_name
|
||||
|
||||
|
|
@ -124,10 +119,8 @@ def test_custom_component_build_template_config():
|
|||
"""
|
||||
Test the build_template_config property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(
|
||||
code=code_default, function_entrypoint_name="build"
|
||||
)
|
||||
config = custom_component.build_template_config
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
config = custom_component.template_config
|
||||
assert isinstance(config, dict)
|
||||
|
||||
|
||||
|
|
@ -135,9 +128,7 @@ def test_custom_component_get_function():
|
|||
"""
|
||||
Test the get_function property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(
|
||||
code="def build(): pass", function_entrypoint_name="build"
|
||||
)
|
||||
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
|
||||
my_function = custom_component.get_function
|
||||
assert isinstance(my_function, types.FunctionType)
|
||||
|
||||
|
|
@ -212,7 +203,7 @@ def test_component_get_function_valid():
|
|||
Test the get_function method of the Component
|
||||
class with valid code and function_entrypoint_name.
|
||||
"""
|
||||
component = Component(code="def build(): pass", function_entrypoint_name="build")
|
||||
component = Component(code="def build(): pass", _function_entrypoint_name="build")
|
||||
my_function = component.get_function()
|
||||
assert callable(my_function)
|
||||
|
||||
|
|
@ -222,9 +213,7 @@ def test_custom_component_get_function_entrypoint_args():
|
|||
Test the get_function_entrypoint_args
|
||||
property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(
|
||||
code=code_default, function_entrypoint_name="build"
|
||||
)
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
args = custom_component.get_function_entrypoint_args
|
||||
assert len(args) == 4
|
||||
assert args[0]["name"] == "self"
|
||||
|
|
@ -237,20 +226,18 @@ def test_custom_component_get_function_entrypoint_return_type():
|
|||
Test the get_function_entrypoint_return_type
|
||||
property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(
|
||||
code=code_default, function_entrypoint_name="build"
|
||||
)
|
||||
from langchain.schema import Document
|
||||
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
return_type = custom_component.get_function_entrypoint_return_type
|
||||
assert return_type == ["Document"]
|
||||
assert return_type == [Document]
|
||||
|
||||
|
||||
def test_custom_component_get_main_class_name():
|
||||
"""
|
||||
Test the get_main_class_name property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(
|
||||
code=code_default, function_entrypoint_name="build"
|
||||
)
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
class_name = custom_component.get_main_class_name
|
||||
assert class_name == "YourComponent"
|
||||
|
||||
|
|
@ -260,9 +247,7 @@ def test_custom_component_get_function_valid():
|
|||
Test the get_function property of the CustomComponent
|
||||
class with valid code and function_entrypoint_name.
|
||||
"""
|
||||
custom_component = CustomComponent(
|
||||
code="def build(): pass", function_entrypoint_name="build"
|
||||
)
|
||||
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
|
||||
my_function = custom_component.get_function
|
||||
assert callable(my_function)
|
||||
|
||||
|
|
@ -297,9 +282,7 @@ def test_code_parser_parse_callable_details_no_args():
|
|||
parser = CodeParser("")
|
||||
node = ast.FunctionDef(
|
||||
name="test",
|
||||
args=ast.arguments(
|
||||
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
|
||||
),
|
||||
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
|
||||
body=[],
|
||||
decorator_list=[],
|
||||
returns=None,
|
||||
|
|
@ -345,9 +328,7 @@ def test_code_parser_parse_function_def_not_init():
|
|||
parser = CodeParser("")
|
||||
stmt = ast.FunctionDef(
|
||||
name="test",
|
||||
args=ast.arguments(
|
||||
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
|
||||
),
|
||||
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
|
||||
body=[],
|
||||
decorator_list=[],
|
||||
returns=None,
|
||||
|
|
@ -365,9 +346,7 @@ def test_code_parser_parse_function_def_init():
|
|||
parser = CodeParser("")
|
||||
stmt = ast.FunctionDef(
|
||||
name="__init__",
|
||||
args=ast.arguments(
|
||||
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
|
||||
),
|
||||
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
|
||||
body=[],
|
||||
decorator_list=[],
|
||||
returns=None,
|
||||
|
|
@ -382,7 +361,7 @@ def test_component_get_code_tree_syntax_error():
|
|||
Test the get_code_tree method of the Component class
|
||||
raises the CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
component = Component(code="import os as", function_entrypoint_name="build")
|
||||
component = Component(code="import os as", _function_entrypoint_name="build")
|
||||
with pytest.raises(CodeSyntaxError):
|
||||
component.get_code_tree(component.code)
|
||||
|
||||
|
|
@ -402,9 +381,7 @@ def test_custom_component_get_code_tree_syntax_error():
|
|||
Test the get_code_tree method of the CustomComponent class
|
||||
raises the CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
custom_component = CustomComponent(
|
||||
code="import os as", function_entrypoint_name="build"
|
||||
)
|
||||
custom_component = CustomComponent(code="import os as", function_entrypoint_name="build")
|
||||
with pytest.raises(CodeSyntaxError):
|
||||
custom_component.get_code_tree(custom_component.code)
|
||||
|
||||
|
|
@ -458,9 +435,7 @@ def test_custom_component_build_not_implemented():
|
|||
Test the build method of the CustomComponent
|
||||
class raises the NotImplementedError.
|
||||
"""
|
||||
custom_component = CustomComponent(
|
||||
code="def build(): pass", function_entrypoint_name="build"
|
||||
)
|
||||
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
|
||||
with pytest.raises(NotImplementedError):
|
||||
custom_component.build()
|
||||
|
||||
|
|
@ -468,7 +443,7 @@ def test_custom_component_build_not_implemented():
|
|||
def test_build_config_no_code():
|
||||
component = CustomComponent(code=None)
|
||||
|
||||
assert component.get_function_entrypoint_args == ""
|
||||
assert component.get_function_entrypoint_args == []
|
||||
assert component.get_function_entrypoint_return_type == []
|
||||
|
||||
|
||||
|
|
@ -494,9 +469,7 @@ def test_flow(db):
|
|||
}
|
||||
|
||||
# Create flow
|
||||
flow = FlowCreate(
|
||||
id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data
|
||||
)
|
||||
flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data)
|
||||
|
||||
# Add to database
|
||||
db.add(flow)
|
||||
|
|
@ -557,3 +530,36 @@ def test_build_config_field_value_keys(component):
|
|||
config = component.build_config()
|
||||
field_values = config["fields"].values()
|
||||
assert all("type" in value for value in field_values)
|
||||
|
||||
|
||||
def test_create_and_validate_component_valid_code(test_component_code):
|
||||
component = create_and_validate_component(test_component_code)
|
||||
assert isinstance(component, CustomComponent)
|
||||
|
||||
|
||||
def test_build_langchain_template_custom_component_valid_code(test_component_code):
|
||||
component = create_and_validate_component(test_component_code)
|
||||
frontend_node = build_custom_component_template(component)
|
||||
assert isinstance(frontend_node, dict)
|
||||
template = frontend_node["template"]
|
||||
assert isinstance(template, dict)
|
||||
assert "param" in template
|
||||
param_options = template["param"]["options"]
|
||||
# Now run it again with an update field
|
||||
frontend_node = build_custom_component_template(component, update_field="param")
|
||||
new_param_options = frontend_node["template"]["param"]["options"]
|
||||
assert param_options != new_param_options
|
||||
|
||||
|
||||
def test_build_langchain_template_custom_component_templatefield(test_component_with_templatefield_code):
|
||||
component = create_and_validate_component(test_component_with_templatefield_code)
|
||||
frontend_node = build_custom_component_template(component)
|
||||
assert isinstance(frontend_node, dict)
|
||||
template = frontend_node["template"]
|
||||
assert isinstance(template, dict)
|
||||
assert "param" in template
|
||||
param_options = template["param"]["options"]
|
||||
# Now run it again with an update field
|
||||
frontend_node = build_custom_component_template(component, update_field="param")
|
||||
new_param_options = frontend_node["template"]["param"]["options"]
|
||||
assert param_options != new_param_options
|
||||
|
|
|
|||
|
|
@ -18,9 +18,7 @@ def test_python_function_tool():
|
|||
with pytest.raises(SyntaxError):
|
||||
code = pytest.CODE_WITH_SYNTAX_ERROR
|
||||
func = get_function(code)
|
||||
func = PythonFunctionTool(
|
||||
name="Test", description="Testing", code=code, func=func
|
||||
)
|
||||
func = PythonFunctionTool(name="Test", description="Testing", code=code, func=func)
|
||||
|
||||
|
||||
def test_python_function():
|
||||
|
|
|
|||
|
|
@ -1,16 +1,14 @@
|
|||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from uuid import UUID, uuid4
|
||||
from sqlmodel import Session
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from langflow.api.v1.schemas import FlowListCreate
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
@ -27,21 +25,17 @@ def json_style():
|
|||
)
|
||||
|
||||
|
||||
def test_create_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
def test_create_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
# flow is optional so we can create a flow without a flow
|
||||
flow = FlowCreate(name="Test Flow")
|
||||
response = client.post(
|
||||
"api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers
|
||||
)
|
||||
flow = FlowCreate(name="Test Flow", description="description")
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(exclude_unset=True), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
|
@ -51,13 +45,13 @@ def test_read_flows(client: TestClient, json_flow: str, active_user, logged_in_h
|
|||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
|
@ -71,7 +65,7 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he
|
|||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
flow_id = response.json()["id"] # flow_id should be a UUID but is a string
|
||||
# turn it into a UUID
|
||||
flow_id = UUID(flow_id)
|
||||
|
|
@ -82,14 +76,12 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he
|
|||
assert response.json()["data"] == flow.data
|
||||
|
||||
|
||||
def test_update_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
|
||||
flow_id = response.json()["id"]
|
||||
updated_flow = FlowUpdate(
|
||||
|
|
@ -97,9 +89,7 @@ def test_update_flow(
|
|||
description="updated description",
|
||||
data=data,
|
||||
)
|
||||
response = client.patch(
|
||||
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == updated_flow.name
|
||||
|
|
@ -107,22 +97,18 @@ def test_update_flow(
|
|||
# assert response.json()["data"] == updated_flow.data
|
||||
|
||||
|
||||
def test_delete_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
def test_delete_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
flow_id = response.json()["id"]
|
||||
response = client.delete(f"api/v1/flows/{flow_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Flow deleted successfully"
|
||||
|
||||
|
||||
def test_create_flows(
|
||||
client: TestClient, session: Session, json_flow: str, logged_in_headers
|
||||
):
|
||||
def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
|
|
@ -133,9 +119,7 @@ def test_create_flows(
|
|||
]
|
||||
)
|
||||
# Make request to endpoint
|
||||
response = client.post(
|
||||
"api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers
|
||||
)
|
||||
response = client.post("api/v1/flows/batch/", json=flow_list.model_dump(), headers=logged_in_headers)
|
||||
# Check response status code
|
||||
assert response.status_code == 201
|
||||
# Check response data
|
||||
|
|
@ -149,9 +133,7 @@ def test_create_flows(
|
|||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
def test_upload_file(
|
||||
client: TestClient, session: Session, json_flow: str, logged_in_headers
|
||||
):
|
||||
def test_upload_file(client: TestClient, session: Session, json_flow: str, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
|
|
@ -161,7 +143,7 @@ def test_upload_file(
|
|||
FlowCreate(name="Flow 2", description="description", data=data),
|
||||
]
|
||||
)
|
||||
file_contents = orjson_dumps(flow_list.dict())
|
||||
file_contents = orjson_dumps(flow_list.model_dump())
|
||||
response = client.post(
|
||||
"api/v1/flows/upload/",
|
||||
files={"file": ("examples.json", file_contents, "application/json")},
|
||||
|
|
@ -200,7 +182,7 @@ def test_download_file(
|
|||
with session_getter(db_manager) as session:
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = active_user.id
|
||||
db_flow = Flow.from_orm(flow)
|
||||
db_flow = Flow.model_validate(flow, from_attributes=True)
|
||||
session.add(db_flow)
|
||||
session.commit()
|
||||
# Make request to endpoint
|
||||
|
|
@ -218,9 +200,7 @@ def test_download_file(
|
|||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
def test_create_flow_with_invalid_data(
|
||||
client: TestClient, active_user, logged_in_headers
|
||||
):
|
||||
def test_create_flow_with_invalid_data(client: TestClient, active_user, logged_in_headers):
|
||||
flow = {"name": "a" * 256, "data": "Invalid flow data"}
|
||||
response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers)
|
||||
assert response.status_code == 422
|
||||
|
|
@ -232,29 +212,19 @@ def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers
|
|||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_update_flow_idempotency(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
def test_update_flow_idempotency(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
flow_data = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post(
|
||||
"api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
response = client.post("api/v1/flows/", json=flow_data.model_dump(), headers=logged_in_headers)
|
||||
flow_id = response.json()["id"]
|
||||
updated_flow = FlowCreate(name="Updated Flow", description="description", data=data)
|
||||
response1 = client.put(
|
||||
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
response2 = client.put(
|
||||
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
assert response1.json() == response2.json()
|
||||
|
||||
|
||||
def test_update_nonexistent_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
uuid = uuid4()
|
||||
|
|
@ -263,9 +233,7 @@ def test_update_nonexistent_flow(
|
|||
description="description",
|
||||
data=data,
|
||||
)
|
||||
response = client.patch(
|
||||
f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,17 @@
|
|||
from collections import namedtuple
|
||||
import time
|
||||
import uuid
|
||||
from langflow.processing.process import Result
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.api_key import ApiKey
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service
|
||||
from collections import namedtuple
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from langflow.interface.tools.constants import CUSTOM_TOOLS
|
||||
from langflow.processing.process import Result
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from langflow.template.frontend_node.chains import TimeTravelGuideChainNode
|
||||
|
||||
import time
|
||||
|
||||
|
||||
def run_post(client, flow_id, headers, post_data):
|
||||
response = client.post(
|
||||
|
|
@ -25,16 +24,13 @@ def run_post(client, flow_id, headers, post_data):
|
|||
|
||||
|
||||
# Helper function to poll task status
|
||||
def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1):
|
||||
def poll_task_status(client, headers, href, max_attempts=20, sleep_time=2):
|
||||
for _ in range(max_attempts):
|
||||
task_status_response = client.get(
|
||||
href,
|
||||
headers=headers,
|
||||
)
|
||||
if (
|
||||
task_status_response.status_code == 200
|
||||
and task_status_response.json()["status"] == "SUCCESS"
|
||||
):
|
||||
if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS":
|
||||
return task_status_response.json()
|
||||
time.sleep(sleep_time)
|
||||
return None # Return None if task did not complete in time
|
||||
|
|
@ -130,11 +126,7 @@ def created_api_key(active_user):
|
|||
)
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
if (
|
||||
existing_api_key := session.query(ApiKey)
|
||||
.filter(ApiKey.api_key == api_key.api_key)
|
||||
.first()
|
||||
):
|
||||
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
|
||||
return existing_api_key
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
|
|
@ -193,9 +185,7 @@ def test_process_flow_invalid_id(client, monkeypatch, created_api_key):
|
|||
}
|
||||
|
||||
invalid_id = uuid.uuid4()
|
||||
response = client.post(
|
||||
f"api/v1/process/{invalid_id}", headers=headers, json=post_data
|
||||
)
|
||||
response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert f"Flow {invalid_id} not found" in response.json()["detail"]
|
||||
|
|
@ -236,9 +226,7 @@ def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_k
|
|||
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
|
||||
monkeypatch.setattr(
|
||||
endpoints, "process_graph_cached_task", mock_process_graph_cached_task
|
||||
)
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task)
|
||||
|
||||
api_key = created_api_key.api_key
|
||||
headers = {"x-api-key": api_key}
|
||||
|
|
@ -510,9 +498,7 @@ def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_a
|
|||
|
||||
|
||||
@pytest.mark.async_test
|
||||
def test_vector_store_in_process(
|
||||
distributed_client, added_vector_store, created_api_key
|
||||
):
|
||||
def test_vector_store_in_process(distributed_client, added_vector_store, created_api_key):
|
||||
# Run the /api/v1/process/{flow_id} endpoint
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
post_data = {"inputs": {"input": "What is Langflow?"}}
|
||||
|
|
@ -563,9 +549,7 @@ def test_async_task_processing(distributed_client, added_flow, created_api_key):
|
|||
|
||||
# Test function without loop
|
||||
@pytest.mark.async_test
|
||||
def test_async_task_processing_vector_store(
|
||||
client, added_vector_store, created_api_key
|
||||
):
|
||||
def test_async_task_processing_vector_store(client, added_vector_store, created_api_key):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
post_data = {"inputs": {"input": "How do I upload examples?"}}
|
||||
|
||||
|
|
@ -594,6 +578,4 @@ def test_async_task_processing_vector_store(
|
|||
# Validate that the task completed successfully and the result is as expected
|
||||
assert "result" in task_status_json, task_status_json
|
||||
assert "output" in task_status_json["result"], task_status_json["result"]
|
||||
assert "Langflow" in task_status_json["result"]["output"], task_status_json[
|
||||
"result"
|
||||
]
|
||||
assert "Langflow" in task_status_json["result"]["output"], task_status_json["result"]
|
||||
|
|
|
|||
|
|
@ -31,17 +31,14 @@ def test_template_field_defaults(sample_template_field: TemplateField):
|
|||
assert sample_template_field.is_list is False
|
||||
assert sample_template_field.show is True
|
||||
assert sample_template_field.multiline is False
|
||||
assert sample_template_field.value is None
|
||||
assert sample_template_field.suffixes == []
|
||||
assert sample_template_field.value == ""
|
||||
assert sample_template_field.file_types == []
|
||||
assert sample_template_field.file_path is None
|
||||
assert sample_template_field.file_path == ""
|
||||
assert sample_template_field.password is False
|
||||
assert sample_template_field.name == "test_field"
|
||||
|
||||
|
||||
def test_template_to_dict(
|
||||
sample_template: Template, sample_template_field: TemplateField
|
||||
):
|
||||
def test_template_to_dict(sample_template: Template, sample_template_field: TemplateField):
|
||||
template_dict = sample_template.to_dict()
|
||||
assert template_dict["_type"] == "test_template"
|
||||
assert len(template_dict) == 2 # _type and test_field
|
||||
|
|
|
|||
|
|
@ -1,22 +1,26 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Type, Union
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langchain.agents import AgentExecutor
|
||||
|
||||
import pytest
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
|
||||
from langflow.graph import Graph
|
||||
from langflow.graph.vertex.types import (
|
||||
FileToolVertex,
|
||||
LLMVertex,
|
||||
ToolkitVertex,
|
||||
)
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.graph.utils import (find_last_node, process_flow,
|
||||
set_new_target_handle, ungroup_node,
|
||||
update_source_handle,
|
||||
update_target_handle, update_template)
|
||||
from langflow.graph.utils import UnbuiltObject
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex.types import (FileToolVertex, LLMVertex,
|
||||
ToolkitVertex)
|
||||
from langflow.processing.process import get_result_and_thought
|
||||
from langflow.utils.payload import get_root_node
|
||||
from langflow.utils.payload import get_root_vertex
|
||||
|
||||
# Test cases for the graph module
|
||||
|
||||
|
|
@ -24,21 +28,57 @@ from langflow.utils.payload import get_root_node
|
|||
# BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_template():
|
||||
return {
|
||||
"field1": {"proxy": {"field": "some_field", "id": "node1"}},
|
||||
"field2": {"proxy": {"field": "other_field", "id": "node2"}},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_nodes():
|
||||
return [
|
||||
{
|
||||
"id": "node1",
|
||||
"data": {"node": {"template": {"some_field": {"show": True, "advanced": False, "name": "Name1"}}}},
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"data": {
|
||||
"node": {
|
||||
"template": {
|
||||
"other_field": {
|
||||
"show": False,
|
||||
"advanced": True,
|
||||
"display_name": "DisplayName2",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node3",
|
||||
"data": {"node": {"template": {"unrelated_field": {"show": True, "advanced": True}}}},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_node_by_type(graph, node_type: Type[Vertex]) -> Union[Vertex, None]:
|
||||
"""Get a node by type"""
|
||||
return next((node for node in graph.nodes if isinstance(node, node_type)), None)
|
||||
return next((node for node in graph.vertices if isinstance(node, node_type)), None)
|
||||
|
||||
|
||||
def test_graph_structure(basic_graph):
|
||||
assert isinstance(basic_graph, Graph)
|
||||
assert len(basic_graph.nodes) > 0
|
||||
assert len(basic_graph.vertices) > 0
|
||||
assert len(basic_graph.edges) > 0
|
||||
for node in basic_graph.nodes:
|
||||
for node in basic_graph.vertices:
|
||||
assert isinstance(node, Vertex)
|
||||
for edge in basic_graph.edges:
|
||||
assert isinstance(edge, Edge)
|
||||
assert edge.source in basic_graph.nodes
|
||||
assert edge.target in basic_graph.nodes
|
||||
assert edge.source_id in basic_graph.vertex_map.keys()
|
||||
assert edge.target_id in basic_graph.vertex_map.keys()
|
||||
|
||||
|
||||
def test_circular_dependencies(basic_graph):
|
||||
|
|
@ -46,7 +86,7 @@ def test_circular_dependencies(basic_graph):
|
|||
|
||||
def check_circular(node, visited):
|
||||
visited.add(node)
|
||||
neighbors = basic_graph.get_nodes_with_target(node)
|
||||
neighbors = basic_graph.get_vertices_with_target(node)
|
||||
for neighbor in neighbors:
|
||||
if neighbor in visited:
|
||||
return True
|
||||
|
|
@ -54,7 +94,7 @@ def test_circular_dependencies(basic_graph):
|
|||
return True
|
||||
return False
|
||||
|
||||
for node in basic_graph.nodes:
|
||||
for node in basic_graph.vertices:
|
||||
assert not check_circular(node, set())
|
||||
|
||||
|
||||
|
|
@ -79,13 +119,13 @@ def test_invalid_node_types():
|
|||
Graph(graph_data["nodes"], graph_data["edges"])
|
||||
|
||||
|
||||
def test_get_nodes_with_target(basic_graph):
|
||||
def test_get_vertices_with_target(basic_graph):
|
||||
"""Test getting connected nodes"""
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_node(basic_graph)
|
||||
root = get_root_vertex(basic_graph)
|
||||
assert root is not None
|
||||
connected_nodes = basic_graph.get_nodes_with_target(root)
|
||||
connected_nodes = basic_graph.get_vertices_with_target(root.id)
|
||||
assert connected_nodes is not None
|
||||
|
||||
|
||||
|
|
@ -94,23 +134,17 @@ def test_get_node_neighbors_basic(basic_graph):
|
|||
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_node(basic_graph)
|
||||
root = get_root_vertex(basic_graph)
|
||||
assert root is not None
|
||||
neighbors = basic_graph.get_node_neighbors(root)
|
||||
neighbors = basic_graph.get_vertex_neighbors(root)
|
||||
assert neighbors is not None
|
||||
assert isinstance(neighbors, dict)
|
||||
# Root Node is an Agent, it requires an LLMChain and tools
|
||||
# We need to check if there is a Chain in the one of the neighbors'
|
||||
# data attribute in the type key
|
||||
assert any(
|
||||
"ConversationBufferMemory" in neighbor.data["type"]
|
||||
for neighbor, val in neighbors.items()
|
||||
if val
|
||||
)
|
||||
assert any("ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
|
||||
|
||||
assert any(
|
||||
"OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
|
||||
)
|
||||
assert any("OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
|
||||
|
||||
|
||||
# def test_get_node_neighbors_complex(complex_graph):
|
||||
|
|
@ -164,8 +198,8 @@ def test_get_node_neighbors_basic(basic_graph):
|
|||
|
||||
def test_get_node(basic_graph):
|
||||
"""Test getting a single node"""
|
||||
node_id = basic_graph.nodes[0].id
|
||||
node = basic_graph.get_node(node_id)
|
||||
node_id = basic_graph.vertices[0].id
|
||||
node = basic_graph.get_vertex(node_id)
|
||||
assert isinstance(node, Vertex)
|
||||
assert node.id == node_id
|
||||
|
||||
|
|
@ -173,8 +207,8 @@ def test_get_node(basic_graph):
|
|||
def test_build_nodes(basic_graph):
|
||||
"""Test building nodes"""
|
||||
|
||||
assert len(basic_graph.nodes) == len(basic_graph._nodes)
|
||||
for node in basic_graph.nodes:
|
||||
assert len(basic_graph.vertices) == len(basic_graph._vertices)
|
||||
for node in basic_graph.vertices:
|
||||
assert isinstance(node, Vertex)
|
||||
|
||||
|
||||
|
|
@ -183,20 +217,21 @@ def test_build_edges(basic_graph):
|
|||
assert len(basic_graph.edges) == len(basic_graph._edges)
|
||||
for edge in basic_graph.edges:
|
||||
assert isinstance(edge, Edge)
|
||||
assert isinstance(edge.source, Vertex)
|
||||
assert isinstance(edge.target, Vertex)
|
||||
|
||||
assert isinstance(edge.source_id, str)
|
||||
assert isinstance(edge.target_id, str)
|
||||
|
||||
|
||||
def test_get_root_node(client, basic_graph, complex_graph):
|
||||
def test_get_root_vertex(client, basic_graph, complex_graph):
|
||||
"""Test getting root node"""
|
||||
assert isinstance(basic_graph, Graph)
|
||||
root = get_root_node(basic_graph)
|
||||
root = get_root_vertex(basic_graph)
|
||||
assert root is not None
|
||||
assert isinstance(root, Vertex)
|
||||
assert root.data["type"] == "TimeTravelGuideChain"
|
||||
# For complex example, the root node is a ZeroShotAgent too
|
||||
assert isinstance(complex_graph, Graph)
|
||||
root = get_root_node(complex_graph)
|
||||
root = get_root_vertex(complex_graph)
|
||||
assert root is not None
|
||||
assert isinstance(root, Vertex)
|
||||
assert root.data["type"] == "ZeroShotAgent"
|
||||
|
|
@ -232,7 +267,7 @@ def test_build_params(basic_graph):
|
|||
# The matched_type attribute should be in the source_types attr
|
||||
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
|
||||
# Get the root node
|
||||
root = get_root_node(basic_graph)
|
||||
root = get_root_vertex(basic_graph)
|
||||
# Root node is a TimeTravelGuideChain
|
||||
# which requires an llm and memory
|
||||
assert root is not None
|
||||
|
|
@ -241,29 +276,32 @@ def test_build_params(basic_graph):
|
|||
assert "memory" in root.params
|
||||
|
||||
|
||||
def test_build(basic_graph):
|
||||
@pytest.mark.asyncio
|
||||
async def test_build(basic_graph):
|
||||
"""Test Node's build method"""
|
||||
assert_agent_was_built(basic_graph)
|
||||
await assert_agent_was_built(basic_graph)
|
||||
|
||||
|
||||
def assert_agent_was_built(graph):
|
||||
async def assert_agent_was_built(graph):
|
||||
"""Assert that the agent was built"""
|
||||
assert isinstance(graph, Graph)
|
||||
# Now we test the build method
|
||||
# Build the Agent
|
||||
result = graph.build()
|
||||
result = await graph.build()
|
||||
# The agent should be a AgentExecutor
|
||||
assert isinstance(result, Chain)
|
||||
|
||||
|
||||
def test_llm_node_build(basic_graph):
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_node_build(basic_graph):
|
||||
llm_node = get_node_by_type(basic_graph, LLMVertex)
|
||||
assert llm_node is not None
|
||||
built_object = llm_node.build()
|
||||
assert built_object is not None
|
||||
built_object = await llm_node.build()
|
||||
assert built_object is not UnbuiltObject()
|
||||
|
||||
|
||||
def test_toolkit_node_build(client, openapi_graph):
|
||||
@pytest.mark.asyncio
|
||||
async def test_toolkit_node_build(client, openapi_graph):
|
||||
# Write a file to the disk
|
||||
file_path = "api-with-examples.yaml"
|
||||
with open(file_path, "w") as f:
|
||||
|
|
@ -271,36 +309,31 @@ def test_toolkit_node_build(client, openapi_graph):
|
|||
|
||||
toolkit_node = get_node_by_type(openapi_graph, ToolkitVertex)
|
||||
assert toolkit_node is not None
|
||||
built_object = toolkit_node.build()
|
||||
assert built_object is not None
|
||||
built_object = await toolkit_node.build()
|
||||
assert built_object is not UnbuiltObject
|
||||
# Remove the file
|
||||
os.remove(file_path)
|
||||
assert not Path(file_path).exists()
|
||||
|
||||
|
||||
def test_file_tool_node_build(client, openapi_graph):
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_tool_node_build(client, openapi_graph):
|
||||
file_path = "api-with-examples.yaml"
|
||||
with open(file_path, "w") as f:
|
||||
f.write("openapi: 3.0.0")
|
||||
|
||||
assert Path(file_path).exists()
|
||||
file_tool_node = get_node_by_type(openapi_graph, FileToolVertex)
|
||||
assert file_tool_node is not None
|
||||
built_object = file_tool_node.build()
|
||||
assert built_object is not None
|
||||
assert file_tool_node is not UnbuiltObject and file_tool_node is not None
|
||||
built_object = await file_tool_node.build()
|
||||
assert built_object is not UnbuiltObject
|
||||
# Remove the file
|
||||
os.remove(file_path)
|
||||
assert not Path(file_path).exists()
|
||||
|
||||
|
||||
# def test_wrapper_node_build(openapi_graph):
|
||||
# wrapper_node = get_node_by_type(openapi_graph, WrapperVertex)
|
||||
# assert wrapper_node is not None
|
||||
# built_object = wrapper_node.build()
|
||||
# assert built_object is not None
|
||||
|
||||
|
||||
def test_get_result_and_thought(basic_graph):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_result_and_thought(basic_graph):
|
||||
"""Test the get_result_and_thought method"""
|
||||
responses = [
|
||||
"Final Answer: I am a response",
|
||||
|
|
@ -312,9 +345,9 @@ def test_get_result_and_thought(basic_graph):
|
|||
assert llm_node is not None
|
||||
llm_node._built_object = FakeListLLM(responses=responses)
|
||||
llm_node._built = True
|
||||
langchain_object = basic_graph.build()
|
||||
langchain_object = await basic_graph.build()
|
||||
# assert all nodes are built
|
||||
assert all(node._built for node in basic_graph.nodes)
|
||||
assert all(node._built for node in basic_graph.vertices)
|
||||
# now build again and check if FakeListLLM was used
|
||||
|
||||
# Get the result and thought
|
||||
|
|
@ -322,27 +355,204 @@ def test_get_result_and_thought(basic_graph):
|
|||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
def test_pickle_graph(json_vector_store):
|
||||
def test_find_last_node(grouped_chat_json_flow):
|
||||
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
|
||||
nodes, edges = grouped_chat_data["nodes"], grouped_chat_data["edges"]
|
||||
last_node = find_last_node(nodes, edges)
|
||||
assert last_node is not None # Replace with the actual expected value
|
||||
assert last_node["id"] == "LLMChain-pimAb" # Replace with the actual expected value
|
||||
|
||||
|
||||
def test_ungroup_node(grouped_chat_json_flow):
|
||||
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
|
||||
group_node = grouped_chat_data["nodes"][2] # Assuming the first node is a group node
|
||||
base_flow = copy.deepcopy(grouped_chat_data)
|
||||
ungroup_node(group_node["data"], base_flow)
|
||||
# after ungroup_node is called, the base_flow and grouped_chat_data should be different
|
||||
assert base_flow != grouped_chat_data
|
||||
# assert node 2 is not a group node anymore
|
||||
assert base_flow["nodes"][2]["data"]["node"].get("flow") is None
|
||||
# assert the edges are updated
|
||||
assert len(base_flow["edges"]) > len(grouped_chat_data["edges"])
|
||||
assert base_flow["edges"][0]["source"] == "ConversationBufferMemory-kUMif"
|
||||
assert base_flow["edges"][0]["target"] == "LLMChain-2P369"
|
||||
assert base_flow["edges"][1]["source"] == "PromptTemplate-Wjk4g"
|
||||
assert base_flow["edges"][1]["target"] == "LLMChain-2P369"
|
||||
assert base_flow["edges"][2]["source"] == "ChatOpenAI-rUJ1b"
|
||||
assert base_flow["edges"][2]["target"] == "LLMChain-2P369"
|
||||
|
||||
|
||||
def test_process_flow(grouped_chat_json_flow):
|
||||
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
|
||||
|
||||
processed_flow = process_flow(grouped_chat_data)
|
||||
assert processed_flow is not None
|
||||
assert isinstance(processed_flow, dict)
|
||||
assert "nodes" in processed_flow
|
||||
assert "edges" in processed_flow
|
||||
|
||||
|
||||
def test_process_flow_one_group(one_grouped_chat_json_flow):
|
||||
grouped_chat_data = json.loads(one_grouped_chat_json_flow).get("data")
|
||||
# There should be only one node
|
||||
assert len(grouped_chat_data["nodes"]) == 1
|
||||
# Get the node, it should be a group node
|
||||
group_node = grouped_chat_data["nodes"][0]
|
||||
node_data = group_node["data"]["node"]
|
||||
assert node_data.get("flow") is not None
|
||||
template_data = node_data["template"]
|
||||
assert any("openai_api_key" in key for key in template_data.keys())
|
||||
# Get the openai_api_key dict
|
||||
openai_api_key = next(
|
||||
(template_data[key] for key in template_data.keys() if "openai_api_key" in key),
|
||||
None,
|
||||
)
|
||||
assert openai_api_key is not None
|
||||
assert openai_api_key["value"] == "test"
|
||||
|
||||
processed_flow = process_flow(grouped_chat_data)
|
||||
assert processed_flow is not None
|
||||
assert isinstance(processed_flow, dict)
|
||||
assert "nodes" in processed_flow
|
||||
assert "edges" in processed_flow
|
||||
|
||||
# Now get the node that has ChatOpenAI in its id
|
||||
chat_openai_node = next((node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None)
|
||||
assert chat_openai_node is not None
|
||||
assert chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] == "test"
|
||||
|
||||
|
||||
def test_process_flow_vector_store_grouped(vector_store_grouped_json_flow):
|
||||
grouped_chat_data = json.loads(vector_store_grouped_json_flow).get("data")
|
||||
nodes = grouped_chat_data["nodes"]
|
||||
assert len(nodes) == 4
|
||||
# There are two group nodes in this flow
|
||||
# One of them is inside the other totalling 7 nodes
|
||||
# 4 nodes grouped, one of these turns into 1 normal node and 1 group node
|
||||
# This group node has 2 nodes inside it
|
||||
|
||||
processed_flow = process_flow(grouped_chat_data)
|
||||
assert processed_flow is not None
|
||||
processed_nodes = processed_flow["nodes"]
|
||||
assert len(processed_nodes) == 7
|
||||
assert isinstance(processed_flow, dict)
|
||||
assert "nodes" in processed_flow
|
||||
assert "edges" in processed_flow
|
||||
edges = processed_flow["edges"]
|
||||
# Expected keywords in source and target fields
|
||||
expected_keywords = [
|
||||
{"source": "VectorStoreInfo", "target": "VectorStoreAgent"},
|
||||
{"source": "ChatOpenAI", "target": "VectorStoreAgent"},
|
||||
{"source": "OpenAIEmbeddings", "target": "Chroma"},
|
||||
{"source": "Chroma", "target": "VectorStoreInfo"},
|
||||
{"source": "WebBaseLoader", "target": "RecursiveCharacterTextSplitter"},
|
||||
{"source": "RecursiveCharacterTextSplitter", "target": "Chroma"},
|
||||
]
|
||||
|
||||
for idx, expected_keyword in enumerate(expected_keywords):
|
||||
for key, value in expected_keyword.items():
|
||||
assert (
|
||||
value in edges[idx][key].split("-")[0]
|
||||
), f"Edge {idx}, key {key} expected to contain {value} but got {edges[idx][key]}"
|
||||
|
||||
|
||||
def test_update_template(sample_template, sample_nodes):
|
||||
# Making a deep copy to keep original sample_nodes unchanged
|
||||
nodes_copy = copy.deepcopy(sample_nodes)
|
||||
update_template(sample_template, nodes_copy)
|
||||
|
||||
# Now, validate the updates.
|
||||
node1_updated = next((n for n in nodes_copy if n["id"] == "node1"), None)
|
||||
node2_updated = next((n for n in nodes_copy if n["id"] == "node2"), None)
|
||||
node3_updated = next((n for n in nodes_copy if n["id"] == "node3"), None)
|
||||
|
||||
assert node1_updated is not None
|
||||
assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True
|
||||
assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False
|
||||
assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1"
|
||||
|
||||
assert node2_updated is not None
|
||||
assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False
|
||||
assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True
|
||||
assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2"
|
||||
|
||||
# Ensure node3 remains unchanged
|
||||
assert node3_updated == sample_nodes[2]
|
||||
|
||||
|
||||
# Test `update_target_handle`
|
||||
def test_update_target_handle_proxy():
|
||||
new_edge = {
|
||||
"data": {
|
||||
"targetHandle": {
|
||||
"type": "some_type",
|
||||
"proxy": {"id": "some_id", "field": ""},
|
||||
}
|
||||
}
|
||||
}
|
||||
g_nodes = [{"id": "some_id", "data": {"node": {"flow": None}}}]
|
||||
group_node_id = "group_id"
|
||||
updated_edge = update_target_handle(new_edge, g_nodes, group_node_id)
|
||||
assert updated_edge["data"]["targetHandle"] == new_edge["data"]["targetHandle"]
|
||||
|
||||
|
||||
# Test `set_new_target_handle`
|
||||
def test_set_new_target_handle():
|
||||
proxy_id = "proxy_id"
|
||||
new_edge = {"target": None, "data": {"targetHandle": {}}}
|
||||
target_handle = {"type": "type_1", "proxy": {"field": "field_1"}}
|
||||
node = {
|
||||
"data": {
|
||||
"node": {
|
||||
"flow": True,
|
||||
"template": {"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}},
|
||||
}
|
||||
}
|
||||
}
|
||||
set_new_target_handle(proxy_id, new_edge, target_handle, node)
|
||||
assert new_edge["target"] == "proxy_id"
|
||||
assert new_edge["data"]["targetHandle"]["fieldName"] == "field_1"
|
||||
assert new_edge["data"]["targetHandle"]["proxy"] == {
|
||||
"field": "new_field",
|
||||
"id": "new_id",
|
||||
}
|
||||
|
||||
|
||||
# Test `update_source_handle`
|
||||
def test_update_source_handle():
|
||||
new_edge = {"source": None, "data": {"sourceHandle": {"id": None}}}
|
||||
flow_data = {
|
||||
"nodes": [{"id": "some_node"}, {"id": "last_node"}],
|
||||
"edges": [{"source": "some_node"}],
|
||||
}
|
||||
updated_edge = update_source_handle(new_edge, flow_data["nodes"], flow_data["edges"])
|
||||
assert updated_edge["source"] == "last_node"
|
||||
assert updated_edge["data"]["sourceHandle"]["id"] == "last_node"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pickle_graph(json_vector_store):
|
||||
loaded_json = json.loads(json_vector_store)
|
||||
graph = Graph.from_payload(loaded_json)
|
||||
assert isinstance(graph, Graph)
|
||||
first_result = graph.build()
|
||||
first_result = await graph.build()
|
||||
assert isinstance(first_result, AgentExecutor)
|
||||
pickled = pickle.dumps(graph)
|
||||
assert pickled is not None
|
||||
assert pickled is not UnbuiltObject
|
||||
unpickled = pickle.loads(pickled)
|
||||
assert unpickled is not None
|
||||
result = unpickled.build()
|
||||
assert unpickled is not UnbuiltObject
|
||||
result = await unpickled.build()
|
||||
assert isinstance(result, AgentExecutor)
|
||||
|
||||
|
||||
def test_pickle_each_vertex(json_vector_store):
|
||||
@pytest.mark.asyncio
|
||||
async def test_pickle_each_vertex(json_vector_store):
|
||||
loaded_json = json.loads(json_vector_store)
|
||||
graph = Graph.from_payload(loaded_json)
|
||||
assert isinstance(graph, Graph)
|
||||
for vertex in graph.nodes:
|
||||
vertex.build()
|
||||
for vertex in graph.vertices:
|
||||
await vertex.build()
|
||||
pickled = pickle.dumps(vertex)
|
||||
assert pickled is not None
|
||||
assert pickled is not UnbuiltObject
|
||||
unpickled = pickle.loads(pickled)
|
||||
assert unpickled is not None
|
||||
assert unpickled is not UnbuiltObject
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["verbose"] == {
|
||||
"required": False,
|
||||
|
|
@ -35,6 +36,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["client"] == {
|
||||
"required": False,
|
||||
|
|
@ -48,6 +50,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["model_name"] == {
|
||||
"required": False,
|
||||
|
|
@ -69,6 +72,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
# Add more assertions for other properties here
|
||||
assert template["temperature"] == {
|
||||
|
|
@ -84,6 +88,8 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["max_tokens"] == {
|
||||
"required": False,
|
||||
|
|
@ -98,6 +104,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["top_p"] == {
|
||||
"required": False,
|
||||
|
|
@ -112,6 +119,8 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["frequency_penalty"] == {
|
||||
"required": False,
|
||||
|
|
@ -126,6 +135,8 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["presence_penalty"] == {
|
||||
"required": False,
|
||||
|
|
@ -140,6 +151,8 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["n"] == {
|
||||
"required": False,
|
||||
|
|
@ -154,6 +167,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["best_of"] == {
|
||||
"required": False,
|
||||
|
|
@ -168,6 +182,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["model_kwargs"] == {
|
||||
"required": False,
|
||||
|
|
@ -181,6 +196,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["openai_api_key"] == {
|
||||
"required": False,
|
||||
|
|
@ -196,6 +212,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["batch_size"] == {
|
||||
"required": False,
|
||||
|
|
@ -210,6 +227,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["request_timeout"] == {
|
||||
"required": False,
|
||||
|
|
@ -223,6 +241,8 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["logit_bias"] == {
|
||||
"required": False,
|
||||
|
|
@ -236,6 +256,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["max_retries"] == {
|
||||
"required": False,
|
||||
|
|
@ -243,13 +264,14 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"placeholder": "",
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
"value": 6,
|
||||
"value": 2,
|
||||
"password": False,
|
||||
"name": "max_retries",
|
||||
"type": "int",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["streaming"] == {
|
||||
"required": False,
|
||||
|
|
@ -264,6 +286,7 @@ def test_openai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -289,6 +312,7 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["client"] == {
|
||||
"required": False,
|
||||
|
|
@ -302,6 +326,7 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["model_name"] == {
|
||||
"required": False,
|
||||
|
|
@ -313,6 +338,7 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"password": False,
|
||||
"options": [
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-3.5-turbo",
|
||||
|
|
@ -323,6 +349,7 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["temperature"] == {
|
||||
"required": False,
|
||||
|
|
@ -337,6 +364,8 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["model_kwargs"] == {
|
||||
"required": False,
|
||||
|
|
@ -350,6 +379,7 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["openai_api_key"] == {
|
||||
"required": False,
|
||||
|
|
@ -365,6 +395,7 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["request_timeout"] == {
|
||||
"required": False,
|
||||
|
|
@ -378,6 +409,8 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["max_retries"] == {
|
||||
"required": False,
|
||||
|
|
@ -385,13 +418,14 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"placeholder": "",
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
"value": 6,
|
||||
"value": 2,
|
||||
"password": False,
|
||||
"name": "max_retries",
|
||||
"type": "int",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["streaming"] == {
|
||||
"required": False,
|
||||
|
|
@ -406,6 +440,7 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["n"] == {
|
||||
"required": False,
|
||||
|
|
@ -420,6 +455,7 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["max_tokens"] == {
|
||||
|
|
@ -434,6 +470,7 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
assert template["_type"] == "ChatOpenAI"
|
||||
assert (
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from langchain.chains.base import Chain
|
||||
from langflow.processing.process import load_flow_from_json
|
||||
from langflow.graph import Graph
|
||||
from langflow.utils.payload import get_root_node
|
||||
from langflow.processing.process import load_flow_from_json
|
||||
from langflow.utils.payload import get_root_vertex
|
||||
|
||||
|
||||
def test_load_flow_from_json():
|
||||
|
|
@ -22,14 +23,15 @@ def test_load_flow_from_json_with_tweaks():
|
|||
assert loaded.llm.model_name == "test model"
|
||||
|
||||
|
||||
def test_get_root_node():
|
||||
def test_get_root_vertex():
|
||||
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
root = get_root_node(graph)
|
||||
root = get_root_vertex(graph)
|
||||
assert root is not None
|
||||
assert hasattr(root, "id")
|
||||
assert hasattr(root, "data")
|
||||
assert hasattr(root, "data")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service
|
||||
from langflow.services.deps import get_db_service
|
||||
import pytest
|
||||
from langflow.services.database.models.user import User
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
|
|
@ -9,9 +9,7 @@ from langflow.services.auth.utils import get_password_hash
|
|||
def test_user():
|
||||
return User(
|
||||
username="testuser",
|
||||
password=get_password_hash(
|
||||
"testpassword"
|
||||
), # Assuming password needs to be hashed
|
||||
password=get_password_hash("testpassword"), # Assuming password needs to be hashed
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
|
@ -23,17 +21,13 @@ def test_login_successful(client, test_user):
|
|||
session.add(test_user)
|
||||
session.commit()
|
||||
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "testuser", "password": "testpassword"}
|
||||
)
|
||||
response = client.post("api/v1/login", data={"username": "testuser", "password": "testpassword"})
|
||||
assert response.status_code == 200
|
||||
assert "access_token" in response.json()
|
||||
|
||||
|
||||
def test_login_unsuccessful_wrong_username(client):
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "wrongusername", "password": "testpassword"}
|
||||
)
|
||||
response = client.post("api/v1/login", data={"username": "wrongusername", "password": "testpassword"})
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Incorrect username or password"
|
||||
|
||||
|
|
@ -43,8 +37,6 @@ def test_login_unsuccessful_wrong_password(client, test_user, session):
|
|||
session.add(test_user)
|
||||
session.commit()
|
||||
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "testuser", "password": "wrongpassword"}
|
||||
)
|
||||
response = client.post("api/v1/login", data={"username": "testuser", "password": "wrongpassword"})
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Incorrect username or password"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
from langflow.processing.process import process_tweaks
|
||||
from langflow.services.getters import get_session_service
|
||||
from langflow.services.deps import get_session_service
|
||||
|
||||
|
||||
def test_no_tweaks():
|
||||
|
|
@ -197,39 +198,42 @@ def test_tweak_not_in_template():
|
|||
assert result == graph_data
|
||||
|
||||
|
||||
def test_load_langchain_object_with_cached_session(client, basic_graph_data):
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_langchain_object_with_cached_session(client, basic_graph_data):
|
||||
# Provide a non-existent session_id
|
||||
session_service = get_session_service()
|
||||
session_id1 = "non-existent-session-id"
|
||||
graph1, artifacts1 = session_service.load_session(session_id1, basic_graph_data)
|
||||
graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data)
|
||||
# Use the new session_id to get the langchain_object again
|
||||
graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data)
|
||||
graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data)
|
||||
|
||||
assert graph1 == graph2
|
||||
assert artifacts1 == artifacts2
|
||||
|
||||
|
||||
def test_load_langchain_object_with_no_cached_session(client, basic_graph_data):
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_langchain_object_with_no_cached_session(client, basic_graph_data):
|
||||
# Provide a non-existent session_id
|
||||
session_service = get_session_service()
|
||||
session_id1 = "non-existent-session-id"
|
||||
session_id = session_service.build_key(session_id1, basic_graph_data)
|
||||
graph1, artifacts1 = session_service.load_session(session_id, basic_graph_data)
|
||||
graph1, artifacts1 = await session_service.load_session(session_id, basic_graph_data)
|
||||
# Clear the cache
|
||||
session_service.clear_session(session_id)
|
||||
# Use the new session_id to get the langchain_object again
|
||||
graph2, artifacts2 = session_service.load_session(session_id, basic_graph_data)
|
||||
graph2, artifacts2 = await session_service.load_session(session_id, basic_graph_data)
|
||||
|
||||
assert id(graph1) != id(graph2)
|
||||
# Since the cache was cleared, objects should be different
|
||||
|
||||
|
||||
def test_load_langchain_object_without_session_id(client, basic_graph_data):
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_langchain_object_without_session_id(client, basic_graph_data):
|
||||
# Provide a non-existent session_id
|
||||
session_service = get_session_service()
|
||||
session_id1 = None
|
||||
graph1, artifacts1 = session_service.load_session(session_id1, basic_graph_data)
|
||||
graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data)
|
||||
# Use the new session_id to get the langchain_object again
|
||||
graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data)
|
||||
graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data)
|
||||
|
||||
assert graph1 == graph2
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from fastapi.testclient import TestClient
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
|
||||
def test_prompts_settings(client: TestClient, logged_in_headers):
|
||||
|
|
@ -31,6 +32,7 @@ def test_prompt_template(client: TestClient, logged_in_headers):
|
|||
"list": True,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["output_parser"] == {
|
||||
|
|
@ -45,6 +47,7 @@ def test_prompt_template(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["partial_variables"] == {
|
||||
|
|
@ -59,6 +62,7 @@ def test_prompt_template(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["template"] == {
|
||||
|
|
@ -73,6 +77,23 @@ def test_prompt_template(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["template_format"] == {
|
||||
"required": False,
|
||||
"dynamic": True,
|
||||
"placeholder": "",
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
"value": "f-string",
|
||||
"password": False,
|
||||
"name": "template_format",
|
||||
"type": "str",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
||||
assert template["validate_template"] == {
|
||||
|
|
@ -88,4 +109,5 @@ def test_prompt_template(client: TestClient, logged_in_headers):
|
|||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
"fileTypes": [],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,17 +1,11 @@
|
|||
from unittest.mock import patch, MagicMock
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.settings.constants import (
|
||||
DEFAULT_SUPERUSER,
|
||||
DEFAULT_SUPERUSER_PASSWORD,
|
||||
)
|
||||
from langflow.services.utils import (
|
||||
teardown_superuser,
|
||||
)
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langflow.services.settings.constants import DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD
|
||||
from langflow.services.utils import teardown_superuser
|
||||
|
||||
# @patch("langflow.services.getters.get_session")
|
||||
# @patch("langflow.services.deps.get_session")
|
||||
# @patch("langflow.services.utils.create_super_user")
|
||||
# @patch("langflow.services.getters.get_settings_service")
|
||||
# @patch("langflow.services.deps.get_settings_service")
|
||||
# # @patch("langflow.services.utils.verify_password")
|
||||
# def test_setup_superuser(
|
||||
# mock_get_session, mock_create_super_user, mock_get_settings_service
|
||||
|
|
@ -92,11 +86,9 @@ from langflow.services.utils import (
|
|||
# assert str(actual_expr) == str(expected_expr)
|
||||
|
||||
|
||||
@patch("langflow.services.getters.get_settings_service")
|
||||
@patch("langflow.services.getters.get_session")
|
||||
def test_teardown_superuser_default_superuser(
|
||||
mock_get_session, mock_get_settings_service
|
||||
):
|
||||
@patch("langflow.services.deps.get_settings_service")
|
||||
@patch("langflow.services.deps.get_session")
|
||||
def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = True
|
||||
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
|
||||
|
|
@ -111,20 +103,12 @@ def test_teardown_superuser_default_superuser(
|
|||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.query.assert_called_once_with(User)
|
||||
actual_expr = mock_session.query.return_value.filter.call_args[0][0]
|
||||
expected_expr = User.username == DEFAULT_SUPERUSER
|
||||
|
||||
assert str(actual_expr) == str(expected_expr)
|
||||
mock_session.delete.assert_called_once_with(mock_user)
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.query.assert_not_called()
|
||||
|
||||
|
||||
@patch("langflow.services.getters.get_settings_service")
|
||||
@patch("langflow.services.getters.get_session")
|
||||
def test_teardown_superuser_no_default_superuser(
|
||||
mock_get_session, mock_get_settings_service
|
||||
):
|
||||
@patch("langflow.services.deps.get_settings_service")
|
||||
@patch("langflow.services.deps.get_session")
|
||||
def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
ADMIN_USER_NAME = "admin_user"
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = False
|
||||
|
|
@ -135,11 +119,11 @@ def test_teardown_superuser_no_default_superuser(
|
|||
mock_session = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_user.is_superuser = False
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_session.exec.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_get_session.return_value = [mock_session]
|
||||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.exec.assert_called_once()
|
||||
mock_session.delete.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -65,11 +65,9 @@ def test_build_template_from_function():
|
|||
assert "base_classes" in result
|
||||
|
||||
# Test with add_function=True
|
||||
result_with_function = build_template_from_function(
|
||||
"ExampleClass1", type_to_loader_dict, add_function=True
|
||||
)
|
||||
result_with_function = build_template_from_function("ExampleClass1", type_to_loader_dict, add_function=True)
|
||||
assert result_with_function is not None
|
||||
assert "function" in result_with_function["base_classes"]
|
||||
assert "Callable" in result_with_function["base_classes"]
|
||||
|
||||
# Test with invalid name
|
||||
with pytest.raises(ValueError, match=r".* not found"):
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from datetime import datetime
|
||||
from langflow.services.auth.utils import create_super_user, get_password_hash
|
||||
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service, get_settings_service
|
||||
import pytest
|
||||
|
||||
from langflow.services.auth.utils import create_super_user, get_password_hash
|
||||
from langflow.services.database.models.user import UserUpdate
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -85,15 +86,11 @@ def test_deactivated_user_cannot_access(client, deactivated_user, logged_in_head
|
|||
assert response.json()["detail"] == "The user doesn't have enough privileges"
|
||||
|
||||
|
||||
def test_data_consistency_after_update(
|
||||
client, active_user, logged_in_headers, super_user_headers
|
||||
):
|
||||
def test_data_consistency_after_update(client, active_user, logged_in_headers, super_user_headers):
|
||||
user_id = active_user.id
|
||||
update_data = UserUpdate(is_active=False)
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=super_user_headers
|
||||
)
|
||||
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=super_user_headers)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
# Fetch the updated user from the database
|
||||
|
|
@ -120,7 +117,7 @@ def test_inactive_user(client):
|
|||
username="inactiveuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=False,
|
||||
last_login_at="2023-01-01T00:00:00", # Set to a valid datetime string
|
||||
last_login_at=datetime.now(),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
|
@ -167,17 +164,13 @@ def test_patch_user(client, active_user, logged_in_headers):
|
|||
username="newname",
|
||||
)
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 200, response.json()
|
||||
update_data = UserUpdate(
|
||||
profile_image="new_image",
|
||||
)
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
|
||||
|
|
@ -189,7 +182,7 @@ def test_patch_reset_password(client, active_user, logged_in_headers):
|
|||
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}/reset-password",
|
||||
json=update_data.dict(),
|
||||
json=update_data.model_dump(),
|
||||
headers=logged_in_headers,
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
|
@ -205,19 +198,13 @@ def test_patch_user_wrong_id(client, active_user, logged_in_headers):
|
|||
username="newname",
|
||||
)
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 422, response.json()
|
||||
assert response.json() == {
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["path", "user_id"],
|
||||
"msg": "value is not a valid uuid",
|
||||
"type": "type_error.uuid",
|
||||
}
|
||||
]
|
||||
}
|
||||
json_response = response.json()
|
||||
detail = json_response["detail"]
|
||||
assert detail[0]["type"] == "uuid_parsing"
|
||||
assert detail[0]["loc"] == ["path", "user_id"]
|
||||
assert detail[0]["input"] == "wrong_id"
|
||||
|
||||
|
||||
def test_delete_user(client, test_user, super_user_headers):
|
||||
|
|
@ -231,15 +218,11 @@ def test_delete_user_wrong_id(client, test_user, super_user_headers):
|
|||
user_id = "wrong_id"
|
||||
response = client.delete(f"/api/v1/users/{user_id}", headers=super_user_headers)
|
||||
assert response.status_code == 422
|
||||
assert response.json() == {
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["path", "user_id"],
|
||||
"msg": "value is not a valid uuid",
|
||||
"type": "type_error.uuid",
|
||||
}
|
||||
]
|
||||
}
|
||||
json_response = response.json()
|
||||
detail = json_response["detail"]
|
||||
assert detail[0]["type"] == "uuid_parsing"
|
||||
assert detail[0]["loc"] == ["path", "user_id"]
|
||||
assert detail[0]["input"] == "wrong_id"
|
||||
|
||||
|
||||
def test_normal_user_cant_delete_user(client, test_user, logged_in_headers):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from fastapi.testclient import TestClient
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
|
||||
# check that all agents are in settings.agents
|
||||
|
|
|
|||
|
|
@ -31,9 +31,7 @@ def test_websocket_endpoint(client: TestClient, active_user, logged_in_headers):
|
|||
# Assuming your websocket_endpoint uses chat_service which caches data from stream_build
|
||||
access_token = logged_in_headers["Authorization"].split(" ")[1]
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect(
|
||||
f"api/v1/chat/non_existing_client_id?token={access_token}"
|
||||
) as websocket:
|
||||
with client.websocket_connect(f"api/v1/chat/non_existing_client_id?token={access_token}") as websocket:
|
||||
websocket.send_json({"type": "test"})
|
||||
data = websocket.receive_json()
|
||||
assert "Please, build the flow before sending messages" in data["message"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue