refactor: Prevent infinite loop in get_successors function. (#3332)
* refactor: Prevent infinite loop in get_successors function. * feat(base): improve get_all_successors method to handle cyclic graphs and flat parameter efficiently
This commit is contained in:
parent
15dc4347e0
commit
33e72f162f
2 changed files with 21 additions and 21 deletions
|
|
@ -1412,33 +1412,25 @@ class Graph:
|
|||
"""Returns the predecessors of a vertex."""
|
||||
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
|
||||
|
||||
def get_all_successors(self, vertex: "Vertex", recursive=True, flat=True):
|
||||
# Recursively get the successors of the current vertex
|
||||
# successors = vertex.successors
|
||||
# if not successors:
|
||||
# return []
|
||||
# successors_result = []
|
||||
# for successor in successors:
|
||||
# # Just return a list of successors
|
||||
# if recursive:
|
||||
# next_successors = self.get_all_successors(successor)
|
||||
# successors_result.extend(next_successors)
|
||||
# successors_result.append(successor)
|
||||
# return successors_result
|
||||
# The above is the version without the flat parameter
|
||||
# The below is the version with the flat parameter
|
||||
# the flat parameter will define if each layer of successors
|
||||
# becomes one list or if the result is a list of lists
|
||||
# if flat is True, the result will be a list of vertices
|
||||
# if flat is False, the result will be a list of lists of vertices
|
||||
# each list will represent a layer of successors
|
||||
def get_all_successors(self, vertex: "Vertex", recursive=True, flat=True, visited=None):
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
# Prevent revisiting vertices to avoid infinite loops in cyclic graphs
|
||||
if vertex in visited:
|
||||
return []
|
||||
|
||||
visited.add(vertex)
|
||||
|
||||
successors = vertex.successors
|
||||
if not successors:
|
||||
return []
|
||||
|
||||
successors_result = []
|
||||
|
||||
for successor in successors:
|
||||
if recursive:
|
||||
next_successors = self.get_all_successors(successor)
|
||||
next_successors = self.get_all_successors(successor, recursive=recursive, flat=flat, visited=visited)
|
||||
if flat:
|
||||
successors_result.extend(next_successors)
|
||||
else:
|
||||
|
|
@ -1447,6 +1439,10 @@ class Graph:
|
|||
successors_result.append(successor)
|
||||
else:
|
||||
successors_result.append([successor])
|
||||
|
||||
if not flat and successors_result:
|
||||
return [successors] + successors_result
|
||||
|
||||
return successors_result
|
||||
|
||||
def get_successors(self, vertex: "Vertex") -> List["Vertex"]:
|
||||
|
|
|
|||
|
|
@ -241,8 +241,12 @@ def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id):
|
|||
def get_successors(graph: Dict[str, Dict[str, List[str]]], vertex_id: str) -> List[str]:
|
||||
successors_result = []
|
||||
stack = [vertex_id]
|
||||
visited = set()
|
||||
while stack:
|
||||
current_id = stack.pop()
|
||||
if current_id in visited:
|
||||
continue
|
||||
visited.add(current_id)
|
||||
successors_result.append(current_id)
|
||||
stack.extend(graph[current_id]["successors"])
|
||||
return successors_result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue