feat: added edge validation
This commit is contained in:
parent
230f0d95e9
commit
cccbeeb4d2
4 changed files with 173 additions and 119 deletions
|
|
@ -28,7 +28,7 @@ def test_get_nodes_with_target():
|
|||
assert connected_nodes is not None
|
||||
|
||||
|
||||
def test_get_node_neighbors():
|
||||
def test_get_node_neighbors_basic():
|
||||
"""Test getting node neighbors"""
|
||||
|
||||
graph = get_graph(basic=True)
|
||||
|
|
@ -72,6 +72,46 @@ def test_get_node_neighbors():
|
|||
)
|
||||
|
||||
|
||||
def test_get_node_neighbors_complex():
|
||||
"""Test getting node neighbors"""
|
||||
|
||||
graph = get_graph(basic=False)
|
||||
assert isinstance(graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_node(graph)
|
||||
assert root is not None
|
||||
neighbors = graph.get_node_neighbors(root)
|
||||
assert neighbors is not None
|
||||
assert isinstance(neighbors, dict)
|
||||
# Root Node is an Agent, it requires an LLMChain and tools
|
||||
# We need to check if there is a Chain in the one of the neighbors'
|
||||
# data attribute in the type key
|
||||
assert any(
|
||||
"Chain" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
|
||||
)
|
||||
# assert BaseTool is in the neighbors
|
||||
assert any(
|
||||
"BaseTool" in neighbor.data["type"]
|
||||
for neighbor, val in neighbors.items()
|
||||
if val
|
||||
)
|
||||
# Now on to the BaseTool's neighbors
|
||||
base_tool = next(
|
||||
neighbor
|
||||
for neighbor, val in neighbors.items()
|
||||
if "BaseTool" in neighbor.data["type"] and val
|
||||
)
|
||||
base_tool_neighbors = graph.get_node_neighbors(base_tool)
|
||||
assert base_tool_neighbors is not None
|
||||
assert isinstance(base_tool_neighbors, dict)
|
||||
# Check if there is an ZeroShotAgent in the base_tool's neighbors
|
||||
assert any(
|
||||
"ZeroShotAgent" in neighbor.data["type"]
|
||||
for neighbor, val in base_tool_neighbors.items()
|
||||
if val
|
||||
)
|
||||
|
||||
|
||||
def test_get_node():
|
||||
"""Test getting a single node"""
|
||||
graph = get_graph()
|
||||
|
|
@ -127,7 +167,7 @@ def test_build_json():
|
|||
assert isinstance(json_data["llm_chain"], dict)
|
||||
assert json_data["llm_chain"]["_type"] == "llm_chain"
|
||||
assert json_data["llm_chain"]["memory"] is None
|
||||
assert json_data["llm_chain"]["verbose"] is True
|
||||
assert json_data["llm_chain"]["verbose"] is False
|
||||
assert isinstance(json_data["llm_chain"]["prompt"], dict)
|
||||
assert isinstance(json_data["llm_chain"]["llm"], dict)
|
||||
assert json_data["llm_chain"]["output_key"] == "text"
|
||||
|
|
@ -135,3 +175,11 @@ def test_build_json():
|
|||
assert all(isinstance(tool, dict) for tool in json_data["allowed_tools"])
|
||||
assert isinstance(json_data["return_values"], list)
|
||||
assert all(isinstance(val, str) for val in json_data["return_values"])
|
||||
|
||||
|
||||
def test_validate_edges():
|
||||
"""Test validating edges"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in graph.edges)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue