feat: Add cache parameter to RunnableVerticesManager (#2181)

This commit adds a new optional `cache` parameter (default `True`) to the `get_next_runnable_vertices` method of the `RunnableVerticesManager` class in `runnable_vertices_manager.py`, and threads it through `Graph.build_vertex` and `Graph.get_next_and_top_level_vertices`. The parameter controls whether the graph data is cached when retrieving the next runnable vertices for a given vertex: when `cache` is `True`, the graph data is cached via the `set_cache_coro` coroutine; when `cache` is `False`, the caching step is skipped.

This change was made to provide more flexibility in managing the caching behavior of the graph data.

Ref: #2180
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-15 07:04:00 -07:00 committed by GitHub
commit 81bb749ea8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 41 additions and 25 deletions

View file

@ -5,8 +5,6 @@ from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Tuple, Type, Union
from loguru import logger
from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
@ -21,6 +19,7 @@ from langflow.services.cache.utils import CacheMiss
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service
from langflow.services.monitor.utils import log_transaction
from loguru import logger
if TYPE_CHECKING:
from langflow.graph.schema import ResultData
@ -713,6 +712,7 @@ class Graph:
files: Optional[list[str]] = None,
user_id: Optional[str] = None,
fallback_to_env_vars: bool = False,
cache: bool = True,
):
"""
Builds a vertex in the graph.
@ -768,7 +768,7 @@ class Graph:
raise ValueError(f"No result found for vertex {vertex_id}")
set_cache_coro = partial(chat_service.set_cache, key=self.flow_id)
next_runnable_vertices, top_level_vertices = await self.get_next_and_top_level_vertices(
lock, set_cache_coro, vertex
lock, set_cache_coro, vertex, cache=cache
)
flow_id = self.flow_id
log_transaction(flow_id, vertex, status="success")
@ -780,7 +780,11 @@ class Graph:
raise exc
async def get_next_and_top_level_vertices(
self, lock: asyncio.Lock, set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine], vertex: Vertex
self,
lock: asyncio.Lock,
set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine],
vertex: Vertex,
cache: bool = True,
):
"""
Retrieves the next runnable vertices and the top level vertices for a given vertex.
@ -793,7 +797,9 @@ class Graph:
Returns:
Tuple[List[Vertex], List[Vertex]]: A tuple containing the next runnable vertices and the top level vertices.
"""
next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(lock, set_cache_coro, self, vertex)
next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(
lock, set_cache_coro, self, vertex, cache=cache
)
top_level_vertices = self.run_manager.get_top_level_vertices(self, next_runnable_vertices)
return next_runnable_vertices, top_level_vertices
@ -834,13 +840,13 @@ class Graph:
chat_service = get_chat_service()
run_id = uuid.uuid4()
self.set_run_id(run_id)
lock = chat_service._cache_locks[self.run_id]
while to_process:
current_batch = list(to_process) # Copy current deque items to a list
to_process.clear() # Clear the deque for new items
tasks = []
for vertex_id in current_batch:
vertex = self.get_vertex(vertex_id)
lock = chat_service._cache_locks[self.run_id]
task = asyncio.create_task(
self.build_vertex(
lock=lock,
@ -849,6 +855,7 @@ class Graph:
user_id=self.user_id,
inputs_dict={},
fallback_to_env_vars=fallback_to_env_vars,
cache=False,
),
name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}",
)
@ -856,8 +863,15 @@ class Graph:
vertex_task_run_count[vertex_id] = vertex_task_run_count.get(vertex_id, 0) + 1
logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
next_runnable_vertices = await self._execute_tasks(tasks)
try:
next_runnable_vertices = await self._execute_tasks(tasks)
except Exception as e:
logger.error(f"Error executing tasks in layer {layer_index}: {e}")
break
if not next_runnable_vertices:
break
to_process.extend(next_runnable_vertices)
layer_index += 1
logger.debug("Graph processing complete")
return self
@ -865,25 +879,23 @@ class Graph:
async def _execute_tasks(self, tasks: List[asyncio.Task]) -> List[str]:
"""Executes tasks in parallel, handling exceptions for each task."""
results = []
for i, task in enumerate(asyncio.as_completed(tasks)):
try:
result = await task
if isinstance(result, tuple) and len(result) == 7:
# Get the next runnable vertices
next_runnable_vertices = result[0]
results.extend(next_runnable_vertices)
else:
raise ValueError(f"Invalid result: {result}")
except Exception as e:
# Log the exception along with the task name for easier debugging
# task_name = task.get_name()
# coroutine has not attribute get_name
task_name = tasks[i].get_name()
logger.error(f"Task {task_name} failed with exception: {e}")
completed_tasks = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(completed_tasks):
task_name = tasks[i].get_name()
if isinstance(result, Exception):
logger.error(f"Task {task_name} failed with exception: {result}")
# Cancel all remaining tasks
for t in tasks[i:]:
for t in tasks[i + 1 :]:
t.cancel()
raise e
raise result
elif isinstance(result, tuple) and len(result) == 7:
# Get the next runnable vertices
next_runnable_vertices = result[0]
results.extend(next_runnable_vertices)
else:
raise ValueError(f"Invalid result from task {task_name}: {result}")
return results
def topological_sort(self) -> List[Vertex]:
@ -1360,3 +1372,5 @@ class Graph:
predecessor_map[edge.target_id].append(edge.source_id)
successor_map[edge.source_id].append(edge.target_id)
return predecessor_map, successor_map
return predecessor_map, successor_map
return predecessor_map, successor_map

View file

@ -59,6 +59,7 @@ class RunnableVerticesManager:
set_cache_coro: Callable[["Graph", asyncio.Lock], Awaitable[None]],
graph: "Graph",
vertex: "Vertex",
cache: bool = True,
):
"""
Retrieves the next runnable vertices in the graph for a given vertex.
@ -86,7 +87,8 @@ class RunnableVerticesManager:
for v_id in set(next_runnable_vertices): # Use set to avoid duplicates
self.update_vertex_run_state(v_id, is_runnable=False)
self.remove_from_predecessors(v_id)
await set_cache_coro(data=graph, lock=lock) # type: ignore
if cache:
await set_cache_coro(data=graph, lock=lock) # type: ignore
return next_runnable_vertices
@staticmethod