From 33e72f162f712d6cedc4e4c5e3061e2934fe6ce6 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 14 Aug 2024 14:04:51 -0300 Subject: [PATCH] refactor: Prevent infinite loop in get_successors function. (#3332) * refactor: Prevent infinite loop in get_successors function. * feat(base): improve get_all_successors method to handle cyclic graphs and flat parameter efficiently --- src/backend/base/langflow/graph/graph/base.py | 38 +++++++++---------- .../base/langflow/graph/graph/utils.py | 4 ++ 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 9a93b6053..f3d522901 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -1412,33 +1412,25 @@ class Graph: """Returns the predecessors of a vertex.""" return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])] - def get_all_successors(self, vertex: "Vertex", recursive=True, flat=True): - # Recursively get the successors of the current vertex - # successors = vertex.successors - # if not successors: - # return [] - # successors_result = [] - # for successor in successors: - # # Just return a list of successors - # if recursive: - # next_successors = self.get_all_successors(successor) - # successors_result.extend(next_successors) - # successors_result.append(successor) - # return successors_result - # The above is the version without the flat parameter - # The below is the version with the flat parameter - # the flat parameter will define if each layer of successors - # becomes one list or if the result is a list of lists - # if flat is True, the result will be a list of vertices - # if flat is False, the result will be a list of lists of vertices - # each list will represent a layer of successors + def get_all_successors(self, vertex: "Vertex", recursive=True, flat=True, visited=None): + if visited is None: + visited = set() + + # Prevent revisiting vertices to avoid infinite loops in cyclic graphs + if vertex in visited: + return [] + + visited.add(vertex) + successors = vertex.successors if not successors: return [] + successors_result = [] + for successor in successors: if recursive: - next_successors = self.get_all_successors(successor) + next_successors = self.get_all_successors(successor, recursive=recursive, flat=flat, visited=visited) if flat: successors_result.extend(next_successors) else: @@ -1447,6 +1439,10 @@ class Graph: successors_result.append(successor) else: successors_result.append([successor]) + + if not flat and successors_result: + return [successors] + successors_result + return successors_result def get_successors(self, vertex: "Vertex") -> List["Vertex"]: diff --git a/src/backend/base/langflow/graph/graph/utils.py b/src/backend/base/langflow/graph/graph/utils.py index ccc3c40f7..21e00f696 100644 --- a/src/backend/base/langflow/graph/graph/utils.py +++ b/src/backend/base/langflow/graph/graph/utils.py @@ -241,8 +241,12 @@ def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id): def get_successors(graph: Dict[str, Dict[str, List[str]]], vertex_id: str) -> List[str]: successors_result = [] stack = [vertex_id] + visited = set() while stack: current_id = stack.pop() + if current_id in visited: + continue + visited.add(current_id) successors_result.append(current_id) stack.extend(graph[current_id]["successors"]) return successors_result