refactor: runnable_vertices_manager.py (#2646)

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Ítalo Johnny 2024-07-15 14:57:00 -03:00 committed by GitHub
commit d93382e90a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 290 additions and 156 deletions

View file

@ -12,7 +12,6 @@ from langflow.services.store.schema import StoreComponentCreate
from langflow.services.store.utils import get_lf_version_from_pypi
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
from langflow.services.database.models.flow.model import Flow
@ -179,43 +178,6 @@ def format_exception_message(exc: Exception) -> str:
return str(exc)
async def get_next_runnable_vertices(
graph: Graph,
vertex: "Vertex",
vertex_id: str,
chat_service: ChatService,
flow_id: str,
):
"""
Retrieves the next runnable vertices in the graph for a given vertex.
Args:
graph (Graph): The graph object representing the flow.
vertex (Vertex): The current vertex.
vertex_id (str): The ID of the current vertex.
chat_service (ChatService): The chat service object.
flow_id (str): The ID of the flow.
Returns:
list: A list of IDs of the next runnable vertices.
"""
async with chat_service._async_cache_locks[flow_id] as lock:
graph.remove_from_predecessors(vertex_id)
direct_successors_ready = [v for v in vertex.successors_ids if graph.is_vertex_runnable(v)]
if not direct_successors_ready:
# No direct successors ready, look for runnable predecessors of successors
next_runnable_vertices = graph.find_runnable_predecessors_for_successors(vertex_id)
else:
next_runnable_vertices = direct_successors_ready
for v_id in set(next_runnable_vertices): # Use set to avoid duplicates
graph.vertices_to_run.remove(v_id)
graph.remove_from_predecessors(v_id)
await chat_service.set_cache(key=flow_id, data=graph, lock=lock)
return next_runnable_vertices
def get_top_level_vertices(graph, vertices_ids):
"""
Retrieves the top-level vertices from the given graph based on the provided vertex IDs.

View file

@ -1,7 +1,6 @@
import time
import traceback
import uuid
from functools import partial
from typing import TYPE_CHECKING, Annotated, Optional
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
@ -202,11 +201,8 @@ async def build_vertex(
inputs_dict=inputs.model_dump() if inputs else {},
files=files,
)
set_cache_coro = partial(get_chat_service().set_cache, key=flow_id_str)
next_runnable_vertices = await graph.run_manager.get_next_runnable_vertices(
lock, set_cache_coro, graph=graph, vertex=vertex, cache=False
)
top_level_vertices = graph.run_manager.get_top_level_vertices(graph, next_runnable_vertices)
next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False)
top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices)
result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True)
except Exception as exc:
@ -292,7 +288,7 @@ async def build_vertex(
componentErrorMessage=str(exc),
),
)
logger.error(f"Error building Component:\n\n{exc}")
logger.error(f"Error building Component: \n\n{exc}")
logger.exception(exc)
message = parse_exception(exc)
raise HTTPException(status_code=500, detail=message) from exc

View file

@ -209,7 +209,8 @@ class Graph:
for successor in successors:
if successor.params.get("stream") or successor.params.get("streaming"):
raise ValueError(
f"Components {vertex.display_name} and {successor.display_name} are connected and both have stream or streaming set to True"
f"Components {vertex.display_name} and {successor.display_name} "
"are connected and both have stream or streaming set to True"
)
@property
@ -438,7 +439,7 @@ class Graph:
Args:
inputs (list[Dict[str, str]]): The input values for the graph.
inputs_components (Optional[list[list[str]]], optional): The components to run for the inputs. Defaults to None.
inputs_components (Optional[list[list[str]]], optional): Components to run for the inputs. Defaults to None.
outputs (Optional[list[str]], optional): The outputs to retrieve from the graph. Defaults to None.
session_id (Optional[str], optional): The session ID for the graph. Defaults to None.
stream (bool, optional): Whether to stream the results or not. Defaults to False.
@ -909,7 +910,7 @@ class Graph:
return result_dict, params, valid, artifacts, vertex
except Exception as exc:
if not isinstance(exc, ComponentBuildException):
logger.exception(f"Error building Component:\n\n{exc}")
logger.exception(f"Error building Component: \n\n{exc}")
flow_id = self.flow_id
log_transaction(flow_id, vertex, status="failure", error=str(exc))
raise exc
@ -987,6 +988,29 @@ class Graph:
logger.debug("Graph processing complete")
return self
def find_next_runnable_vertices(self, vertex_id: str, vertex_successors_ids: List[str]) -> List[str]:
direct_successors_ready = [v_id for v_id in vertex_successors_ids if self.is_vertex_runnable(v_id)]
if not direct_successors_ready:
return self.find_runnable_predecessors_for_successors(vertex_id)
return direct_successors_ready
async def get_next_runnable_vertices(self, lock: asyncio.Lock, vertex: "Vertex", cache: bool = True) -> List[str]:
v_id = vertex.id
v_successors_ids = vertex.successors_ids
async with lock:
self.run_manager.remove_vertex_from_runnables(v_id)
next_runnable_vertices = self.find_next_runnable_vertices(v_id, v_successors_ids)
for i in set(next_runnable_vertices): # Use set to avoid duplicates
if i == v_id:
next_runnable_vertices.remove(v_id)
else:
self.run_manager.add_to_vertices_being_run(v_id)
if cache:
set_cache_coro = partial(get_chat_service().set_cache, key=self.flow_id)
await set_cache_coro(self, lock)
return next_runnable_vertices
async def _execute_tasks(self, tasks: List[asyncio.Task], lock: asyncio.Lock) -> List[str]:
"""Executes tasks in parallel, handling exceptions for each task."""
results = []
@ -1012,11 +1036,8 @@ class Graph:
# This could usually happen with input vertices like ChatInput
self.run_manager.remove_vertex_from_runnables(v.id)
set_cache_coro = partial(get_chat_service().set_cache, key=self.flow_id)
for v in vertices:
next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(
lock, set_cache_coro, graph=self, vertex=v, cache=False
)
next_runnable_vertices = await self.get_next_runnable_vertices(lock, vertex=v, cache=False)
results.extend(next_runnable_vertices)
return results
@ -1197,7 +1218,7 @@ class Graph:
def __repr__(self):
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
return f"Graph: \nNodes: {vertex_ids}\nConnections: \n{edges_repr}"
def layered_topological_sort(
self,
@ -1412,7 +1433,8 @@ class Graph:
def is_vertex_runnable(self, vertex_id: str) -> bool:
"""Returns whether a vertex is runnable."""
return self.run_manager.is_vertex_runnable(self.get_vertex(vertex_id))
is_active = self.get_vertex(vertex_id).is_active()
return self.run_manager.is_vertex_runnable(vertex_id, is_active)
def build_run_map(self):
"""
@ -1432,7 +1454,26 @@ class Graph:
This checks the direct predecessors of each successor to identify any that are
immediately runnable, expanding the search to ensure progress can be made.
"""
return self.run_manager.find_runnable_predecessors_for_successors(self.get_vertex(vertex_id))
runnable_vertices = []
visited = set()
def find_runnable_predecessors(predecessor: "Vertex"):
predecessor_id = predecessor.id
if predecessor_id in visited:
return
visited.add(predecessor_id)
is_active = self.get_vertex(predecessor_id).is_active()
if self.run_manager.is_vertex_runnable(predecessor_id, is_active):
runnable_vertices.append(predecessor_id)
else:
for pred_pred_id in self.run_manager.run_predecessors.get(predecessor_id, []):
find_runnable_predecessors(self.get_vertex(pred_pred_id))
for successor_id in self.run_manager.run_map.get(vertex_id, []):
for predecessor_id in self.run_manager.run_predecessors.get(successor_id, []):
find_runnable_predecessors(self.get_vertex(predecessor_id))
return runnable_vertices
def remove_from_predecessors(self, vertex_id: str):
self.run_manager.remove_from_predecessors(vertex_id)
@ -1440,6 +1481,26 @@ class Graph:
def remove_vertex_from_runnables(self, vertex_id: str):
self.run_manager.remove_vertex_from_runnables(vertex_id)
def get_top_level_vertices(self, vertices_ids):
"""
Retrieves the top-level vertices from the given graph based on the provided vertex IDs.
Args:
vertices_ids (list): A list of vertex IDs.
Returns:
list: A list of top-level vertex IDs.
"""
top_level_vertices = []
for vertex_id in vertices_ids:
vertex = self.get_vertex(vertex_id)
if vertex.parent_is_top_level:
top_level_vertices.append(vertex.parent_node_id)
else:
top_level_vertices.append(vertex_id)
return top_level_vertices
def build_in_degree(self, edges: List[ContractEdge]) -> Dict[str, int]:
in_degree: Dict[str, int] = defaultdict(int)
for edge in edges:

View file

@ -1,10 +1,4 @@
import asyncio
from collections import defaultdict
from typing import TYPE_CHECKING, Callable, Coroutine, List
if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
class RunnableVerticesManager:
@ -19,6 +13,7 @@ class RunnableVerticesManager:
"run_map": self.run_map,
"run_predecessors": self.run_predecessors,
"vertices_to_run": self.vertices_to_run,
"vertices_being_run": self.vertices_being_run,
}
@classmethod
@ -27,6 +22,7 @@ class RunnableVerticesManager:
instance.run_map = data["run_map"]
instance.run_predecessors = data["run_predecessors"]
instance.vertices_to_run = data["vertices_to_run"]
instance.vertices_being_run = data["vertices_being_run"]
return instance
def __getstate__(self) -> object:
@ -34,12 +30,14 @@ class RunnableVerticesManager:
"run_map": self.run_map,
"run_predecessors": self.run_predecessors,
"vertices_to_run": self.vertices_to_run,
"vertices_being_run": self.vertices_being_run,
}
def __setstate__(self, state: dict) -> None:
self.run_map = state["run_map"]
self.run_predecessors = state["run_predecessors"]
self.vertices_to_run = state["vertices_to_run"]
self.vertices_being_run = state["vertices_being_run"]
def all_predecessors_are_fulfilled(self) -> bool:
return all(not value for value in self.run_predecessors.values())
@ -49,42 +47,21 @@ class RunnableVerticesManager:
self.vertices_to_run.update(vertices_to_run)
self.build_run_map(self.run_predecessors, self.vertices_to_run)
def is_vertex_runnable(self, vertex: "Vertex") -> bool:
def is_vertex_runnable(self, vertex_id: str, is_active: bool) -> bool:
"""Determines if a vertex is runnable."""
return (
vertex.is_active()
and self.are_all_predecessors_fulfilled(vertex.id)
and vertex.id in self.vertices_to_run
and vertex.id not in self.vertices_being_run
)
if not is_active:
return False
if vertex_id in self.vertices_being_run:
return False
if vertex_id not in self.vertices_to_run:
return False
if not self.are_all_predecessors_fulfilled(vertex_id):
return False
return True
def are_all_predecessors_fulfilled(self, vertex_id: str) -> bool:
return not any(self.run_predecessors.get(vertex_id, []))
def find_runnable_predecessors_for_successors(self, vertex: "Vertex") -> List[str]:
"""Finds runnable predecessors for the successors of a given vertex."""
runnable_vertices = []
visited = set()
get_vertex = vertex.graph.get_vertex
def find_runnable_predecessors(predecessor: "Vertex"):
predecessor_id = predecessor.id
if predecessor_id in visited:
return
visited.add(predecessor_id)
if self.is_vertex_runnable(predecessor):
runnable_vertices.append(predecessor_id)
else:
for pred_pred_id in self.run_predecessors.get(predecessor_id, []):
find_runnable_predecessors(get_vertex(pred_pred_id))
for successor_id in self.run_map.get(vertex.id, []):
for predecessor_id in self.run_predecessors.get(successor_id, []):
find_runnable_predecessors(get_vertex(predecessor_id))
return runnable_vertices
def remove_from_predecessors(self, vertex_id: str):
"""Removes a vertex from the predecessor list of its successors."""
predecessors = self.run_map.get(vertex_id, [])
@ -108,71 +85,9 @@ class RunnableVerticesManager:
else:
self.vertices_being_run.discard(vertex_id)
async def get_next_runnable_vertices(
self,
lock: asyncio.Lock,
set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine],
graph: "Graph",
vertex: "Vertex",
cache: bool = True,
) -> List[str]:
"""
Retrieves the next runnable vertices in the graph for a given vertex.
Args:
lock (asyncio.Lock): The lock object to be used for synchronization.
set_cache_coro (Callable): The coroutine function to set the cache.
graph (Graph): The graph object containing the vertices.
vertex (Vertex): The vertex object for which the next runnable vertices are to be retrieved.
cache (bool, optional): A flag to indicate if the cache should be updated. Defaults to True.
Returns:
list: A list of IDs of the next runnable vertices.
"""
async with lock:
self.remove_vertex_from_runnables(vertex.id)
direct_successors_ready = [v for v in vertex.successors_ids if self.is_vertex_runnable(graph.get_vertex(v))]
if not direct_successors_ready:
# No direct successors ready, look for runnable predecessors of successors
next_runnable_vertices = self.find_runnable_predecessors_for_successors(vertex)
else:
next_runnable_vertices = direct_successors_ready
for v_id in set(next_runnable_vertices): # Use set to avoid duplicates
if vertex.id == v_id:
next_runnable_vertices.remove(v_id)
else:
self.add_to_vertices_being_run(v_id)
if cache:
await set_cache_coro(data=graph, lock=lock) # type: ignore
return next_runnable_vertices
def remove_vertex_from_runnables(self, v_id):
self.update_vertex_run_state(v_id, is_runnable=False)
self.remove_from_predecessors(v_id)
def add_to_vertices_being_run(self, v_id):
self.vertices_being_run.add(v_id)
@staticmethod
def get_top_level_vertices(graph, vertices_ids):
"""
Retrieves the top-level vertices from the given graph based on the provided vertex IDs.
Args:
graph (Graph): The graph object containing the vertices.
vertices_ids (list): A list of vertex IDs.
Returns:
list: A list of top-level vertex IDs.
"""
top_level_vertices = []
for vertex_id in vertices_ids:
vertex = graph.get_vertex(vertex_id)
if vertex.parent_is_top_level:
top_level_vertices.append(vertex.parent_node_id)
else:
top_level_vertices.append(vertex_id)
return top_level_vertices

View file

@ -0,0 +1,200 @@
import pickle
from collections import defaultdict
import pytest
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
@pytest.fixture
def data():
run_map: defaultdict(list) = {"A": ["B", "C"], "B": ["D"], "C": ["D"], "D": []}
run_predecessors: defaultdict(set) = {"A": set(), "B": {"A"}, "C": {"A"}, "D": {"B", "C"}}
vertices_to_run: set = {"A", "B", "C"}
vertices_being_run = {"A"}
return {
"run_map": run_map,
"run_predecessors": run_predecessors,
"vertices_to_run": vertices_to_run,
"vertices_being_run": vertices_being_run,
}
def test_to_dict(data):
result = RunnableVerticesManager.from_dict(data).to_dict()
assert all(key in result.keys() for key in data.keys())
def test_from_dict(data):
result = RunnableVerticesManager.from_dict(data)
assert isinstance(result, RunnableVerticesManager)
def test_from_dict_without_run_map__bad_case(data):
data.pop("run_map")
with pytest.raises(KeyError):
RunnableVerticesManager.from_dict(data)
def test_from_dict_without_run_predecessors__bad_case(data):
data.pop("run_predecessors")
with pytest.raises(KeyError):
RunnableVerticesManager.from_dict(data)
def test_from_dict_without_vertices_to_run__bad_case(data):
data.pop("vertices_to_run")
with pytest.raises(KeyError):
RunnableVerticesManager.from_dict(data)
def test_from_dict_without_vertices_being_run__bad_case(data):
data.pop("vertices_being_run")
with pytest.raises(KeyError):
RunnableVerticesManager.from_dict(data)
def test_pickle(data):
manager = RunnableVerticesManager.from_dict(data)
binary = pickle.dumps(manager)
result = pickle.loads(binary)
assert result.run_map == manager.run_map
assert result.run_predecessors == manager.run_predecessors
assert result.vertices_to_run == manager.vertices_to_run
assert result.vertices_being_run == manager.vertices_being_run
def test_update_run_state(data):
manager = RunnableVerticesManager.from_dict(data)
run_predecessors = {"E": {"D"}}
vertices_to_run = {"D"}
manager.update_run_state(run_predecessors, vertices_to_run)
assert "D" in manager.run_map
assert "D" in manager.vertices_to_run
assert "D" in manager.run_predecessors["E"]
def test_is_vertex_runnable(data):
manager = RunnableVerticesManager.from_dict(data)
vertex_id = "A"
is_active = True
result = manager.is_vertex_runnable(vertex_id, is_active)
assert result is False
def test_is_vertex_runnable__wrong_is_active(data):
manager = RunnableVerticesManager.from_dict(data)
vertex_id = "A"
is_active = False
result = manager.is_vertex_runnable(vertex_id, is_active)
assert result is False
def test_is_vertex_runnable__wrong_vertices_to_run(data):
manager = RunnableVerticesManager.from_dict(data)
vertex_id = "D"
is_active = True
result = manager.is_vertex_runnable(vertex_id, is_active)
assert result is False
def test_is_vertex_runnable__wrong_run_predecessors(data):
manager = RunnableVerticesManager.from_dict(data)
vertex_id = "C"
is_active = True
result = manager.is_vertex_runnable(vertex_id, is_active)
assert result is False
def test_are_all_predecessors_fulfilled(data):
manager = RunnableVerticesManager.from_dict(data)
vertex_id = "A"
result = manager.are_all_predecessors_fulfilled(vertex_id)
assert result is True
def test_are_all_predecessors_fulfilled__wrong(data):
manager = RunnableVerticesManager.from_dict(data)
vertex_id = "D"
result = manager.are_all_predecessors_fulfilled(vertex_id)
assert result is False
def test_remove_from_predecessors(data):
manager = RunnableVerticesManager.from_dict(data)
vertex_id = "A"
manager.remove_from_predecessors(vertex_id)
assert all(vertex_id not in predecessors for predecessors in manager.run_predecessors.values())
def test_build_run_map(data):
manager = RunnableVerticesManager.from_dict(data)
vertices_to_run = {}
predecessor_map = {"Z": set(), "X": {"Z"}, "Y": {"Z"}, "W": {"X", "Y"}}
manager.build_run_map(predecessor_map, vertices_to_run)
assert all(v in manager.run_map.keys() for v in ["Z", "X", "Y"])
assert "W" not in manager.run_map.keys()
def test_update_vertex_run_state(data):
manager = RunnableVerticesManager.from_dict(data)
vertex_id = "C"
is_runnable = True
manager.update_vertex_run_state(vertex_id, is_runnable)
assert vertex_id in manager.vertices_to_run
def test_update_vertex_run_state__bad_case(data):
manager = RunnableVerticesManager.from_dict(data)
vertex_id = "C"
is_runnable = False
manager.update_vertex_run_state(vertex_id, is_runnable)
assert vertex_id not in manager.vertices_being_run
def test_remove_vertex_from_runnables(data):
manager = RunnableVerticesManager.from_dict(data)
vertex_id = "C"
manager.remove_vertex_from_runnables(vertex_id)
assert vertex_id not in manager.vertices_being_run
def test_add_to_vertices_being_run(data):
manager = RunnableVerticesManager.from_dict(data)
vertex_id = "C"
manager.add_to_vertices_being_run(vertex_id)
assert vertex_id in manager.vertices_being_run

View file

@ -184,10 +184,10 @@ def test_directory_without_mocks():
# check if the directory component can load them
# just check if the number of results is the same as the number of files
directory_component = data.DirectoryComponent()
docs_path = Path(__file__).parent.parent.parent / "docs" / "docs" / "components"
docs_path = Path(__file__).parent.parent.parent / "docs" / "docs" / "Components"
directory_component.set_attributes({"path": str(docs_path), "use_multithreading": False})
results = directory_component.load_directory()
docs_files = list(docs_path.glob("*.mdx"))
docs_files = list(docs_path.glob("*.md")) + list(docs_path.glob("*.json"))
assert len(results) == len(docs_files)