From 563356515637633cc90ae10e6435a766fd7ee73d Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 4 Mar 2024 15:19:22 -0300 Subject: [PATCH] Refactor graph processing and error handling --- src/backend/langflow/graph/graph/base.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 43d32c176..daff60b6d 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -111,7 +111,12 @@ class Graph: vertex = self.get_vertex(vertex_id) if vertex is None: raise ValueError(f"Vertex {vertex_id} not found") - if not stream and hasattr(vertex, "consume_async_generator"): + + if ( + not vertex.result + and not stream + and hasattr(vertex, "consume_async_generator") + ): await vertex.consume_async_generator() outputs.append(vertex.result) return outputs @@ -472,15 +477,19 @@ class Graph: async def process(self) -> "Graph": """Processes the graph with vertices in each layer run in parallel.""" vertices_layers = self.sorted_vertices_layers - + vertex_task_run_count = {} for layer_index, layer in enumerate(vertices_layers): tasks = [] for vertex_id in layer: vertex = self.get_vertex(vertex_id) task = asyncio.create_task( - vertex.build(), name=f"layer-{layer_index}-vertex-{vertex_id}" + vertex.build(), + name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}", ) tasks.append(task) + vertex_task_run_count[vertex_id] = ( + vertex_task_run_count.get(vertex_id, 0) + 1 + ) logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks") await self._execute_tasks(tasks) logger.debug("Graph processing complete") @@ -499,6 +508,10 @@ class Graph: # coroutine has not attribute get_name task_name = tasks[i].get_name() logger.error(f"Task {task_name} failed with exception: {e}") + # Cancel all remaining tasks + for t in tasks[i:]: + t.cancel() + raise e return results def topological_sort(self) -> List[Vertex]: @@ -815,6 +828,7 @@ class Graph: vertices_layers = self.sort_by_avg_build_time(vertices_layers) # vertices_layers = self.sort_chat_inputs_first(vertices_layers) self.increment_run_count() + self._sorted_vertices_layers = vertices_layers first_layer = vertices_layers[0] # save the only the rest self.vertices_layers = vertices_layers[1:]