From b7427d08d3cdcda6c0d931b82b12926b05857034 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 16 Feb 2024 13:03:35 -0300 Subject: [PATCH] Refactor graph class and add layer refinement algorithm --- src/backend/langflow/graph/graph/base.py | 82 ++++++++++++++++++++---- 1 file changed, 68 insertions(+), 14 deletions(-) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index f57f01f7d..958c9d319 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -6,7 +6,12 @@ from langflow.graph.edge.base import ContractEdge from langflow.graph.graph.constants import lazy_load_vertex_dict from langflow.graph.graph.utils import process_flow from langflow.graph.vertex.base import Vertex -from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, ToolkitVertex +from langflow.graph.vertex.types import ( + ChatVertex, + FileToolVertex, + LLMVertex, + ToolkitVertex, +) from langflow.interface.tools.constants import FILE_TOOLS from langflow.utils import payload from loguru import logger @@ -127,7 +132,9 @@ class Graph: return for vertex in self.vertices: if not self._validate_vertex(vertex): - raise ValueError(f"{vertex.vertex_type} is not connected to any other components") + raise ValueError( + f"{vertex.vertex_type} is not connected to any other components" + ) def _validate_vertex(self, vertex: Vertex) -> bool: """Validates a vertex.""" @@ -140,7 +147,11 @@ class Graph: def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]: """Returns a list of edges for a given vertex.""" - return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id] + return [ + edge + for edge in self.edges + if edge.source_id == vertex_id or edge.target_id == vertex_id + ] def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]: """Returns the vertices connected to a vertex.""" @@ -178,7 +189,9 @@ class Graph: def dfs(vertex): if state[vertex] == 1: # We have a cycle - raise ValueError("Graph contains a cycle, cannot perform topological sort") + raise ValueError( + "Graph contains a cycle, cannot perform topological sort" + ) if state[vertex] == 0: state[vertex] = 1 for edge in vertex.edges: @@ -237,7 +250,9 @@ class Graph: edges.append(ContractEdge(source, target, edge)) return edges - def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]: + def _get_vertex_class( + self, node_type: str, node_base_type: str, node_id: str + ) -> Type[Vertex]: """Returns the node class based on the node type.""" # First we check for the node_base_type node_name = node_id.split("-")[0] @@ -267,14 +282,18 @@ class Graph: vertex_type: str = vertex_data["type"] # type: ignore vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) + VertexClass = self._get_vertex_class( + vertex_type, vertex_base_type, vertex_data["id"] + ) vertex_instance = VertexClass(vertex, graph=self) vertex_instance.set_top_level(self.top_level_vertices) vertices.append(vertex_instance) return vertices - def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]: + def get_children_by_vertex_type( + self, vertex: Vertex, vertex_type: str + ) -> List[Vertex]: """Returns the children of a vertex based on the vertex type.""" children = [] vertex_types = [vertex.data["type"]] @@ -286,7 +305,9 @@ class Graph: def __repr__(self): vertex_ids = [vertex.id for vertex in self.vertices] - edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]) + edges_repr = "\n".join( + [f"{edge.source_id} --> {edge.target_id}" for edge in self.edges] + ) return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}" def layered_topological_sort(self): @@ -299,7 +320,9 @@ class Graph: in_degree[edge.target_id] += 1 # Queue for vertices with no incoming edges - queue = deque(vertex.id for vertex in self.vertices if in_degree[vertex.id] == 0) + queue = deque( + vertex.id for vertex in self.vertices if in_degree[vertex.id] == 0 + ) layers = [] current_layer = 0 @@ -314,9 +337,40 @@ class Graph: if in_degree[neighbor] == 0: queue.append(neighbor) current_layer += 1 # Next layer + new_layers = self.refine_layers(graph, layers) + return new_layers - return layers - return layers - return layers - return layers - return layers + def refine_layers(self, graph, initial_layers): + # Map each vertex to its current layer + vertex_to_layer = {} + for layer_index, layer in enumerate(initial_layers): + for vertex in layer: + vertex_to_layer[vertex] = layer_index + + # Build the adjacency list for reverse lookup (dependencies) + + refined_layers = [[] for _ in initial_layers] # Start with empty layers + new_layer_index_map = defaultdict( + int + ) # Map each vertex to its highest dependency layer + + for vertex_id, deps in graph.items(): + for dep in deps: + new_layer_index_map[vertex_id] = ( + max(new_layer_index_map[vertex_id], vertex_to_layer[dep]) - 1 + ) + + for layer_index, layer in enumerate(initial_layers): + for vertex_id in layer: + # Place the vertex in the highest possible layer where its dependencies are met + new_layer_index = new_layer_index_map[vertex_id] + if new_layer_index > layer_index: + refined_layers[new_layer_index].append(vertex_id) + vertex_to_layer[vertex_id] = new_layer_index + else: + refined_layers[layer_index].append(vertex_id) + + # Remove empty layers if any + refined_layers = [layer for layer in refined_layers if layer] + + return refined_layers