diff --git a/src/backend/langflow/api/validate.py b/src/backend/langflow/api/validate.py index 53a7ee350..e90e554f0 100644 --- a/src/backend/langflow/api/validate.py +++ b/src/backend/langflow/api/validate.py @@ -9,7 +9,7 @@ from langflow.api.base import ( PromptValidationResponse, validate_prompt, ) -from langflow.graph.node.types import VectorStoreNode +from langflow.graph.vertex.types import VectorStoreVertex from langflow.interface.run import build_graph from langflow.utils.logger import logger from langflow.utils.validate import validate_code @@ -49,7 +49,7 @@ def post_validate_node(node_id: str, data: dict): node = graph.get_node(node_id) if node is None: raise ValueError(f"Node {node_id} not found") - if not isinstance(node, VectorStoreNode): + if not isinstance(node, VectorStoreVertex): node.build() return json.dumps({"valid": True, "params": str(node._built_object_repr())}) except Exception as e: diff --git a/src/backend/langflow/graph/__init__.py b/src/backend/langflow/graph/__init__.py index 44859da02..a68e844ee 100644 --- a/src/backend/langflow/graph/__init__.py +++ b/src/backend/langflow/graph/__init__.py @@ -1,35 +1,35 @@ from langflow.graph.edge.base import Edge from langflow.graph.graph.base import Graph -from langflow.graph.node.base import Node -from langflow.graph.node.types import ( - AgentNode, - ChainNode, - DocumentLoaderNode, - EmbeddingNode, - LLMNode, - MemoryNode, - PromptNode, - TextSplitterNode, - ToolNode, - ToolkitNode, - VectorStoreNode, - WrapperNode, +from langflow.graph.vertex.base import Vertex +from langflow.graph.vertex.types import ( + AgentVertex, + ChainVertex, + DocumentLoaderVertex, + EmbeddingVertex, + LLMVertex, + MemoryVertex, + PromptVertex, + TextSplitterVertex, + ToolVertex, + ToolkitVertex, + VectorStoreVertex, + WrapperVertex, ) __all__ = [ "Graph", - "Node", + "Vertex", "Edge", - "AgentNode", - "ChainNode", - "DocumentLoaderNode", - "EmbeddingNode", - "LLMNode", - "MemoryNode", - "PromptNode", - "TextSplitterNode", - "ToolNode", - "ToolkitNode", - "VectorStoreNode", - "WrapperNode", + "AgentVertex", + "ChainVertex", + "DocumentLoaderVertex", + "EmbeddingVertex", + "LLMVertex", + "MemoryVertex", + "PromptVertex", + "TextSplitterVertex", + "ToolVertex", + "ToolkitVertex", + "VectorStoreVertex", + "WrapperVertex", ] diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index 2bf5a1ba4..08f084a5c 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -2,13 +2,13 @@ from langflow.utils.logger import logger from typing import TYPE_CHECKING if TYPE_CHECKING: - from langflow.graph.node.base import Node + from langflow.graph.vertex.base import Vertex class Edge: - def __init__(self, source: "Node", target: "Node"): - self.source: "Node" = source - self.target: "Node" = target + def __init__(self, source: "Vertex", target: "Vertex"): + self.source: "Vertex" = source + self.target: "Vertex" = target self.validate_edge() def validate_edge(self) -> None: @@ -41,7 +41,7 @@ class Edge: logger.debug(self.target_reqs) if no_matched_type: raise ValueError( - f"Edge between {self.source.node_type} and {self.target.node_type} " + f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " f"has no matched type" ) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 3ba67837f..020f539ec 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -1,12 +1,12 @@ from typing import Dict, List, Type, Union from langflow.graph.edge.base import Edge -from langflow.graph.graph.constants import NODE_TYPE_MAP -from langflow.graph.node.base import Node -from langflow.graph.node.types import ( - FileToolNode, - LLMNode, - ToolkitNode, +from langflow.graph.graph.constants import VERTEX_TYPE_MAP +from langflow.graph.vertex.base import Vertex +from langflow.graph.vertex.types import ( + FileToolVertex, + LLMVertex, + ToolkitVertex, ) from langflow.interface.tools.constants import FILE_TOOLS from langflow.utils import payload @@ -26,7 +26,7 @@ class Graph: def _build_graph(self) -> None: """Builds the graph from the nodes and edges.""" - self.nodes = self._build_nodes() + self.nodes = self._build_vertices() self.edges = self._build_edges() for edge in self.edges: edge.source.add_edge(edge) @@ -43,12 +43,12 @@ class Graph: llm_node = None for node in self.nodes: node._build_params() - if isinstance(node, LLMNode): + if isinstance(node, LLMVertex): llm_node = node if llm_node: for node in self.nodes: - if isinstance(node, ToolkitNode): + if isinstance(node, ToolkitVertex): node.params["llm"] = llm_node def _remove_invalid_nodes(self) -> None: @@ -60,23 +60,23 @@ class Graph: or (len(self.nodes) == 1 and len(self.edges) == 0) ] - def _validate_node(self, node: Node) -> bool: + def _validate_node(self, node: Vertex) -> bool: """Validates a node.""" # All nodes that do not have edges are invalid return len(node.edges) > 0 - def get_node(self, node_id: str) -> Union[None, Node]: + def get_node(self, node_id: str) -> Union[None, Vertex]: """Returns a node by id.""" return next((node for node in self.nodes if node.id == node_id), None) - def get_nodes_with_target(self, node: Node) -> List[Node]: + def get_nodes_with_target(self, node: Vertex) -> List[Vertex]: """Returns the nodes connected to a node.""" - connected_nodes: List[Node] = [ + connected_nodes: List[Vertex] = [ edge.source for edge in self.edges if edge.target == node ] return connected_nodes - def build(self) -> List[Node]: + def build(self) -> List[Vertex]: """Builds the graph.""" # Get root node root_node = payload.get_root_node(self) @@ -84,9 +84,9 @@ class Graph: raise ValueError("No root node found") return root_node.build() - def get_node_neighbors(self, node: Node) -> Dict[Node, int]: + def get_node_neighbors(self, node: Vertex) -> Dict[Vertex, int]: """Returns the neighbors of a node.""" - neighbors: Dict[Node, int] = {} + neighbors: Dict[Vertex, int] = {} for edge in self.edges: if edge.source == node: neighbor = edge.target @@ -117,28 +117,30 @@ class Graph: edges.append(Edge(source, target)) return edges - def _get_node_class(self, node_type: str, node_lc_type: str) -> Type[Node]: + def _get_vertex_class(self, node_type: str, node_lc_type: str) -> Type[Vertex]: """Returns the node class based on the node type.""" if node_type in FILE_TOOLS: - return FileToolNode - if node_type in NODE_TYPE_MAP: - return NODE_TYPE_MAP[node_type] - return NODE_TYPE_MAP[node_lc_type] if node_lc_type in NODE_TYPE_MAP else Node + return FileToolVertex + if node_type in VERTEX_TYPE_MAP: + return VERTEX_TYPE_MAP[node_type] + return ( + VERTEX_TYPE_MAP[node_lc_type] if node_lc_type in VERTEX_TYPE_MAP else Vertex + ) - def _build_nodes(self) -> List[Node]: - """Builds the nodes of the graph.""" - nodes: List[Node] = [] + def _build_vertices(self) -> List[Vertex]: + """Builds the vertices of the graph.""" + nodes: List[Vertex] = [] for node in self._nodes: node_data = node["data"] node_type: str = node_data["type"] # type: ignore node_lc_type: str = node_data["node"]["template"]["_type"] # type: ignore - NodeClass = self._get_node_class(node_type, node_lc_type) - nodes.append(NodeClass(node)) + VertexClass = self._get_vertex_class(node_type, node_lc_type) + nodes.append(VertexClass(node)) return nodes - def get_children_by_node_type(self, node: Node, node_type: str) -> List[Node]: + def get_children_by_node_type(self, node: Vertex, node_type: str) -> List[Vertex]: """Returns the children of a node based on the node type.""" children = [] node_types = [node.data["type"]] diff --git a/src/backend/langflow/graph/graph/constants.py b/src/backend/langflow/graph/graph/constants.py index f5bc9b8e3..ff1317d39 100644 --- a/src/backend/langflow/graph/graph/constants.py +++ b/src/backend/langflow/graph/graph/constants.py @@ -1,17 +1,17 @@ -from langflow.graph.node.base import Node -from langflow.graph.node.types import ( - AgentNode, - ChainNode, - DocumentLoaderNode, - EmbeddingNode, - LLMNode, - MemoryNode, - PromptNode, - TextSplitterNode, - ToolNode, - ToolkitNode, - VectorStoreNode, - WrapperNode, +from langflow.graph.vertex.base import Vertex +from langflow.graph.vertex.types import ( + AgentVertex, + ChainVertex, + DocumentLoaderVertex, + EmbeddingVertex, + LLMVertex, + MemoryVertex, + PromptVertex, + TextSplitterVertex, + ToolVertex, + ToolkitVertex, + VectorStoreVertex, + WrapperVertex, ) from langflow.interface.agents.base import agent_creator from langflow.interface.chains.base import chain_creator @@ -33,17 +33,17 @@ from typing import Dict, Type DIRECT_TYPES = ["str", "bool", "code", "int", "float", "Any", "prompt"] -NODE_TYPE_MAP: Dict[str, Type[Node]] = { - **{t: PromptNode for t in prompt_creator.to_list()}, - **{t: AgentNode for t in agent_creator.to_list()}, - **{t: ChainNode for t in chain_creator.to_list()}, - **{t: ToolNode for t in tool_creator.to_list()}, - **{t: ToolkitNode for t in toolkits_creator.to_list()}, - **{t: WrapperNode for t in wrapper_creator.to_list()}, - **{t: LLMNode for t in llm_creator.to_list()}, - **{t: MemoryNode for t in memory_creator.to_list()}, - **{t: EmbeddingNode for t in embedding_creator.to_list()}, - **{t: VectorStoreNode for t in vectorstore_creator.to_list()}, - **{t: DocumentLoaderNode for t in documentloader_creator.to_list()}, - **{t: TextSplitterNode for t in textsplitter_creator.to_list()}, +VERTEX_TYPE_MAP: Dict[str, Type[Vertex]] = { + **{t: PromptVertex for t in prompt_creator.to_list()}, + **{t: AgentVertex for t in agent_creator.to_list()}, + **{t: ChainVertex for t in chain_creator.to_list()}, + **{t: ToolVertex for t in tool_creator.to_list()}, + **{t: ToolkitVertex for t in toolkits_creator.to_list()}, + **{t: WrapperVertex for t in wrapper_creator.to_list()}, + **{t: LLMVertex for t in llm_creator.to_list()}, + **{t: MemoryVertex for t in memory_creator.to_list()}, + **{t: EmbeddingVertex for t in embedding_creator.to_list()}, + **{t: VectorStoreVertex for t in vectorstore_creator.to_list()}, + **{t: DocumentLoaderVertex for t in documentloader_creator.to_list()}, + **{t: TextSplitterVertex for t in textsplitter_creator.to_list()}, } diff --git a/src/backend/langflow/graph/node/__init__.py b/src/backend/langflow/graph/vertex/__init__.py similarity index 100% rename from src/backend/langflow/graph/node/__init__.py rename to src/backend/langflow/graph/vertex/__init__.py diff --git a/src/backend/langflow/graph/node/base.py b/src/backend/langflow/graph/vertex/base.py similarity index 92% rename from src/backend/langflow/graph/node/base.py rename to src/backend/langflow/graph/vertex/base.py index 5076deb9c..4593e0a40 100644 --- a/src/backend/langflow/graph/node/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -1,5 +1,5 @@ from langflow.cache import base as cache_utils -from langflow.graph.node.constants import DIRECT_TYPES +from langflow.graph.vertex.constants import DIRECT_TYPES from langflow.interface import loading from langflow.interface.listing import ALL_TYPES_DICT from langflow.utils.logger import logger @@ -17,7 +17,7 @@ if TYPE_CHECKING: from langflow.graph.edge.base import Edge -class Node: +class Vertex: def __init__(self, data: Dict, base_type: Optional[str] = None) -> None: self.id: str = data["id"] self._data = data @@ -48,12 +48,12 @@ class Node: ] template_dict = self.data["node"]["template"] - self.node_type = ( + self.vertex_type = ( self.data["type"] if "Tool" not in self.output else template_dict["_type"] ) if self.base_type is None: for base_type, value in ALL_TYPES_DICT.items(): - if self.node_type in value: + if self.vertex_type in value: self.base_type = base_type break @@ -113,7 +113,7 @@ class Node: if value["required"] and not edges: # If a required parameter is not found, raise an error raise ValueError( - f"Required input {key} for module {self.node_type} not found" + f"Required input {key} for module {self.vertex_type} not found" ) elif value["list"]: # If this is a list parameter, append all sources to a list @@ -128,7 +128,7 @@ class Node: # so we need to check if value has value new_value = value.get("value") if new_value is None: - warnings.warn(f"Value for {key} in {self.node_type} is None. ") + warnings.warn(f"Value for {key} in {self.vertex_type} is None. ") if value.get("type") == "int": with contextlib.suppress(TypeError, ValueError): new_value = int(new_value) # type: ignore @@ -148,12 +148,12 @@ class Node: # and continue # Another aspect is that the node_type is the class that we need to import # and instantiate with these built params - logger.debug(f"Building {self.node_type}") + logger.debug(f"Building {self.vertex_type}") # Build each node in the params dict for key, value in self.params.copy().items(): # Check if Node or list of Nodes and not self # to avoid recursion - if isinstance(value, Node): + if isinstance(value, Vertex): if value == self: del self.params[key] continue @@ -177,7 +177,7 @@ class Node: self.params[key] = result elif isinstance(value, list) and all( - isinstance(node, Node) for node in value + isinstance(node, Vertex) for node in value ): self.params[key] = [] for node in value: @@ -193,17 +193,17 @@ class Node: try: self._built_object = loading.instantiate_class( - node_type=self.node_type, + node_type=self.vertex_type, base_type=self.base_type, params=self.params, ) except Exception as exc: raise ValueError( - f"Error building node {self.node_type}: {str(exc)}" + f"Error building node {self.vertex_type}: {str(exc)}" ) from exc if self._built_object is None: - raise ValueError(f"Node type {self.node_type} not found") + raise ValueError(f"Node type {self.vertex_type} not found") self._built = True @@ -220,7 +220,7 @@ class Node: return f"Node(id={self.id}, data={self.data})" def __eq__(self, __o: object) -> bool: - return self.id == __o.id if isinstance(__o, Node) else False + return self.id == __o.id if isinstance(__o, Vertex) else False def __hash__(self) -> int: return id(self) diff --git a/src/backend/langflow/graph/node/constants.py b/src/backend/langflow/graph/vertex/constants.py similarity index 100% rename from src/backend/langflow/graph/node/constants.py rename to src/backend/langflow/graph/vertex/constants.py diff --git a/src/backend/langflow/graph/node/types.py b/src/backend/langflow/graph/vertex/types.py similarity index 78% rename from src/backend/langflow/graph/node/types.py rename to src/backend/langflow/graph/vertex/types.py index 9b25fd6ee..4a3290c13 100644 --- a/src/backend/langflow/graph/node/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -1,22 +1,22 @@ from typing import Any, Dict, List, Optional, Union -from langflow.graph.node.base import Node +from langflow.graph.vertex.base import Vertex from langflow.graph.utils import extract_input_variables_from_prompt -class AgentNode(Node): +class AgentVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="agents") - self.tools: List[ToolNode] = [] - self.chains: List[ChainNode] = [] + self.tools: List[ToolVertex] = [] + self.chains: List[ChainVertex] = [] def _set_tools_and_chains(self) -> None: for edge in self.edges: source_node = edge.source - if isinstance(source_node, ToolNode): + if isinstance(source_node, ToolVertex): self.tools.append(source_node) - elif isinstance(source_node, ChainNode): + elif isinstance(source_node, ChainVertex): self.chains.append(source_node) def build(self, force: bool = False) -> Any: @@ -33,24 +33,28 @@ class AgentNode(Node): self._build() #! Cannot deepcopy VectorStore, VectorStoreRouter, or SQL agents - if self.node_type in ["VectorStoreAgent", "VectorStoreRouterAgent", "SQLAgent"]: + if self.vertex_type in [ + "VectorStoreAgent", + "VectorStoreRouterAgent", + "SQLAgent", + ]: return self._built_object return self._built_object -class ToolNode(Node): +class ToolVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="tools") -class PromptNode(Node): +class PromptVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="prompts") def build( self, force: bool = False, - tools: Optional[Union[List[Node], List[ToolNode]]] = None, + tools: Optional[Union[List[Vertex], List[ToolVertex]]] = None, ) -> Any: if not self._built or force: if ( @@ -59,7 +63,7 @@ class PromptNode(Node): ): self.params["input_variables"] = [] # Check if it is a ZeroShotPrompt and needs a tool - if "ShotPrompt" in self.node_type: + if "ShotPrompt" in self.vertex_type: tools = ( [tool_node.build() for tool_node in tools] if tools is not None @@ -83,31 +87,31 @@ class PromptNode(Node): return self._built_object -class ChainNode(Node): +class ChainVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="chains") def build( self, force: bool = False, - tools: Optional[Union[List[Node], List[ToolNode]]] = None, + tools: Optional[Union[List[Vertex], List[ToolVertex]]] = None, ) -> Any: if not self._built or force: # Check if the chain requires a PromptNode for key, value in self.params.items(): - if isinstance(value, PromptNode): + if isinstance(value, PromptVertex): # Build the PromptNode, passing the tools if available self.params[key] = value.build(tools=tools, force=force) self._build() #! Cannot deepcopy SQLDatabaseChain - if self.node_type in ["SQLDatabaseChain"]: + if self.vertex_type in ["SQLDatabaseChain"]: return self._built_object return self._built_object -class LLMNode(Node): +class LLMVertex(Vertex): built_node_type = None class_built_object = None @@ -117,28 +121,28 @@ class LLMNode(Node): def build(self, force: bool = False) -> Any: # LLM is different because some models might take up too much memory # or time to load. So we only load them when we need them.ß - if self.node_type == self.built_node_type: + if self.vertex_type == self.built_node_type: return self.class_built_object if not self._built or force: self._build() - self.built_node_type = self.node_type + self.built_node_type = self.vertex_type self.class_built_object = self._built_object # Avoid deepcopying the LLM # that are loaded from a file return self._built_object -class ToolkitNode(Node): +class ToolkitVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="toolkits") -class FileToolNode(ToolNode): +class FileToolVertex(ToolVertex): def __init__(self, data: Dict): super().__init__(data) -class WrapperNode(Node): +class WrapperVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="wrappers") @@ -150,7 +154,7 @@ class WrapperNode(Node): return self._built_object -class DocumentLoaderNode(Node): +class DocumentLoaderVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="documentloaders") @@ -158,17 +162,17 @@ class DocumentLoaderNode(Node): # This built_object is a list of documents. Maybe we should # show how many documents are in the list? if self._built_object: - return f"""{self.node_type}({len(self._built_object)} documents) + return f"""{self.vertex_type}({len(self._built_object)} documents) Documents: {self._built_object[:3]}...""" - return f"{self.node_type}()" + return f"{self.vertex_type}()" -class EmbeddingNode(Node): +class EmbeddingVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="embeddings") -class VectorStoreNode(Node): +class VectorStoreVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="vectorstores") @@ -176,12 +180,12 @@ class VectorStoreNode(Node): return "Vector stores can take time to build. It will build on the first query." -class MemoryNode(Node): +class MemoryVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="memory") -class TextSplitterNode(Node): +class TextSplitterVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="textsplitters") @@ -189,5 +193,5 @@ class TextSplitterNode(Node): # This built_object is a list of documents. Maybe we should # show how many documents are in the list? if self._built_object: - return f"""{self.node_type}({len(self._built_object)} documents)\nDocuments: {self._built_object[:3]}...""" - return f"{self.node_type}()" + return f"""{self.vertex_type}({len(self._built_object)} documents)\nDocuments: {self._built_object[:3]}...""" + return f"{self.vertex_type}()" diff --git a/tests/test_graph.py b/tests/test_graph.py index cdbe0ba93..c7b5ddf0c 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,20 +1,20 @@ from typing import Type, Union from langflow.graph.edge.base import Edge -from langflow.graph.node.base import Node +from langflow.graph.vertex.base import Vertex import pytest from langchain.chains.base import Chain from langchain.llms.fake import FakeListLLM from langflow.graph import Graph -from langflow.graph.node.types import ( - AgentNode, - ChainNode, - FileToolNode, - LLMNode, - PromptNode, - ToolkitNode, - ToolNode, - WrapperNode, +from langflow.graph.vertex.types import ( + AgentVertex, + ChainVertex, + FileToolVertex, + LLMVertex, + PromptVertex, + ToolkitVertex, + ToolVertex, + WrapperVertex, ) from langflow.interface.run import get_result_and_thought from langflow.utils.payload import get_root_node @@ -25,7 +25,7 @@ from langflow.utils.payload import get_root_node # BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH -def get_node_by_type(graph, node_type: Type[Node]) -> Union[Node, None]: +def get_node_by_type(graph, node_type: Type[Vertex]) -> Union[Vertex, None]: """Get a node by type""" return next((node for node in graph.nodes if isinstance(node, node_type)), None) @@ -35,7 +35,7 @@ def test_graph_structure(basic_graph): assert len(basic_graph.nodes) > 0 assert len(basic_graph.edges) > 0 for node in basic_graph.nodes: - assert isinstance(node, Node) + assert isinstance(node, Vertex) for edge in basic_graph.edges: assert isinstance(edge, Edge) assert edge.source in basic_graph.nodes @@ -165,7 +165,7 @@ def test_get_node(basic_graph): """Test getting a single node""" node_id = basic_graph.nodes[0].id node = basic_graph.get_node(node_id) - assert isinstance(node, Node) + assert isinstance(node, Vertex) assert node.id == node_id @@ -174,7 +174,7 @@ def test_build_nodes(basic_graph): assert len(basic_graph.nodes) == len(basic_graph._nodes) for node in basic_graph.nodes: - assert isinstance(node, Node) + assert isinstance(node, Vertex) def test_build_edges(basic_graph): @@ -182,8 +182,8 @@ def test_build_edges(basic_graph): assert len(basic_graph.edges) == len(basic_graph._edges) for edge in basic_graph.edges: assert isinstance(edge, Edge) - assert isinstance(edge.source, Node) - assert isinstance(edge.target, Node) + assert isinstance(edge.source, Vertex) + assert isinstance(edge.target, Vertex) def test_get_root_node(basic_graph, complex_graph): @@ -191,13 +191,13 @@ def test_get_root_node(basic_graph, complex_graph): assert isinstance(basic_graph, Graph) root = get_root_node(basic_graph) assert root is not None - assert isinstance(root, Node) + assert isinstance(root, Vertex) assert root.data["type"] == "TimeTravelGuideChain" # For complex example, the root node is a ZeroShotAgent too assert isinstance(complex_graph, Graph) root = get_root_node(complex_graph) assert root is not None - assert isinstance(root, Node) + assert isinstance(root, Vertex) assert root.data["type"] == "ZeroShotAgent" @@ -257,14 +257,14 @@ def assert_agent_was_built(graph): def test_agent_node_build(complex_graph): - agent_node = get_node_by_type(complex_graph, AgentNode) + agent_node = get_node_by_type(complex_graph, AgentVertex) assert agent_node is not None built_object = agent_node.build() assert built_object is not None def test_tool_node_build(complex_graph): - tool_node = get_node_by_type(complex_graph, ToolNode) + tool_node = get_node_by_type(complex_graph, ToolVertex) assert tool_node is not None built_object = tool_node.build() assert built_object is not None @@ -272,7 +272,7 @@ def test_tool_node_build(complex_graph): def test_chain_node_build(complex_graph): - chain_node = get_node_by_type(complex_graph, ChainNode) + chain_node = get_node_by_type(complex_graph, ChainVertex) assert chain_node is not None built_object = chain_node.build() assert built_object is not None @@ -280,7 +280,7 @@ def test_chain_node_build(complex_graph): def test_prompt_node_build(complex_graph): - prompt_node = get_node_by_type(complex_graph, PromptNode) + prompt_node = get_node_by_type(complex_graph, PromptVertex) assert prompt_node is not None built_object = prompt_node.build() assert built_object is not None @@ -288,7 +288,7 @@ def test_prompt_node_build(complex_graph): def test_llm_node_build(basic_graph): - llm_node = get_node_by_type(basic_graph, LLMNode) + llm_node = get_node_by_type(basic_graph, LLMVertex) assert llm_node is not None built_object = llm_node.build() assert built_object is not None @@ -296,7 +296,7 @@ def test_llm_node_build(basic_graph): def test_toolkit_node_build(openapi_graph): - toolkit_node = get_node_by_type(openapi_graph, ToolkitNode) + toolkit_node = get_node_by_type(openapi_graph, ToolkitVertex) assert toolkit_node is not None built_object = toolkit_node.build() assert built_object is not None @@ -304,7 +304,7 @@ def test_toolkit_node_build(openapi_graph): def test_file_tool_node_build(openapi_graph): - file_tool_node = get_node_by_type(openapi_graph, FileToolNode) + file_tool_node = get_node_by_type(openapi_graph, FileToolVertex) assert file_tool_node is not None built_object = file_tool_node.build() assert built_object is not None @@ -312,7 +312,7 @@ def test_file_tool_node_build(openapi_graph): def test_wrapper_node_build(openapi_graph): - wrapper_node = get_node_by_type(openapi_graph, WrapperNode) + wrapper_node = get_node_by_type(openapi_graph, WrapperVertex) assert wrapper_node is not None built_object = wrapper_node.build() assert built_object is not None @@ -327,7 +327,7 @@ def test_get_result_and_thought(basic_graph): message = "Hello" # Find the node that is an LLMNode and change the # _built_object to a FakeListLLM - llm_node = get_node_by_type(basic_graph, LLMNode) + llm_node = get_node_by_type(basic_graph, LLMVertex) assert llm_node is not None llm_node._built_object = FakeListLLM(responses=responses) llm_node._built = True