refactor: Add method to get root of group node in Graph class

This commit is contained in:
ogabrielluiz 2024-06-18 14:10:40 -03:00
commit 549b420545

View file

@ -748,6 +748,21 @@ class Graph:
except KeyError:
raise ValueError(f"Vertex {vertex_id} not found")
def get_root_of_group_node(self, vertex_id: str) -> Vertex:
"""Returns the root of a group node."""
if vertex_id in self.top_level_vertices:
# Get all vertices with vertex_id as .parent_node_id
# then get the one at the top
vertices = [vertex for vertex in self.vertices if vertex.parent_node_id == vertex_id]
# Now go through successors of the vertices
# and get the one that none of its successors is in vertices
for vertex in vertices:
successors = self.get_all_successors(vertex, recursive=False)
if not any(successor in vertices for successor in successors):
return vertex
else:
raise ValueError(f"Vertex {vertex_id} is not a top level vertex")
async def build_vertex(
self,
lock: asyncio.Lock,
@ -1127,7 +1142,6 @@ class Graph:
# Initial setup
visited = set() # To keep track of visited vertices
excluded = set() # To keep track of vertices that should be excluded
stack = [vertex_id] # Use a list as a stack for DFS
def get_successors(vertex, recursive=True):
# Recursively get the successors of the current vertex
@ -1143,7 +1157,12 @@ class Graph:
successors_result.append(successor)
return successors_result
stop_or_start_vertex = self.get_vertex(vertex_id)
try:
stop_or_start_vertex = self.get_vertex(vertex_id)
stack = [vertex_id] # Use a list as a stack for DFS
except ValueError:
stop_or_start_vertex = self.get_root_of_group_node(vertex_id)
stack = [stop_or_start_vertex.id]
stop_predecessors = [pre.id for pre in stop_or_start_vertex.predecessors]
# DFS to collect all vertices that can reach the specified vertex
while stack: