feat: added tests for building the langchain obj
This commit is contained in:
parent
cccbeeb4d2
commit
8ccc22086b
2 changed files with 100 additions and 2 deletions
|
|
@ -2,6 +2,7 @@ import json
|
|||
from langflow.utils.graph import Edge, Graph, Node
|
||||
import pytest
|
||||
from langflow.utils.payload import build_json, get_root_node
|
||||
from langchain.agents import AgentExecutor
|
||||
|
||||
# Test cases for the graph module
|
||||
|
||||
|
|
@ -183,3 +184,100 @@ def test_validate_edges():
|
|||
assert isinstance(graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in graph.edges)
|
||||
|
||||
|
||||
def test_matched_type():
|
||||
"""Test matched type attribute in Edge"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in graph.edges)
|
||||
# all edges should have a matched_type attribute
|
||||
assert all(hasattr(edge, "matched_type") for edge in graph.edges)
|
||||
# The matched_type attribute should be in the source_types attr
|
||||
assert all(edge.matched_type in edge.source_types for edge in graph.edges)
|
||||
|
||||
|
||||
def test_build_params():
|
||||
"""Test building params"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in graph.edges)
|
||||
# all edges should have a matched_type attribute
|
||||
assert all(hasattr(edge, "matched_type") for edge in graph.edges)
|
||||
# The matched_type attribute should be in the source_types attr
|
||||
assert all(edge.matched_type in edge.source_types for edge in graph.edges)
|
||||
# Get the root node
|
||||
root = get_root_node(graph)
|
||||
# Root node is a ZeroShotAgent
|
||||
# which requires an llm_chain, allowed_tools and return_values
|
||||
assert isinstance(root.params, dict)
|
||||
assert "llm_chain" in root.params
|
||||
assert "allowed_tools" in root.params
|
||||
assert "return_values" in root.params
|
||||
# The llm_chain should be a Node
|
||||
assert isinstance(root.params["llm_chain"], Node)
|
||||
# The allowed_tools should be a list of Nodes
|
||||
assert isinstance(root.params["allowed_tools"], list)
|
||||
assert all(isinstance(tool, Node) for tool in root.params["allowed_tools"])
|
||||
# The return_values is of type str so it should be a list of strings
|
||||
assert isinstance(root.params["return_values"], list)
|
||||
assert all(isinstance(val, str) for val in root.params["return_values"])
|
||||
# The llm_chain should have a prompt and llm
|
||||
llm_chain_node = root.params["llm_chain"]
|
||||
assert isinstance(llm_chain_node.params, dict)
|
||||
assert "prompt" in llm_chain_node.params
|
||||
assert "llm" in llm_chain_node.params
|
||||
# The prompt should be a Node
|
||||
assert isinstance(llm_chain_node.params["prompt"], Node)
|
||||
# The llm should be a Node
|
||||
assert isinstance(llm_chain_node.params["llm"], Node)
|
||||
# The prompt should have format_insctructions, suffix, prefix
|
||||
prompt_node = llm_chain_node.params["prompt"]
|
||||
assert isinstance(prompt_node.params, dict)
|
||||
assert "format_instructions" in prompt_node.params
|
||||
assert "suffix" in prompt_node.params
|
||||
assert "prefix" in prompt_node.params
|
||||
# All of them should be of type str
|
||||
assert isinstance(prompt_node.params["format_instructions"], str)
|
||||
assert isinstance(prompt_node.params["suffix"], str)
|
||||
assert isinstance(prompt_node.params["prefix"], str)
|
||||
# The llm should have a model
|
||||
llm_node = llm_chain_node.params["llm"]
|
||||
assert isinstance(llm_node.params, dict)
|
||||
assert "model_name" in llm_node.params
|
||||
# The model should be a str
|
||||
assert isinstance(llm_node.params["model_name"], str)
|
||||
|
||||
|
||||
def test_build():
|
||||
"""Test Node's build method"""
|
||||
# def build(self):
|
||||
# # The params dict is used to build the module
|
||||
# # it contains values and keys that point to nodes which
|
||||
# # have their own params dict
|
||||
# # When build is called, we iterate through the params dict
|
||||
# # and if the value is a node, we call build on that node
|
||||
# # and use the output of that build as the value for the param
|
||||
# # if the value is not a node, then we use the value as the param
|
||||
# # and continue
|
||||
# # Another aspect is that the module_type is the class that we need to import
|
||||
# # and instantiate with these built params
|
||||
|
||||
# # Build each node in the params dict
|
||||
# for key, value in self.params.items():
|
||||
# if isinstance(value, Node):
|
||||
# self.params[key] = value.build()
|
||||
|
||||
# # Get the class from LANGCHAIN_TYPES_DICT
|
||||
# # and instantiate it with the params
|
||||
# # and return the instance
|
||||
# return LANGCHAIN_TYPES_DICT[self.module_type](**self.params)
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
# Now we test the build method
|
||||
# Build the Agent
|
||||
agent = graph.build()
|
||||
# The agent should be a AgentExecutor
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
|
|
|
|||
|
|
@ -64,9 +64,9 @@ def test_build_json_missing_child():
|
|||
if isinstance(value, dict) and "required" in value:
|
||||
value["required"] = True
|
||||
|
||||
graph = Graph(nodes, edges)
|
||||
root = get_root_node(graph)
|
||||
with pytest.raises(ValueError):
|
||||
graph = Graph(nodes, edges)
|
||||
root = get_root_node(graph)
|
||||
build_json(root, graph)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue