diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 68e16bed8..341a4729c 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -1,3 +1,4 @@ +import asyncio from collections import defaultdict, deque from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union @@ -40,8 +41,10 @@ class Graph: self._runs = 0 self._updates = 0 self.flow_id = flow_id - self._inputs = [] - self._outputs = [] + self._is_input_vertices = [] + self._is_output_vertices = [] + self._has_session_id_vertices = [] + self._sorted_vertices_layers = [] self.top_level_vertices = [] for vertex in self._vertices: @@ -54,38 +57,37 @@ class Graph: self.inactive_vertices = set() self._build_graph() self.build_graph_maps() - self.define_inputs_and_outputs() + self.define_vertices_lists() - def define_inputs_and_outputs(self): + @property + def sorted_vertices_layers(self): + if not self._sorted_vertices_layers: + self.sort_vertices() + return self._sorted_vertices_layers + + def define_vertices_lists(self): """ - Defines the input and output vertices of the graph. + Defines the lists of vertices that are inputs, outputs, and have session_id. """ + attributes = ["is_input", "is_output", "has_session_id"] for vertex in self.vertices: - if vertex.is_input: - self._inputs.append(vertex.id) - if vertex.is_output: - self._outputs.append(vertex.id) + for attribute in attributes: + if getattr(vertex, attribute): + getattr(self, f"_{attribute}_vertices").append(vertex.id) - def run(self, inputs: Dict[str, str]) -> List["ResultData"]: + async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]: """Runs the graph with the given inputs.""" - - # inputs is {"message": "Hello, world!"} - # we need to go through self.inputs and update the self._raw_params - # of the vertices that are inputs - for vertex_id in self.inputs: vertex = self.get_vertex(vertex_id) if vertex is None: raise ValueError(f"Vertex {vertex_id} not found") vertex.update_raw_params(inputs) try: - self.build() + await self.process() self.increment_run_count() except Exception as exc: logger.exception(exc) raise ValueError(f"Error running graph: {exc}") from exc - - # Now we get the outputs from the self.outputs outputs = [] for vertex_id in self.outputs: vertex = self.get_vertex(vertex_id) @@ -94,6 +96,23 @@ class Graph: outputs.append(vertex.result) return outputs + async def run(self, inputs: Dict[str, Union[str, list[str]]]) -> List["ResultData"]: + """Runs the graph with the given inputs.""" + + # inputs is {"message": "Hello, world!"} + # we need to go through self.inputs and update the self._raw_params + # of the vertices that are inputs + # if the value is a list, we need to run multiple times + outputs = [] + inputs_values = inputs.get("input_value") + if not isinstance(inputs_values, list): + inputs_values = [inputs_values] + for input_value in inputs_values: + run_outputs = await self._run({"input_value": input_value}) + logger.debug(f"Run outputs: {run_outputs}") + outputs.extend(run_outputs) + return outputs + @property def metadata(self): return { @@ -404,6 +423,36 @@ class Graph: raise ValueError("No root vertex found") return await root_vertex.build() + async def process(self) -> "Graph": + """Processes the graph with vertices in each layer run in parallel.""" + vertices_layers = self.sorted_vertices_layers + + for layer_index, layer in enumerate(vertices_layers): + tasks = [] + for vertex_id in layer: + vertex = self.get_vertex(vertex_id) + task = asyncio.create_task( + vertex.build(), name=f"layer-{layer_index}-vertex-{vertex_id}" + ) + tasks.append(task) + logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks") + await self._execute_tasks(tasks) + logger.debug("Graph processing complete") + return self + + async def _execute_tasks(self, tasks): + """Executes tasks in parallel, handling exceptions for each task.""" + results = [] + for task in asyncio.as_completed(tasks): + try: + result = await task + results.append(result) + except Exception as e: + # Log the exception along with the task name for easier debugging + task_name = task.get_name() + logger.error(f"Task {task_name} failed with exception: {e}") + return results + def topological_sort(self) -> List[Vertex]: """ Performs a topological sort of the vertices in the graph. @@ -671,6 +720,7 @@ class Graph: vertices_layers = self.sort_by_avg_build_time(vertices_layers) vertices_layers = self.sort_chat_inputs_first(vertices_layers) self.increment_run_count() + self._sorted_vertices_layers = vertices_layers return vertices_layers def sort_interface_components_first(