From 0d75e2905f6e5048378647669127bfefd450f499 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 15 Apr 2024 08:59:27 -0300 Subject: [PATCH] Refactor process method to align it with endpoint logic (#1700) * Refactor Graph class to improve parallel processing in base.py * Fix type hint for run_id parameter in set_run_id method --- src/backend/base/langflow/graph/graph/base.py | 44 +++++++++++++++---- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index cccb1739e..ec526e414 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -1,5 +1,7 @@ import asyncio +import uuid from collections import defaultdict, deque +from functools import partial from itertools import chain from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Type, Union @@ -16,6 +18,7 @@ from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, R from langflow.interface.tools.constants import FILE_TOOLS from langflow.schema import Record from langflow.schema.schema import INPUT_FIELD_NAME, InputType +from langflow.services.deps import get_chat_service if TYPE_CHECKING: from langflow.graph.schema import ResultData @@ -164,13 +167,14 @@ class Graph: raise ValueError("Run ID not set") return self._run_id - def set_run_id(self, run_id: str): + def set_run_id(self, run_id: str | uuid.UUID): """ Sets the ID of the current run. Args: run_id (str): The run ID. """ + run_id = str(run_id) for vertex in self.vertices: self.state_manager.subscribe(run_id, vertex.update_graph_state) self._run_id = run_id @@ -748,31 +752,53 @@ class Graph: async def process(self, start_component_id: Optional[str] = None) -> "Graph": """Processes the graph with vertices in each layer run in parallel.""" - self.sort_vertices(start_component_id=start_component_id) - vertices_layers = self.sorted_vertices_layers + first_layer = self.sort_vertices(start_component_id=start_component_id) vertex_task_run_count: Dict[str, int] = {} - for layer_index, layer in enumerate(vertices_layers): + to_process = deque(first_layer) + layer_index = 0 + chat_service = get_chat_service() + run_id = uuid.uuid4() + self.set_run_id(run_id) + while to_process: + current_batch = list(to_process) # Copy current deque items to a list + to_process.clear() # Clear the deque for new items tasks = [] - for vertex_id in layer: + for vertex_id in current_batch: vertex = self.get_vertex(vertex_id) + lock = chat_service._cache_locks[self.run_id] + set_cache_coro = partial(chat_service.set_cache, flow_id=self.run_id) task = asyncio.create_task( - vertex.build(), + self.build_vertex( + lock=lock, + set_cache_coro=set_cache_coro, + vertex_id=vertex_id, + user_id=None, + inputs_dict={}, + ), name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}", ) tasks.append(task) vertex_task_run_count[vertex_id] = vertex_task_run_count.get(vertex_id, 0) + 1 + logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks") - await self._execute_tasks(tasks) + next_runnable_vertices = await self._execute_tasks(tasks) + to_process.extend(next_runnable_vertices) + logger.debug("Graph processing complete") return self - async def _execute_tasks(self, tasks): + async def _execute_tasks(self, tasks: List[asyncio.Task]) -> List[str]: """Executes tasks in parallel, handling exceptions for each task.""" results = [] for i, task in enumerate(asyncio.as_completed(tasks)): try: result = await task - results.append(result) + if isinstance(result, tuple) and len(result) == 7: + # Get the next runnable vertices + next_runnable_vertices = result[0] + results.extend(next_runnable_vertices) + else: + raise ValueError(f"Invalid result: {result}") except Exception as e: # Log the exception along with the task name for easier debugging # task_name = task.get_name()