Fix: update initial graph sort for disconnected graph (#5867)

* refactor: add test for sorting vertices in unconnected graph

Add a new test, `test_get_sorted_vertices_with_unconnected_graph`, to `test_utils.py`. It verifies that vertices are sorted correctly in an unconnected graph: the test defines a graph structure and checks that the first layer contains the input vertices and that the remaining layers contain the rest of the vertices in the correct order.

Refactor the code to improve test coverage and ensure the correctness of the sorting algorithm.

* refactor: improve handling of unconnected vertices in graph sorting

* [autofix.ci] apply automated fixes

* Refactor: Update start_component_id in test_get_sorted_vertices_with_unconnected_graph

The start_component_id parameter in the test_get_sorted_vertices_with_unconnected_graph function was updated to "A" to improve the handling of unconnected vertices in graph sorting.

* Refactor: Improve handling of unconnected vertices in graph sorting

* [autofix.ci] apply automated fixes

* Refactor: Add test_filter_vertices_from_vertex function to test_utils.py

* Refactor: Add error handling for missing graph information in filter_vertices_up_to_vertex and filter_vertices_from_vertex functions

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
anovazzi1 2025-01-22 17:27:23 -03:00 committed by GitHub
commit 1df53a87f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 208 additions and 1 deletions

View file

@ -841,6 +841,31 @@ def get_sorted_vertices(
)
vertices_ids = list(filtered_vertices)
# If we have a start component, we need to filter out unconnected vertices
# but keep vertices that are connected to the graph even if not reachable from start
if start_component_id is not None:
# First get all vertices reachable from start
reachable_vertices = filter_vertices_from_vertex(
vertices_ids,
start_component_id,
get_vertex_predecessors=get_vertex_predecessors,
get_vertex_successors=get_vertex_successors,
graph_dict=graph_dict,
)
# Then get all vertices that can reach any reachable vertex
connected_vertices = set()
for vertex in reachable_vertices:
connected_vertices.update(
filter_vertices_up_to_vertex(
vertices_ids,
vertex,
get_vertex_predecessors=get_vertex_predecessors,
get_vertex_successors=get_vertex_successors,
graph_dict=graph_dict,
)
)
vertices_ids = list(connected_vertices)
# Get the layers
layers = layered_topological_sort(
vertices_ids=set(vertices_ids),
@ -905,7 +930,8 @@ def filter_vertices_up_to_vertex(
# Build predecessor map if not provided
if get_vertex_predecessors is None:
if graph_dict is None:
return set()
msg = "Either get_vertex_predecessors or graph_dict must be provided"
raise ValueError(msg)
def get_vertex_predecessors(v):
return graph_dict[v]["predecessors"]
@ -931,3 +957,58 @@ def filter_vertices_up_to_vertex(
queue.append(predecessor)
return filtered_vertices
def filter_vertices_from_vertex(
vertices_ids: list[str],
vertex_id: str,
get_vertex_predecessors: Callable[[str], list[str]] | None = None,
get_vertex_successors: Callable[[str], list[str]] | None = None,
graph_dict: dict[str, Any] | None = None,
) -> set[str]:
"""Filter vertices starting from a given vertex.
Args:
vertices_ids: List of vertex IDs to filter
vertex_id: ID of the vertex to start filtering from
get_vertex_predecessors: Function to get predecessors of a vertex
get_vertex_successors: Function to get successors of a vertex
graph_dict: Dictionary containing graph information
Returns:
Set of vertex IDs that are successors of the given vertex
"""
vertices_set = set(vertices_ids)
if vertex_id not in vertices_set:
return set()
# Build predecessor map if not provided
if get_vertex_predecessors is None:
if graph_dict is None:
msg = "Either get_vertex_predecessors or graph_dict must be provided"
raise ValueError(msg)
def get_vertex_predecessors(v):
return graph_dict[v]["predecessors"]
# Build successor map if not provided
if get_vertex_successors is None:
if graph_dict is None:
return set()
def get_vertex_successors(v):
return graph_dict[v]["successors"]
# Start with the target vertex
filtered_vertices = {vertex_id}
queue = deque([vertex_id])
# Process vertices in breadth-first order
while queue:
current_vertex = queue.popleft()
for successor in get_vertex_successors(current_vertex):
if successor in vertices_set and successor not in filtered_vertices:
filtered_vertices.add(successor)
queue.append(successor)
return filtered_vertices

View file

@ -851,3 +851,129 @@ def test_get_sorted_vertices_exact_sequence(graph_with_loop):
assert len(sequence) == len(expected_sequence), (
f"Expected sequence length {len(expected_sequence)}, but got {len(sequence)}"
)
def test_get_sorted_vertices_with_unconnected_graph():
    """Sorting from start vertex A keeps C (which feeds B) and leaves out the isolated K."""
    # A and C both feed B, B feeds D, and K has no edges at all.
    graph_dict = {
        "A": {"successors": ["B"], "predecessors": []},
        "C": {"successors": ["B"], "predecessors": []},
        "B": {"successors": ["D"], "predecessors": ["A", "C"]},
        "D": {"successors": [], "predecessors": ["B"]},
        "K": {"successors": [], "predecessors": []},
    }

    # Derive the lookup maps in a single pass over the graph description.
    predecessor_map = {}
    successor_map = {}
    in_degree_map = {}
    for name, edges in graph_dict.items():
        predecessor_map[name] = edges["predecessors"]
        successor_map[name] = edges["successors"]
        in_degree_map[name] = len(edges["predecessors"])

    def is_input_vertex(vertex_id: str) -> bool:
        return vertex_id == "A"

    def get_vertex_predecessors(vertex_id: str) -> list[str]:
        return predecessor_map[vertex_id]

    def get_vertex_successors(vertex_id: str) -> list[str]:
        return successor_map[vertex_id]

    first_layer, remaining_layers = utils.get_sorted_vertices(
        vertices_ids=["A", "B", "C", "D", "K"],
        cycle_vertices=set(),
        stop_component_id=None,
        start_component_id="A",
        graph_dict=graph_dict,
        in_degree_map=in_degree_map,
        successor_map=successor_map,
        predecessor_map=predecessor_map,
        is_input_vertex=is_input_vertex,
        get_vertex_predecessors=get_vertex_predecessors,
        get_vertex_successors=get_vertex_successors,
        is_cyclic=False,
    )

    # The first layer holds both source vertices of the connected component.
    assert set(first_layer) == {"A", "C"}

    # The rest of the component follows in dependency order: B, then D.
    assert len(remaining_layers) == 2
    assert [remaining_layers[0], remaining_layers[1]] == [["B"], ["D"]]
def test_filter_vertices_from_vertex():
    """Check ``filter_vertices_from_vertex`` on linear, branching and disconnected graphs."""

    def node(successors: list[str], predecessors: list[str]) -> dict[str, list[str]]:
        return {"successors": successors, "predecessors": predecessors}

    # A -> B -> C -> D
    linear = {
        "A": node(["B"], []),
        "B": node(["C"], ["A"]),
        "C": node(["D"], ["B"]),
        "D": node([], ["C"]),
    }
    # A fans out to B and C; B -> D and C -> E
    branching = {
        "A": node(["B", "C"], []),
        "B": node(["D"], ["A"]),
        "C": node(["E"], ["A"]),
        "D": node([], ["B"]),
        "E": node([], ["C"]),
    }
    # Two separate chains: A -> B -> C and X -> Y
    disconnected = {
        "A": node(["B"], []),
        "B": node(["C"], ["A"]),
        "C": node([], ["B"]),
        "X": node(["Y"], []),
        "Y": node([], ["X"]),
    }

    # (graph, start vertex, expected reachable set)
    scenarios = [
        (linear, "A", {"A", "B", "C", "D"}),
        (linear, "B", {"B", "C", "D"}),
        (linear, "D", {"D"}),
        (branching, "A", {"A", "B", "C", "D", "E"}),
        (branching, "B", {"B", "D"}),
        (disconnected, "A", {"A", "B", "C"}),
        (disconnected, "X", {"X", "Y"}),
        # A start vertex that is not in the graph yields an empty set.
        (disconnected, "Z", set()),
    ]
    for graph_dict, start, expected in scenarios:
        result = utils.filter_vertices_from_vertex(list(graph_dict), start, graph_dict=graph_dict)
        assert result == expected

    # Same traversal driven by callbacks instead of graph_dict.
    def get_successors(v: str) -> list[str]:
        return disconnected[v]["successors"]

    def get_predecessors(v: str) -> list[str]:
        return disconnected[v]["predecessors"]

    result = utils.filter_vertices_from_vertex(
        list(disconnected),
        "A",
        get_vertex_predecessors=get_predecessors,
        get_vertex_successors=get_successors,
    )
    assert result == {"A", "B", "C"}