diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index d82ccfeb8..90a6e4f9a 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -3,6 +3,8 @@ from collections import defaultdict, deque from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union from langchain.chains.base import Chain +from loguru import logger + from langflow.graph.edge.base import ContractEdge from langflow.graph.graph.constants import lazy_load_vertex_dict from langflow.graph.graph.state_manager import GraphStateManager @@ -20,7 +22,6 @@ from langflow.graph.vertex.types import ( from langflow.interface.tools.constants import FILE_TOOLS from langflow.schema import Record from langflow.utils import payload -from loguru import logger if TYPE_CHECKING: from langflow.graph.schema import ResultData @@ -43,6 +44,7 @@ class Graph: self.flow_id = flow_id self._is_input_vertices: List[str] = [] self._is_output_vertices: List[str] = [] + self._is_state_vertices: List[str] = [] self._has_session_id_vertices: List[str] = [] self._sorted_vertices_layers: List[List[str]] = [] self.run_id = None @@ -73,16 +75,23 @@ class Graph: # all StateVertex in self.vertices that are not the caller # essentially notifying all the other vertices that the state has changed # This also has to activate their successors - caller_vertex = self.get_vertex(caller) - for vertex in self.vertices: - if vertex.id != caller and isinstance(vertex, StateVertex): - successors = self.get_all_successors(vertex) - self.activated_vertices.add(vertex.id) - for successor in successors: - self.activated_vertices.add(successor.id) + self.activate_state_vertices(name, caller) self.state_manager.update_state(name, record) + def activate_state_vertices(self, name: str, caller: str): + for vertex_id in self._is_state_vertices: + vertex = self.get_vertex(vertex_id) + if ( + name in vertex._raw_params["name"] + and vertex_id != caller + and isinstance(vertex, StateVertex) + ): + successors = self.get_all_successors(vertex) + self.activated_vertices.add(vertex_id) + for successor in successors: + self.activated_vertices.add(successor.id) + def reset_activated_vertices(self): self.activated_vertices = set() @@ -91,7 +100,8 @@ class Graph: ) -> None: """Appends the state of the graph.""" if caller: - self.state_manager.subscribe(name, caller) + + self.activate_state_vertices(name, caller) self.state_manager.append_state(name, record) @@ -113,7 +123,7 @@ class Graph: """ Defines the lists of vertices that are inputs, outputs, and have session_id. """ - attributes = ["is_input", "is_output", "has_session_id"] + attributes = ["is_input", "is_output", "has_session_id", "is_state"] for vertex in self.vertices: for attribute in attributes: if getattr(vertex, attribute):