diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index f9d77741b..c9c1077f6 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -1,7 +1,7 @@ +from typing import TYPE_CHECKING, List, Optional + from loguru import logger -from typing import TYPE_CHECKING from pydantic import BaseModel, Field -from typing import List, Optional if TYPE_CHECKING: from langflow.graph.vertex.base import Vertex @@ -22,8 +22,8 @@ class TargetHandle(BaseModel): class Edge: def __init__(self, source: "Vertex", target: "Vertex", edge: dict): - self.source: "Vertex" = source - self.target: "Vertex" = target + self.source_id: str = source.id + self.target_id: str = target.id if data := edge.get("data", {}): self._source_handle = data.get("sourceHandle", {}) self._target_handle = data.get("targetHandle", {}) @@ -31,7 +31,7 @@ class Edge: self.target_handle: TargetHandle = TargetHandle(**self._target_handle) self.target_param = self.target_handle.fieldName # validate handles - self.validate_handles() + self.validate_handles(source, target) else: # Logging here because this is a breaking change logger.error("Edge data is empty") @@ -41,9 +41,9 @@ class Edge: # target_param is documents self.target_param = self._target_handle.split("|")[1] # Validate in __init__ to fail fast - self.validate_edge() + self.validate_edge(source, target) - def validate_handles(self) -> None: + def validate_handles(self, source, target) -> None: if self.target_handle.inputTypes is None: self.valid_handles = self.target_handle.type in self.source_handle.baseClasses else: @@ -54,26 +54,20 @@ class Edge: if not self.valid_handles: logger.debug(self.source_handle) logger.debug(self.target_handle) - raise ValueError( - f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " f"has invalid handles" - ) + raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles") def __setstate__(self, state): - self.source = state["source"] - self.target = state["target"] + self.source_id = state["source_id"] + self.target_id = state["target_id"] self.target_param = state["target_param"] self.source_handle = state.get("source_handle") self.target_handle = state.get("target_handle") - def reset(self) -> None: - self.source._build_params() - self.target._build_params() - - def validate_edge(self) -> None: + def validate_edge(self, source, target) -> None: # Validate that the outputs of the source node are valid inputs # for the target node - self.source_types = self.source.output - self.target_reqs = self.target.required_inputs + self.target.optional_inputs + self.source_types = source.output + self.target_reqs = target.required_inputs + target.optional_inputs # Both lists contain strings and sometimes a string contains the value we are # looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"] # so we need to check if any of the strings in source_types is in target_reqs @@ -88,13 +82,11 @@ class Edge: if no_matched_type: logger.debug(self.source_types) logger.debug(self.target_reqs) - raise ValueError( - f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " f"has no matched type" - ) + raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type") def __repr__(self) -> str: return ( - f"Edge(source={self.source.id}, target={self.target.id}, target_param={self.target_param}" + f"Edge(source={self.source_id}, target={self.target_id}, target_param={self.target_param}" f", matched_type={self.matched_type})" ) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 48e2bac30..9f80907a8 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -13,24 +13,24 @@ from langflow.utils import payload class Graph: - """A class representing a graph of nodes and edges.""" + """A class representing a graph of vertices and edges.""" def __init__( self, nodes: List[Dict], edges: List[Dict[str, str]], ) -> None: - self._nodes = nodes + self._vertices = nodes self._edges = edges self.raw_graph_data = {"nodes": nodes, "edges": edges} - self.top_level_nodes = [] - for node in self._nodes: - if node_id := node.get("id"): - self.top_level_nodes.append(node_id) + self.top_level_vertices = [] + for vertex in self._vertices: + if vertex_id := vertex.get("id"): + self.top_level_vertices.append(vertex_id) self._graph_data = process_flow(self.raw_graph_data) - self._nodes = self._graph_data["nodes"] + self._vertices = self._graph_data["nodes"] self._edges = self._graph_data["edges"] self._build_graph() @@ -54,9 +54,9 @@ class Graph: if "data" in payload: payload = payload["data"] try: - nodes = payload["nodes"] + vertices = payload["nodes"] edges = payload["edges"] - return cls(nodes, edges) + return cls(vertices, edges) except KeyError as exc: logger.exception(exc) raise ValueError( @@ -69,61 +69,69 @@ class Graph: return self.__repr__() == other.__repr__() def _build_graph(self) -> None: - """Builds the graph from the nodes and edges.""" - self.nodes = self._build_vertices() + """Builds the graph from the vertices and edges.""" + self.vertices = self._build_vertices() + self.vertex_ids = [vertex.id for vertex in self.vertices] self.edges = self._build_edges() - for edge in self.edges: - edge.source.add_edge(edge) - edge.target.add_edge(edge) - # This is a hack to make sure that the LLM node is sent to - # the toolkit node - self._build_node_params() - # remove invalid nodes - self._validate_nodes() + # This is a hack to make sure that the LLM vertex is sent to + # the toolkit vertex + self._build_vertex_params() + # remove invalid vertices + self._validate_vertices() - def _build_node_params(self) -> None: - """Identifies and handles the LLM node within the graph.""" - llm_node = None - for node in self.nodes: - node._build_params() - if isinstance(node, LLMVertex): - llm_node = node + def _build_vertex_params(self) -> None: + """Identifies and handles the LLM vertex within the graph.""" + llm_vertex = None + for vertex in self.vertices: + vertex._build_params() + if isinstance(vertex, LLMVertex): + llm_vertex = vertex - if llm_node: - for node in self.nodes: - if isinstance(node, ToolkitVertex): - node.params["llm"] = llm_node + if llm_vertex: + for vertex in self.vertices: + if isinstance(vertex, ToolkitVertex): + vertex.params["llm"] = llm_vertex - def _validate_nodes(self) -> None: - """Check that all nodes have edges""" - if len(self.nodes) == 1: + def _validate_vertices(self) -> None: + """Check that all vertices have edges""" + if len(self.vertices) == 1: return - for node in self.nodes: - if not self._validate_node(node): - raise ValueError(f"{node.vertex_type} is not connected to any other components") + for vertex in self.vertices: + if not self._validate_vertex(vertex): + raise ValueError(f"{vertex.vertex_type} is not connected to any other components") - def _validate_node(self, node: Vertex) -> bool: - """Validates a node.""" - # All nodes that do not have edges are invalid - return len(node.edges) > 0 + def _validate_vertex(self, vertex: Vertex) -> bool: + """Validates a vertex.""" + # All vertices that do not have edges are invalid + return len(self.get_vertex_edges(vertex.id)) > 0 - def get_node(self, node_id: str) -> Union[None, Vertex]: - """Returns a node by id.""" - return next((node for node in self.nodes if node.id == node_id), None) + def get_vertex(self, vertex_id: str) -> Union[None, Vertex]: + """Returns a vertex by id.""" + return next((vertex for vertex in self.vertices if vertex.id == vertex_id), None) - def get_nodes_with_target(self, node: Vertex) -> List[Vertex]: - """Returns the nodes connected to a node.""" - connected_nodes: List[Vertex] = [edge.source for edge in self.edges if edge.target == node] - return connected_nodes + def get_vertex_edges(self, vertex_id: str) -> List[Edge]: + """Returns a list of edges for a given vertex.""" + return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id] + + def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]: + """Returns the vertices connected to a vertex.""" + vertices: List[Vertex] = [] + for edge in self.edges: + if edge.target_id == vertex_id: + vertex = self.get_vertex(edge.source_id) + if vertex is None: + continue + vertices.append(vertex) + return vertices async def build(self) -> Chain: """Builds the graph.""" - # Get root node - root_node = payload.get_root_node(self) - if root_node is None: - raise ValueError("No root node found") - return await root_node.build() + # Get root vertex + root_vertex = payload.get_root_vertex(self) + if root_vertex is None: + raise ValueError("No root vertex found") + return await root_vertex.build() def topological_sort(self) -> List[Vertex]: """ @@ -136,25 +144,25 @@ class Graph: ValueError: If the graph contains a cycle. """ # States: 0 = unvisited, 1 = visiting, 2 = visited - state = {node: 0 for node in self.nodes} + state = {vertex: 0 for vertex in self.vertices} sorted_vertices = [] - def dfs(node): - if state[node] == 1: + def dfs(vertex): + if state[vertex] == 1: # We have a cycle raise ValueError("Graph contains a cycle, cannot perform topological sort") - if state[node] == 0: - state[node] = 1 - for edge in node.edges: - if edge.source == node: + if state[vertex] == 0: + state[vertex] = 1 + for edge in vertex.edges: + if edge.source == vertex: dfs(edge.target) - state[node] = 2 - sorted_vertices.append(node) + state[vertex] = 2 + sorted_vertices.append(vertex) - # Visit each node - for node in self.nodes: - if state[node] == 0: - dfs(node) + # Visit each vertex + for vertex in self.vertices: + if state[vertex] == 0: + dfs(vertex) return list(reversed(sorted_vertices)) @@ -164,17 +172,21 @@ class Graph: logger.debug("There are %s vertices in the graph", len(sorted_vertices)) yield from sorted_vertices - def get_node_neighbors(self, node: Vertex) -> Dict[Vertex, int]: - """Returns the neighbors of a node.""" + def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]: + """Returns the neighbors of a vertex.""" neighbors: Dict[Vertex, int] = {} for edge in self.edges: - if edge.source == node: - neighbor = edge.target + if edge.source_id == vertex.id: + neighbor = self.get_vertex(edge.target_id) + if neighbor is None: + continue if neighbor not in neighbors: neighbors[neighbor] = 0 neighbors[neighbor] += 1 - elif edge.target == node: - neighbor = edge.source + elif edge.target_id == vertex.id: + neighbor = self.get_vertex(edge.source_id) + if neighbor is None: + continue if neighbor not in neighbors: neighbors[neighbor] = 0 neighbors[neighbor] += 1 @@ -182,59 +194,59 @@ class Graph: def _build_edges(self) -> List[Edge]: """Builds the edges of the graph.""" - # Edge takes two nodes as arguments, so we need to build the nodes first + # Edge takes two vertices as arguments, so we need to build the vertices first # and then build the edges - # if we can't find a node, we raise an error + # if we can't find a vertex, we raise an error edges: List[Edge] = [] for edge in self._edges: - source = self.get_node(edge["source"]) - target = self.get_node(edge["target"]) + source = self.get_vertex(edge["source"]) + target = self.get_vertex(edge["target"]) if source is None: - raise ValueError(f"Source node {edge['source']} not found") + raise ValueError(f"Source vertex {edge['source']} not found") if target is None: - raise ValueError(f"Target node {edge['target']} not found") + raise ValueError(f"Target vertex {edge['target']} not found") edges.append(Edge(source, target, edge)) return edges - def _get_vertex_class(self, node_type: str, node_lc_type: str) -> Type[Vertex]: - """Returns the node class based on the node type.""" - if node_type in FILE_TOOLS: + def _get_vertex_class(self, vertex_type: str, vertex_lc_type: str) -> Type[Vertex]: + """Returns the vertex class based on the vertex type.""" + if vertex_type in FILE_TOOLS: return FileToolVertex - if node_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP: - return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_type] + if vertex_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP: + return lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_type] return ( - lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_lc_type] - if node_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP + lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_lc_type] + if vertex_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP else Vertex ) def _build_vertices(self) -> List[Vertex]: """Builds the vertices of the graph.""" - nodes: List[Vertex] = [] - for node in self._nodes: - node_data = node["data"] - node_type: str = node_data["type"] # type: ignore - node_lc_type: str = node_data["node"]["template"]["_type"] # type: ignore + vertices: List[Vertex] = [] + for vertex in self._vertices: + vertex_data = vertex["data"] + vertex_type: str = vertex_data["type"] # type: ignore + vertex_lc_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - VertexClass = self._get_vertex_class(node_type, node_lc_type) - vertex = VertexClass(node) - vertex.set_top_level(self.top_level_nodes) - nodes.append(vertex) + VertexClass = self._get_vertex_class(vertex_type, vertex_lc_type) + vertex = VertexClass(vertex, graph=self) + vertex.set_top_level(self.top_level_vertices) + vertices.append(vertex) - return nodes + return vertices - def get_children_by_node_type(self, node: Vertex, node_type: str) -> List[Vertex]: - """Returns the children of a node based on the node type.""" + def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]: + """Returns the children of a vertex based on the vertex type.""" children = [] - node_types = [node.data["type"]] - if "node" in node.data: - node_types += node.data["node"]["base_classes"] - if node_type in node_types: - children.append(node) + vertex_types = [vertex.data["type"]] + if "node" in vertex.data: + vertex_types += vertex.data["node"]["base_classes"] + if vertex_type in vertex_types: + children.append(vertex) return children def __repr__(self): - node_ids = [node.id for node in self.nodes] - edges_repr = "\n".join([f"{edge.source.id} --> {edge.target.id}" for edge in self.edges]) - return f"Graph:\nNodes: {node_ids}\nConnections:\n{edges_repr}" + vertex_ids = [vertex.id for vertex in self.vertices] + edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]) + return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}" diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 5ea645980..dc9e76f8a 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -1,33 +1,32 @@ import ast import inspect -import pickle import types from typing import TYPE_CHECKING, Any, Dict, List, Optional -from loguru import logger - from langflow.graph.utils import UnbuiltObject -from langflow.graph.vertex.utils import is_basic_type from langflow.interface.initialize import loading from langflow.interface.listing import lazy_load_dict from langflow.utils.constants import DIRECT_TYPES from langflow.utils.util import sync_to_async +from loguru import logger if TYPE_CHECKING: from langflow.graph.edge.base import Edge + from langflow.graph.graph.base import Graph class Vertex: def __init__( self, data: Dict, + graph: "Graph", base_type: Optional[str] = None, is_task: bool = False, params: Optional[Dict] = None, ) -> None: + self.graph = graph self.id: str = data["id"] self._data = data - self.edges: List["Edge"] = [] self.base_type: Optional[str] = base_type self._parse_data() self._built_object = UnbuiltObject() @@ -39,43 +38,28 @@ class Vertex: self.parent_node_id: Optional[str] = self._data.get("parent_node_id") self.parent_is_top_level = False - def reset_params(self): - for edge in self.edges: - if edge.source != self: - target_param = edge.target_param - if target_param in ["document", "texts"]: - # this means they got data and have already ingested it - # so we continue after removing the param - self.params.pop(target_param, None) - continue - - if target_param in self.params and not is_basic_type(self.params[target_param]): - # edge.source.params = {} - edge.source._build_params() - edge.source._built_object = UnbuiltObject() - edge.source._built = False - - self.params[target_param] = edge.source + @property + def edges(self) -> List["Edge"]: + return self.graph.get_vertex_edges(self.id) def __getstate__(self): - state_dict = self.__dict__.copy() - try: - # try pickling the built object - # if it fails, then we need to delete it - # and build it again - pickle.dumps(state_dict["_built_object"]) - except Exception: - self.reset_params() - del state_dict["_built_object"] - del state_dict["_built"] - return state_dict + return { + "_data": self._data, + "params": {}, + "base_type": self.base_type, + "is_task": self.is_task, + "id": self.id, + "_built_object": UnbuiltObject(), + "_built": False, + "parent_node_id": self.parent_node_id, + "parent_is_top_level": self.parent_is_top_level, + } def __setstate__(self, state): self._data = state["_data"] self.params = state["params"] self.base_type = state["base_type"] self.is_task = state["is_task"] - self.edges = state["edges"] self.id = state["id"] self._parse_data() if "_built_object" in state: @@ -144,6 +128,10 @@ class Vertex: # and use that as the value for the param # If the type is "str", then we need to get the value of the "value" key # and use that as the value for the param + + if self.graph is None: + raise ValueError("Graph not found") + template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} params = self.params.copy() if self.params else {} @@ -155,9 +143,9 @@ class Vertex: if template_dict[param_key]["list"]: if param_key not in params: params[param_key] = [] - params[param_key].append(edge.source) - elif edge.target.id == self.id: - params[param_key] = edge.source + params[param_key].append(self.graph.get_vertex(edge.source_id)) + elif edge.target_id == self.id: + params[param_key] = self.graph.get_vertex(edge.source_id) for key, value in template_dict.items(): if key in params: @@ -177,33 +165,33 @@ class Vertex: else: raise ValueError(f"File path not found for {self.vertex_type}") elif value.get("type") in DIRECT_TYPES and params.get(key) is None: + val = value.get("value") if value.get("type") == "code": try: - params[key] = ast.literal_eval(value.get("value")) + params[key] = ast.literal_eval(val) if val else None except Exception as exc: logger.debug(f"Error parsing code: {exc}") - params[key] = value.get("value") + params[key] = val elif value.get("type") in ["dict", "NestedDict"]: # When dict comes from the frontend it comes as a # list of dicts, so we need to convert it to a dict # before passing it to the build method - _value = value.get("value") - if isinstance(_value, list): + if isinstance(val, list): params[key] = {k: v for item in value.get("value", []) for k, v in item.items()} - elif isinstance(_value, dict): - params[key] = _value - elif value.get("type") == "int" and value.get("value") is not None: + elif isinstance(val, dict): + params[key] = val + elif value.get("type") == "int" and val is not None: try: - params[key] = int(value.get("value")) + params[key] = int(val) except ValueError: - params[key] = value.get("value") - elif value.get("type") == "float" and value.get("value") is not None: + params[key] = val + elif value.get("type") == "float" and val is not None: try: - params[key] = float(value.get("value")) + params[key] = float(val) except ValueError: - params[key] = value.get("value") + params[key] = val else: - params[key] = value.get("value") + params[key] = val if not value.get("required") and params.get(key) is None: if value.get("default"): @@ -266,7 +254,7 @@ class Vertex: pass # If there's no task_id, build the vertex locally - await self.build(user_id) + await self.build(user_id=user_id) return self._built_object async def _build_node_and_update_params(self, key, node, user_id=None):