refactor(tests): remove unused imports and variables, fix typos and update node types

This commit is contained in:
Gabriel Almeida 2023-05-06 10:26:42 -03:00 committed by Gabriel Luiz Freitas Almeida
commit 3b733ada01
4 changed files with 17 additions and 91 deletions

View file

@ -1,7 +1,7 @@
from typing import Type, Union
import pytest
from langchain.agents import AgentExecutor
from langchain.chains.base import Chain
from langchain.llms.fake import FakeListLLM
from langflow.graph import Edge, Graph, Node
from langflow.graph.nodes import (
@ -102,32 +102,13 @@ def test_get_node_neighbors_basic(basic_graph):
# We need to check if there is a Chain in the one of the neighbors'
# data attribute in the type key
assert any(
"Chain" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
)
# assert Serper Search is in the neighbors
assert any(
"Serper" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
)
# Now on to the Chain's neighbors
chain = next(
neighbor
"ConversationBufferMemory" in neighbor.data["type"]
for neighbor, val in neighbors.items()
if "Chain" in neighbor.data["type"] and val
)
chain_neighbors = basic_graph.get_node_neighbors(chain)
assert chain_neighbors is not None
assert isinstance(chain_neighbors, dict)
# Check if there is a LLM in the chain's neighbors
assert any(
"OpenAI" in neighbor.data["type"]
for neighbor, val in chain_neighbors.items()
if val
)
# Chain should have a Prompt as a neighbor
assert any(
"Prompt" in neighbor.data["type"]
for neighbor, val in chain_neighbors.items()
if val
"OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
)
@ -209,7 +190,7 @@ def test_get_root_node(basic_graph, complex_graph):
root = get_root_node(basic_graph)
assert root is not None
assert isinstance(root, Node)
assert root.data["type"] == "ZeroShotAgent"
assert root.data["type"] == "TimeTravelGuideChain"
# For complex example, the root node is a ZeroShotAgent too
assert isinstance(complex_graph, Graph)
root = get_root_node(complex_graph)
@ -218,26 +199,6 @@ def test_get_root_node(basic_graph, complex_graph):
assert root.data["type"] == "ZeroShotAgent"
def test_build_json(basic_graph):
"""Test building JSON from graph"""
assert isinstance(basic_graph, Graph)
root = get_root_node(basic_graph)
json_data = build_json(root, basic_graph)
assert isinstance(json_data, dict)
assert json_data["_type"] == "zero-shot-react-description"
assert isinstance(json_data["llm_chain"], dict)
assert json_data["llm_chain"]["_type"] == "llm_chain"
assert json_data["llm_chain"]["memory"] is None
assert json_data["llm_chain"]["verbose"] is False
assert isinstance(json_data["llm_chain"]["prompt"], dict)
assert isinstance(json_data["llm_chain"]["llm"], dict)
assert json_data["llm_chain"]["output_key"] == "text"
assert isinstance(json_data["allowed_tools"], list)
assert all(isinstance(tool, dict) for tool in json_data["allowed_tools"])
assert isinstance(json_data["return_values"], list)
assert all(isinstance(val, str) for val in json_data["return_values"])
def test_validate_edges(basic_graph):
"""Test validating edges"""
@ -269,45 +230,11 @@ def test_build_params(basic_graph):
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
# Get the root node
root = get_root_node(basic_graph)
# Root node is a ZeroShotAgent
# which requires an llm_chain, allowed_tools and return_values
# Root node is a TimeTravelGuideChain
# which requires an llm and memory
assert isinstance(root.params, dict)
assert "llm_chain" in root.params
assert "allowed_tools" in root.params
assert "return_values" in root.params
# The llm_chain should be a Node
assert isinstance(root.params["llm_chain"], Node)
# The allowed_tools should be a list of Nodes
assert isinstance(root.params["allowed_tools"], list)
assert all(isinstance(tool, Node) for tool in root.params["allowed_tools"])
# The return_values is of type str so it should be a list of strings
assert isinstance(root.params["return_values"], list)
assert all(isinstance(val, str) for val in root.params["return_values"])
# The llm_chain should have a prompt and llm
llm_chain_node = root.params["llm_chain"]
assert isinstance(llm_chain_node.params, dict)
assert "prompt" in llm_chain_node.params
assert "llm" in llm_chain_node.params
# The prompt should be a Node
assert isinstance(llm_chain_node.params["prompt"], Node)
# The llm should be a Node
assert isinstance(llm_chain_node.params["llm"], Node)
# The prompt should have format_insctructions, suffix, prefix
prompt_node = llm_chain_node.params["prompt"]
assert isinstance(prompt_node.params, dict)
assert "format_instructions" in prompt_node.params
assert "suffix" in prompt_node.params
assert "prefix" in prompt_node.params
# All of them should be of type str
assert isinstance(prompt_node.params["format_instructions"], str)
assert isinstance(prompt_node.params["suffix"], str)
assert isinstance(prompt_node.params["prefix"], str)
# The llm should have a model
llm_node = llm_chain_node.params["llm"]
assert isinstance(llm_node.params, dict)
assert "model_name" in llm_node.params
# The model should be a str
assert isinstance(llm_node.params["model_name"], str)
assert "llm" in root.params
assert "memory" in root.params
def test_build(basic_graph, complex_graph, openapi_graph):
@ -324,18 +251,18 @@ def assert_agent_was_built(graph):
# Build the Agent
result = graph.build()
# The agent should be a AgentExecutor
assert isinstance(result, AgentExecutor)
assert isinstance(result, Chain)
def test_agent_node_build(basic_graph):
agent_node = get_node_by_type(basic_graph, AgentNode)
def test_agent_node_build(complex_graph):
agent_node = get_node_by_type(complex_graph, AgentNode)
assert agent_node is not None
built_object = agent_node.build()
assert built_object is not None
def test_tool_node_build(basic_graph):
tool_node = get_node_by_type(basic_graph, ToolNode)
def test_tool_node_build(complex_graph):
tool_node = get_node_by_type(complex_graph, ToolNode)
assert tool_node is not None
built_object = tool_node.build()
assert built_object is not None