feat: added more tests for nodes
This commit is contained in:
parent
48d2ab27da
commit
0fc454f9b7
3 changed files with 662 additions and 71 deletions
|
|
@ -11,6 +11,9 @@ def pytest_configure():
|
|||
pytest.COMPLEX_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "complex_example.json"
|
||||
)
|
||||
pytest.OPENAPI_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "Openapi.json"
|
||||
)
|
||||
|
||||
pytest.CODE_WITH_SYNTAX_ERROR = """
|
||||
def get_text():
|
||||
|
|
|
|||
445
tests/data/Openapi.json
Normal file
445
tests/data/Openapi.json
Normal file
File diff suppressed because one or more lines are too long
|
|
@ -1,16 +1,36 @@
|
|||
import json
|
||||
from langflow.graph.nodes import (
|
||||
WrapperNode,
|
||||
AgentNode,
|
||||
ToolNode,
|
||||
ChainNode,
|
||||
PromptNode,
|
||||
LLMNode,
|
||||
ToolkitNode,
|
||||
FileToolNode,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from langchain.agents import AgentExecutor
|
||||
from langflow.graph import Edge, Graph, Node
|
||||
from langflow.utils.payload import build_json, get_root_node
|
||||
|
||||
|
||||
# Test cases for the graph module
|
||||
|
||||
# now we have three types of graph:
|
||||
# BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH
|
||||
|
||||
def get_graph(basic=True):
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
path = pytest.BASIC_EXAMPLE_PATH if basic else pytest.COMPLEX_EXAMPLE_PATH
|
||||
if _type == "basic":
|
||||
path = pytest.BASIC_EXAMPLE_PATH
|
||||
elif _type == "complex":
|
||||
path = pytest.COMPLEX_EXAMPLE_PATH
|
||||
elif _type == "openapi":
|
||||
path = pytest.OPENAPI_EXAMPLE_PATH
|
||||
|
||||
with open(path, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
|
|
@ -19,26 +39,94 @@ def get_graph(basic=True):
|
|||
return Graph(nodes, edges)
|
||||
|
||||
|
||||
def test_get_nodes_with_target():
|
||||
@pytest.fixture
|
||||
def basic_graph():
|
||||
return get_graph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_graph():
|
||||
return get_graph("complex")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_graph():
|
||||
return get_graph("openapi")
|
||||
|
||||
|
||||
def get_node_by_type(graph, node_type):
|
||||
"""Get a node by type"""
|
||||
return next((node for node in graph.nodes if isinstance(node, node_type)), None)
|
||||
|
||||
|
||||
def test_graph_structure(basic_graph):
|
||||
assert isinstance(basic_graph, Graph)
|
||||
assert len(basic_graph.nodes) > 0
|
||||
assert len(basic_graph.edges) > 0
|
||||
for node in basic_graph.nodes:
|
||||
assert isinstance(node, Node)
|
||||
for edge in basic_graph.edges:
|
||||
assert isinstance(edge, Edge)
|
||||
assert edge.source in basic_graph.nodes
|
||||
assert edge.target in basic_graph.nodes
|
||||
|
||||
|
||||
def test_circular_dependencies(basic_graph):
|
||||
assert isinstance(basic_graph, Graph)
|
||||
|
||||
def check_circular(node, visited):
|
||||
visited.add(node)
|
||||
neighbors = basic_graph.get_nodes_with_target(node)
|
||||
for neighbor in neighbors:
|
||||
if neighbor in visited:
|
||||
return True
|
||||
if check_circular(neighbor, visited.copy()):
|
||||
return True
|
||||
return False
|
||||
|
||||
for node in basic_graph.nodes:
|
||||
assert not check_circular(node, set())
|
||||
|
||||
|
||||
def test_invalid_node_types():
|
||||
graph_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"data": {
|
||||
"node": {
|
||||
"base_classes": ["BaseClass"],
|
||||
"template": {
|
||||
"_type": "InvalidNodeType",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
with pytest.raises(Exception):
|
||||
Graph(graph_data["nodes"], graph_data["edges"])
|
||||
|
||||
|
||||
def test_get_nodes_with_target(basic_graph):
|
||||
"""Test getting connected nodes"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_node(graph)
|
||||
root = get_root_node(basic_graph)
|
||||
assert root is not None
|
||||
connected_nodes = graph.get_nodes_with_target(root)
|
||||
connected_nodes = basic_graph.get_nodes_with_target(root)
|
||||
assert connected_nodes is not None
|
||||
|
||||
|
||||
def test_get_node_neighbors_basic():
|
||||
def test_get_node_neighbors_basic(basic_graph):
|
||||
"""Test getting node neighbors"""
|
||||
|
||||
graph = get_graph(basic=True)
|
||||
assert isinstance(graph, Graph)
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_node(graph)
|
||||
root = get_root_node(basic_graph)
|
||||
assert root is not None
|
||||
neighbors = graph.get_node_neighbors(root)
|
||||
neighbors = basic_graph.get_node_neighbors(root)
|
||||
assert neighbors is not None
|
||||
assert isinstance(neighbors, dict)
|
||||
# Root Node is an Agent, it requires an LLMChain and tools
|
||||
|
|
@ -57,7 +145,7 @@ def test_get_node_neighbors_basic():
|
|||
for neighbor, val in neighbors.items()
|
||||
if "Chain" in neighbor.data["type"] and val
|
||||
)
|
||||
chain_neighbors = graph.get_node_neighbors(chain)
|
||||
chain_neighbors = basic_graph.get_node_neighbors(chain)
|
||||
assert chain_neighbors is not None
|
||||
assert isinstance(chain_neighbors, dict)
|
||||
# Check if there is a LLM in the chain's neighbors
|
||||
|
|
@ -74,15 +162,13 @@ def test_get_node_neighbors_basic():
|
|||
)
|
||||
|
||||
|
||||
def test_get_node_neighbors_complex():
|
||||
def test_get_node_neighbors_complex(complex_graph):
|
||||
"""Test getting node neighbors"""
|
||||
|
||||
graph = get_graph(basic=False)
|
||||
assert isinstance(graph, Graph)
|
||||
assert isinstance(complex_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_node(graph)
|
||||
root = get_root_node(complex_graph)
|
||||
assert root is not None
|
||||
neighbors = graph.get_nodes_with_target(root)
|
||||
neighbors = complex_graph.get_nodes_with_target(root)
|
||||
assert neighbors is not None
|
||||
# Neighbors should be a list of nodes
|
||||
assert isinstance(neighbors, list)
|
||||
|
|
@ -93,7 +179,7 @@ def test_get_node_neighbors_complex():
|
|||
assert any("Tool" in neighbor.data["type"] for neighbor in neighbors)
|
||||
# Now on to the Chain's neighbors
|
||||
chain = next(neighbor for neighbor in neighbors if "Chain" in neighbor.data["type"])
|
||||
chain_neighbors = graph.get_nodes_with_target(chain)
|
||||
chain_neighbors = complex_graph.get_nodes_with_target(chain)
|
||||
assert chain_neighbors is not None
|
||||
# Check if there is a LLM in the chain's neighbors
|
||||
assert any("OpenAI" in neighbor.data["type"] for neighbor in chain_neighbors)
|
||||
|
|
@ -101,7 +187,7 @@ def test_get_node_neighbors_complex():
|
|||
assert any("Prompt" in neighbor.data["type"] for neighbor in chain_neighbors)
|
||||
# Now on to the Tool's neighbors
|
||||
tool = next(neighbor for neighbor in neighbors if "Tool" in neighbor.data["type"])
|
||||
tool_neighbors = graph.get_nodes_with_target(tool)
|
||||
tool_neighbors = complex_graph.get_nodes_with_target(tool)
|
||||
assert tool_neighbors is not None
|
||||
# Check if there is an Agent in the tool's neighbors
|
||||
assert any("Agent" in neighbor.data["type"] for neighbor in tool_neighbors)
|
||||
|
|
@ -109,7 +195,7 @@ def test_get_node_neighbors_complex():
|
|||
agent = next(
|
||||
neighbor for neighbor in tool_neighbors if "Agent" in neighbor.data["type"]
|
||||
)
|
||||
agent_neighbors = graph.get_nodes_with_target(agent)
|
||||
agent_neighbors = complex_graph.get_nodes_with_target(agent)
|
||||
assert agent_neighbors is not None
|
||||
# Check if there is a Tool in the agent's neighbors
|
||||
assert any("Tool" in neighbor.data["type"] for neighbor in agent_neighbors)
|
||||
|
|
@ -117,62 +203,57 @@ def test_get_node_neighbors_complex():
|
|||
tool = next(
|
||||
neighbor for neighbor in agent_neighbors if "Tool" in neighbor.data["type"]
|
||||
)
|
||||
tool_neighbors = graph.get_nodes_with_target(tool)
|
||||
tool_neighbors = complex_graph.get_nodes_with_target(tool)
|
||||
assert tool_neighbors is not None
|
||||
# Check if there is a PythonFunction in the tool's neighbors
|
||||
assert any("PythonFunction" in neighbor.data["type"] for neighbor in tool_neighbors)
|
||||
|
||||
|
||||
def test_get_node():
|
||||
def test_get_node(basic_graph):
|
||||
"""Test getting a single node"""
|
||||
graph = get_graph()
|
||||
node_id = graph.nodes[0].id
|
||||
node = graph.get_node(node_id)
|
||||
node_id = basic_graph.nodes[0].id
|
||||
node = basic_graph.get_node(node_id)
|
||||
assert isinstance(node, Node)
|
||||
assert node.id == node_id
|
||||
|
||||
|
||||
def test_build_nodes():
|
||||
def test_build_nodes(basic_graph):
|
||||
"""Test building nodes"""
|
||||
graph = get_graph()
|
||||
assert len(graph.nodes) == len(graph._nodes)
|
||||
for node in graph.nodes:
|
||||
|
||||
assert len(basic_graph.nodes) == len(basic_graph._nodes)
|
||||
for node in basic_graph.nodes:
|
||||
assert isinstance(node, Node)
|
||||
|
||||
|
||||
def test_build_edges():
|
||||
def test_build_edges(basic_graph):
|
||||
"""Test building edges"""
|
||||
graph = get_graph()
|
||||
assert len(graph.edges) == len(graph._edges)
|
||||
for edge in graph.edges:
|
||||
assert len(basic_graph.edges) == len(basic_graph._edges)
|
||||
for edge in basic_graph.edges:
|
||||
assert isinstance(edge, Edge)
|
||||
assert isinstance(edge.source, Node)
|
||||
assert isinstance(edge.target, Node)
|
||||
|
||||
|
||||
def test_get_root_node():
|
||||
def test_get_root_node(basic_graph, complex_graph):
|
||||
"""Test getting root node"""
|
||||
graph = get_graph(basic=True)
|
||||
assert isinstance(graph, Graph)
|
||||
root = get_root_node(graph)
|
||||
assert isinstance(basic_graph, Graph)
|
||||
root = get_root_node(basic_graph)
|
||||
assert root is not None
|
||||
assert isinstance(root, Node)
|
||||
assert root.data["type"] == "ZeroShotAgent"
|
||||
# For complex example, the root node is a ZeroShotAgent too
|
||||
graph = get_graph(basic=False)
|
||||
assert isinstance(graph, Graph)
|
||||
root = get_root_node(graph)
|
||||
assert isinstance(complex_graph, Graph)
|
||||
root = get_root_node(complex_graph)
|
||||
assert root is not None
|
||||
assert isinstance(root, Node)
|
||||
assert root.data["type"] == "ZeroShotAgent"
|
||||
|
||||
|
||||
def test_build_json():
|
||||
def test_build_json(basic_graph):
|
||||
"""Test building JSON from graph"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
root = get_root_node(graph)
|
||||
json_data = build_json(root, graph)
|
||||
assert isinstance(basic_graph, Graph)
|
||||
root = get_root_node(basic_graph)
|
||||
json_data = build_json(root, basic_graph)
|
||||
assert isinstance(json_data, dict)
|
||||
assert json_data["_type"] == "zero-shot-react-description"
|
||||
assert isinstance(json_data["llm_chain"], dict)
|
||||
|
|
@ -188,38 +269,37 @@ def test_build_json():
|
|||
assert all(isinstance(val, str) for val in json_data["return_values"])
|
||||
|
||||
|
||||
def test_validate_edges():
|
||||
def test_validate_edges(basic_graph):
|
||||
"""Test validating edges"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in graph.edges)
|
||||
assert all(edge.valid for edge in basic_graph.edges)
|
||||
|
||||
|
||||
def test_matched_type():
|
||||
def test_matched_type(basic_graph):
|
||||
"""Test matched type attribute in Edge"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in graph.edges)
|
||||
assert all(edge.valid for edge in basic_graph.edges)
|
||||
# all edges should have a matched_type attribute
|
||||
assert all(hasattr(edge, "matched_type") for edge in graph.edges)
|
||||
assert all(hasattr(edge, "matched_type") for edge in basic_graph.edges)
|
||||
# The matched_type attribute should be in the source_types attr
|
||||
assert all(edge.matched_type in edge.source_types for edge in graph.edges)
|
||||
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
|
||||
|
||||
|
||||
def test_build_params():
|
||||
def test_build_params(basic_graph):
|
||||
"""Test building params"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in graph.edges)
|
||||
assert all(edge.valid for edge in basic_graph.edges)
|
||||
# all edges should have a matched_type attribute
|
||||
assert all(hasattr(edge, "matched_type") for edge in graph.edges)
|
||||
assert all(hasattr(edge, "matched_type") for edge in basic_graph.edges)
|
||||
# The matched_type attribute should be in the source_types attr
|
||||
assert all(edge.matched_type in edge.source_types for edge in graph.edges)
|
||||
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
|
||||
# Get the root node
|
||||
root = get_root_node(graph)
|
||||
root = get_root_node(basic_graph)
|
||||
# Root node is a ZeroShotAgent
|
||||
# which requires an llm_chain, allowed_tools and return_values
|
||||
assert isinstance(root.params, dict)
|
||||
|
|
@ -261,7 +341,7 @@ def test_build_params():
|
|||
assert isinstance(llm_node.params["model_name"], str)
|
||||
|
||||
|
||||
def test_build():
|
||||
def test_build(basic_graph, complex_graph):
|
||||
"""Test Node's build method"""
|
||||
# def build(self):
|
||||
# # The params dict is used to build the module
|
||||
|
|
@ -284,18 +364,81 @@ def test_build():
|
|||
# # and instantiate it with the params
|
||||
# # and return the instance
|
||||
# return LANGCHAIN_TYPES_DICT[self.node_type](**self.params)
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Now we test the build method
|
||||
# Build the Agent
|
||||
agent = graph.build()
|
||||
agent = basic_graph.build()
|
||||
# The agent should be a AgentExecutor
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
|
||||
# Now we test the complex example
|
||||
graph = get_graph(basic=False)
|
||||
assert isinstance(graph, Graph)
|
||||
assert isinstance(complex_graph, Graph)
|
||||
# Now we test the build method
|
||||
agent = graph.build()
|
||||
agent = complex_graph.build()
|
||||
# The agent should be a AgentExecutor
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
|
||||
|
||||
def test_agent_node_build(basic_graph):
|
||||
agent_node = get_node_by_type(basic_graph, AgentNode)
|
||||
assert agent_node is not None
|
||||
built_object = agent_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the AgentNode's build() method
|
||||
|
||||
|
||||
def test_tool_node_build(basic_graph):
|
||||
tool_node = get_node_by_type(basic_graph, ToolNode)
|
||||
assert tool_node is not None
|
||||
built_object = tool_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the ToolNode's build() method
|
||||
|
||||
|
||||
def test_chain_node_build(complex_graph):
|
||||
chain_node = get_node_by_type(complex_graph, ChainNode)
|
||||
assert chain_node is not None
|
||||
built_object = chain_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the ChainNode's build() method
|
||||
|
||||
|
||||
def test_prompt_node_build(complex_graph):
|
||||
prompt_node = get_node_by_type(complex_graph, PromptNode)
|
||||
assert prompt_node is not None
|
||||
built_object = prompt_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the PromptNode's build() method
|
||||
|
||||
|
||||
def test_llm_node_build(basic_graph):
|
||||
llm_node = get_node_by_type(basic_graph, LLMNode)
|
||||
assert llm_node is not None
|
||||
built_object = llm_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the LLMNode's build() method
|
||||
|
||||
|
||||
def test_toolkit_node_build(openapi_graph):
|
||||
toolkit_node = get_node_by_type(openapi_graph, ToolkitNode)
|
||||
assert toolkit_node is not None
|
||||
built_object = toolkit_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the ToolkitNode's build() method
|
||||
|
||||
|
||||
def test_file_tool_node_build(openapi_graph):
|
||||
file_tool_node = get_node_by_type(openapi_graph, FileToolNode)
|
||||
assert file_tool_node is not None
|
||||
built_object = file_tool_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the FileToolNode's build() method
|
||||
|
||||
|
||||
def test_wrapper_node_build(openapi_graph):
|
||||
wrapper_node = get_node_by_type(openapi_graph, WrapperNode)
|
||||
assert wrapper_node is not None
|
||||
built_object = wrapper_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the WrapperNode's build() method
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue