diff --git a/src/backend/langflow/api/validate.py b/src/backend/langflow/api/validate.py index 0e2a7752c..53a7ee350 100644 --- a/src/backend/langflow/api/validate.py +++ b/src/backend/langflow/api/validate.py @@ -9,7 +9,7 @@ from langflow.api.base import ( PromptValidationResponse, validate_prompt, ) -from langflow.graph.nodes import VectorStoreNode +from langflow.graph.node.types import VectorStoreNode from langflow.interface.run import build_graph from langflow.utils.logger import logger from langflow.utils.validate import validate_code diff --git a/src/backend/langflow/graph/__init__.py b/src/backend/langflow/graph/__init__.py index 097b7a695..44859da02 100644 --- a/src/backend/langflow/graph/__init__.py +++ b/src/backend/langflow/graph/__init__.py @@ -1,4 +1,35 @@ -from langflow.graph.base import Edge, Node -from langflow.graph.graph import Graph +from langflow.graph.edge.base import Edge +from langflow.graph.graph.base import Graph +from langflow.graph.node.base import Node +from langflow.graph.node.types import ( + AgentNode, + ChainNode, + DocumentLoaderNode, + EmbeddingNode, + LLMNode, + MemoryNode, + PromptNode, + TextSplitterNode, + ToolNode, + ToolkitNode, + VectorStoreNode, + WrapperNode, +) -__all__ = ["Graph", "Node", "Edge"] +__all__ = [ + "Graph", + "Node", + "Edge", + "AgentNode", + "ChainNode", + "DocumentLoaderNode", + "EmbeddingNode", + "LLMNode", + "MemoryNode", + "PromptNode", + "TextSplitterNode", + "ToolNode", + "ToolkitNode", + "VectorStoreNode", + "WrapperNode", +] diff --git a/src/backend/langflow/graph/edge/__init__.py b/src/backend/langflow/graph/edge/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py new file mode 100644 index 000000000..2bf5a1ba4 --- /dev/null +++ b/src/backend/langflow/graph/edge/base.py @@ -0,0 +1,52 @@ +from langflow.utils.logger import logger +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from langflow.graph.node.base import Node + + +class Edge: + def __init__(self, source: "Node", target: "Node"): + self.source: "Node" = source + self.target: "Node" = target + self.validate_edge() + + def validate_edge(self) -> None: + # Validate that the outputs of the source node are valid inputs + # for the target node + self.source_types = self.source.output + self.target_reqs = self.target.required_inputs + self.target.optional_inputs + # Both lists contain strings and sometimes a string contains the value we are + # looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"] + # so we need to check if any of the strings in source_types is in target_reqs + self.valid = any( + output in target_req + for output in self.source_types + for target_req in self.target_reqs + ) + # Get what type of input the target node is expecting + + self.matched_type = next( + ( + output + for output in self.source_types + for target_req in self.target_reqs + if output in target_req + ), + None, + ) + no_matched_type = self.matched_type is None + if no_matched_type: + logger.debug(self.source_types) + logger.debug(self.target_reqs) + if no_matched_type: + raise ValueError( + f"Edge between {self.source.node_type} and {self.target.node_type} " + f"has no matched type" + ) + + def __repr__(self) -> str: + return ( + f"Edge(source={self.source.id}, target={self.target.id}, valid={self.valid}" + f", matched_type={self.matched_type})" + ) diff --git a/src/backend/langflow/graph/graph/__init__.py b/src/backend/langflow/graph/graph/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/langflow/graph/graph.py b/src/backend/langflow/graph/graph/base.py similarity index 64% rename from src/backend/langflow/graph/graph.py rename to src/backend/langflow/graph/graph/base.py index b289d5c31..3ba67837f 100644 --- a/src/backend/langflow/graph/graph.py +++ b/src/backend/langflow/graph/graph/base.py @@ -1,38 +1,20 @@ from typing import Dict, List, Type, Union -from langflow.graph.base import Edge, Node -from langflow.graph.nodes import ( - AgentNode, - ChainNode, - DocumentLoaderNode, - EmbeddingNode, +from langflow.graph.edge.base import Edge +from langflow.graph.graph.constants import NODE_TYPE_MAP +from langflow.graph.node.base import Node +from langflow.graph.node.types import ( FileToolNode, LLMNode, - MemoryNode, - PromptNode, - TextSplitterNode, ToolkitNode, - ToolNode, - VectorStoreNode, - WrapperNode, ) -from langflow.interface.agents.base import agent_creator -from langflow.interface.chains.base import chain_creator -from langflow.interface.document_loaders.base import documentloader_creator -from langflow.interface.embeddings.base import embedding_creator -from langflow.interface.llms.base import llm_creator -from langflow.interface.memories.base import memory_creator -from langflow.interface.prompts.base import prompt_creator -from langflow.interface.text_splitters.base import textsplitter_creator -from langflow.interface.toolkits.base import toolkits_creator -from langflow.interface.tools.base import tool_creator from langflow.interface.tools.constants import FILE_TOOLS -from langflow.interface.vector_store.base import vectorstore_creator -from langflow.interface.wrappers.base import wrapper_creator from langflow.utils import payload class Graph: + """A class representing a graph of nodes and edges.""" + def __init__( self, nodes: List[Dict[str, Union[str, Dict[str, Union[str, List[str]]]]]], @@ -43,6 +25,7 @@ class Graph: self._build_graph() def _build_graph(self) -> None: + """Builds the graph from the nodes and edges.""" self.nodes = self._build_nodes() self.edges = self._build_edges() for edge in self.edges: @@ -51,17 +34,25 @@ class Graph: # This is a hack to make sure that the LLM node is sent to # the toolkit node + self._build_node_params() + # remove invalid nodes + self._remove_invalid_nodes() + + def _build_node_params(self) -> None: + """Identifies and handles the LLM node within the graph.""" llm_node = None for node in self.nodes: node._build_params() - if isinstance(node, LLMNode): llm_node = node - for node in self.nodes: - if isinstance(node, ToolkitNode): - node.params["llm"] = llm_node - # remove invalid nodes + if llm_node: + for node in self.nodes: + if isinstance(node, ToolkitNode): + node.params["llm"] = llm_node + + def _remove_invalid_nodes(self) -> None: + """Removes invalid nodes from the graph.""" self.nodes = [ node for node in self.nodes @@ -70,19 +61,23 @@ class Graph: ] def _validate_node(self, node: Node) -> bool: + """Validates a node.""" # All nodes that do not have edges are invalid return len(node.edges) > 0 def get_node(self, node_id: str) -> Union[None, Node]: + """Returns a node by id.""" return next((node for node in self.nodes if node.id == node_id), None) def get_nodes_with_target(self, node: Node) -> List[Node]: + """Returns the nodes connected to a node.""" connected_nodes: List[Node] = [ edge.source for edge in self.edges if edge.target == node ] return connected_nodes def build(self) -> List[Node]: + """Builds the graph.""" # Get root node root_node = payload.get_root_node(self) if root_node is None: @@ -90,6 +85,7 @@ class Graph: return root_node.build() def get_node_neighbors(self, node: Node) -> Dict[Node, int]: + """Returns the neighbors of a node.""" neighbors: Dict[Node, int] = {} for edge in self.edges: if edge.source == node: @@ -105,6 +101,7 @@ class Graph: return neighbors def _build_edges(self) -> List[Edge]: + """Builds the edges of the graph.""" # Edge takes two nodes as arguments, so we need to build the nodes first # and then build the edges # if we can't find a node, we raise an error @@ -121,30 +118,15 @@ class Graph: return edges def _get_node_class(self, node_type: str, node_lc_type: str) -> Type[Node]: - node_type_map: Dict[str, Type[Node]] = { - **{t: PromptNode for t in prompt_creator.to_list()}, - **{t: AgentNode for t in agent_creator.to_list()}, - **{t: ChainNode for t in chain_creator.to_list()}, - **{t: ToolNode for t in tool_creator.to_list()}, - **{t: ToolkitNode for t in toolkits_creator.to_list()}, - **{t: WrapperNode for t in wrapper_creator.to_list()}, - **{t: LLMNode for t in llm_creator.to_list()}, - **{t: MemoryNode for t in memory_creator.to_list()}, - **{t: EmbeddingNode for t in embedding_creator.to_list()}, - **{t: VectorStoreNode for t in vectorstore_creator.to_list()}, - **{t: DocumentLoaderNode for t in documentloader_creator.to_list()}, - **{t: TextSplitterNode for t in textsplitter_creator.to_list()}, - } - + """Returns the node class based on the node type.""" if node_type in FILE_TOOLS: return FileToolNode - if node_type in node_type_map: - return node_type_map[node_type] - if node_lc_type in node_type_map: - return node_type_map[node_lc_type] - return Node + if node_type in NODE_TYPE_MAP: + return NODE_TYPE_MAP[node_type] + return NODE_TYPE_MAP[node_lc_type] if node_lc_type in NODE_TYPE_MAP else Node def _build_nodes(self) -> List[Node]: + """Builds the nodes of the graph.""" nodes: List[Node] = [] for node in self._nodes: node_data = node["data"] @@ -157,6 +139,7 @@ class Graph: return nodes def get_children_by_node_type(self, node: Node, node_type: str) -> List[Node]: + """Returns the children of a node based on the node type.""" children = [] node_types = [node.data["type"]] if "node" in node.data: diff --git a/src/backend/langflow/graph/graph/constants.py b/src/backend/langflow/graph/graph/constants.py new file mode 100644 index 000000000..f5bc9b8e3 --- /dev/null +++ b/src/backend/langflow/graph/graph/constants.py @@ -0,0 +1,49 @@ +from langflow.graph.node.base import Node +from langflow.graph.node.types import ( + AgentNode, + ChainNode, + DocumentLoaderNode, + EmbeddingNode, + LLMNode, + MemoryNode, + PromptNode, + TextSplitterNode, + ToolNode, + ToolkitNode, + VectorStoreNode, + WrapperNode, +) +from langflow.interface.agents.base import agent_creator +from langflow.interface.chains.base import chain_creator +from langflow.interface.document_loaders.base import documentloader_creator +from langflow.interface.embeddings.base import embedding_creator +from langflow.interface.llms.base import llm_creator +from langflow.interface.memories.base import memory_creator +from langflow.interface.prompts.base import prompt_creator +from langflow.interface.text_splitters.base import textsplitter_creator +from langflow.interface.toolkits.base import toolkits_creator +from langflow.interface.tools.base import tool_creator +from langflow.interface.vector_store.base import vectorstore_creator +from langflow.interface.wrappers.base import wrapper_creator + + +from typing import Dict, Type + + +DIRECT_TYPES = ["str", "bool", "code", "int", "float", "Any", "prompt"] + + +NODE_TYPE_MAP: Dict[str, Type[Node]] = { + **{t: PromptNode for t in prompt_creator.to_list()}, + **{t: AgentNode for t in agent_creator.to_list()}, + **{t: ChainNode for t in chain_creator.to_list()}, + **{t: ToolNode for t in tool_creator.to_list()}, + **{t: ToolkitNode for t in toolkits_creator.to_list()}, + **{t: WrapperNode for t in wrapper_creator.to_list()}, + **{t: LLMNode for t in llm_creator.to_list()}, + **{t: MemoryNode for t in memory_creator.to_list()}, + **{t: EmbeddingNode for t in embedding_creator.to_list()}, + **{t: VectorStoreNode for t in vectorstore_creator.to_list()}, + **{t: DocumentLoaderNode for t in documentloader_creator.to_list()}, + **{t: TextSplitterNode for t in textsplitter_creator.to_list()}, +} diff --git a/src/backend/langflow/graph/node/__init__.py b/src/backend/langflow/graph/node/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/langflow/graph/base.py b/src/backend/langflow/graph/node/base.py similarity index 81% rename from src/backend/langflow/graph/base.py rename to src/backend/langflow/graph/node/base.py index cc5e2902b..5076deb9c 100644 --- a/src/backend/langflow/graph/base.py +++ b/src/backend/langflow/graph/node/base.py @@ -1,27 +1,27 @@ -# Description: Graph class for building a graph of nodes and edges -# Insights: -# - Defer prompts building to the last moment or when they have all the tools -# - Build each inner agent first, then build the outer agent - -import contextlib -import inspect -import types -import warnings -from typing import Any, Dict, List, Optional - from langflow.cache import base as cache_utils -from langflow.graph.constants import DIRECT_TYPES +from langflow.graph.node.constants import DIRECT_TYPES from langflow.interface import loading from langflow.interface.listing import ALL_TYPES_DICT from langflow.utils.logger import logger from langflow.utils.util import sync_to_async +import contextlib +import inspect +import types +import warnings +from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from langflow.graph.edge.base import Edge + + class Node: def __init__(self, data: Dict, base_type: Optional[str] = None) -> None: self.id: str = data["id"] self._data = data - self.edges: List[Edge] = [] + self.edges: List["Edge"] = [] self.base_type: Optional[str] = base_type self._parse_data() self._built_object = None @@ -227,50 +227,3 @@ class Node: def _built_object_repr(self): return repr(self._built_object) - - -class Edge: - def __init__(self, source: "Node", target: "Node"): - self.source: "Node" = source - self.target: "Node" = target - self.validate_edge() - - def validate_edge(self) -> None: - # Validate that the outputs of the source node are valid inputs - # for the target node - self.source_types = self.source.output - self.target_reqs = self.target.required_inputs + self.target.optional_inputs - # Both lists contain strings and sometimes a string contains the value we are - # looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"] - # so we need to check if any of the strings in source_types is in target_reqs - self.valid = any( - output in target_req - for output in self.source_types - for target_req in self.target_reqs - ) - # Get what type of input the target node is expecting - - self.matched_type = next( - ( - output - for output in self.source_types - for target_req in self.target_reqs - if output in target_req - ), - None, - ) - no_matched_type = self.matched_type is None - if no_matched_type: - logger.debug(self.source_types) - logger.debug(self.target_reqs) - if no_matched_type: - raise ValueError( - f"Edge between {self.source.node_type} and {self.target.node_type} " - f"has no matched type" - ) - - def __repr__(self) -> str: - return ( - f"Edge(source={self.source.id}, target={self.target.id}, valid={self.valid}" - f", matched_type={self.matched_type})" - ) diff --git a/src/backend/langflow/graph/constants.py b/src/backend/langflow/graph/node/constants.py similarity index 100% rename from src/backend/langflow/graph/constants.py rename to src/backend/langflow/graph/node/constants.py diff --git a/src/backend/langflow/graph/nodes.py b/src/backend/langflow/graph/node/types.py similarity index 99% rename from src/backend/langflow/graph/nodes.py rename to src/backend/langflow/graph/node/types.py index 189e40b5c..9b25fd6ee 100644 --- a/src/backend/langflow/graph/nodes.py +++ b/src/backend/langflow/graph/node/types.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional, Union -from langflow.graph.base import Node +from langflow.graph.node.base import Node from langflow.graph.utils import extract_input_variables_from_prompt diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index 69c697823..a3799be16 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -12,6 +12,7 @@ from langchain.agents.load_tools import ( _LLM_TOOLS, ) from langchain.agents.loading import load_agent_from_config +from langflow.graph import Graph from langchain.agents.tools import Tool from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager @@ -164,7 +165,6 @@ def instantiate_utility(node_type, class_object, params): def load_flow_from_json(path: str, build=True): """Load flow from json file""" # This is done to avoid circular imports - from langflow.graph import Graph with open(path, "r", encoding="utf-8") as f: flow_graph = json.load(f) diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index d24b6a0dc..c2483416f 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -6,7 +6,7 @@ from langchain.schema import AgentAction from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler # type: ignore from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict -from langflow.graph.graph import Graph +from langflow.graph import Graph from langflow.utils.logger import logger diff --git a/tests/conftest.py b/tests/conftest.py index 870c48a32..d0af2ad84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import json from pathlib import Path from typing import AsyncGenerator +from langflow.graph.graph.base import Graph import pytest from fastapi.testclient import TestClient from httpx import AsyncClient @@ -46,7 +47,6 @@ def client(): def get_graph(_type="basic"): """Get a graph from a json file""" - from langflow.graph.graph import Graph if _type == "basic": path = pytest.BASIC_EXAMPLE_PATH diff --git a/tests/test_graph.py b/tests/test_graph.py index a0f5945fc..cdbe0ba93 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,10 +1,12 @@ from typing import Type, Union +from langflow.graph.edge.base import Edge +from langflow.graph.node.base import Node import pytest from langchain.chains.base import Chain from langchain.llms.fake import FakeListLLM -from langflow.graph import Edge, Graph, Node -from langflow.graph.nodes import ( +from langflow.graph import Graph +from langflow.graph.node.types import ( AgentNode, ChainNode, FileToolNode,