diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index c8df5cf4d..258277bca 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -73,6 +73,7 @@ async def get_vertices( # We need to get the id of each vertex # and return the same structure but only with the ids run_id = uuid.uuid4() + graph.set_run_id(run_id) return VerticesOrderResponse(ids=vertices, run_id=run_id) except Exception as exc: @@ -98,8 +99,12 @@ async def build_vertex( cache = chat_service.get_cache(flow_id) if not cache: # If there's no cache - logger.warning(f"No cache found for {flow_id}. Building graph starting at {vertex_id}") - graph = build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service) + logger.warning( + f"No cache found for {flow_id}. Building graph starting at {vertex_id}" + ) + graph = build_and_cache_graph( + flow_id=flow_id, session=next(get_session()), chat_service=chat_service + ) else: graph = cache.get("result") result_data_response = ResultDataResponse(results={}) @@ -154,13 +159,14 @@ async def build_vertex( graph.reset_inactive_vertices() chat_service.set_cache(flow_id, graph) - return VertexBuildResponse( + build_response = VertexBuildResponse( inactive_vertices=inactive_vertices, valid=valid, params=params, id=vertex.id, data=result_data_response, ) + return build_response except Exception as exc: logger.error(f"Error building vertex: {exc}") logger.exception(exc) @@ -191,7 +197,9 @@ async def build_vertex_stream( else: graph = cache.get("result") else: - session_data = await session_service.load_session(session_id, flow_id=flow_id) + session_data = await session_service.load_session( + session_id, flow_id=flow_id + ) graph, artifacts = session_data if session_data else (None, None) if not graph: raise ValueError(f"No graph found for {flow_id}.") diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 0e1da7e30..966111d5f 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -7,6 +7,7 @@ from loguru import logger from langflow.graph.edge.base import ContractEdge from langflow.graph.graph.constants import lazy_load_vertex_dict +from langflow.graph.graph.state_manager import GraphStateManager from langflow.graph.graph.utils import process_flow from langflow.graph.schema import INPUT_FIELD_NAME, InterfaceComponentTypes from langflow.graph.vertex.base import Vertex @@ -43,6 +44,7 @@ class Graph: self._is_output_vertices: List[str] = [] self._has_session_id_vertices: List[str] = [] self._sorted_vertices_layers: List[List[str]] = [] + self.run_id = None self.top_level_vertices = [] for vertex in self._vertices: @@ -58,9 +60,18 @@ class Graph: self._build_graph() self.build_graph_maps() self.define_vertices_lists() + self.state_manager = GraphStateManager() + + def set_run_id(self, run_id: str): + for vertex in self.vertices: + self.state_manager.subscribe(run_id, vertex.update_graph_state) + self.run_id = run_id + + def add_state(self, state: str): + self.state_manager.append_state(self.run_id, state) @property - def sorted_vertices_layers(self): + def sorted_vertices_layers(self) -> List[List[str]]: if not self._sorted_vertices_layers: self.sort_vertices() return self._sorted_vertices_layers @@ -75,7 +86,9 @@ class Graph: if getattr(vertex, attribute): getattr(self, f"_{attribute}_vertices").append(vertex.id) - async def _run(self, inputs: Dict[str, str], stream: bool) -> List[Optional["ResultData"]]: + async def _run( + self, inputs: Dict[str, str], stream: bool + ) -> List[Optional["ResultData"]]: """Runs the graph with the given inputs.""" for vertex_id in self._is_input_vertices: vertex = self.get_vertex(vertex_id) @@ -98,7 +111,9 @@ class Graph: outputs.append(vertex.result) return outputs - async def run(self, inputs: Dict[str, Union[str, list[str]]], stream: bool) -> List[Optional["ResultData"]]: + async def run( + self, inputs: Dict[str, Union[str, list[str]]], stream: bool + ) -> List[Optional["ResultData"]]: """Runs the graph with the given inputs.""" # inputs is {"message": "Hello, world!"} @@ -110,7 +125,9 @@ class Graph: if not isinstance(inputs_values, list): inputs_values = [inputs_values] for input_value in inputs_values: - run_outputs = await self._run({INPUT_FIELD_NAME: input_value}, stream=stream) + run_outputs = await self._run( + {INPUT_FIELD_NAME: input_value}, stream=stream + ) logger.debug(f"Run outputs: {run_outputs}") outputs.extend(run_outputs) return outputs @@ -150,7 +167,9 @@ class Graph: def build_parent_child_map(self): parent_child_map = defaultdict(list) for vertex in self.vertices: - parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)] + parent_child_map[vertex.id] = [ + child.id for child in self.get_successors(vertex) + ] return parent_child_map def increment_run_count(self): @@ -325,7 +344,11 @@ class Graph: return self.vertices.remove(vertex) self.vertex_map.pop(vertex_id) - self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id] + self.edges = [ + edge + for edge in self.edges + if edge.source_id != vertex_id and edge.target_id != vertex_id + ] def _build_vertex_params(self) -> None: """Identifies and handles the LLM vertex within the graph.""" @@ -346,7 +369,9 @@ class Graph: return for vertex in self.vertices: if not self._validate_vertex(vertex): - raise ValueError(f"{vertex.display_name} is not connected to any other components") + raise ValueError( + f"{vertex.display_name} is not connected to any other components" + ) def _validate_vertex(self, vertex: Vertex) -> bool: """Validates a vertex.""" @@ -403,7 +428,9 @@ class Graph: tasks = [] for vertex_id in layer: vertex = self.get_vertex(vertex_id) - task = asyncio.create_task(vertex.build(), name=f"layer-{layer_index}-vertex-{vertex_id}") + task = asyncio.create_task( + vertex.build(), name=f"layer-{layer_index}-vertex-{vertex_id}" + ) tasks.append(task) logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks") await self._execute_tasks(tasks) @@ -442,7 +469,9 @@ class Graph: def dfs(vertex): if state[vertex] == 1: # We have a cycle - raise ValueError("Graph contains a cycle, cannot perform topological sort") + raise ValueError( + "Graph contains a cycle, cannot perform topological sort" + ) if state[vertex] == 0: state[vertex] = 1 for edge in vertex.edges: @@ -466,11 +495,17 @@ class Graph: def get_predecessors(self, vertex): """Returns the predecessors of a vertex.""" - return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])] + return [ + self.get_vertex(source_id) + for source_id in self.predecessor_map.get(vertex.id, []) + ] def get_successors(self, vertex): """Returns the successors of a vertex.""" - return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])] + return [ + self.get_vertex(target_id) + for target_id in self.successor_map.get(vertex.id, []) + ] def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]: """Returns the neighbors of a vertex.""" @@ -509,7 +544,9 @@ class Graph: edges.append(ContractEdge(source, target, edge)) return edges - def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]: + def _get_vertex_class( + self, node_type: str, node_base_type: str, node_id: str + ) -> Type[Vertex]: """Returns the node class based on the node type.""" # First we check for the node_base_type node_name = node_id.split("-")[0] @@ -540,14 +577,18 @@ class Graph: vertex_type: str = vertex_data["type"] # type: ignore vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) + VertexClass = self._get_vertex_class( + vertex_type, vertex_base_type, vertex_data["id"] + ) vertex_instance = VertexClass(vertex, graph=self) vertex_instance.set_top_level(self.top_level_vertices) vertices.append(vertex_instance) return vertices - def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]: + def get_children_by_vertex_type( + self, vertex: Vertex, vertex_type: str + ) -> List[Vertex]: """Returns the children of a vertex based on the vertex type.""" children = [] vertex_types = [vertex.data["type"]] @@ -559,7 +600,9 @@ class Graph: def __repr__(self): vertex_ids = [vertex.id for vertex in self.vertices] - edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]) + edges_repr = "\n".join( + [f"{edge.source_id} --> {edge.target_id}" for edge in self.edges] + ) return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}" def sort_up_to_vertex(self, vertex_id: str) -> List[Vertex]: @@ -590,7 +633,9 @@ class Graph: """Performs a layered topological sort of the vertices in the graph.""" # Queue for vertices with no incoming edges - queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0) + queue = deque( + vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0 + ) layers: List[List[str]] = [] current_layer = 0 @@ -646,7 +691,9 @@ class Graph: return refined_layers - def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_chat_inputs_first( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: chat_inputs_first = [] for layer in vertices_layers: for vertex_id in layer: @@ -675,11 +722,15 @@ class Graph: self._sorted_vertices_layers = vertices_layers return vertices_layers - def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_interface_components_first( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: """Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first.""" def contains_interface_component(vertex): - return any(component.value in vertex for component in InterfaceComponentTypes) + return any( + component.value in vertex for component in InterfaceComponentTypes + ) # Sort each inner list so that vertices containing ChatInput or ChatOutput come first sorted_vertices = [ @@ -691,16 +742,22 @@ class Graph: ] return sorted_vertices - def sort_by_avg_build_time(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_by_avg_build_time( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: """Sorts the vertices in the graph so that vertices with the lowest average build time come first.""" def sort_layer_by_avg_build_time(vertices_ids: List[str]) -> List[str]: """Sorts the vertices in the graph so that vertices with the lowest average build time come first.""" if len(vertices_ids) == 1: return vertices_ids - vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time) + vertices_ids.sort( + key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time + ) return vertices_ids - sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers] + sorted_vertices = [ + sort_layer_by_avg_build_time(layer) for layer in vertices_layers + ] return sorted_vertices diff --git a/src/backend/langflow/graph/graph/state_manager.py b/src/backend/langflow/graph/graph/state_manager.py new file mode 100644 index 000000000..64476011b --- /dev/null +++ b/src/backend/langflow/graph/graph/state_manager.py @@ -0,0 +1,39 @@ +from collections import defaultdict +from threading import Lock +from typing import Callable + + +class GraphStateManager: + def __init__(self): + self.states = {} + self.observers = defaultdict(list) + self.lock = Lock() + + def append_state(self, key, new_state): + with self.lock: + if key not in self.states: + self.states[key] = [] + self.states[key].append(new_state) + self.notify_append_observers(key, new_state) + + def update_state(self, key, new_state): + with self.lock: + self.states[key] = new_state + self.notify_observers(key, new_state) + + def get_state(self, key): + with self.lock: + return self.states.get(key, None) + + def subscribe(self, key, observer: Callable): + with self.lock: + if observer not in self.observers[key]: + self.observers[key].append(observer) + + def notify_observers(self, key, new_state): + for callback in self.observers[key]: + callback(key, new_state, append=False) + + def notify_append_observers(self, key, new_state): + for callback in self.observers[key]: + callback(key, new_state, append=True) diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 26230c648..6425f5e08 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -47,8 +47,13 @@ class Vertex: self.will_stream = False self.updated_raw_params = False self.id: str = data["id"] - self.is_input = any(input_component_name in self.id for input_component_name in INPUT_COMPONENTS) - self.is_output = any(output_component_name in self.id for output_component_name in OUTPUT_COMPONENTS) + self.is_input = any( + input_component_name in self.id for input_component_name in INPUT_COMPONENTS + ) + self.is_output = any( + output_component_name in self.id + for output_component_name in OUTPUT_COMPONENTS + ) self.has_session_id = None self._custom_component = None self.has_external_input = False @@ -79,14 +84,30 @@ class Vertex: self.use_result = False self.build_times: List[float] = [] self.state = VertexStates.ACTIVE + self.graph_state = {} + + def update_graph_state(self, key, new_state, append: bool): + if append: + if key in self.graph_state: + self.graph_state[key].append(new_state) + else: + self.graph_state[key] = [new_state] + else: + self.graph_state[key] = new_state def set_state(self, state: str): self.state = VertexStates[state] - if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] < 2: + if ( + self.state == VertexStates.INACTIVE + and self.graph.in_degree_map[self.id] < 2 + ): # If the vertex is inactive and has only one in degree # it means that it is not a merge point in the graph self.graph.inactive_vertices.add(self.id) - elif self.state == VertexStates.ACTIVE and self.id in self.graph.inactive_vertices: + elif ( + self.state == VertexStates.ACTIVE + and self.id in self.graph.inactive_vertices + ): self.graph.inactive_vertices.remove(self.id) @property @@ -103,7 +124,9 @@ class Vertex: # If the Vertex.type is a power component # then we need to return the built object # instead of the result dict - if self.is_interface_component and not isinstance(self._built_object, UnbuiltObject): + if self.is_interface_component and not isinstance( + self._built_object, UnbuiltObject + ): result = self._built_object # if it is not a dict or a string and hasattr model_dump then # return the model_dump @@ -113,7 +136,11 @@ class Vertex: if isinstance(self._built_result, UnbuiltResult): return {} - return self._built_result if isinstance(self._built_result, dict) else {"result": self._built_result} + return ( + self._built_result + if isinstance(self._built_result, dict) + else {"result": self._built_result} + ) def set_artifacts(self) -> None: pass @@ -179,19 +206,31 @@ class Vertex: self.selected_output_type = self.data["node"].get("selected_output_type") self.is_input = self.data["node"].get("is_input") or self.is_input self.is_output = self.data["node"].get("is_output") or self.is_output - template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} + template_dicts = { + key: value + for key, value in self.data["node"]["template"].items() + if isinstance(value, dict) + } self.has_session_id = "session_id" in template_dicts self.required_inputs = [ - template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"] + template_dicts[key]["type"] + for key, value in template_dicts.items() + if value["required"] ] self.optional_inputs = [ - template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"] + template_dicts[key]["type"] + for key, value in template_dicts.items() + if not value["required"] ] # Add the template_dicts[key]["input_types"] to the optional_inputs self.optional_inputs.extend( - [input_type for value in template_dicts.values() for input_type in value.get("input_types", [])] + [ + input_type + for value in template_dicts.values() + for input_type in value.get("input_types", []) + ] ) template_dict = self.data["node"]["template"] @@ -238,7 +277,11 @@ class Vertex: self.updated_raw_params = False return - template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} + template_dict = { + key: value + for key, value in self.data["node"]["template"].items() + if isinstance(value, dict) + } params = {} for edge in self.edges: @@ -289,7 +332,11 @@ class Vertex: # list of dicts, so we need to convert it to a dict # before passing it to the build method if isinstance(val, list): - params[key] = {k: v for item in value.get("value", []) for k, v in item.items()} + params[key] = { + k: v + for item in value.get("value", []) + for k, v in item.items() + } elif isinstance(val, dict): params[key] = val elif value.get("type") == "int" and val is not None: @@ -382,7 +429,9 @@ class Vertex: if isinstance(self._built_object, str): self._built_result = self._built_object - result = await generate_result(self._built_object, inputs, self.has_external_output, session_id) + result = await generate_result( + self._built_object, inputs, self.has_external_output, session_id + ) self._built_result = result async def _build_each_node_in_params_dict(self, user_id=None): @@ -412,7 +461,9 @@ class Vertex: """ return all(self._is_node(node) for node in value) - async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any: + async def get_result( + self, requester: Optional["Vertex"] = None, user_id=None, timeout=None + ) -> Any: # PLEASE REVIEW THIS IF STATEMENT # Check if the Vertex was built already if self._built: @@ -446,7 +497,9 @@ class Vertex: self._extend_params_list_with_result(key, result) self.params[key] = result - async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None): + async def _build_list_of_nodes_and_update_params( + self, key, nodes: List["Vertex"], user_id=None + ): """ Iterates over a list of nodes, builds each and updates the params dictionary. """ @@ -500,7 +553,9 @@ class Vertex: except Exception as exc: logger.exception(exc) - raise ValueError(f"Error building node {self.display_name}: {str(exc)}") from exc + raise ValueError( + f"Error building node {self.display_name}: {str(exc)}" + ) from exc def _update_built_object_and_artifacts(self, result): """ @@ -580,16 +635,24 @@ class Vertex: return self._built_object # Get the requester edge - requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None) + requester_edge = next( + (edge for edge in self.edges if edge.target_id == requester.id), None + ) # Return the result of the requester edge - return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester) + return ( + None + if requester_edge is None + else await requester_edge.get_result(source=self, target=requester) + ) def add_edge(self, edge: "ContractEdge") -> None: if edge not in self.edges: self.edges.append(edge) def __repr__(self) -> str: - return f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})" + return ( + f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})" + ) def __eq__(self, __o: object) -> bool: try: @@ -602,7 +665,11 @@ class Vertex: def _built_object_repr(self): # Add a message with an emoji, stars for sucess, - return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵‍💫" + return ( + "Built sucessfully ✨" + if self._built_object is not None + else "Failed to build 😵‍💫" + ) class StatefulVertex(Vertex):