refactor(tests): remove unused imports and variables, fix typos and update node types
This commit is contained in:
parent
7ae42bf7d2
commit
3b733ada01
4 changed files with 17 additions and 91 deletions
|
|
@ -48,7 +48,6 @@ def test_zero_shot_agent(client: TestClient):
|
|||
"type": "Tool",
|
||||
"list": True,
|
||||
"advanced": False,
|
||||
"value": [],
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ def test_llm_checker_chain(client: TestClient):
|
|||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "llm",
|
||||
"type": "BaseLLM",
|
||||
"type": "BaseLanguageModel",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Type, Union
|
||||
|
||||
import pytest
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
from langflow.graph import Edge, Graph, Node
|
||||
from langflow.graph.nodes import (
|
||||
|
|
@ -102,32 +102,13 @@ def test_get_node_neighbors_basic(basic_graph):
|
|||
# We need to check if there is a Chain in the one of the neighbors'
|
||||
# data attribute in the type key
|
||||
assert any(
|
||||
"Chain" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
|
||||
)
|
||||
# assert Serper Search is in the neighbors
|
||||
assert any(
|
||||
"Serper" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
|
||||
)
|
||||
# Now on to the Chain's neighbors
|
||||
chain = next(
|
||||
neighbor
|
||||
"ConversationBufferMemory" in neighbor.data["type"]
|
||||
for neighbor, val in neighbors.items()
|
||||
if "Chain" in neighbor.data["type"] and val
|
||||
)
|
||||
chain_neighbors = basic_graph.get_node_neighbors(chain)
|
||||
assert chain_neighbors is not None
|
||||
assert isinstance(chain_neighbors, dict)
|
||||
# Check if there is a LLM in the chain's neighbors
|
||||
assert any(
|
||||
"OpenAI" in neighbor.data["type"]
|
||||
for neighbor, val in chain_neighbors.items()
|
||||
if val
|
||||
)
|
||||
# Chain should have a Prompt as a neighbor
|
||||
|
||||
assert any(
|
||||
"Prompt" in neighbor.data["type"]
|
||||
for neighbor, val in chain_neighbors.items()
|
||||
if val
|
||||
"OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -209,7 +190,7 @@ def test_get_root_node(basic_graph, complex_graph):
|
|||
root = get_root_node(basic_graph)
|
||||
assert root is not None
|
||||
assert isinstance(root, Node)
|
||||
assert root.data["type"] == "ZeroShotAgent"
|
||||
assert root.data["type"] == "TimeTravelGuideChain"
|
||||
# For complex example, the root node is a ZeroShotAgent too
|
||||
assert isinstance(complex_graph, Graph)
|
||||
root = get_root_node(complex_graph)
|
||||
|
|
@ -218,26 +199,6 @@ def test_get_root_node(basic_graph, complex_graph):
|
|||
assert root.data["type"] == "ZeroShotAgent"
|
||||
|
||||
|
||||
def test_build_json(basic_graph):
|
||||
"""Test building JSON from graph"""
|
||||
assert isinstance(basic_graph, Graph)
|
||||
root = get_root_node(basic_graph)
|
||||
json_data = build_json(root, basic_graph)
|
||||
assert isinstance(json_data, dict)
|
||||
assert json_data["_type"] == "zero-shot-react-description"
|
||||
assert isinstance(json_data["llm_chain"], dict)
|
||||
assert json_data["llm_chain"]["_type"] == "llm_chain"
|
||||
assert json_data["llm_chain"]["memory"] is None
|
||||
assert json_data["llm_chain"]["verbose"] is False
|
||||
assert isinstance(json_data["llm_chain"]["prompt"], dict)
|
||||
assert isinstance(json_data["llm_chain"]["llm"], dict)
|
||||
assert json_data["llm_chain"]["output_key"] == "text"
|
||||
assert isinstance(json_data["allowed_tools"], list)
|
||||
assert all(isinstance(tool, dict) for tool in json_data["allowed_tools"])
|
||||
assert isinstance(json_data["return_values"], list)
|
||||
assert all(isinstance(val, str) for val in json_data["return_values"])
|
||||
|
||||
|
||||
def test_validate_edges(basic_graph):
|
||||
"""Test validating edges"""
|
||||
|
||||
|
|
@ -269,45 +230,11 @@ def test_build_params(basic_graph):
|
|||
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
|
||||
# Get the root node
|
||||
root = get_root_node(basic_graph)
|
||||
# Root node is a ZeroShotAgent
|
||||
# which requires an llm_chain, allowed_tools and return_values
|
||||
# Root node is a TimeTravelGuideChain
|
||||
# which requires an llm and memory
|
||||
assert isinstance(root.params, dict)
|
||||
assert "llm_chain" in root.params
|
||||
assert "allowed_tools" in root.params
|
||||
assert "return_values" in root.params
|
||||
# The llm_chain should be a Node
|
||||
assert isinstance(root.params["llm_chain"], Node)
|
||||
# The allowed_tools should be a list of Nodes
|
||||
assert isinstance(root.params["allowed_tools"], list)
|
||||
assert all(isinstance(tool, Node) for tool in root.params["allowed_tools"])
|
||||
# The return_values is of type str so it should be a list of strings
|
||||
assert isinstance(root.params["return_values"], list)
|
||||
assert all(isinstance(val, str) for val in root.params["return_values"])
|
||||
# The llm_chain should have a prompt and llm
|
||||
llm_chain_node = root.params["llm_chain"]
|
||||
assert isinstance(llm_chain_node.params, dict)
|
||||
assert "prompt" in llm_chain_node.params
|
||||
assert "llm" in llm_chain_node.params
|
||||
# The prompt should be a Node
|
||||
assert isinstance(llm_chain_node.params["prompt"], Node)
|
||||
# The llm should be a Node
|
||||
assert isinstance(llm_chain_node.params["llm"], Node)
|
||||
# The prompt should have format_insctructions, suffix, prefix
|
||||
prompt_node = llm_chain_node.params["prompt"]
|
||||
assert isinstance(prompt_node.params, dict)
|
||||
assert "format_instructions" in prompt_node.params
|
||||
assert "suffix" in prompt_node.params
|
||||
assert "prefix" in prompt_node.params
|
||||
# All of them should be of type str
|
||||
assert isinstance(prompt_node.params["format_instructions"], str)
|
||||
assert isinstance(prompt_node.params["suffix"], str)
|
||||
assert isinstance(prompt_node.params["prefix"], str)
|
||||
# The llm should have a model
|
||||
llm_node = llm_chain_node.params["llm"]
|
||||
assert isinstance(llm_node.params, dict)
|
||||
assert "model_name" in llm_node.params
|
||||
# The model should be a str
|
||||
assert isinstance(llm_node.params["model_name"], str)
|
||||
assert "llm" in root.params
|
||||
assert "memory" in root.params
|
||||
|
||||
|
||||
def test_build(basic_graph, complex_graph, openapi_graph):
|
||||
|
|
@ -324,18 +251,18 @@ def assert_agent_was_built(graph):
|
|||
# Build the Agent
|
||||
result = graph.build()
|
||||
# The agent should be a AgentExecutor
|
||||
assert isinstance(result, AgentExecutor)
|
||||
assert isinstance(result, Chain)
|
||||
|
||||
|
||||
def test_agent_node_build(basic_graph):
|
||||
agent_node = get_node_by_type(basic_graph, AgentNode)
|
||||
def test_agent_node_build(complex_graph):
|
||||
agent_node = get_node_by_type(complex_graph, AgentNode)
|
||||
assert agent_node is not None
|
||||
built_object = agent_node.build()
|
||||
assert built_object is not None
|
||||
|
||||
|
||||
def test_tool_node_build(basic_graph):
|
||||
tool_node = get_node_by_type(basic_graph, ToolNode)
|
||||
def test_tool_node_build(complex_graph):
|
||||
tool_node = get_node_by_type(complex_graph, ToolNode)
|
||||
assert tool_node is not None
|
||||
built_object = tool_node.build()
|
||||
assert built_object is not None
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.chains.base import Chain
|
||||
from langflow import load_flow_from_json
|
||||
from langflow.graph import Graph
|
||||
from langflow.utils.payload import get_root_node
|
||||
|
|
@ -11,7 +11,7 @@ def test_load_flow_from_json():
|
|||
"""Test loading a flow from a json file"""
|
||||
loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH)
|
||||
assert loaded is not None
|
||||
assert isinstance(loaded, AgentExecutor)
|
||||
assert isinstance(loaded, Chain)
|
||||
|
||||
|
||||
def test_get_root_node():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue