diff --git a/tests/test_graph.py b/tests/test_graph.py index c33d387da..e5a72e949 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,6 +2,7 @@ import json from langflow.utils.graph import Edge, Graph, Node import pytest from langflow.utils.payload import build_json, get_root_node +from langchain.agents import AgentExecutor # Test cases for the graph module @@ -183,3 +184,100 @@ def test_validate_edges(): assert isinstance(graph, Graph) # all edges should be valid assert all(edge.valid for edge in graph.edges) + + +def test_matched_type(): + """Test matched type attribute in Edge""" + graph = get_graph() + assert isinstance(graph, Graph) + # all edges should be valid + assert all(edge.valid for edge in graph.edges) + # all edges should have a matched_type attribute + assert all(hasattr(edge, "matched_type") for edge in graph.edges) + # The matched_type attribute should be in the source_types attr + assert all(edge.matched_type in edge.source_types for edge in graph.edges) + + +def test_build_params(): + """Test building params""" + graph = get_graph() + assert isinstance(graph, Graph) + # all edges should be valid + assert all(edge.valid for edge in graph.edges) + # all edges should have a matched_type attribute + assert all(hasattr(edge, "matched_type") for edge in graph.edges) + # The matched_type attribute should be in the source_types attr + assert all(edge.matched_type in edge.source_types for edge in graph.edges) + # Get the root node + root = get_root_node(graph) + # Root node is a ZeroShotAgent + # which requires an llm_chain, allowed_tools and return_values + assert isinstance(root.params, dict) + assert "llm_chain" in root.params + assert "allowed_tools" in root.params + assert "return_values" in root.params + # The llm_chain should be a Node + assert isinstance(root.params["llm_chain"], Node) + # The allowed_tools should be a list of Nodes + assert isinstance(root.params["allowed_tools"], list) + assert all(isinstance(tool, Node) for tool in root.params["allowed_tools"]) + # The return_values is of type str so it should be a list of strings + assert isinstance(root.params["return_values"], list) + assert all(isinstance(val, str) for val in root.params["return_values"]) + # The llm_chain should have a prompt and llm + llm_chain_node = root.params["llm_chain"] + assert isinstance(llm_chain_node.params, dict) + assert "prompt" in llm_chain_node.params + assert "llm" in llm_chain_node.params + # The prompt should be a Node + assert isinstance(llm_chain_node.params["prompt"], Node) + # The llm should be a Node + assert isinstance(llm_chain_node.params["llm"], Node) + # The prompt should have format_insctructions, suffix, prefix + prompt_node = llm_chain_node.params["prompt"] + assert isinstance(prompt_node.params, dict) + assert "format_instructions" in prompt_node.params + assert "suffix" in prompt_node.params + assert "prefix" in prompt_node.params + # All of them should be of type str + assert isinstance(prompt_node.params["format_instructions"], str) + assert isinstance(prompt_node.params["suffix"], str) + assert isinstance(prompt_node.params["prefix"], str) + # The llm should have a model + llm_node = llm_chain_node.params["llm"] + assert isinstance(llm_node.params, dict) + assert "model_name" in llm_node.params + # The model should be a str + assert isinstance(llm_node.params["model_name"], str) + + +def test_build(): + """Test Node's build method""" + # def build(self): + # # The params dict is used to build the module + # # it contains values and keys that point to nodes which + # # have their own params dict + # # When build is called, we iterate through the params dict + # # and if the value is a node, we call build on that node + # # and use the output of that build as the value for the param + # # if the value is not a node, then we use the value as the param + # # and continue + # # Another aspect is that the module_type is the class that we need to import + # # and instantiate with these built params + + # # Build each node in the params dict + # for key, value in self.params.items(): + # if isinstance(value, Node): + # self.params[key] = value.build() + + # # Get the class from LANGCHAIN_TYPES_DICT + # # and instantiate it with the params + # # and return the instance + # return LANGCHAIN_TYPES_DICT[self.module_type](**self.params) + graph = get_graph() + assert isinstance(graph, Graph) + # Now we test the build method + # Build the Agent + agent = graph.build() + # The agent should be a AgentExecutor + assert isinstance(agent, AgentExecutor) diff --git a/tests/test_loading.py b/tests/test_loading.py index 5f44e55f8..96a530921 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -64,9 +64,9 @@ def test_build_json_missing_child(): if isinstance(value, dict) and "required" in value: value["required"] = True - graph = Graph(nodes, edges) - root = get_root_node(graph) with pytest.raises(ValueError): + graph = Graph(nodes, edges) + root = get_root_node(graph) build_json(root, graph)