Add run and update count properties to Graph class
This commit is contained in:
parent
ccd2355ca1
commit
f58245fab6
1 changed files with 27 additions and 23 deletions
|
|
@ -31,6 +31,8 @@ class Graph:
|
|||
self._vertices = nodes
|
||||
self._edges = edges
|
||||
self.raw_graph_data = {"nodes": nodes, "edges": edges}
|
||||
self._runs = 0
|
||||
self._updates = 0
|
||||
|
||||
self.top_level_vertices = []
|
||||
for vertex in self._vertices:
|
||||
|
|
@ -44,10 +46,23 @@ class Graph:
|
|||
self._build_graph()
|
||||
self.build_graph_maps()
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return {
|
||||
"runs": self._runs,
|
||||
"updates": self._updates,
|
||||
}
|
||||
|
||||
def build_graph_maps(self):
|
||||
self.predecessor_map, self.successor_map = self.build_adjacency_maps()
|
||||
self.in_degree_map = self.build_in_degree()
|
||||
|
||||
def increment_run_count(self):
|
||||
self._runs += 1
|
||||
|
||||
def increment_update_count(self):
|
||||
self._updates += 1
|
||||
|
||||
def __getstate__(self):
|
||||
return self.raw_graph_data
|
||||
|
||||
|
|
@ -137,6 +152,7 @@ class Graph:
|
|||
self._add_vertex(new_vertex)
|
||||
|
||||
self.build_graph_maps()
|
||||
self.increment_update_count()
|
||||
return self
|
||||
|
||||
def reset_all_edges_of_vertex(self, vertex: Vertex) -> None:
|
||||
|
|
@ -439,24 +455,14 @@ class Graph:
|
|||
|
||||
# Filter the original graph's vertices and edges to keep only those in `visited`
|
||||
vertices_to_keep = [self.get_vertex(vid) for vid in visited]
|
||||
edges_to_keep = [
|
||||
e for e in self.edges if e.target_id in visited and e.source_id in visited
|
||||
]
|
||||
|
||||
vertices = self.layered_topological_sort(vertices_to_keep, edges_to_keep)
|
||||
vertices = self.sort_interface_components_first(vertices)
|
||||
return self.sort_chat_inputs_first(vertices)
|
||||
return vertices_to_keep
|
||||
|
||||
def layered_topological_sort(
|
||||
self,
|
||||
vertices: Optional[List[Vertex]] = None,
|
||||
edges: Optional[List[ContractEdge]] = None,
|
||||
vertices: List[Vertex],
|
||||
) -> List[List[str]]:
|
||||
"""Performs a layered topological sort of the vertices in the graph."""
|
||||
if vertices is None:
|
||||
vertices = self.vertices
|
||||
if edges is None:
|
||||
edges = self.edges
|
||||
|
||||
# Queue for vertices with no incoming edges
|
||||
queue = deque(
|
||||
|
|
@ -533,19 +539,17 @@ class Graph:
|
|||
|
||||
return vertices
|
||||
|
||||
def sort_vertices(self) -> List[List[str]]:
|
||||
def sort_vertices(self, component_id: Optional[str] = None) -> List[List[str]]:
|
||||
"""Sorts the vertices in the graph."""
|
||||
vertices = self.layered_topological_sort()
|
||||
# Sort each layer to have ChatInput or ChatOutput first
|
||||
# each layer consists of a list of vertex ids
|
||||
# formatted as ComponentName-5letters
|
||||
# e.g. ChatInput-abcde
|
||||
# we just need to check if the vertex id contains ChatInput or ChatOutput
|
||||
# and sort the layers accordingly
|
||||
# InterfaceComponentTypes is an enum
|
||||
# check all values of the enum and sort the layers
|
||||
if component_id:
|
||||
vertices = self.sort_up_to_vertex(component_id)
|
||||
else:
|
||||
vertices = self.vertices
|
||||
vertices = self.layered_topological_sort(vertices)
|
||||
vertices = self.sort_interface_components_first(vertices)
|
||||
return self.sort_chat_inputs_first(vertices)
|
||||
vertices = self.sort_chat_inputs_first(vertices)
|
||||
self.increment_run_count()
|
||||
return vertices
|
||||
|
||||
def sort_interface_components_first(self, vertices: List[Vertex]) -> List[Vertex]:
|
||||
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue