From e9772d06fc630654d4ba75e5675afb74781501c6 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sat, 2 Mar 2024 17:38:21 -0300 Subject: [PATCH] Refactor graph state management and add get_state method --- src/backend/langflow/graph/graph/base.py | 49 +++++++++++++++---- .../langflow/graph/graph/state_manager.py | 24 +++++---- .../custom_component/custom_component.py | 2 +- 3 files changed, 56 insertions(+), 19 deletions(-) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 95ab6db6b..df270a201 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -47,7 +47,7 @@ class Graph: self._is_state_vertices: List[str] = [] self._has_session_id_vertices: List[str] = [] self._sorted_vertices_layers: List[List[str]] = [] - self.run_id = None + self._run_id = None self.top_level_vertices = [] for vertex in self._vertices: @@ -66,6 +66,10 @@ class Graph: self.define_vertices_lists() self.state_manager = GraphStateManager() + def get_state(self, name: str) -> Optional[Record]: + """Returns the state of the graph.""" + return self.state_manager.get_state(name, run_id=self._run_id) + def update_state( self, name: str, record: Union[str, Record], caller: Optional[str] = None ) -> None: @@ -77,7 +81,7 @@ class Graph: # This also has to activate their successors self.activate_state_vertices(name, caller) - self.state_manager.update_state(name, record) + self.state_manager.update_state(name, record, run_id=self._run_id) def activate_state_vertices(self, name: str, caller: str): layers = [] @@ -109,15 +113,21 @@ class Graph: self.activate_state_vertices(name, caller) - self.state_manager.append_state(name, record) + self.state_manager.append_state(name, record, run_id=self._run_id) + + @property + def run_id(self): + if not self._run_id: + raise ValueError("Run ID not set") + return self._run_id def set_run_id(self, run_id: str): for vertex in self.vertices: self.state_manager.subscribe(run_id, vertex.update_graph_state) - self.run_id = run_id + self._run_id = run_id def add_state(self, state: str): - self.state_manager.append_state(self.run_id, state) + self.state_manager.append_state(self._run_id, state) @property def sorted_vertices_layers(self) -> List[List[str]]: @@ -760,21 +770,28 @@ class Graph: def layered_topological_sort( self, vertices: List[Vertex], + filter_graphs: bool = False, ) -> List[List[str]]: """Performs a layered topological sort of the vertices in the graph.""" vertices_ids = {vertex.id for vertex in vertices} # Queue for vertices with no incoming edges queue = deque( - vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0 + vertex.id + for vertex in vertices + # if filter_graphs then only vertex.is_input will be considered + if self.in_degree_map[vertex.id] == 0 + and (not filter_graphs or vertex.is_input) ) layers: List[List[str]] = [] - + visited = set(queue) current_layer = 0 while queue: layers.append([]) # Start a new layer layer_size = len(queue) for _ in range(layer_size): vertex_id = queue.popleft() + visited.add(vertex_id) + layers[current_layer].append(vertex_id) for neighbor in self.successor_map[vertex_id]: # only vertices in `vertices_ids` should be considered @@ -785,8 +802,16 @@ class Graph: continue self.in_degree_map[neighbor] -= 1 # 'remove' edge - if self.in_degree_map[neighbor] == 0: + if self.in_degree_map[neighbor] == 0 and neighbor not in visited: queue.append(neighbor) + + # if > 0 it might mean not all predecessors have added to the queue + # so we should process the neighbors predecessors + elif self.in_degree_map[neighbor] > 0: + for predecessor in self.predecessor_map[neighbor]: + if predecessor not in queue and predecessor not in visited: + queue.append(predecessor) + current_layer += 1 # Next layer new_layers = self.refine_layers(layers) return new_layers @@ -851,9 +876,15 @@ class Graph: self.mark_all_vertices("ACTIVE") if component_id: vertices = self.sort_up_to_vertex(component_id) + vertices_layers = self.layered_topological_sort(vertices) else: vertices = self.vertices - vertices_layers = self.layered_topological_sort(vertices) + # without component_id we are probably running in the chat + # so we want to pick only graphs that start with ChatInput or + # TextInput + vertices_layers = self.layered_topological_sort( + vertices, filter_graphs=True + ) vertices_layers = self.sort_by_avg_build_time(vertices_layers) vertices_layers = self.sort_chat_inputs_first(vertices_layers) self.increment_run_count() diff --git a/src/backend/langflow/graph/graph/state_manager.py b/src/backend/langflow/graph/graph/state_manager.py index 3fcbb68a3..ed5844d87 100644 --- a/src/backend/langflow/graph/graph/state_manager.py +++ b/src/backend/langflow/graph/graph/state_manager.py @@ -11,23 +11,29 @@ class GraphStateManager: self.observers = defaultdict(list) self.lock = Lock() - def append_state(self, key, new_state): + def append_state(self, key, new_state, run_id: str): with self.lock: - if key not in self.states: - self.states[key] = [] + if run_id not in self.states: + self.states[run_id] = {} + if key not in self.states[run_id]: + self.states[run_id][key] = [] elif not isinstance(self.states[key], list): - self.states[key] = [self.states[key]] - self.states[key].append(new_state) + self.states[run_id][key] = [self.states[key]] + self.states[run_id][key].append(new_state) self.notify_append_observers(key, new_state) - def update_state(self, key, new_state): + def update_state(self, key, new_state, run_id: str): with self.lock: - self.states[key] = new_state + if run_id not in self.states: + self.states[run_id] = {} + if key not in self.states[run_id]: + self.states[run_id][key] = {} + self.states[run_id][key] = new_state self.notify_observers(key, new_state) - def get_state(self, key): + def get_state(self, key, run_id: str): with self.lock: - return self.states.get(key, "") + return self.states.get(run_id, {}).get(key, "") def subscribe(self, key, observer: Callable): with self.lock: diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index 11c4c941d..1596533a2 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -92,7 +92,7 @@ class CustomComponent(Component): def get_state(self, name: str): try: - return self.vertex.graph.state_manager.get_state(key=name) + return self.vertex.graph.get_state(name=name) except Exception as e: raise ValueError(f"Error getting state: {e}")