refactor: Prevent infinite loop in get_successors function. (#3332)

* refactor: Prevent infinite loop in get_successors function.

* feat(base): improve get_all_successors method to handle cyclic graphs and flat parameter efficiently
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-08-14 14:04:51 -03:00 committed by GitHub
commit 33e72f162f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 21 additions and 21 deletions

View file

@ -1412,33 +1412,25 @@ class Graph:
"""Returns the predecessors of a vertex."""
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
def get_all_successors(self, vertex: "Vertex", recursive=True, flat=True):
# Recursively get the successors of the current vertex
# successors = vertex.successors
# if not successors:
# return []
# successors_result = []
# for successor in successors:
# # Just return a list of successors
# if recursive:
# next_successors = self.get_all_successors(successor)
# successors_result.extend(next_successors)
# successors_result.append(successor)
# return successors_result
# The above is the version without the flat parameter
# The below is the version with the flat parameter
# the flat parameter will define if each layer of successors
# becomes one list or if the result is a list of lists
# if flat is True, the result will be a list of vertices
# if flat is False, the result will be a list of lists of vertices
# each list will represent a layer of successors
def get_all_successors(self, vertex: "Vertex", recursive=True, flat=True, visited=None):
if visited is None:
visited = set()
# Prevent revisiting vertices to avoid infinite loops in cyclic graphs
if vertex in visited:
return []
visited.add(vertex)
successors = vertex.successors
if not successors:
return []
successors_result = []
for successor in successors:
if recursive:
next_successors = self.get_all_successors(successor)
next_successors = self.get_all_successors(successor, recursive=recursive, flat=flat, visited=visited)
if flat:
successors_result.extend(next_successors)
else:
@ -1447,6 +1439,10 @@ class Graph:
successors_result.append(successor)
else:
successors_result.append([successor])
if not flat and successors_result:
return [successors] + successors_result
return successors_result
def get_successors(self, vertex: "Vertex") -> List["Vertex"]:

View file

@ -241,8 +241,12 @@ def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id):
def get_successors(graph: Dict[str, Dict[str, List[str]]], vertex_id: str) -> List[str]:
successors_result = []
stack = [vertex_id]
visited = set()
while stack:
current_id = stack.pop()
if current_id in visited:
continue
visited.add(current_id)
successors_result.append(current_id)
stack.extend(graph[current_id]["successors"])
return successors_result