Refactor graph activation logic and add successors_ids property to StateVertex

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-03 22:47:34 -03:00
commit 178a98dc8e
3 changed files with 17 additions and 15 deletions

View file

@ -162,7 +162,7 @@ async def build_vertex(
inactivated_vertices = None
inactivated_vertices = list(graph.inactivated_vertices)
graph.reset_inactivated_vertices()
activated_layers = graph.activated_layers
activated_layers = graph.activated_vertices
graph.reset_activated_vertices()
chat_service.set_cache(flow_id, graph)

View file

@ -59,7 +59,7 @@ class Graph:
self._vertices = self._graph_data["nodes"]
self._edges = self._graph_data["edges"]
self.inactivated_vertices: set = set()
self.activated_layers: List[List[str]] = []
self.activated_vertices: List[str] = []
self.vertices_layers = []
self.vertices_to_run = set()
self.stop_vertex = None
@ -90,7 +90,7 @@ class Graph:
self.state_manager.update_state(name, record, run_id=self._run_id)
def activate_state_vertices(self, name: str, caller: str):
layers = []
vertices_ids = []
for vertex_id in self._is_state_vertices:
if vertex_id == caller:
continue
@ -101,15 +101,14 @@ class Graph:
and vertex_id != caller
and isinstance(vertex, StateVertex)
):
layers.append([vertex_id])
successors = self.get_all_successors(vertex, flat=False)
for layer in successors:
layers.append([v.id for v in layer])
self.activated_layers = layers
vertices_ids.append(vertex_id)
successors = self.get_all_successors(vertex, flat=True)
self.vertices_to_run.update(list(map(lambda x: x.id, successors)))
self.activated_vertices = vertices_ids
self.vertices_to_run.update(vertices_ids)
def reset_activated_vertices(self):
self.activated_layers = []
self.activated_vertices = []
def append_state(
self, name: str, record: Union[str, Record], caller: Optional[str] = None
@ -922,15 +921,13 @@ class Graph:
vertices = self.sort_up_to_vertex(stop_component_id)
elif start_component_id:
vertices = self.sort_up_to_vertex(start_component_id, is_start=True)
else:
vertices = self.vertices
# without component_id we are probably running in the chat
# so we want to pick only graphs that start with ChatInput or
# TextInput
vertices_layers = self.layered_topological_sort(
vertices, filter_graphs=True
)
vertices_layers = self.layered_topological_sort(vertices)
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
# vertices_layers = self.sort_chat_inputs_first(vertices_layers)
self.increment_run_count()
@ -938,7 +935,7 @@ class Graph:
# save the only the rest
self.vertices_layers = vertices_layers[1:]
self.vertices_to_run = {
vertex for vertex in chain.from_iterable(vertices_layers)
vertex_id for vertex_id in chain.from_iterable(vertices_layers)
}
# Return just the first layer
return first_layer

View file

@ -500,6 +500,11 @@ class StateVertex(Vertex):
self.steps = [self._build]
self.is_state = True
@property
def successors_ids(self) -> List[str]:
successors = self.graph.successor_map.get(self.id, [])
return successors + self.graph.activated_vertices
def dict_to_codeblock(d: dict) -> str:
serialized = {key: serialize_field(val) for key, val in d.items()}