The changes include:
- Moved the `Edge` class to a new `edge` package
- Moved the `Graph` class to a new `graph` package
- Moved the `Node` class to a new `node` package
- Moved the `VectorStoreNode` class to the `node/types.py` module
- Moved the `Edge`, `Graph`, and `Node` classes to their respective `base.py` modules
- Added an `__init__.py` file to each package to allow for importing of classes
- Added a `constants.py` module to the `graph` package to store constants used in the `Graph` class
- Refactored the `Graph` class to use the new `Node` and `Edge` classes
- Refactored the `Graph` class to use a dictionary to map node types to their respective classes
- Refactored the `Graph` class to remove invalid nodes from the graph
- Refactored the `Graph` class to handle the LLM node within the graph
- Refactored the `Graph` class to build the nodes before building the edges
- Refactored the `Graph` class to use the `get_node` method to find nodes by id
- Refactored the `Graph` class to use the `get_node_neighbors` method to find the neighbors of a node
- Refactored the `Graph` class to use the `get_children_by_node_type` method to find the children of a node based on the node type

These changes improve the modularity and maintainability of the `langflow` package by separating the classes into their respective packages and modules. The changes also make it easier to add new node types to the `Graph` class by using a dictionary to map node types to their respective classes.

🚀 feat(node): add Node class to represent a node in the graph
🚀 feat(constants.py): add DIRECT_TYPES constant to represent direct types in a node's template

The Node class represents a node in the graph and is responsible for parsing the data and building the module. The DIRECT_TYPES constant is a list of direct types in a node's template.
🚧 chore(types.py): add import statements for typing and Node classes
This commit adds import statements for the typing module and the Node class to the types.py file. This is necessary for the code to run properly as it uses these classes and modules.

🚧 chore(loading.py): remove unnecessary import statement
This commit removes an unnecessary import statement from the loading.py file. The import statement was causing a circular import error and was not needed for the code to run properly.

🚧 chore(run.py): update import statement for Graph class
This commit updates the import statement for the Graph class in the run.py file. The import statement was outdated and was causing an import error.

🚧 chore(conftest.py): update import statement for Graph class
This commit updates the import statement for the Graph class in the conftest.py file. The import statement was outdated and was causing an import error.

🚧 chore(test_graph.py): update import statements for Node and Edge classes
This commit updates the import statements for the Node and Edge classes in the test_graph.py file. The import statements were outdated and were causing import errors.
344 lines
12 KiB
Python
344 lines
12 KiB
Python
from typing import Type, Union
|
|
from langflow.graph.edge.base import Edge
|
|
from langflow.graph.node.base import Node
|
|
|
|
import pytest
|
|
from langchain.chains.base import Chain
|
|
from langchain.llms.fake import FakeListLLM
|
|
from langflow.graph import Graph
|
|
from langflow.graph.node.types import (
|
|
AgentNode,
|
|
ChainNode,
|
|
FileToolNode,
|
|
LLMNode,
|
|
PromptNode,
|
|
ToolkitNode,
|
|
ToolNode,
|
|
WrapperNode,
|
|
)
|
|
from langflow.interface.run import get_result_and_thought
|
|
from langflow.utils.payload import get_root_node
|
|
|
|
# Test cases for the graph module
|
|
|
|
# now we have three types of graph:
|
|
# BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH
|
|
|
|
|
|
def get_node_by_type(graph, node_type: Type[Node]) -> Union[Node, None]:
    """Return the first entry of ``graph.nodes`` that is an instance of ``node_type``.

    Nodes are scanned in iteration order; ``None`` is returned when nothing
    matches (including when ``graph.nodes`` is empty).
    """
    for candidate in graph.nodes:
        if isinstance(candidate, node_type):
            return candidate
    return None
|
|
|
|
|
|
def test_graph_structure(basic_graph):
    """Check that the basic graph is non-empty and consistently typed.

    Every node must be a ``Node``; every edge must be an ``Edge`` whose
    ``source`` and ``target`` are both members of ``basic_graph.nodes``.
    """
    assert isinstance(basic_graph, Graph)
    assert len(basic_graph.nodes) > 0
    assert len(basic_graph.edges) > 0

    for graph_node in basic_graph.nodes:
        assert isinstance(graph_node, Node)

    for connection in basic_graph.edges:
        assert isinstance(connection, Edge)
        # Both endpoints must be nodes the graph actually knows about.
        assert connection.source in basic_graph.nodes
        assert connection.target in basic_graph.nodes
|
|
|
|
|
|
def test_circular_dependencies(basic_graph):
    """Assert that no node of the basic graph can reach itself again.

    Starting from each node, the graph is walked through
    ``basic_graph.get_nodes_with_target`` and the walk must never revisit a
    node already on the current path.
    """
    assert isinstance(basic_graph, Graph)

    def has_cycle(current, path):
        """Return True if following ``get_nodes_with_target`` from ``current`` revisits ``path``."""
        path.add(current)
        for linked in basic_graph.get_nodes_with_target(current):
            # Each branch gets its own copy of the path, so nodes seen on a
            # sibling branch are not mistaken for a cycle.
            if linked in path or has_cycle(linked, path.copy()):
                return True
        return False

    for start in basic_graph.nodes:
        assert not has_cycle(start, set())
|
|
|
|
|
|
def test_invalid_node_types():
    """Constructing a ``Graph`` from a node whose template ``_type`` is "InvalidNodeType" must raise."""
    bad_node = {
        "id": "1",
        "data": {
            "node": {
                "base_classes": ["BaseClass"],
                "template": {
                    "_type": "InvalidNodeType",
                },
            },
        },
    }
    graph_data = {"nodes": [bad_node], "edges": []}

    with pytest.raises(Exception):
        Graph(graph_data["nodes"], graph_data["edges"])
|
|
|
|
|
|
def test_get_nodes_with_target(basic_graph):
    """Check that ``get_nodes_with_target`` returns a value for the root node."""
    assert isinstance(basic_graph, Graph)

    root_node = get_root_node(basic_graph)
    assert root_node is not None

    # Only the presence of a result is verified, not its contents.
    assert basic_graph.get_nodes_with_target(root_node) is not None
|
|
|
|
|
|
def test_get_node_neighbors_basic(basic_graph):
    """Check the neighbor mapping returned for the basic graph's root node.

    ``get_node_neighbors`` must return a dict keyed by node; among the entries
    with a truthy value there must be one whose ``data["type"]`` contains
    "ConversationBufferMemory" and one whose type contains "OpenAI".
    """
    assert isinstance(basic_graph, Graph)

    root_node = get_root_node(basic_graph)
    assert root_node is not None

    neighbors = basic_graph.get_node_neighbors(root_node)
    assert neighbors is not None
    assert isinstance(neighbors, dict)

    def neighbor_has_type(fragment):
        """Return True if a truthy-valued neighbor has ``fragment`` in its type."""
        return any(
            fragment in candidate.data["type"]
            for candidate, flag in neighbors.items()
            if flag
        )

    assert neighbor_has_type("ConversationBufferMemory")
    assert neighbor_has_type("OpenAI")
|
|
|
|
|
|
def test_get_node_neighbors_complex(complex_graph):
    """Walk the complex graph from its root via ``get_nodes_with_target``.

    Expected chain of type-name fragments, each found among the nodes
    returned for the previous one:
    root -> {Chain, Tool}; Chain -> {OpenAI, Prompt}; Tool -> Agent;
    Agent -> Tool; that Tool -> PythonFunction.
    """

    def contains_type(nodes, fragment):
        """Return True if any node's ``data["type"]`` contains ``fragment``."""
        return any(fragment in item.data["type"] for item in nodes)

    def first_of_type(nodes, fragment):
        """Return the first node whose ``data["type"]`` contains ``fragment``."""
        return next(item for item in nodes if fragment in item.data["type"])

    assert isinstance(complex_graph, Graph)

    root_node = get_root_node(complex_graph)
    assert root_node is not None

    # Nodes returned for the root: must be a list with a Chain and a Tool.
    neighbors = complex_graph.get_nodes_with_target(root_node)
    assert neighbors is not None
    assert isinstance(neighbors, list)
    assert contains_type(neighbors, "Chain")
    assert contains_type(neighbors, "Tool")

    # The Chain must be fed by an LLM ("OpenAI") and a Prompt.
    chain = first_of_type(neighbors, "Chain")
    chain_neighbors = complex_graph.get_nodes_with_target(chain)
    assert chain_neighbors is not None
    assert contains_type(chain_neighbors, "OpenAI")
    assert contains_type(chain_neighbors, "Prompt")

    # The root's Tool must have an Agent among its returned nodes.
    tool = first_of_type(neighbors, "Tool")
    tool_neighbors = complex_graph.get_nodes_with_target(tool)
    assert tool_neighbors is not None
    assert contains_type(tool_neighbors, "Agent")

    # That inner Agent must itself have a Tool.
    agent = first_of_type(tool_neighbors, "Agent")
    agent_neighbors = complex_graph.get_nodes_with_target(agent)
    assert agent_neighbors is not None
    assert contains_type(agent_neighbors, "Tool")

    # The inner Tool must have a PythonFunction among its returned nodes.
    tool = first_of_type(agent_neighbors, "Tool")
    tool_neighbors = complex_graph.get_nodes_with_target(tool)
    assert tool_neighbors is not None
    assert contains_type(tool_neighbors, "PythonFunction")
|
|
|
|
|
|
def test_get_node(basic_graph):
    """Look up the first node by its id and check the same id comes back."""
    expected_id = basic_graph.nodes[0].id
    found = basic_graph.get_node(expected_id)
    assert isinstance(found, Node)
    assert found.id == expected_id
|
|
|
|
|
|
def test_build_nodes(basic_graph):
    """Check that ``nodes`` has as many entries as ``_nodes`` and holds only ``Node`` objects."""
    # NOTE(review): `_nodes` is presumably the raw node payload the graph was
    # built from -- confirm against the Graph class.
    assert len(basic_graph.nodes) == len(basic_graph._nodes)

    for graph_node in basic_graph.nodes:
        assert isinstance(graph_node, Node)
|
|
|
|
|
|
def test_build_edges(basic_graph):
    """Check that ``edges`` has as many entries as ``_edges`` and that each edge links two ``Node`` objects."""
    # NOTE(review): `_edges` is presumably the raw edge payload the graph was
    # built from -- confirm against the Graph class.
    assert len(basic_graph.edges) == len(basic_graph._edges)

    for connection in basic_graph.edges:
        assert isinstance(connection, Edge)
        assert isinstance(connection.source, Node)
        assert isinstance(connection.target, Node)
|
|
|
|
|
|
def test_get_root_node(basic_graph, complex_graph):
    """Check the root node type reported for the basic and complex graphs.

    The basic graph's root must be a "TimeTravelGuideChain"; the complex
    graph's root must be a "ZeroShotAgent".
    """
    expectations = (
        (basic_graph, "TimeTravelGuideChain"),
        (complex_graph, "ZeroShotAgent"),
    )
    for graph, expected_type in expectations:
        assert isinstance(graph, Graph)
        root_node = get_root_node(graph)
        assert root_node is not None
        assert isinstance(root_node, Node)
        assert root_node.data["type"] == expected_type
|
|
|
|
|
|
def test_validate_edges(basic_graph):
    """Check that every edge of the basic graph reports itself as valid."""
    assert isinstance(basic_graph, Graph)

    for connection in basic_graph.edges:
        assert connection.valid
|
|
|
|
|
|
def test_matched_type(basic_graph):
    """Check the ``matched_type`` attribute on every edge of the basic graph.

    All edges must be valid, expose ``matched_type``, and that value must be
    one of the edge's ``source_types``.
    """
    assert isinstance(basic_graph, Graph)

    # Validity is checked for all edges before any attribute is inspected.
    assert all(connection.valid for connection in basic_graph.edges)
    # Existence is checked before the value is read in the final assertion.
    assert all(
        hasattr(connection, "matched_type") for connection in basic_graph.edges
    )
    assert all(
        connection.matched_type in connection.source_types
        for connection in basic_graph.edges
    )
|
|
|
|
|
|
def test_build_params(basic_graph):
    """Test that the root node of the basic graph exposes its build params.

    Preconditions checked first: every edge is valid, has a ``matched_type``
    attribute, and that value is among the edge's ``source_types``. The root
    node's ``params`` must then be a dict containing "llm" and "memory".
    """
    assert isinstance(basic_graph, Graph)
    # all edges should be valid
    assert all(edge.valid for edge in basic_graph.edges)
    # all edges should have a matched_type attribute
    assert all(hasattr(edge, "matched_type") for edge in basic_graph.edges)
    # The matched_type attribute should be in the source_types attr
    assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)

    # Get the root node
    root = get_root_node(basic_graph)
    # Guard like the other root-based tests do, so a missing root fails with a
    # clear assertion instead of an AttributeError on `root.params`.
    assert root is not None

    # Root node is a TimeTravelGuideChain
    # which requires an llm and memory
    assert isinstance(root.params, dict)
    assert "llm" in root.params
    assert "memory" in root.params
