From 81bb749ea8458b0f458150ac673ea1a44c6eff55 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sat, 15 Jun 2024 07:04:00 -0700 Subject: [PATCH] feat: Add cache parameter to RunnableVerticesManager (#2181) This commit adds a new optional `cache` parameter to the `get_next_runnable_vertices` method of the `RunnableVerticesManager` class in the `runnable_vertices_manager.py` file. The `cache` parameter allows controlling whether the graph data should be cached or not when retrieving the next runnable vertices for a given vertex. If `cache` is set to `True`, the graph data will be cached using the `set_cache_coro` function. If `cache` is set to `False`, the graph data will not be cached. This change was made to provide more flexibility in managing the caching behavior of the graph data. Ref: #2180 --- src/backend/base/langflow/graph/graph/base.py | 62 ++++++++++++------- .../graph/graph/runnable_vertices_manager.py | 4 +- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 56b5e8bfe..b47b718b3 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -5,8 +5,6 @@ from functools import partial from itertools import chain from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Tuple, Type, Union -from loguru import logger - from langflow.graph.edge.base import ContractEdge from langflow.graph.graph.constants import lazy_load_vertex_dict from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager @@ -21,6 +19,7 @@ from langflow.services.cache.utils import CacheMiss from langflow.services.chat.service import ChatService from langflow.services.deps import get_chat_service from langflow.services.monitor.utils import log_transaction +from loguru import logger if TYPE_CHECKING: from langflow.graph.schema import ResultData @@ -713,6 +712,7 @@ class Graph: files: Optional[list[str]] = 
None, user_id: Optional[str] = None, fallback_to_env_vars: bool = False, + cache: bool = True, ): """ Builds a vertex in the graph. @@ -768,7 +768,7 @@ class Graph: raise ValueError(f"No result found for vertex {vertex_id}") set_cache_coro = partial(chat_service.set_cache, key=self.flow_id) next_runnable_vertices, top_level_vertices = await self.get_next_and_top_level_vertices( - lock, set_cache_coro, vertex + lock, set_cache_coro, vertex, cache=cache ) flow_id = self.flow_id log_transaction(flow_id, vertex, status="success") @@ -780,7 +780,11 @@ class Graph: raise exc async def get_next_and_top_level_vertices( - self, lock: asyncio.Lock, set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine], vertex: Vertex + self, + lock: asyncio.Lock, + set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine], + vertex: Vertex, + cache: bool = True, ): """ Retrieves the next runnable vertices and the top level vertices for a given vertex. @@ -793,7 +797,9 @@ class Graph: Returns: Tuple[List[Vertex], List[Vertex]]: A tuple containing the next runnable vertices and the top level vertices. 
""" - next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(lock, set_cache_coro, self, vertex) + next_runnable_vertices = await self.run_manager.get_next_runnable_vertices( + lock, set_cache_coro, self, vertex, cache=cache + ) top_level_vertices = self.run_manager.get_top_level_vertices(self, next_runnable_vertices) return next_runnable_vertices, top_level_vertices @@ -834,13 +840,13 @@ class Graph: chat_service = get_chat_service() run_id = uuid.uuid4() self.set_run_id(run_id) + lock = chat_service._cache_locks[self.run_id] while to_process: current_batch = list(to_process) # Copy current deque items to a list to_process.clear() # Clear the deque for new items tasks = [] for vertex_id in current_batch: vertex = self.get_vertex(vertex_id) - lock = chat_service._cache_locks[self.run_id] task = asyncio.create_task( self.build_vertex( lock=lock, @@ -849,6 +855,7 @@ class Graph: user_id=self.user_id, inputs_dict={}, fallback_to_env_vars=fallback_to_env_vars, + cache=False, ), name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}", ) @@ -856,8 +863,15 @@ class Graph: vertex_task_run_count[vertex_id] = vertex_task_run_count.get(vertex_id, 0) + 1 logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks") - next_runnable_vertices = await self._execute_tasks(tasks) + try: + next_runnable_vertices = await self._execute_tasks(tasks) + except Exception as e: + logger.error(f"Error executing tasks in layer {layer_index}: {e}") + break + if not next_runnable_vertices: + break to_process.extend(next_runnable_vertices) + layer_index += 1 logger.debug("Graph processing complete") return self @@ -865,25 +879,23 @@ class Graph: async def _execute_tasks(self, tasks: List[asyncio.Task]) -> List[str]: """Executes tasks in parallel, handling exceptions for each task.""" results = [] - for i, task in enumerate(asyncio.as_completed(tasks)): - try: - result = await task - if isinstance(result, tuple) and len(result) == 7: - # Get the next 
runnable vertices - next_runnable_vertices = result[0] - results.extend(next_runnable_vertices) - else: - raise ValueError(f"Invalid result: {result}") - except Exception as e: - # Log the exception along with the task name for easier debugging - # task_name = task.get_name() - # coroutine has not attribute get_name - task_name = tasks[i].get_name() - logger.error(f"Task {task_name} failed with exception: {e}") + completed_tasks = await asyncio.gather(*tasks, return_exceptions=True) + + for i, result in enumerate(completed_tasks): + task_name = tasks[i].get_name() + if isinstance(result, Exception): + logger.error(f"Task {task_name} failed with exception: {result}") # Cancel all remaining tasks - for t in tasks[i:]: + for t in tasks[i + 1 :]: t.cancel() - raise e + raise result + elif isinstance(result, tuple) and len(result) == 7: + # Get the next runnable vertices + next_runnable_vertices = result[0] + results.extend(next_runnable_vertices) + else: + raise ValueError(f"Invalid result from task {task_name}: {result}") + return results def topological_sort(self) -> List[Vertex]: @@ -1360,3 +1372,5 @@ class Graph: predecessor_map[edge.target_id].append(edge.source_id) successor_map[edge.source_id].append(edge.target_id) return predecessor_map, successor_map + return predecessor_map, successor_map + return predecessor_map, successor_map diff --git a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py index 713aead65..9dd7c346c 100644 --- a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py +++ b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py @@ -59,6 +59,7 @@ class RunnableVerticesManager: set_cache_coro: Callable[["Graph", asyncio.Lock], Awaitable[None]], graph: "Graph", vertex: "Vertex", + cache: bool = True, ): """ Retrieves the next runnable vertices in the graph for a given vertex. 
@@ -86,7 +87,8 @@ class RunnableVerticesManager: for v_id in set(next_runnable_vertices): # Use set to avoid duplicates self.update_vertex_run_state(v_id, is_runnable=False) self.remove_from_predecessors(v_id) - await set_cache_coro(data=graph, lock=lock) # type: ignore + if cache: + await set_cache_coro(data=graph, lock=lock) # type: ignore return next_runnable_vertices @staticmethod