From 549b420545d5881adf4df3b7941fcd15e14e74dd Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Tue, 18 Jun 2024 14:10:40 -0300 Subject: [PATCH] refactor: Add method to get root of group node in Graph class --- src/backend/base/langflow/graph/graph/base.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 9f1b2c60e..66e9bfa63 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -748,6 +748,21 @@ class Graph: except KeyError: raise ValueError(f"Vertex {vertex_id} not found") + def get_root_of_group_node(self, vertex_id: str) -> Vertex: + """Returns the root of a group node.""" + if vertex_id in self.top_level_vertices: + # Get all vertices with vertex_id as .parent_node_id + # then get the one at the top + vertices = [vertex for vertex in self.vertices if vertex.parent_node_id == vertex_id] + # Now go through successors of the vertices + # and get the one that none of its successors is in vertices + for vertex in vertices: + successors = self.get_all_successors(vertex, recursive=False) + if not any(successor in vertices for successor in successors): + return vertex + else: + raise ValueError(f"Vertex {vertex_id} is not a top level vertex") + async def build_vertex( self, lock: asyncio.Lock, @@ -1127,7 +1142,6 @@ class Graph: # Initial setup visited = set() # To keep track of visited vertices excluded = set() # To keep track of vertices that should be excluded - stack = [vertex_id] # Use a list as a stack for DFS def get_successors(vertex, recursive=True): # Recursively get the successors of the current vertex @@ -1143,7 +1157,12 @@ class Graph: successors_result.append(successor) return successors_result - stop_or_start_vertex = self.get_vertex(vertex_id) + try: + stop_or_start_vertex = self.get_vertex(vertex_id) + stack = [vertex_id] # Use a list as a stack for DFS + except ValueError: + stop_or_start_vertex = self.get_root_of_group_node(vertex_id) + stack = [stop_or_start_vertex.id] stop_predecessors = [pre.id for pre in stop_or_start_vertex.predecessors] # DFS to collect all vertices that can reach the specified vertex while stack: