Refactor vertex sorting in Graph class to use avg_build_time

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-21 12:11:29 -03:00
commit 05916ab03c

View file

@ -504,7 +504,7 @@ class Graph:
# If a vertex has dependencies, it will be placed in the lowest layer index of its dependencies
# minus 1
for vertex_id, deps in self.successor_map.items():
indexes = [vertex_to_layer[dep] for dep in deps]
indexes = [vertex_to_layer[dep] for dep in deps if dep in vertex_to_layer]
new_layer_index = max(min(indexes, default=0) - 1, 0)
new_layer_index_map[vertex_id] = new_layer_index
@ -523,21 +523,23 @@ class Graph:
return refined_layers
def sort_chat_inputs_first(self, vertices: List[List[str]]) -> List[List[str]]:
def sort_chat_inputs_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
chat_inputs_first = []
for layer in vertices:
for layer in vertices_layers:
for vertex_id in layer:
if "ChatInput" in vertex_id:
# Remove the ChatInput from the layer
layer.remove(vertex_id)
chat_inputs_first.append(vertex_id)
if not chat_inputs_first:
return vertices
return vertices_layers
vertices = [chat_inputs_first] + vertices
vertices_layers = [chat_inputs_first] + vertices_layers
return vertices
return vertices_layers
def sort_vertices(self, component_id: Optional[str] = None) -> List[List[str]]:
"""Sorts the vertices in the graph."""
@ -545,13 +547,15 @@ class Graph:
vertices = self.sort_up_to_vertex(component_id)
else:
vertices = self.vertices
vertices = self.layered_topological_sort(vertices)
vertices = self.sort_interface_components_first(vertices)
vertices = self.sort_chat_inputs_first(vertices)
vertices_layers = self.layered_topological_sort(vertices)
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
vertices_layers = self.sort_chat_inputs_first(vertices_layers)
self.increment_run_count()
return vertices
return vertices_layers
def sort_interface_components_first(self, vertices: List[Vertex]) -> List[Vertex]:
def sort_interface_components_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
def contains_interface_component(vertex):
@ -565,6 +569,24 @@ class Graph:
inner_list,
key=lambda vertex: not contains_interface_component(vertex),
)
for inner_list in vertices
for inner_list in vertices_layers
]
return sorted_vertices
def sort_by_avg_build_time(self, vertices_layers: List[str]) -> List[str]:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
def sort_layer_by_avg_build_time(vertices_ids: List[str]) -> List[str]:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
if len(vertices_ids) == 1:
return vertices_ids
vertices_ids.sort(
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
)
return vertices_ids
sorted_vertices = [
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
]
return sorted_vertices