Refactor build_vertex method and add RunnableVerticesManager class

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-25 10:56:24 -03:00
commit 42710a246b
3 changed files with 197 additions and 79 deletions

View file

@ -6,13 +6,7 @@ from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from langflow.api.utils import (
build_and_cache_graph,
format_elapsed_time,
format_exception_message,
get_next_runnable_vertices,
get_top_level_vertices,
)
from langflow.api.utils import build_and_cache_graph, format_elapsed_time, format_exception_message
from langflow.api.v1.schemas import (
InputValueRequest,
ResultDataResponse,
@ -147,23 +141,21 @@ async def build_vertex(
graph = cache.get("result")
result_data_response = ResultDataResponse(results={})
duration = ""
vertex = graph.get_vertex(vertex_id)
try:
if not vertex.frozen or not vertex._built:
inputs_dict = inputs.model_dump() if inputs else {}
await vertex.build(user_id=current_user.id, inputs=inputs_dict)
if vertex.result is not None:
params = vertex._built_object_repr()
valid = True
result_dict = vertex.result
artifacts = vertex.artifacts
else:
raise ValueError(f"No result found for vertex {vertex_id}")
next_runnable_vertices = await get_next_runnable_vertices(graph, vertex, vertex_id, chat_service, flow_id)
top_level_vertices = get_top_level_vertices(graph, next_runnable_vertices)
(
next_runnable_vertices,
top_level_vertices,
result_dict,
params,
valid,
artifacts,
vertex,
) = await graph.build_vertex(
chat_service=chat_service,
vertex_id=vertex_id,
user_id=current_user.id,
inputs=inputs.model_dump() if inputs else {},
)
result_data_response = ResultDataResponse(**result_dict.model_dump())
except Exception as exc:

View file

@ -7,6 +7,7 @@ from loguru import logger
from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.utils import process_flow
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
@ -67,6 +68,7 @@ class Graph:
self.inactive_vertices: set = set()
self.edges: List[ContractEdge] = []
self.vertices: List[Vertex] = []
self.run_manager = RunnableVerticesManager()
self._build_graph()
self.build_graph_maps()
self.define_vertices_lists()
@ -427,30 +429,6 @@ class Graph:
def __setstate__(self, state):
self.__init__(**state)
def build_in_degree(self):
in_degree = defaultdict(int)
for edge in self.edges:
in_degree[edge.target_id] += 1
return in_degree
def build_adjacency_maps(self):
"""Returns the adjacency maps for the graph."""
predecessor_map = defaultdict(list)
successor_map = defaultdict(list)
for edge in self.edges:
predecessor_map[edge.target_id].append(edge.source_id)
successor_map[edge.source_id].append(edge.target_id)
return predecessor_map, successor_map
def build_run_map(self):
run_map = defaultdict(list)
# The run map gets the predecessor_map and maps the info like this:
# {vertex_id: every id that contains the vertex_id in the predecessor_map}
for vertex_id, predecessors in self.predecessor_map.items():
for predecessor in predecessors:
run_map[predecessor].append(vertex_id)
return run_map
@classmethod
def from_payload(cls, payload: Dict, flow_id: Optional[str] = None) -> "Graph":
"""
@ -669,6 +647,32 @@ class Graph:
except KeyError:
raise ValueError(f"Vertex {vertex_id} not found")
async def build_vertex(
self, chat_service, vertex_id: str, inputs: Optional[Dict[str, str]] = None, user_id: Optional[str] = None
):
vertex = self.get_vertex(vertex_id)
try:
if not vertex.frozen or not vertex._built:
inputs_dict = inputs.model_dump() if inputs else {}
await vertex.build(user_id=user_id, inputs=inputs_dict)
if vertex.result is not None:
params = vertex._built_object_repr()
valid = True
result_dict = vertex.result
artifacts = vertex.artifacts
else:
raise ValueError(f"No result found for vertex {vertex_id}")
next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(
self, vertex, vertex_id, chat_service, self.flow_id
)
top_level_vertices = self.run_manager.get_top_level_vertices(self, next_runnable_vertices)
return next_runnable_vertices, top_level_vertices, result_dict, params, valid, artifacts, vertex
except Exception as exc:
logger.exception(f"Error building vertex: {exc}")
raise exc
def get_vertex_edges(
self,
vertex_id: str,
@ -1107,41 +1111,10 @@ class Graph:
# save the only the rest
self.vertices_layers = vertices_layers[1:]
self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)}
self.run_map, self.run_predecessors = (
self.build_run_map(),
self.predecessor_map.copy(),
)
self.build_run_map()
# Return just the first layer
return first_layer
def is_vertex_runnable(self, vertex_id: str) -> bool:
"""Returns whether a vertex is runnable."""
return vertex_id in self.vertices_to_run and not self.run_predecessors.get(vertex_id)
def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]:
"""
For each successor of the current vertex, find runnable predecessors if any.
This checks the direct predecessors of each successor to identify any that are
immediately runnable, expanding the search to ensure progress can be made.
"""
runnable_vertices = []
visited = set()
for successor_id in self.run_map.get(vertex_id, []):
for predecessor_id in self.run_predecessors.get(successor_id, []):
if predecessor_id not in visited and self.is_vertex_runnable(predecessor_id):
runnable_vertices.append(predecessor_id)
visited.add(predecessor_id)
return runnable_vertices
def remove_from_predecessors(self, vertex_id: str):
predecessors = self.run_map.get(vertex_id, [])
for predecessor in predecessors:
if vertex_id in self.run_predecessors[predecessor]:
self.run_predecessors[predecessor].remove(vertex_id)
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
@ -1171,3 +1144,45 @@ class Graph:
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
return sorted_vertices
def is_vertex_runnable(self, vertex_id: str) -> bool:
"""Returns whether a vertex is runnable."""
return self.run_manager.is_vertex_runnable(vertex_id)
def build_run_map(self):
"""
Builds the run map for the graph.
This method is responsible for building the run map for the graph,
which maps each node in the graph to its corresponding run function.
Returns:
None
"""
self.run_manager.build_run_map(self)
def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]:
"""
For each successor of the current vertex, find runnable predecessors if any.
This checks the direct predecessors of each successor to identify any that are
immediately runnable, expanding the search to ensure progress can be made.
"""
self.run_manager.find_runnable_predecessors_for_successors(vertex_id)
def remove_from_predecessors(self, vertex_id: str):
self.run_manager.remove_from_predecessors(vertex_id)
def build_in_degree(self):
in_degree = defaultdict(int)
for edge in self.edges:
in_degree[edge.target_id] += 1
return in_degree
def build_adjacency_maps(self):
"""Returns the adjacency maps for the graph."""
predecessor_map = defaultdict(list)
successor_map = defaultdict(list)
for edge in self.edges:
predecessor_map[edge.target_id].append(edge.source_id)
successor_map[edge.source_id].append(edge.target_id)
return predecessor_map, successor_map

View file

@ -0,0 +1,111 @@
from collections import defaultdict
from typing import TYPE_CHECKING, List
if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from langflow.services.chat.service import ChatService
class RunnableVerticesManager:
def __init__(self):
self.run_map = defaultdict(list) # Tracks successors of each vertex
self.run_predecessors = defaultdict(set) # Tracks predecessors for each vertex
self.vertices_to_run = set() # Set of vertices that are ready to run
def is_vertex_runnable(self, vertex_id: str) -> bool:
"""Determines if a vertex is runnable."""
return vertex_id in self.vertices_to_run and not self.run_predecessors.get(vertex_id)
def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]:
"""Finds runnable predecessors for the successors of a given vertex."""
runnable_vertices = []
visited = set()
for successor_id in self.run_map.get(vertex_id, []):
for predecessor_id in self.run_predecessors.get(successor_id, []):
if predecessor_id not in visited and self.is_vertex_runnable(predecessor_id):
runnable_vertices.append(predecessor_id)
visited.add(predecessor_id)
return runnable_vertices
def remove_from_predecessors(self, vertex_id: str):
"""Removes a vertex from the predecessor list of its successors."""
predecessors = self.run_map.get(vertex_id, [])
for predecessor in predecessors:
if vertex_id in self.run_predecessors[predecessor]:
self.run_predecessors[predecessor].remove(vertex_id)
def build_run_map(self, graph):
"""Builds a map of vertices and their runnable successors."""
self.run_map = defaultdict(list)
for vertex_id, predecessors in graph.predecessor_map.items():
for predecessor in predecessors:
self.run_map[predecessor].append(vertex_id)
self.run_predecessors = {k: set(v) for k, v in self.run_map.items()}
def update_vertex_run_state(self, vertex_id: str, is_runnable: bool):
"""Updates the runnable state of a vertex."""
if is_runnable:
self.vertices_to_run.add(vertex_id)
else:
self.vertices_to_run.discard(vertex_id)
@staticmethod
async def get_next_runnable_vertices(
graph: "Graph",
vertex: "Vertex",
vertex_id: str,
chat_service: "ChatService",
flow_id: str,
):
"""
Retrieves the next runnable vertices in the graph for a given vertex.
Args:
graph (Graph): The graph object representing the flow.
vertex (Vertex): The current vertex.
vertex_id (str): The ID of the current vertex.
chat_service (ChatService): The chat service object.
flow_id (str): The ID of the flow.
Returns:
list: A list of IDs of the next runnable vertices.
"""
async with chat_service._cache_locks[flow_id] as lock:
graph.remove_from_predecessors(vertex_id)
direct_successors_ready = [v for v in vertex.successors_ids if graph.is_vertex_runnable(v)]
if not direct_successors_ready:
# No direct successors ready, look for runnable predecessors of successors
next_runnable_vertices = graph.find_runnable_predecessors_for_successors(vertex_id)
else:
next_runnable_vertices = direct_successors_ready
for v_id in set(next_runnable_vertices): # Use set to avoid duplicates
graph.vertices_to_run.remove(v_id)
graph.remove_from_predecessors(v_id)
await chat_service.set_cache(flow_id=flow_id, data=graph, lock=lock)
return next_runnable_vertices
@staticmethod
def get_top_level_vertices(graph, vertices_ids):
"""
Retrieves the top-level vertices from the given graph based on the provided vertex IDs.
Args:
graph (Graph): The graph object containing the vertices.
vertices_ids (list): A list of vertex IDs.
Returns:
list: A list of top-level vertex IDs.
"""
top_level_vertices = []
for vertex_id in vertices_ids:
vertex = graph.get_vertex(vertex_id)
if vertex.parent_is_top_level:
top_level_vertices.append(vertex.parent_node_id)
else:
top_level_vertices.append(vertex_id)
return top_level_vertices