From b4963572b0a68184920e7297778b3fd2b90c79d3 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Fri, 24 Mar 2023 11:29:58 -0300 Subject: [PATCH] feat: added more graph tests --- tests/test_graph.py | 69 +++++++++++++++++++++++++++++++++++++++++-- tests/test_loading.py | 14 ++++----- 2 files changed, 74 insertions(+), 9 deletions(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index 94bc82716..c1b6f0c56 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,7 +1,7 @@ import json -from langflow.utils.graph import Graph +from langflow.utils.graph import Edge, Graph, Node import pytest -from langflow.utils.payload import get_root_node +from langflow.utils.payload import build_json, get_root_node # Test cases for the graph module @@ -70,3 +70,68 @@ def test_get_node_neighbors(): for neighbor, val in chain_neighbors.items() if val ) + + +def test_get_node(): + """Test getting a single node""" + graph = get_graph() + node_id = graph.nodes[0].id + node = graph.get_node(node_id) + assert isinstance(node, Node) + assert node.id == node_id + + +def test_build_nodes(): + """Test building nodes""" + graph = get_graph() + assert len(graph.nodes) == len(graph._nodes) + for node in graph.nodes: + assert isinstance(node, Node) + + +def test_build_edges(): + """Test building edges""" + graph = get_graph() + assert len(graph.edges) == len(graph._edges) + for edge in graph.edges: + assert isinstance(edge, Edge) + assert isinstance(edge.source, Node) + assert isinstance(edge.target, Node) + + +def test_get_root_node(): + """Test getting root node""" + graph = get_graph(basic=True) + assert isinstance(graph, Graph) + root = get_root_node(graph) + assert root is not None + assert isinstance(root, Node) + assert root.data["type"] == "ZeroShotAgent" + # For complex example, the root node is a ZeroShotAgent too + graph = get_graph(basic=False) + assert isinstance(graph, Graph) + root = get_root_node(graph) + assert root is not None + assert isinstance(root, Node) + assert root.data["type"] == "ZeroShotAgent" + + +def test_build_json(): + """Test building JSON from graph""" + graph = get_graph() + assert isinstance(graph, Graph) + root = get_root_node(graph) + json_data = build_json(root, graph) + assert isinstance(json_data, dict) + assert json_data["_type"] == "zero-shot-react-description" + assert isinstance(json_data["llm_chain"], dict) + assert json_data["llm_chain"]["_type"] == "llm_chain" + assert json_data["llm_chain"]["memory"] is None + assert json_data["llm_chain"]["verbose"] is True + assert isinstance(json_data["llm_chain"]["prompt"], dict) + assert isinstance(json_data["llm_chain"]["llm"], dict) + assert json_data["llm_chain"]["output_key"] == "text" + assert isinstance(json_data["allowed_tools"], list) + assert all(isinstance(tool, dict) for tool in json_data["allowed_tools"]) + assert isinstance(json_data["return_values"], list) + assert all(isinstance(val, str) for val in json_data["return_values"]) diff --git a/tests/test_loading.py b/tests/test_loading.py index e9bc8d5cb..5f44e55f8 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -11,12 +11,12 @@ from langflow.interface.loading import load_langchain_type_from_config def test_load_flow_from_json(): """Test loading a flow from a json file""" - loaded = load_flow_from_json(pytest.EXAMPLE_JSON_PATH) + loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH) assert loaded is not None def test_extract_json(): - with open(pytest.EXAMPLE_JSON_PATH, "r") as f: + with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: flow_graph = json.load(f) data_graph = flow_graph["data"] extracted = extract_json(data_graph) @@ -25,7 +25,7 @@ def test_extract_json(): def test_get_root_node(): - with open(pytest.EXAMPLE_JSON_PATH, "r") as f: + with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: flow_graph = json.load(f) data_graph = flow_graph["data"] nodes = data_graph["nodes"] @@ -38,7 +38,7 @@ def test_get_root_node(): def test_build_json(): - with open(pytest.EXAMPLE_JSON_PATH, "r") as f: + with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: flow_graph = json.load(f) data_graph = flow_graph["data"] nodes = data_graph["nodes"] @@ -51,7 +51,7 @@ def test_build_json(): def test_build_json_missing_child(): - with open(pytest.EXAMPLE_JSON_PATH, "r") as f: + with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: flow_graph = json.load(f) data_graph = flow_graph["data"] nodes = data_graph["nodes"] @@ -76,7 +76,7 @@ def test_build_json_no_nodes(): def test_build_json_invalid_edge(): - with open(pytest.EXAMPLE_JSON_PATH, "r") as f: + with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: flow_graph = json.load(f) data_graph = flow_graph["data"] nodes = data_graph["nodes"] @@ -92,7 +92,7 @@ def test_build_json_invalid_edge(): def test_load_langchain_type_from_config(): - with open(pytest.EXAMPLE_JSON_PATH, "r") as f: + with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: flow_graph = json.load(f) data_graph = flow_graph["data"] extracted = extract_json(data_graph)