Refactor graph state management and add get_state method

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-02 17:38:21 -03:00
commit e9772d06fc
3 changed files with 56 additions and 19 deletions

View file

@ -47,7 +47,7 @@ class Graph:
self._is_state_vertices: List[str] = []
self._has_session_id_vertices: List[str] = []
self._sorted_vertices_layers: List[List[str]] = []
self.run_id = None
self._run_id = None
self.top_level_vertices = []
for vertex in self._vertices:
@ -66,6 +66,10 @@ class Graph:
self.define_vertices_lists()
self.state_manager = GraphStateManager()
def get_state(self, name: str) -> Optional[Record]:
"""Return the state stored under `name` for this graph's current run.

Delegates to the state manager's `get_state`, passing the graph's
private `_run_id` as the `run_id` keyword argument.
"""
return self.state_manager.get_state(name, run_id=self._run_id)
def update_state(
self, name: str, record: Union[str, Record], caller: Optional[str] = None
) -> None:
@ -77,7 +81,7 @@ class Graph:
# This also has to activate their successors
self.activate_state_vertices(name, caller)
self.state_manager.update_state(name, record)
self.state_manager.update_state(name, record, run_id=self._run_id)
def activate_state_vertices(self, name: str, caller: str):
layers = []
@ -109,15 +113,21 @@ class Graph:
self.activate_state_vertices(name, caller)
self.state_manager.append_state(name, record)
self.state_manager.append_state(name, record, run_id=self._run_id)
@property
def run_id(self):
"""Return the current run ID.

Raises:
ValueError: if `_run_id` is falsy (for example, still None).
"""
# Fail loudly rather than hand a missing run ID to callers.
if not self._run_id:
raise ValueError("Run ID not set")
return self._run_id
def set_run_id(self, run_id: str):
# Register every vertex's `update_graph_state` callback with the state
# manager, using the run ID as the subscription key, then record the ID.
for vertex in self.vertices:
self.state_manager.subscribe(run_id, vertex.update_graph_state)
# NOTE(review): the two assignments below look like the before/after pair
# from the commit view (public attribute vs. private `_run_id`) — confirm
# which one the file actually contains.
self.run_id = run_id
self._run_id = run_id
def add_state(self, state: str):
    """Append `state` to the list kept for the current run.

    The run ID is used as the state key, as it was before, and is also
    passed as the `run_id` keyword argument that the state manager's
    `append_state` takes; omitting it makes the call fail with a
    TypeError.

    Args:
        state: The value to append.
    """
    # NOTE(review): the key is kept equal to the run ID to preserve the
    # existing call's meaning — confirm this is the intended key.
    self.state_manager.append_state(self._run_id, state, run_id=self._run_id)
@property
def sorted_vertices_layers(self) -> List[List[str]]:
@ -760,21 +770,28 @@ class Graph:
def layered_topological_sort(
self,
vertices: List[Vertex],
filter_graphs: bool = False,
) -> List[List[str]]:
"""Performs a layered topological sort of the vertices in the graph."""
vertices_ids = {vertex.id for vertex in vertices}
# Queue for vertices with no incoming edges
queue = deque(
vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0
vertex.id
for vertex in vertices
# if filter_graphs then only vertex.is_input will be considered
if self.in_degree_map[vertex.id] == 0
and (not filter_graphs or vertex.is_input)
)
layers: List[List[str]] = []
visited = set(queue)
current_layer = 0
while queue:
layers.append([]) # Start a new layer
layer_size = len(queue)
for _ in range(layer_size):
vertex_id = queue.popleft()
visited.add(vertex_id)
layers[current_layer].append(vertex_id)
for neighbor in self.successor_map[vertex_id]:
# only vertices in `vertices_ids` should be considered
@ -785,8 +802,16 @@ class Graph:
continue
self.in_degree_map[neighbor] -= 1 # 'remove' edge
if self.in_degree_map[neighbor] == 0:
if self.in_degree_map[neighbor] == 0 and neighbor not in visited:
queue.append(neighbor)
# if > 0 it might mean that not all predecessors have been added to the queue
# so we should process the neighbor's predecessors
elif self.in_degree_map[neighbor] > 0:
for predecessor in self.predecessor_map[neighbor]:
if predecessor not in queue and predecessor not in visited:
queue.append(predecessor)
current_layer += 1 # Next layer
new_layers = self.refine_layers(layers)
return new_layers
@ -851,9 +876,15 @@ class Graph:
self.mark_all_vertices("ACTIVE")
if component_id:
vertices = self.sort_up_to_vertex(component_id)
vertices_layers = self.layered_topological_sort(vertices)
else:
vertices = self.vertices
vertices_layers = self.layered_topological_sort(vertices)
# without component_id we are probably running in the chat
# so we want to pick only graphs that start with ChatInput or
# TextInput
vertices_layers = self.layered_topological_sort(
vertices, filter_graphs=True
)
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
vertices_layers = self.sort_chat_inputs_first(vertices_layers)
self.increment_run_count()

View file

@ -11,23 +11,29 @@ class GraphStateManager:
self.observers = defaultdict(list)
self.lock = Lock()
def append_state(self, key, new_state, run_id: str):
    """Append `new_state` to the list stored under `key` for `run_id`.

    States are kept per run as `self.states[run_id][key]`. Missing run or
    key entries are created on demand. If the existing value for the key
    is not a list (for example a value written by `update_state`), it is
    wrapped in a list before appending.

    Args:
        key: Name of the state entry.
        new_state: Value to append.
        run_id: Identifier of the run whose state is modified.
    """
    with self.lock:
        run_states = self.states.setdefault(run_id, {})
        if key not in run_states:
            run_states[key] = []
        elif not isinstance(run_states[key], list):
            # Look at the per-run entry; the top level of `self.states`
            # is keyed by run ID, not by state key.
            run_states[key] = [run_states[key]]
        run_states[key].append(new_state)
        # NOTE(review): notification is issued while the lock is held;
        # the original nesting was not recoverable from this view —
        # confirm against the repository.
        self.notify_append_observers(key, new_state)
def update_state(self, key, new_state, run_id: str):
    """Replace the value stored under `key` for `run_id` with `new_state`.

    The per-run mapping is created if it does not exist yet. Observers are
    then notified with the key and the new value.

    Args:
        key: Name of the state entry.
        new_state: Value to store.
        run_id: Identifier of the run whose state is modified.
    """
    with self.lock:
        # A single write is enough: the value is overwritten regardless
        # of whether the key already existed.
        self.states.setdefault(run_id, {})[key] = new_state
        # NOTE(review): notification is issued while the lock is held;
        # the original nesting was not recoverable from this view —
        # confirm against the repository.
        self.notify_observers(key, new_state)
# NOTE(review): two signatures and two return lines appear here; they look
# like the before/after pair from the commit view — confirm against the file.
def get_state(self, key):
def get_state(self, key, run_id: str):
# The read happens under the manager's lock.
with self.lock:
# Flat lookup by key; "" when the key is absent.
return self.states.get(key, "")
# Per-run lookup; "" when either the run ID or the key is absent.
return self.states.get(run_id, {}).get(key, "")
def subscribe(self, key, observer: Callable):
with self.lock:

View file

@ -92,7 +92,7 @@ class CustomComponent(Component):
def get_state(self, name: str):
    """Return the state called `name` from the graph owning this component.

    The lookup goes through the graph's own `get_state`.

    Args:
        name: Name of the state entry to read.

    Raises:
        ValueError: if the lookup fails for any reason; the original
            exception is attached as the cause.
    """
    try:
        return self.vertex.graph.get_state(name=name)
    except Exception as e:
        # Chain explicitly so the underlying failure stays visible.
        raise ValueError(f"Error getting state: {e}") from e