Refactor graph state management and add get_state method
This commit is contained in:
parent
0fe8d1116d
commit
e9772d06fc
3 changed files with 56 additions and 19 deletions
|
|
@ -47,7 +47,7 @@ class Graph:
|
|||
self._is_state_vertices: List[str] = []
|
||||
self._has_session_id_vertices: List[str] = []
|
||||
self._sorted_vertices_layers: List[List[str]] = []
|
||||
self.run_id = None
|
||||
self._run_id = None
|
||||
|
||||
self.top_level_vertices = []
|
||||
for vertex in self._vertices:
|
||||
|
|
@ -66,6 +66,10 @@ class Graph:
|
|||
self.define_vertices_lists()
|
||||
self.state_manager = GraphStateManager()
|
||||
|
||||
def get_state(self, name: str) -> Optional[Record]:
|
||||
"""Returns the state of the graph."""
|
||||
return self.state_manager.get_state(name, run_id=self._run_id)
|
||||
|
||||
def update_state(
|
||||
self, name: str, record: Union[str, Record], caller: Optional[str] = None
|
||||
) -> None:
|
||||
|
|
@ -77,7 +81,7 @@ class Graph:
|
|||
# This also has to activate their successors
|
||||
self.activate_state_vertices(name, caller)
|
||||
|
||||
self.state_manager.update_state(name, record)
|
||||
self.state_manager.update_state(name, record, run_id=self._run_id)
|
||||
|
||||
def activate_state_vertices(self, name: str, caller: str):
|
||||
layers = []
|
||||
|
|
@ -109,15 +113,21 @@ class Graph:
|
|||
|
||||
self.activate_state_vertices(name, caller)
|
||||
|
||||
self.state_manager.append_state(name, record)
|
||||
self.state_manager.append_state(name, record, run_id=self._run_id)
|
||||
|
||||
@property
|
||||
def run_id(self):
|
||||
if not self._run_id:
|
||||
raise ValueError("Run ID not set")
|
||||
return self._run_id
|
||||
|
||||
def set_run_id(self, run_id: str):
|
||||
for vertex in self.vertices:
|
||||
self.state_manager.subscribe(run_id, vertex.update_graph_state)
|
||||
self.run_id = run_id
|
||||
self._run_id = run_id
|
||||
|
||||
def add_state(self, state: str):
|
||||
self.state_manager.append_state(self.run_id, state)
|
||||
self.state_manager.append_state(self._run_id, state)
|
||||
|
||||
@property
|
||||
def sorted_vertices_layers(self) -> List[List[str]]:
|
||||
|
|
@ -760,21 +770,28 @@ class Graph:
|
|||
def layered_topological_sort(
|
||||
self,
|
||||
vertices: List[Vertex],
|
||||
filter_graphs: bool = False,
|
||||
) -> List[List[str]]:
|
||||
"""Performs a layered topological sort of the vertices in the graph."""
|
||||
vertices_ids = {vertex.id for vertex in vertices}
|
||||
# Queue for vertices with no incoming edges
|
||||
queue = deque(
|
||||
vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0
|
||||
vertex.id
|
||||
for vertex in vertices
|
||||
# if filter_graphs then only vertex.is_input will be considered
|
||||
if self.in_degree_map[vertex.id] == 0
|
||||
and (not filter_graphs or vertex.is_input)
|
||||
)
|
||||
layers: List[List[str]] = []
|
||||
|
||||
visited = set(queue)
|
||||
current_layer = 0
|
||||
while queue:
|
||||
layers.append([]) # Start a new layer
|
||||
layer_size = len(queue)
|
||||
for _ in range(layer_size):
|
||||
vertex_id = queue.popleft()
|
||||
visited.add(vertex_id)
|
||||
|
||||
layers[current_layer].append(vertex_id)
|
||||
for neighbor in self.successor_map[vertex_id]:
|
||||
# only vertices in `vertices_ids` should be considered
|
||||
|
|
@ -785,8 +802,16 @@ class Graph:
|
|||
continue
|
||||
|
||||
self.in_degree_map[neighbor] -= 1 # 'remove' edge
|
||||
if self.in_degree_map[neighbor] == 0:
|
||||
if self.in_degree_map[neighbor] == 0 and neighbor not in visited:
|
||||
queue.append(neighbor)
|
||||
|
||||
# if > 0 it might mean not all predecessors have added to the queue
|
||||
# so we should process the neighbors predecessors
|
||||
elif self.in_degree_map[neighbor] > 0:
|
||||
for predecessor in self.predecessor_map[neighbor]:
|
||||
if predecessor not in queue and predecessor not in visited:
|
||||
queue.append(predecessor)
|
||||
|
||||
current_layer += 1 # Next layer
|
||||
new_layers = self.refine_layers(layers)
|
||||
return new_layers
|
||||
|
|
@ -851,9 +876,15 @@ class Graph:
|
|||
self.mark_all_vertices("ACTIVE")
|
||||
if component_id:
|
||||
vertices = self.sort_up_to_vertex(component_id)
|
||||
vertices_layers = self.layered_topological_sort(vertices)
|
||||
else:
|
||||
vertices = self.vertices
|
||||
vertices_layers = self.layered_topological_sort(vertices)
|
||||
# without component_id we are probably running in the chat
|
||||
# so we want to pick only graphs that start with ChatInput or
|
||||
# TextInput
|
||||
vertices_layers = self.layered_topological_sort(
|
||||
vertices, filter_graphs=True
|
||||
)
|
||||
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
|
||||
vertices_layers = self.sort_chat_inputs_first(vertices_layers)
|
||||
self.increment_run_count()
|
||||
|
|
|
|||
|
|
@ -11,23 +11,29 @@ class GraphStateManager:
|
|||
self.observers = defaultdict(list)
|
||||
self.lock = Lock()
|
||||
|
||||
def append_state(self, key, new_state):
|
||||
def append_state(self, key, new_state, run_id: str):
|
||||
with self.lock:
|
||||
if key not in self.states:
|
||||
self.states[key] = []
|
||||
if run_id not in self.states:
|
||||
self.states[run_id] = {}
|
||||
if key not in self.states[run_id]:
|
||||
self.states[run_id][key] = []
|
||||
elif not isinstance(self.states[key], list):
|
||||
self.states[key] = [self.states[key]]
|
||||
self.states[key].append(new_state)
|
||||
self.states[run_id][key] = [self.states[key]]
|
||||
self.states[run_id][key].append(new_state)
|
||||
self.notify_append_observers(key, new_state)
|
||||
|
||||
def update_state(self, key, new_state):
|
||||
def update_state(self, key, new_state, run_id: str):
|
||||
with self.lock:
|
||||
self.states[key] = new_state
|
||||
if run_id not in self.states:
|
||||
self.states[run_id] = {}
|
||||
if key not in self.states[run_id]:
|
||||
self.states[run_id][key] = {}
|
||||
self.states[run_id][key] = new_state
|
||||
self.notify_observers(key, new_state)
|
||||
|
||||
def get_state(self, key):
|
||||
def get_state(self, key, run_id: str):
|
||||
with self.lock:
|
||||
return self.states.get(key, "")
|
||||
return self.states.get(run_id, {}).get(key, "")
|
||||
|
||||
def subscribe(self, key, observer: Callable):
|
||||
with self.lock:
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ class CustomComponent(Component):
|
|||
|
||||
def get_state(self, name: str):
|
||||
try:
|
||||
return self.vertex.graph.state_manager.get_state(key=name)
|
||||
return self.vertex.graph.get_state(name=name)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error getting state: {e}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue