Fix: update initial graph sort for disconnected graph (#5867)
* refactor: add test for sorting vertices in unconnected graph Add a new test case to the `test_get_sorted_vertices_with_unconnected_graph` function in the `test_utils.py` file. This test verifies the correct sorting of vertices in an unconnected graph. The test defines a graph structure and checks that the first layer contains the input vertices and the remaining layers contain the rest of the vertices in the correct order. Refactor the code to improve test coverage and ensure the correctness of the sorting algorithm. * refactor: improve handling of unconnected vertices in graph sorting * [autofix.ci] apply automated fixes * Refactor: Update start_component_id in test_get_sorted_vertices_with_unconnected_graph The start_component_id parameter in the test_get_sorted_vertices_with_unconnected_graph function was updated to "A" to improve the handling of unconnected vertices in graph sorting. * Refactor: Improve handling of unconnected vertices in graph sorting * [autofix.ci] apply automated fixes * Refactor: Add test_filter_vertices_from_vertex function to test_utils.py * Refactor: Add error handling for missing graph information in filter_vertices_up_to_vertex and filter_vertices_from_vertex functions * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
33ba516e48
commit
1df53a87f4
2 changed files with 208 additions and 1 deletions
|
|
@ -841,6 +841,31 @@ def get_sorted_vertices(
|
|||
)
|
||||
vertices_ids = list(filtered_vertices)
|
||||
|
||||
# If we have a start component, we need to filter out unconnected vertices
|
||||
# but keep vertices that are connected to the graph even if not reachable from start
|
||||
if start_component_id is not None:
|
||||
# First get all vertices reachable from start
|
||||
reachable_vertices = filter_vertices_from_vertex(
|
||||
vertices_ids,
|
||||
start_component_id,
|
||||
get_vertex_predecessors=get_vertex_predecessors,
|
||||
get_vertex_successors=get_vertex_successors,
|
||||
graph_dict=graph_dict,
|
||||
)
|
||||
# Then get all vertices that can reach any reachable vertex
|
||||
connected_vertices = set()
|
||||
for vertex in reachable_vertices:
|
||||
connected_vertices.update(
|
||||
filter_vertices_up_to_vertex(
|
||||
vertices_ids,
|
||||
vertex,
|
||||
get_vertex_predecessors=get_vertex_predecessors,
|
||||
get_vertex_successors=get_vertex_successors,
|
||||
graph_dict=graph_dict,
|
||||
)
|
||||
)
|
||||
vertices_ids = list(connected_vertices)
|
||||
|
||||
# Get the layers
|
||||
layers = layered_topological_sort(
|
||||
vertices_ids=set(vertices_ids),
|
||||
|
|
@ -905,7 +930,8 @@ def filter_vertices_up_to_vertex(
|
|||
# Build predecessor map if not provided
|
||||
if get_vertex_predecessors is None:
|
||||
if graph_dict is None:
|
||||
return set()
|
||||
msg = "Either get_vertex_predecessors or graph_dict must be provided"
|
||||
raise ValueError(msg)
|
||||
|
||||
def get_vertex_predecessors(v):
|
||||
return graph_dict[v]["predecessors"]
|
||||
|
|
@ -931,3 +957,58 @@ def filter_vertices_up_to_vertex(
|
|||
queue.append(predecessor)
|
||||
|
||||
return filtered_vertices
|
||||
|
||||
|
||||
def filter_vertices_from_vertex(
|
||||
vertices_ids: list[str],
|
||||
vertex_id: str,
|
||||
get_vertex_predecessors: Callable[[str], list[str]] | None = None,
|
||||
get_vertex_successors: Callable[[str], list[str]] | None = None,
|
||||
graph_dict: dict[str, Any] | None = None,
|
||||
) -> set[str]:
|
||||
"""Filter vertices starting from a given vertex.
|
||||
|
||||
Args:
|
||||
vertices_ids: List of vertex IDs to filter
|
||||
vertex_id: ID of the vertex to start filtering from
|
||||
get_vertex_predecessors: Function to get predecessors of a vertex
|
||||
get_vertex_successors: Function to get successors of a vertex
|
||||
graph_dict: Dictionary containing graph information
|
||||
|
||||
Returns:
|
||||
Set of vertex IDs that are successors of the given vertex
|
||||
"""
|
||||
vertices_set = set(vertices_ids)
|
||||
if vertex_id not in vertices_set:
|
||||
return set()
|
||||
|
||||
# Build predecessor map if not provided
|
||||
if get_vertex_predecessors is None:
|
||||
if graph_dict is None:
|
||||
msg = "Either get_vertex_predecessors or graph_dict must be provided"
|
||||
raise ValueError(msg)
|
||||
|
||||
def get_vertex_predecessors(v):
|
||||
return graph_dict[v]["predecessors"]
|
||||
|
||||
# Build successor map if not provided
|
||||
if get_vertex_successors is None:
|
||||
if graph_dict is None:
|
||||
return set()
|
||||
|
||||
def get_vertex_successors(v):
|
||||
return graph_dict[v]["successors"]
|
||||
|
||||
# Start with the target vertex
|
||||
filtered_vertices = {vertex_id}
|
||||
queue = deque([vertex_id])
|
||||
|
||||
# Process vertices in breadth-first order
|
||||
while queue:
|
||||
current_vertex = queue.popleft()
|
||||
for successor in get_vertex_successors(current_vertex):
|
||||
if successor in vertices_set and successor not in filtered_vertices:
|
||||
filtered_vertices.add(successor)
|
||||
queue.append(successor)
|
||||
|
||||
return filtered_vertices
|
||||
|
|
|
|||
|
|
@ -851,3 +851,129 @@ def test_get_sorted_vertices_exact_sequence(graph_with_loop):
|
|||
assert len(sequence) == len(expected_sequence), (
|
||||
f"Expected sequence length {len(expected_sequence)}, but got {len(sequence)}"
|
||||
)
|
||||
|
||||
|
||||
def test_get_sorted_vertices_with_unconnected_graph():
|
||||
# Define a graph with the specified structure
|
||||
vertices_ids = ["A", "B", "C", "D", "K"]
|
||||
cycle_vertices = set()
|
||||
graph_dict = {
|
||||
"A": {"successors": ["B"], "predecessors": []},
|
||||
"C": {"successors": ["B"], "predecessors": []},
|
||||
"B": {"successors": ["D"], "predecessors": ["A", "C"]},
|
||||
"D": {"successors": [], "predecessors": ["B"]},
|
||||
"K": {"successors": [], "predecessors": []},
|
||||
}
|
||||
in_degree_map = {vertex: len(data["predecessors"]) for vertex, data in graph_dict.items()}
|
||||
successor_map = {vertex: data["successors"] for vertex, data in graph_dict.items()}
|
||||
predecessor_map = {vertex: data["predecessors"] for vertex, data in graph_dict.items()}
|
||||
|
||||
def is_input_vertex(vertex_id: str) -> bool:
|
||||
return vertex_id in ["A"]
|
||||
|
||||
def get_vertex_predecessors(vertex_id: str) -> list[str]:
|
||||
return predecessor_map[vertex_id]
|
||||
|
||||
def get_vertex_successors(vertex_id: str) -> list[str]:
|
||||
return successor_map[vertex_id]
|
||||
|
||||
first_layer, remaining_layers = utils.get_sorted_vertices(
|
||||
vertices_ids=vertices_ids,
|
||||
cycle_vertices=cycle_vertices,
|
||||
stop_component_id=None,
|
||||
start_component_id="A",
|
||||
graph_dict=graph_dict,
|
||||
in_degree_map=in_degree_map,
|
||||
successor_map=successor_map,
|
||||
predecessor_map=predecessor_map,
|
||||
is_input_vertex=is_input_vertex,
|
||||
get_vertex_predecessors=get_vertex_predecessors,
|
||||
get_vertex_successors=get_vertex_successors,
|
||||
is_cyclic=False,
|
||||
)
|
||||
|
||||
# Verify the first layer contains all input vertices
|
||||
assert set(first_layer) == {"A", "C"}
|
||||
|
||||
# Verify the remaining layers contain the rest of the vertices in the correct order
|
||||
assert len(remaining_layers) == 2
|
||||
assert remaining_layers[0] == ["B"]
|
||||
assert remaining_layers[1] == ["D"]
|
||||
|
||||
|
||||
def test_filter_vertices_from_vertex():
|
||||
# Test case 1: Simple linear graph
|
||||
vertices_ids = ["A", "B", "C", "D"]
|
||||
graph_dict = {
|
||||
"A": {"successors": ["B"], "predecessors": []},
|
||||
"B": {"successors": ["C"], "predecessors": ["A"]},
|
||||
"C": {"successors": ["D"], "predecessors": ["B"]},
|
||||
"D": {"successors": [], "predecessors": ["C"]},
|
||||
}
|
||||
|
||||
# Starting from A should return all vertices
|
||||
result = utils.filter_vertices_from_vertex(vertices_ids, "A", graph_dict=graph_dict)
|
||||
assert result == {"A", "B", "C", "D"}
|
||||
|
||||
# Starting from B should return B, C, D
|
||||
result = utils.filter_vertices_from_vertex(vertices_ids, "B", graph_dict=graph_dict)
|
||||
assert result == {"B", "C", "D"}
|
||||
|
||||
# Starting from D should return only D
|
||||
result = utils.filter_vertices_from_vertex(vertices_ids, "D", graph_dict=graph_dict)
|
||||
assert result == {"D"}
|
||||
|
||||
# Test case 2: Graph with branches
|
||||
vertices_ids = ["A", "B", "C", "D", "E"]
|
||||
graph_dict = {
|
||||
"A": {"successors": ["B", "C"], "predecessors": []},
|
||||
"B": {"successors": ["D"], "predecessors": ["A"]},
|
||||
"C": {"successors": ["E"], "predecessors": ["A"]},
|
||||
"D": {"successors": [], "predecessors": ["B"]},
|
||||
"E": {"successors": [], "predecessors": ["C"]},
|
||||
}
|
||||
|
||||
# Starting from A should return all vertices
|
||||
result = utils.filter_vertices_from_vertex(vertices_ids, "A", graph_dict=graph_dict)
|
||||
assert result == {"A", "B", "C", "D", "E"}
|
||||
|
||||
# Starting from B should return B and D
|
||||
result = utils.filter_vertices_from_vertex(vertices_ids, "B", graph_dict=graph_dict)
|
||||
assert result == {"B", "D"}
|
||||
|
||||
# Test case 3: Graph with unconnected vertices
|
||||
vertices_ids = ["A", "B", "C", "X", "Y"]
|
||||
graph_dict = {
|
||||
"A": {"successors": ["B"], "predecessors": []},
|
||||
"B": {"successors": ["C"], "predecessors": ["A"]},
|
||||
"C": {"successors": [], "predecessors": ["B"]},
|
||||
"X": {"successors": ["Y"], "predecessors": []},
|
||||
"Y": {"successors": [], "predecessors": ["X"]},
|
||||
}
|
||||
|
||||
# Starting from A should return only A, B, C
|
||||
result = utils.filter_vertices_from_vertex(vertices_ids, "A", graph_dict=graph_dict)
|
||||
assert result == {"A", "B", "C"}
|
||||
|
||||
# Starting from X should return only X, Y
|
||||
result = utils.filter_vertices_from_vertex(vertices_ids, "X", graph_dict=graph_dict)
|
||||
assert result == {"X", "Y"}
|
||||
|
||||
# Test case 4: Invalid vertex
|
||||
result = utils.filter_vertices_from_vertex(vertices_ids, "Z", graph_dict=graph_dict)
|
||||
assert result == set()
|
||||
|
||||
# Test case 5: Using callback functions instead of graph_dict
|
||||
def get_successors(v: str) -> list[str]:
|
||||
return graph_dict[v]["successors"]
|
||||
|
||||
def get_predecessors(v: str) -> list[str]:
|
||||
return graph_dict[v]["predecessors"]
|
||||
|
||||
result = utils.filter_vertices_from_vertex(
|
||||
vertices_ids,
|
||||
"A",
|
||||
get_vertex_predecessors=get_predecessors,
|
||||
get_vertex_successors=get_successors,
|
||||
)
|
||||
assert result == {"A", "B", "C"}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue