fix: do not mark self.stop caller (#3335)
* refactor: Refactor mark_branch method in Graph class. * fix: do not mark caller vertex
This commit is contained in:
parent
e5ee0ba946
commit
a90cb1249b
1 changed files with 14 additions and 4 deletions
|
|
@ -782,16 +782,18 @@ class Graph:
|
|||
if state == VertexStates.INACTIVE:
|
||||
self.run_manager.remove_from_predecessors(vertex_id)
|
||||
|
||||
def mark_branch(self, vertex_id: str, state: str, visited: Optional[set] = None, output_name: Optional[str] = None):
|
||||
def _mark_branch(
|
||||
self, vertex_id: str, state: str, visited: Optional[set] = None, output_name: Optional[str] = None
|
||||
):
|
||||
"""Marks a branch of the graph."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
else:
|
||||
self.mark_vertex(vertex_id, state)
|
||||
if vertex_id in visited:
|
||||
return
|
||||
visited.add(vertex_id)
|
||||
|
||||
self.mark_vertex(vertex_id, state)
|
||||
|
||||
for child_id in self.parent_child_map[vertex_id]:
|
||||
# Only child_id that have an edge with the vertex_id through the output_name
|
||||
# should be marked
|
||||
|
|
@ -799,7 +801,15 @@ class Graph:
|
|||
edge = self.get_edge(vertex_id, child_id)
|
||||
if edge and edge.source_handle.name != output_name:
|
||||
continue
|
||||
self.mark_branch(child_id, state)
|
||||
self._mark_branch(child_id, state, visited)
|
||||
|
||||
def mark_branch(self, vertex_id: str, state: str, output_name: Optional[str] = None):
|
||||
self._mark_branch(vertex_id=vertex_id, state=state, output_name=output_name)
|
||||
new_predecessor_map, _ = self.build_adjacency_maps(self.edges)
|
||||
self.run_manager.update_run_state(
|
||||
run_predecessors=new_predecessor_map,
|
||||
vertices_to_run=self.vertices_to_run,
|
||||
)
|
||||
|
||||
def get_edge(self, source_id: str, target_id: str) -> Optional[CycleEdge]:
|
||||
"""Returns the edge between two vertices."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue