diff --git a/tests/test_agents_template.py b/tests/test_agents_template.py index db746a424..750050c25 100644 --- a/tests/test_agents_template.py +++ b/tests/test_agents_template.py @@ -48,7 +48,6 @@ def test_zero_shot_agent(client: TestClient): "type": "Tool", "list": True, "advanced": False, - "value": [], } diff --git a/tests/test_chains_template.py b/tests/test_chains_template.py index 33af32e57..f1154a556 100644 --- a/tests/test_chains_template.py +++ b/tests/test_chains_template.py @@ -191,7 +191,7 @@ def test_llm_checker_chain(client: TestClient): "multiline": False, "password": False, "name": "llm", - "type": "BaseLLM", + "type": "BaseLanguageModel", "list": False, "advanced": False, } diff --git a/tests/test_graph.py b/tests/test_graph.py index e109850e3..07c4630d6 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,7 +1,7 @@ from typing import Type, Union import pytest -from langchain.agents import AgentExecutor +from langchain.chains.base import Chain from langchain.llms.fake import FakeListLLM from langflow.graph import Edge, Graph, Node from langflow.graph.nodes import ( @@ -102,32 +102,13 @@ def test_get_node_neighbors_basic(basic_graph): # We need to check if there is a Chain in the one of the neighbors' # data attribute in the type key assert any( - "Chain" in neighbor.data["type"] for neighbor, val in neighbors.items() if val - ) - # assert Serper Search is in the neighbors - assert any( - "Serper" in neighbor.data["type"] for neighbor, val in neighbors.items() if val - ) - # Now on to the Chain's neighbors - chain = next( - neighbor + "ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() - if "Chain" in neighbor.data["type"] and val - ) - chain_neighbors = basic_graph.get_node_neighbors(chain) - assert chain_neighbors is not None - assert isinstance(chain_neighbors, dict) - # Check if there is a LLM in the chain's neighbors - assert any( - "OpenAI" in neighbor.data["type"] - for neighbor, val in chain_neighbors.items() if val ) - # Chain should have a Prompt as a neighbor + assert any( - "Prompt" in neighbor.data["type"] - for neighbor, val in chain_neighbors.items() - if val + "OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val ) @@ -209,7 +190,7 @@ def test_get_root_node(basic_graph, complex_graph): root = get_root_node(basic_graph) assert root is not None assert isinstance(root, Node) - assert root.data["type"] == "ZeroShotAgent" + assert root.data["type"] == "TimeTravelGuideChain" # For complex example, the root node is a ZeroShotAgent too assert isinstance(complex_graph, Graph) root = get_root_node(complex_graph) @@ -218,26 +199,6 @@ def test_get_root_node(basic_graph, complex_graph): assert root.data["type"] == "ZeroShotAgent" -def test_build_json(basic_graph): - """Test building JSON from graph""" - assert isinstance(basic_graph, Graph) - root = get_root_node(basic_graph) - json_data = build_json(root, basic_graph) - assert isinstance(json_data, dict) - assert json_data["_type"] == "zero-shot-react-description" - assert isinstance(json_data["llm_chain"], dict) - assert json_data["llm_chain"]["_type"] == "llm_chain" - assert json_data["llm_chain"]["memory"] is None - assert json_data["llm_chain"]["verbose"] is False - assert isinstance(json_data["llm_chain"]["prompt"], dict) - assert isinstance(json_data["llm_chain"]["llm"], dict) - assert json_data["llm_chain"]["output_key"] == "text" - assert isinstance(json_data["allowed_tools"], list) - assert all(isinstance(tool, dict) for tool in json_data["allowed_tools"]) - assert isinstance(json_data["return_values"], list) - assert all(isinstance(val, str) for val in json_data["return_values"]) - - def test_validate_edges(basic_graph): """Test validating edges""" @@ -269,45 +230,11 @@ def test_build_params(basic_graph): assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges) # Get the root node root = get_root_node(basic_graph) - # Root node is a ZeroShotAgent - # which requires an llm_chain, allowed_tools and return_values + # Root node is a TimeTravelGuideChain + # which requires an llm and memory assert isinstance(root.params, dict) - assert "llm_chain" in root.params - assert "allowed_tools" in root.params - assert "return_values" in root.params - # The llm_chain should be a Node - assert isinstance(root.params["llm_chain"], Node) - # The allowed_tools should be a list of Nodes - assert isinstance(root.params["allowed_tools"], list) - assert all(isinstance(tool, Node) for tool in root.params["allowed_tools"]) - # The return_values is of type str so it should be a list of strings - assert isinstance(root.params["return_values"], list) - assert all(isinstance(val, str) for val in root.params["return_values"]) - # The llm_chain should have a prompt and llm - llm_chain_node = root.params["llm_chain"] - assert isinstance(llm_chain_node.params, dict) - assert "prompt" in llm_chain_node.params - assert "llm" in llm_chain_node.params - # The prompt should be a Node - assert isinstance(llm_chain_node.params["prompt"], Node) - # The llm should be a Node - assert isinstance(llm_chain_node.params["llm"], Node) - # The prompt should have format_insctructions, suffix, prefix - prompt_node = llm_chain_node.params["prompt"] - assert isinstance(prompt_node.params, dict) - assert "format_instructions" in prompt_node.params - assert "suffix" in prompt_node.params - assert "prefix" in prompt_node.params - # All of them should be of type str - assert isinstance(prompt_node.params["format_instructions"], str) - assert isinstance(prompt_node.params["suffix"], str) - assert isinstance(prompt_node.params["prefix"], str) - # The llm should have a model - llm_node = llm_chain_node.params["llm"] - assert isinstance(llm_node.params, dict) - assert "model_name" in llm_node.params - # The model should be a str - assert isinstance(llm_node.params["model_name"], str) + assert "llm" in root.params + assert "memory" in root.params def test_build(basic_graph, complex_graph, openapi_graph): @@ -324,18 +251,18 @@ def assert_agent_was_built(graph): # Build the Agent result = graph.build() # The agent should be a AgentExecutor - assert isinstance(result, AgentExecutor) + assert isinstance(result, Chain) -def test_agent_node_build(basic_graph): - agent_node = get_node_by_type(basic_graph, AgentNode) +def test_agent_node_build(complex_graph): + agent_node = get_node_by_type(complex_graph, AgentNode) assert agent_node is not None built_object = agent_node.build() assert built_object is not None -def test_tool_node_build(basic_graph): - tool_node = get_node_by_type(basic_graph, ToolNode) +def test_tool_node_build(complex_graph): + tool_node = get_node_by_type(complex_graph, ToolNode) assert tool_node is not None built_object = tool_node.build() assert built_object is not None diff --git a/tests/test_loading.py b/tests/test_loading.py index 444c85fd9..872314699 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -1,7 +1,7 @@ import json import pytest -from langchain.agents import AgentExecutor +from langchain.chains.base import Chain from langflow import load_flow_from_json from langflow.graph import Graph from langflow.utils.payload import get_root_node @@ -11,7 +11,7 @@ def test_load_flow_from_json(): """Test loading a flow from a json file""" loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH) assert loaded is not None - assert isinstance(loaded, AgentExecutor) + assert isinstance(loaded, Chain) def test_get_root_node():