Add functions for retrieving next runnable vertices and top-level vertices in the graph

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-20 16:11:12 -03:00
commit c4b2ad763f

View file

@ -13,6 +13,7 @@ from langflow.services.store.schema import StoreComponentCreate
from langflow.services.store.utils import get_lf_version_from_pypi
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
from langflow.services.database.models.flow.model import Flow
@ -238,3 +239,62 @@ def format_exception_message(exc: Exception) -> str:
if isinstance(causing_exception, SyntaxError):
return format_syntax_error_message(causing_exception)
return str(exc)
async def get_next_runnable_vertices(
graph: Graph,
vertex: "Vertex",
vertex_id: str,
chat_service: ChatService,
flow_id: str,
):
"""
Retrieves the next runnable vertices in the graph for a given vertex.
Args:
graph (Graph): The graph object representing the flow.
vertex (Vertex): The current vertex.
vertex_id (str): The ID of the current vertex.
chat_service (ChatService): The chat service object.
flow_id (str): The ID of the flow.
Returns:
list: A list of IDs of the next runnable vertices.
"""
async with chat_service._cache_locks[flow_id] as lock:
graph.remove_from_predecessors(vertex_id)
direct_successors_ready = [v for v in vertex.successors_ids if graph.is_vertex_runnable(v)]
if not direct_successors_ready:
# No direct successors ready, look for runnable predecessors of successors
next_runnable_vertices = graph.find_runnable_predecessors_for_successors(vertex_id)
else:
next_runnable_vertices = direct_successors_ready
for v_id in set(next_runnable_vertices): # Use set to avoid duplicates
graph.vertices_to_run.remove(v_id)
graph.remove_from_predecessors(v_id)
await chat_service.set_cache(flow_id=flow_id, data=graph, lock=lock)
return next_runnable_vertices
def get_top_level_vertices(graph, vertices_ids):
"""
Retrieves the top-level vertices from the given graph based on the provided vertex IDs.
Args:
graph (Graph): The graph object containing the vertices.
vertices_ids (list): A list of vertex IDs.
Returns:
list: A list of top-level vertex IDs.
"""
top_level_vertices = []
for vertex_id in vertices_ids:
vertex = graph.get_vertex(vertex_id)
if vertex.parent_is_top_level:
top_level_vertices.append(vertex.parent_node_id)
else:
top_level_vertices.append(vertex_id)
return top_level_vertices