Refactor graph activation logic and add successors_ids property to StateVertex
This commit is contained in:
parent
e8d047239e
commit
178a98dc8e
3 changed files with 17 additions and 15 deletions
|
|
@ -162,7 +162,7 @@ async def build_vertex(
|
|||
inactivated_vertices = None
|
||||
inactivated_vertices = list(graph.inactivated_vertices)
|
||||
graph.reset_inactivated_vertices()
|
||||
activated_layers = graph.activated_layers
|
||||
activated_layers = graph.activated_vertices
|
||||
graph.reset_activated_vertices()
|
||||
chat_service.set_cache(flow_id, graph)
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class Graph:
|
|||
self._vertices = self._graph_data["nodes"]
|
||||
self._edges = self._graph_data["edges"]
|
||||
self.inactivated_vertices: set = set()
|
||||
self.activated_layers: List[List[str]] = []
|
||||
self.activated_vertices: List[str] = []
|
||||
self.vertices_layers = []
|
||||
self.vertices_to_run = set()
|
||||
self.stop_vertex = None
|
||||
|
|
@ -90,7 +90,7 @@ class Graph:
|
|||
self.state_manager.update_state(name, record, run_id=self._run_id)
|
||||
|
||||
def activate_state_vertices(self, name: str, caller: str):
|
||||
layers = []
|
||||
vertices_ids = []
|
||||
for vertex_id in self._is_state_vertices:
|
||||
if vertex_id == caller:
|
||||
continue
|
||||
|
|
@ -101,15 +101,14 @@ class Graph:
|
|||
and vertex_id != caller
|
||||
and isinstance(vertex, StateVertex)
|
||||
):
|
||||
layers.append([vertex_id])
|
||||
successors = self.get_all_successors(vertex, flat=False)
|
||||
for layer in successors:
|
||||
|
||||
layers.append([v.id for v in layer])
|
||||
self.activated_layers = layers
|
||||
vertices_ids.append(vertex_id)
|
||||
successors = self.get_all_successors(vertex, flat=True)
|
||||
self.vertices_to_run.update(list(map(lambda x: x.id, successors)))
|
||||
self.activated_vertices = vertices_ids
|
||||
self.vertices_to_run.update(vertices_ids)
|
||||
|
||||
def reset_activated_vertices(self):
|
||||
self.activated_layers = []
|
||||
self.activated_vertices = []
|
||||
|
||||
def append_state(
|
||||
self, name: str, record: Union[str, Record], caller: Optional[str] = None
|
||||
|
|
@ -922,15 +921,13 @@ class Graph:
|
|||
vertices = self.sort_up_to_vertex(stop_component_id)
|
||||
elif start_component_id:
|
||||
vertices = self.sort_up_to_vertex(start_component_id, is_start=True)
|
||||
|
||||
else:
|
||||
vertices = self.vertices
|
||||
# without component_id we are probably running in the chat
|
||||
# so we want to pick only graphs that start with ChatInput or
|
||||
# TextInput
|
||||
vertices_layers = self.layered_topological_sort(
|
||||
vertices, filter_graphs=True
|
||||
)
|
||||
|
||||
vertices_layers = self.layered_topological_sort(vertices)
|
||||
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
|
||||
# vertices_layers = self.sort_chat_inputs_first(vertices_layers)
|
||||
self.increment_run_count()
|
||||
|
|
@ -938,7 +935,7 @@ class Graph:
|
|||
# save the only the rest
|
||||
self.vertices_layers = vertices_layers[1:]
|
||||
self.vertices_to_run = {
|
||||
vertex for vertex in chain.from_iterable(vertices_layers)
|
||||
vertex_id for vertex_id in chain.from_iterable(vertices_layers)
|
||||
}
|
||||
# Return just the first layer
|
||||
return first_layer
|
||||
|
|
|
|||
|
|
@ -500,6 +500,11 @@ class StateVertex(Vertex):
|
|||
self.steps = [self._build]
|
||||
self.is_state = True
|
||||
|
||||
@property
|
||||
def successors_ids(self) -> List[str]:
|
||||
successors = self.graph.successor_map.get(self.id, [])
|
||||
return successors + self.graph.activated_vertices
|
||||
|
||||
|
||||
def dict_to_codeblock(d: dict) -> str:
|
||||
serialized = {key: serialize_field(val) for key, val in d.items()}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue