feat: Add cache parameter to RunnableVerticesManager (#2181)
This commit adds an optional `cache` parameter to `RunnableVerticesManager.get_next_runnable_vertices` in `runnable_vertices_manager.py`, and threads it through `Graph.build_vertex` and `Graph.get_next_and_top_level_vertices`. The parameter controls whether the graph data is cached while the next runnable vertices are retrieved for a given vertex: with `cache=True` the graph data is written to the cache via the `set_cache_coro` coroutine; with `cache=False` the cache write is skipped. This gives callers finer control over caching behavior during graph processing. Ref: #2180
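For reference, the flag simply gates the cache write that happens after the next runnable vertices have been collected. A minimal sketch of the pattern, assuming a stripped-down manager in which compute_next_vertices is a hypothetical stand-in for the real run-state bookkeeping:

import asyncio
from typing import Any, Awaitable, Callable, List


class SimplifiedRunnableVerticesManager:
    """Illustrative only; the real manager also updates run state and predecessors."""

    async def get_next_runnable_vertices(
        self,
        lock: asyncio.Lock,
        set_cache_coro: Callable[..., Awaitable[None]],
        graph: Any,
        vertex_id: str,
        cache: bool = True,
    ) -> List[str]:
        next_runnable_vertices = self.compute_next_vertices(vertex_id)
        if cache:
            # Persist the graph state only when the caller asked for it.
            await set_cache_coro(data=graph, lock=lock)
        return next_runnable_vertices

    def compute_next_vertices(self, vertex_id: str) -> List[str]:
        # Hypothetical placeholder for resolving runnable successors.
        return []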
This commit is contained in:
parent c7a15de0a9
commit 81bb749ea8
2 changed files with 41 additions and 25 deletions
@@ -5,8 +5,6 @@ from functools import partial
 from itertools import chain
 from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Tuple, Type, Union

-from loguru import logger
-
 from langflow.graph.edge.base import ContractEdge
 from langflow.graph.graph.constants import lazy_load_vertex_dict
 from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
@@ -21,6 +19,7 @@ from langflow.services.cache.utils import CacheMiss
 from langflow.services.chat.service import ChatService
 from langflow.services.deps import get_chat_service
 from langflow.services.monitor.utils import log_transaction
+from loguru import logger

 if TYPE_CHECKING:
     from langflow.graph.schema import ResultData
@@ -713,6 +712,7 @@ class Graph:
         files: Optional[list[str]] = None,
         user_id: Optional[str] = None,
         fallback_to_env_vars: bool = False,
+        cache: bool = True,
     ):
         """
         Builds a vertex in the graph.
@@ -768,7 +768,7 @@ class Graph:
                 raise ValueError(f"No result found for vertex {vertex_id}")
             set_cache_coro = partial(chat_service.set_cache, key=self.flow_id)
             next_runnable_vertices, top_level_vertices = await self.get_next_and_top_level_vertices(
-                lock, set_cache_coro, vertex
+                lock, set_cache_coro, vertex, cache=cache
             )
             flow_id = self.flow_id
             log_transaction(flow_id, vertex, status="success")
@@ -780,7 +780,11 @@ class Graph:
             raise exc

     async def get_next_and_top_level_vertices(
-        self, lock: asyncio.Lock, set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine], vertex: Vertex
+        self,
+        lock: asyncio.Lock,
+        set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine],
+        vertex: Vertex,
+        cache: bool = True,
     ):
         """
         Retrieves the next runnable vertices and the top level vertices for a given vertex.
@@ -793,7 +797,9 @@ class Graph:
         Returns:
             Tuple[List[Vertex], List[Vertex]]: A tuple containing the next runnable vertices and the top level vertices.
         """
-        next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(lock, set_cache_coro, self, vertex)
+        next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(
+            lock, set_cache_coro, self, vertex, cache=cache
+        )
         top_level_vertices = self.run_manager.get_top_level_vertices(self, next_runnable_vertices)
         return next_runnable_vertices, top_level_vertices

@@ -834,13 +840,13 @@ class Graph:
         chat_service = get_chat_service()
         run_id = uuid.uuid4()
         self.set_run_id(run_id)
+        lock = chat_service._cache_locks[self.run_id]
         while to_process:
             current_batch = list(to_process)  # Copy current deque items to a list
             to_process.clear()  # Clear the deque for new items
             tasks = []
             for vertex_id in current_batch:
                 vertex = self.get_vertex(vertex_id)
-                lock = chat_service._cache_locks[self.run_id]
                 task = asyncio.create_task(
                     self.build_vertex(
                         lock=lock,
@@ -849,6 +855,7 @@ class Graph:
                         user_id=self.user_id,
                         inputs_dict={},
                         fallback_to_env_vars=fallback_to_env_vars,
+                        cache=False,
                     ),
                     name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}",
                 )
@@ -856,8 +863,15 @@ class Graph:
                 vertex_task_run_count[vertex_id] = vertex_task_run_count.get(vertex_id, 0) + 1

             logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
-            next_runnable_vertices = await self._execute_tasks(tasks)
+            try:
+                next_runnable_vertices = await self._execute_tasks(tasks)
+            except Exception as e:
+                logger.error(f"Error executing tasks in layer {layer_index}: {e}")
+                break
+            if not next_runnable_vertices:
+                break
             to_process.extend(next_runnable_vertices)
             layer_index += 1

+        logger.debug("Graph processing complete")
         return self
@@ -865,25 +879,23 @@ class Graph:
     async def _execute_tasks(self, tasks: List[asyncio.Task]) -> List[str]:
         """Executes tasks in parallel, handling exceptions for each task."""
         results = []
-        for i, task in enumerate(asyncio.as_completed(tasks)):
-            try:
-                result = await task
-                if isinstance(result, tuple) and len(result) == 7:
-                    # Get the next runnable vertices
-                    next_runnable_vertices = result[0]
-                    results.extend(next_runnable_vertices)
-                else:
-                    raise ValueError(f"Invalid result: {result}")
-            except Exception as e:
-                # Log the exception along with the task name for easier debugging
-                # task_name = task.get_name()
-                # coroutine has not attribute get_name
-                task_name = tasks[i].get_name()
-                logger.error(f"Task {task_name} failed with exception: {e}")
+        completed_tasks = await asyncio.gather(*tasks, return_exceptions=True)
+
+        for i, result in enumerate(completed_tasks):
+            task_name = tasks[i].get_name()
+            if isinstance(result, Exception):
+                logger.error(f"Task {task_name} failed with exception: {result}")
                 # Cancel all remaining tasks
-                for t in tasks[i:]:
+                for t in tasks[i + 1 :]:
                     t.cancel()
-                raise e
+                raise result
+            elif isinstance(result, tuple) and len(result) == 7:
+                # Get the next runnable vertices
+                next_runnable_vertices = result[0]
+                results.extend(next_runnable_vertices)
+            else:
+                raise ValueError(f"Invalid result from task {task_name}: {result}")
+
         return results

     def topological_sort(self) -> List[Vertex]:
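The rewritten _execute_tasks above swaps asyncio.as_completed for asyncio.gather(..., return_exceptions=True). With as_completed, results arrive in completion order, so a loop index no longer identifies which task produced a result; gather keeps results aligned with the input task list, so tasks[i].get_name() names the right task. A standalone sketch of that pattern, with a toy square coroutine and names that are illustrative only:

import asyncio
from typing import Any, List


async def execute_named_tasks(tasks: List[asyncio.Task]) -> List[Any]:
    """Await all tasks; on the first failure, cancel the remaining ones and re-raise."""
    completed = await asyncio.gather(*tasks, return_exceptions=True)
    results: List[Any] = []
    for i, result in enumerate(completed):
        if isinstance(result, Exception):
            # tasks[i] is exactly the task that produced this result.
            print(f"task {tasks[i].get_name()} failed: {result!r}")
            for t in tasks[i + 1 :]:
                t.cancel()
            raise result
        results.append(result)
    return results


async def main() -> None:
    async def square(n: int) -> int:
        await asyncio.sleep(0)
        return n * n

    tasks = [asyncio.create_task(square(n), name=f"square-{n}") for n in range(3)]
    print(await execute_named_tasks(tasks))  # [0, 1, 4]


asyncio.run(main())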
@@ -1360,3 +1372,5 @@ class Graph:
             predecessor_map[edge.target_id].append(edge.source_id)
             successor_map[edge.source_id].append(edge.target_id)
         return predecessor_map, successor_map
@@ -59,6 +59,7 @@ class RunnableVerticesManager:
         set_cache_coro: Callable[["Graph", asyncio.Lock], Awaitable[None]],
         graph: "Graph",
         vertex: "Vertex",
+        cache: bool = True,
     ):
         """
         Retrieves the next runnable vertices in the graph for a given vertex.
@@ -86,7 +87,8 @@ class RunnableVerticesManager:
             for v_id in set(next_runnable_vertices):  # Use set to avoid duplicates
                 self.update_vertex_run_state(v_id, is_runnable=False)
                 self.remove_from_predecessors(v_id)
-            await set_cache_coro(data=graph, lock=lock)  # type: ignore
+            if cache:
+                await set_cache_coro(data=graph, lock=lock)  # type: ignore
         return next_runnable_vertices

     @staticmethod
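On the caller side, this flag is what lets Graph.process() skip the per-vertex cache write while it fans out build tasks: it passes cache=False per task, as shown in the diff above. A hedged usage sketch; apart from lock, user_id, inputs_dict, fallback_to_env_vars, and cache, which appear in the diff, the keyword names and the helper itself are assumptions:

import asyncio


async def build_layer(graph, vertex_ids, lock, user_id=None):
    """Build a batch of vertices without caching each one individually."""
    tasks = [
        asyncio.create_task(
            graph.build_vertex(
                lock=lock,
                vertex_id=vertex_id,  # assumed keyword; not visible in the diff
                user_id=user_id,
                inputs_dict={},
                fallback_to_env_vars=False,
                cache=False,  # skip the cache write inside the hot loop
            ),
            name=f"build {vertex_id}",
        )
        for vertex_id in vertex_ids
    ]
    return await asyncio.gather(*tasks, return_exceptions=True)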