|
|
|
|
|
|
def test_build(basic_graph, complex_graph, openapi_graph):
    """Run the build check on the basic, complex and OpenAPI graphs, in that order."""
    for graph in (basic_graph, complex_graph, openapi_graph):
        assert_agent_was_built(graph)
|
|
|
|
|
|
def assert_agent_was_built(graph):
    """Assert that ``graph`` is a ``Graph`` and that ``graph.build()`` returns a langchain ``Chain``."""
    assert isinstance(graph, Graph)

    built = graph.build()
    # The check is against the `Chain` base class, not a specific subclass.
    assert isinstance(built, Chain)
|
|
|
|
|
|
def test_agent_node_build(complex_graph):
    """The complex graph contains an ``AgentNode`` and its ``build()`` returns a non-None object."""
    node = get_node_by_type(complex_graph, AgentNode)
    assert node is not None

    assert node.build() is not None
|
|
|
|
|
|
def test_tool_node_build(complex_graph):
    """The complex graph contains a ``ToolNode`` and its ``build()`` returns a non-None object."""
    node = get_node_by_type(complex_graph, ToolNode)
    assert node is not None

    # TODO: add assertions specific to what ToolNode.build() produces.
    assert node.build() is not None
|
|
|
|
|
|
def test_chain_node_build(complex_graph):
    """The complex graph contains a ``ChainNode`` and its ``build()`` returns a non-None object."""
    node = get_node_by_type(complex_graph, ChainNode)
    assert node is not None

    # TODO: add assertions specific to what ChainNode.build() produces.
    assert node.build() is not None
|
|
|
|
|
|
def test_prompt_node_build(complex_graph):
    """The complex graph contains a ``PromptNode`` and its ``build()`` returns a non-None object."""
    node = get_node_by_type(complex_graph, PromptNode)
    assert node is not None

    # TODO: add assertions specific to what PromptNode.build() produces.
    assert node.build() is not None
|
|
|
|
|
|
def test_llm_node_build(basic_graph):
    """The basic graph contains an ``LLMNode`` and its ``build()`` returns a non-None object."""
    node = get_node_by_type(basic_graph, LLMNode)
    assert node is not None

    # TODO: add assertions specific to what LLMNode.build() produces.
    assert node.build() is not None
|
|
|
|
|
|
def test_toolkit_node_build(openapi_graph):
    """The OpenAPI graph contains a ``ToolkitNode`` and its ``build()`` returns a non-None object."""
    node = get_node_by_type(openapi_graph, ToolkitNode)
    assert node is not None

    # TODO: add assertions specific to what ToolkitNode.build() produces.
    assert node.build() is not None
|
|
|
|
|
|
def test_file_tool_node_build(openapi_graph):
    """The OpenAPI graph contains a ``FileToolNode`` and its ``build()`` returns a non-None object."""
    node = get_node_by_type(openapi_graph, FileToolNode)
    assert node is not None

    # TODO: add assertions specific to what FileToolNode.build() produces.
    assert node.build() is not None
|
|
|
|
|
|
def test_wrapper_node_build(openapi_graph):
    """The OpenAPI graph contains a ``WrapperNode`` and its ``build()`` returns a non-None object."""
    node = get_node_by_type(openapi_graph, WrapperNode)
    assert node is not None

    # TODO: add assertions specific to what WrapperNode.build() produces.
    assert node.build() is not None
|
|
|
|
|
|
def test_get_result_and_thought(basic_graph):
    """Run ``get_result_and_thought`` on the built basic graph with a fake LLM in place.

    The graph's ``LLMNode`` is pre-populated with a ``FakeListLLM`` before the
    graph is built; both values returned by ``get_result_and_thought`` must
    be strings.
    """
    message = "Hello"
    canned_responses = [
        "Final Answer: I am a response",
    ]

    llm_node = get_node_by_type(basic_graph, LLMNode)
    assert llm_node is not None

    # Inject the fake LLM and flag the node as built.
    # NOTE(review): presumably build() then reuses `_built_object` instead of
    # constructing a real LLM -- confirm against Node.build().
    llm_node._built_object = FakeListLLM(responses=canned_responses)
    llm_node._built = True

    langchain_object = basic_graph.build()
    # Building the graph must leave every node flagged as built.
    assert all(graph_node._built for graph_node in basic_graph.nodes)

    result, thought = get_result_and_thought(langchain_object, message)
    # Both the result and the thought are expected to be plain strings.
    assert isinstance(result, str)
    assert isinstance(thought, str)
|