From 5ad879bd9b0c5fdaea7a32fbb3007175b8ef90b2 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 27 Nov 2023 21:51:31 -0300 Subject: [PATCH] Fix inconsistencies in test cases --- tests/test_cache.py | 4 +-- tests/test_custom_component.py | 7 ++--- tests/test_graph.py | 55 ++++++++++++++++++---------------- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index c2c706ee9..925402769 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,7 +1,7 @@ import json -from langflow.graph import Graph import pytest +from langflow.graph import Graph def get_graph(_type="basic"): @@ -41,5 +41,5 @@ def langchain_objects_are_equal(obj1, obj2): def test_build_graph(client, basic_data_graph): graph = Graph.from_payload(basic_data_graph) assert graph is not None - assert len(graph.nodes) == len(basic_data_graph["nodes"]) + assert len(graph.vertices) == len(basic_data_graph["nodes"]) assert len(graph.edges) == len(basic_data_graph["edges"]) diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index b07753b8d..636bd63b1 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -7,10 +7,7 @@ from fastapi import HTTPException from langflow.field_typing.constants import Data from langflow.interface.custom.base import CustomComponent from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError -from langflow.interface.custom.component import ( - Component, - ComponentCodeNullError, -) +from langflow.interface.custom.component import Component, ComponentCodeNullError from langflow.services.database.models.flow import Flow, FlowCreate code_default = """ @@ -445,7 +442,7 @@ def test_custom_component_build_not_implemented(): def test_build_config_no_code(): component = CustomComponent(code=None) - assert component.get_function_entrypoint_args == "" + assert component.get_function_entrypoint_args == [] assert component.get_function_entrypoint_return_type == [] diff --git a/tests/test_graph.py b/tests/test_graph.py index cb69d79d5..020642798 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -24,7 +24,7 @@ from langflow.graph.utils import UnbuiltObject from langflow.graph.vertex.base import Vertex from langflow.graph.vertex.types import FileToolVertex, LLMVertex, ToolkitVertex from langflow.processing.process import get_result_and_thought -from langflow.utils.payload import get_root_node +from langflow.utils.payload import get_root_vertex # Test cases for the graph module @@ -70,19 +70,19 @@ def sample_nodes(): def get_node_by_type(graph, node_type: Type[Vertex]) -> Union[Vertex, None]: """Get a node by type""" - return next((node for node in graph.nodes if isinstance(node, node_type)), None) + return next((node for node in graph.vertices if isinstance(node, node_type)), None) def test_graph_structure(basic_graph): assert isinstance(basic_graph, Graph) - assert len(basic_graph.nodes) > 0 + assert len(basic_graph.vertices) > 0 assert len(basic_graph.edges) > 0 - for node in basic_graph.nodes: + for node in basic_graph.vertices: assert isinstance(node, Vertex) for edge in basic_graph.edges: assert isinstance(edge, Edge) - assert edge.source in basic_graph.nodes - assert edge.target in basic_graph.nodes + assert edge.source_id in basic_graph.vertex_ids + assert edge.target_id in basic_graph.vertex_ids def test_circular_dependencies(basic_graph): @@ -90,7 +90,7 @@ def test_circular_dependencies(basic_graph): def check_circular(node, visited): visited.add(node) - neighbors = basic_graph.get_nodes_with_target(node) + neighbors = basic_graph.get_vertices_with_target(node) for neighbor in neighbors: if neighbor in visited: return True @@ -98,7 +98,7 @@ def test_circular_dependencies(basic_graph): return True return False - for node in basic_graph.nodes: + for node in basic_graph.vertices: assert not check_circular(node, set()) @@ -123,13 +123,13 @@ def test_invalid_node_types(): Graph(graph_data["nodes"], graph_data["edges"]) -def test_get_nodes_with_target(basic_graph): +def test_get_vertices_with_target(basic_graph): """Test getting connected nodes""" assert isinstance(basic_graph, Graph) # Get root node - root = get_root_node(basic_graph) + root = get_root_vertex(basic_graph) assert root is not None - connected_nodes = basic_graph.get_nodes_with_target(root) + connected_nodes = basic_graph.get_vertices_with_target(root.id) assert connected_nodes is not None @@ -138,9 +138,9 @@ def test_get_node_neighbors_basic(basic_graph): assert isinstance(basic_graph, Graph) # Get root node - root = get_root_node(basic_graph) + root = get_root_vertex(basic_graph) assert root is not None - neighbors = basic_graph.get_node_neighbors(root) + neighbors = basic_graph.get_vertex_neighbors(root) assert neighbors is not None assert isinstance(neighbors, dict) # Root Node is an Agent, it requires an LLMChain and tools @@ -153,8 +153,8 @@ def test_get_node_neighbors_basic(basic_graph): def test_get_node(basic_graph): """Test getting a single node""" - node_id = basic_graph.nodes[0].id - node = basic_graph.get_node(node_id) + node_id = basic_graph.vertices[0].id + node = basic_graph.get_vertex(node_id) assert isinstance(node, Vertex) assert node.id == node_id @@ -162,8 +162,8 @@ def test_get_node(basic_graph): def test_build_nodes(basic_graph): """Test building nodes""" - assert len(basic_graph.nodes) == len(basic_graph._nodes) - for node in basic_graph.nodes: + assert len(basic_graph.vertices) == len(basic_graph._vertices) + for node in basic_graph.vertices: assert isinstance(node, Vertex) @@ -172,20 +172,21 @@ def test_build_edges(basic_graph): assert len(basic_graph.edges) == len(basic_graph._edges) for edge in basic_graph.edges: assert isinstance(edge, Edge) - assert isinstance(edge.source, Vertex) - assert isinstance(edge.target, Vertex) + + assert isinstance(edge.source_id, str) + assert isinstance(edge.target_id, str) -def test_get_root_node(client, basic_graph, complex_graph): +def test_get_root_vertex(client, basic_graph, complex_graph): """Test getting root node""" assert isinstance(basic_graph, Graph) - root = get_root_node(basic_graph) + root = get_root_vertex(basic_graph) assert root is not None assert isinstance(root, Vertex) assert root.data["type"] == "TimeTravelGuideChain" # For complex example, the root node is a ZeroShotAgent too assert isinstance(complex_graph, Graph) - root = get_root_node(complex_graph) + root = get_root_vertex(complex_graph) assert root is not None assert isinstance(root, Vertex) assert root.data["type"] == "ZeroShotAgent" @@ -221,7 +222,7 @@ def test_build_params(basic_graph): # The matched_type attribute should be in the source_types attr assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges) # Get the root node - root = get_root_node(basic_graph) + root = get_root_vertex(basic_graph) # Root node is a TimeTravelGuideChain # which requires an llm and memory assert root is not None @@ -278,7 +279,7 @@ async def test_file_tool_node_build(client, openapi_graph): assert Path(file_path).exists() file_tool_node = get_node_by_type(openapi_graph, FileToolVertex) - assert file_tool_node is not UnbuiltObject + assert file_tool_node is not UnbuiltObject and file_tool_node is not None built_object = await file_tool_node.build() assert built_object is not UnbuiltObject # Remove the file @@ -301,7 +302,7 @@ async def test_get_result_and_thought(basic_graph): llm_node._built = True langchain_object = await basic_graph.build() # assert all nodes are built - assert all(node._built for node in basic_graph.nodes) + assert all(node._built for node in basic_graph.vertices) # now build again and check if FakeListLLM was used # Get the result and thought @@ -420,10 +421,12 @@ def test_update_template(sample_template, sample_nodes): node2_updated = next((n for n in nodes_copy if n["id"] == "node2"), None) node3_updated = next((n for n in nodes_copy if n["id"] == "node3"), None) + assert node1_updated is not None assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1" + assert node2_updated is not None assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2" @@ -502,7 +505,7 @@ async def test_pickle_each_vertex(json_vector_store): loaded_json = json.loads(json_vector_store) graph = Graph.from_payload(loaded_json) assert isinstance(graph, Graph) - for vertex in graph.nodes: + for vertex in graph.vertices: await vertex.build() pickled = pickle.dumps(vertex) assert pickled is not UnbuiltObject