Refactor build_vertex method and add RunnableVerticesManager class
This commit is contained in:
parent
0f9dccc4fd
commit
42710a246b
3 changed files with 197 additions and 79 deletions
|
|
@ -6,13 +6,7 @@ from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
|
|||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from langflow.api.utils import (
|
||||
build_and_cache_graph,
|
||||
format_elapsed_time,
|
||||
format_exception_message,
|
||||
get_next_runnable_vertices,
|
||||
get_top_level_vertices,
|
||||
)
|
||||
from langflow.api.utils import build_and_cache_graph, format_elapsed_time, format_exception_message
|
||||
from langflow.api.v1.schemas import (
|
||||
InputValueRequest,
|
||||
ResultDataResponse,
|
||||
|
|
@ -147,23 +141,21 @@ async def build_vertex(
|
|||
graph = cache.get("result")
|
||||
result_data_response = ResultDataResponse(results={})
|
||||
duration = ""
|
||||
|
||||
vertex = graph.get_vertex(vertex_id)
|
||||
try:
|
||||
if not vertex.frozen or not vertex._built:
|
||||
inputs_dict = inputs.model_dump() if inputs else {}
|
||||
await vertex.build(user_id=current_user.id, inputs=inputs_dict)
|
||||
|
||||
if vertex.result is not None:
|
||||
params = vertex._built_object_repr()
|
||||
valid = True
|
||||
result_dict = vertex.result
|
||||
artifacts = vertex.artifacts
|
||||
else:
|
||||
raise ValueError(f"No result found for vertex {vertex_id}")
|
||||
|
||||
next_runnable_vertices = await get_next_runnable_vertices(graph, vertex, vertex_id, chat_service, flow_id)
|
||||
top_level_vertices = get_top_level_vertices(graph, next_runnable_vertices)
|
||||
(
|
||||
next_runnable_vertices,
|
||||
top_level_vertices,
|
||||
result_dict,
|
||||
params,
|
||||
valid,
|
||||
artifacts,
|
||||
vertex,
|
||||
) = await graph.build_vertex(
|
||||
chat_service=chat_service,
|
||||
vertex_id=vertex_id,
|
||||
user_id=current_user.id,
|
||||
inputs=inputs.model_dump() if inputs else {},
|
||||
)
|
||||
result_data_response = ResultDataResponse(**result_dict.model_dump())
|
||||
|
||||
except Exception as exc:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from loguru import logger
|
|||
|
||||
from langflow.graph.edge.base import ContractEdge
|
||||
from langflow.graph.graph.constants import lazy_load_vertex_dict
|
||||
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
|
||||
from langflow.graph.graph.state_manager import GraphStateManager
|
||||
from langflow.graph.graph.utils import process_flow
|
||||
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
|
||||
|
|
@ -67,6 +68,7 @@ class Graph:
|
|||
self.inactive_vertices: set = set()
|
||||
self.edges: List[ContractEdge] = []
|
||||
self.vertices: List[Vertex] = []
|
||||
self.run_manager = RunnableVerticesManager()
|
||||
self._build_graph()
|
||||
self.build_graph_maps()
|
||||
self.define_vertices_lists()
|
||||
|
|
@ -427,30 +429,6 @@ class Graph:
|
|||
def __setstate__(self, state):
|
||||
self.__init__(**state)
|
||||
|
||||
def build_in_degree(self):
|
||||
in_degree = defaultdict(int)
|
||||
for edge in self.edges:
|
||||
in_degree[edge.target_id] += 1
|
||||
return in_degree
|
||||
|
||||
def build_adjacency_maps(self):
|
||||
"""Returns the adjacency maps for the graph."""
|
||||
predecessor_map = defaultdict(list)
|
||||
successor_map = defaultdict(list)
|
||||
for edge in self.edges:
|
||||
predecessor_map[edge.target_id].append(edge.source_id)
|
||||
successor_map[edge.source_id].append(edge.target_id)
|
||||
return predecessor_map, successor_map
|
||||
|
||||
def build_run_map(self):
|
||||
run_map = defaultdict(list)
|
||||
# The run map gets the predecessor_map and maps the info like this:
|
||||
# {vertex_id: every id that contains the vertex_id in the predecessor_map}
|
||||
for vertex_id, predecessors in self.predecessor_map.items():
|
||||
for predecessor in predecessors:
|
||||
run_map[predecessor].append(vertex_id)
|
||||
return run_map
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: Dict, flow_id: Optional[str] = None) -> "Graph":
|
||||
"""
|
||||
|
|
@ -669,6 +647,32 @@ class Graph:
|
|||
except KeyError:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
|
||||
async def build_vertex(
|
||||
self, chat_service, vertex_id: str, inputs: Optional[Dict[str, str]] = None, user_id: Optional[str] = None
|
||||
):
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
try:
|
||||
if not vertex.frozen or not vertex._built:
|
||||
inputs_dict = inputs.model_dump() if inputs else {}
|
||||
await vertex.build(user_id=user_id, inputs=inputs_dict)
|
||||
|
||||
if vertex.result is not None:
|
||||
params = vertex._built_object_repr()
|
||||
valid = True
|
||||
result_dict = vertex.result
|
||||
artifacts = vertex.artifacts
|
||||
else:
|
||||
raise ValueError(f"No result found for vertex {vertex_id}")
|
||||
|
||||
next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(
|
||||
self, vertex, vertex_id, chat_service, self.flow_id
|
||||
)
|
||||
top_level_vertices = self.run_manager.get_top_level_vertices(self, next_runnable_vertices)
|
||||
return next_runnable_vertices, top_level_vertices, result_dict, params, valid, artifacts, vertex
|
||||
except Exception as exc:
|
||||
logger.exception(f"Error building vertex: {exc}")
|
||||
raise exc
|
||||
|
||||
def get_vertex_edges(
|
||||
self,
|
||||
vertex_id: str,
|
||||
|
|
@ -1107,41 +1111,10 @@ class Graph:
|
|||
# save the only the rest
|
||||
self.vertices_layers = vertices_layers[1:]
|
||||
self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)}
|
||||
self.run_map, self.run_predecessors = (
|
||||
self.build_run_map(),
|
||||
self.predecessor_map.copy(),
|
||||
)
|
||||
|
||||
self.build_run_map()
|
||||
# Return just the first layer
|
||||
return first_layer
|
||||
|
||||
def is_vertex_runnable(self, vertex_id: str) -> bool:
|
||||
"""Returns whether a vertex is runnable."""
|
||||
return vertex_id in self.vertices_to_run and not self.run_predecessors.get(vertex_id)
|
||||
|
||||
def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]:
|
||||
"""
|
||||
For each successor of the current vertex, find runnable predecessors if any.
|
||||
This checks the direct predecessors of each successor to identify any that are
|
||||
immediately runnable, expanding the search to ensure progress can be made.
|
||||
"""
|
||||
runnable_vertices = []
|
||||
visited = set()
|
||||
|
||||
for successor_id in self.run_map.get(vertex_id, []):
|
||||
for predecessor_id in self.run_predecessors.get(successor_id, []):
|
||||
if predecessor_id not in visited and self.is_vertex_runnable(predecessor_id):
|
||||
runnable_vertices.append(predecessor_id)
|
||||
visited.add(predecessor_id)
|
||||
|
||||
return runnable_vertices
|
||||
|
||||
def remove_from_predecessors(self, vertex_id: str):
|
||||
predecessors = self.run_map.get(vertex_id, [])
|
||||
for predecessor in predecessors:
|
||||
if vertex_id in self.run_predecessors[predecessor]:
|
||||
self.run_predecessors[predecessor].remove(vertex_id)
|
||||
|
||||
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
|
||||
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
|
||||
|
||||
|
|
@ -1171,3 +1144,45 @@ class Graph:
|
|||
|
||||
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
|
||||
return sorted_vertices
|
||||
|
||||
def is_vertex_runnable(self, vertex_id: str) -> bool:
|
||||
"""Returns whether a vertex is runnable."""
|
||||
return self.run_manager.is_vertex_runnable(vertex_id)
|
||||
|
||||
def build_run_map(self):
|
||||
"""
|
||||
Builds the run map for the graph.
|
||||
|
||||
This method is responsible for building the run map for the graph,
|
||||
which maps each node in the graph to its corresponding run function.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.run_manager.build_run_map(self)
|
||||
|
||||
def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]:
|
||||
"""
|
||||
For each successor of the current vertex, find runnable predecessors if any.
|
||||
This checks the direct predecessors of each successor to identify any that are
|
||||
immediately runnable, expanding the search to ensure progress can be made.
|
||||
"""
|
||||
self.run_manager.find_runnable_predecessors_for_successors(vertex_id)
|
||||
|
||||
def remove_from_predecessors(self, vertex_id: str):
|
||||
self.run_manager.remove_from_predecessors(vertex_id)
|
||||
|
||||
def build_in_degree(self):
|
||||
in_degree = defaultdict(int)
|
||||
for edge in self.edges:
|
||||
in_degree[edge.target_id] += 1
|
||||
return in_degree
|
||||
|
||||
def build_adjacency_maps(self):
|
||||
"""Returns the adjacency maps for the graph."""
|
||||
predecessor_map = defaultdict(list)
|
||||
successor_map = defaultdict(list)
|
||||
for edge in self.edges:
|
||||
predecessor_map[edge.target_id].append(edge.source_id)
|
||||
successor_map[edge.source_id].append(edge.target_id)
|
||||
return predecessor_map, successor_map
|
||||
|
|
|
|||
111
src/backend/langflow/graph/graph/runnable_vertices_manager.py
Normal file
111
src/backend/langflow/graph/graph/runnable_vertices_manager.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.services.chat.service import ChatService
|
||||
|
||||
|
||||
class RunnableVerticesManager:
|
||||
def __init__(self):
|
||||
self.run_map = defaultdict(list) # Tracks successors of each vertex
|
||||
self.run_predecessors = defaultdict(set) # Tracks predecessors for each vertex
|
||||
self.vertices_to_run = set() # Set of vertices that are ready to run
|
||||
|
||||
def is_vertex_runnable(self, vertex_id: str) -> bool:
|
||||
"""Determines if a vertex is runnable."""
|
||||
return vertex_id in self.vertices_to_run and not self.run_predecessors.get(vertex_id)
|
||||
|
||||
def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]:
|
||||
"""Finds runnable predecessors for the successors of a given vertex."""
|
||||
runnable_vertices = []
|
||||
visited = set()
|
||||
|
||||
for successor_id in self.run_map.get(vertex_id, []):
|
||||
for predecessor_id in self.run_predecessors.get(successor_id, []):
|
||||
if predecessor_id not in visited and self.is_vertex_runnable(predecessor_id):
|
||||
runnable_vertices.append(predecessor_id)
|
||||
visited.add(predecessor_id)
|
||||
return runnable_vertices
|
||||
|
||||
def remove_from_predecessors(self, vertex_id: str):
|
||||
"""Removes a vertex from the predecessor list of its successors."""
|
||||
predecessors = self.run_map.get(vertex_id, [])
|
||||
for predecessor in predecessors:
|
||||
if vertex_id in self.run_predecessors[predecessor]:
|
||||
self.run_predecessors[predecessor].remove(vertex_id)
|
||||
|
||||
def build_run_map(self, graph):
|
||||
"""Builds a map of vertices and their runnable successors."""
|
||||
self.run_map = defaultdict(list)
|
||||
for vertex_id, predecessors in graph.predecessor_map.items():
|
||||
for predecessor in predecessors:
|
||||
self.run_map[predecessor].append(vertex_id)
|
||||
self.run_predecessors = {k: set(v) for k, v in self.run_map.items()}
|
||||
|
||||
def update_vertex_run_state(self, vertex_id: str, is_runnable: bool):
|
||||
"""Updates the runnable state of a vertex."""
|
||||
if is_runnable:
|
||||
self.vertices_to_run.add(vertex_id)
|
||||
else:
|
||||
self.vertices_to_run.discard(vertex_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_next_runnable_vertices(
|
||||
graph: "Graph",
|
||||
vertex: "Vertex",
|
||||
vertex_id: str,
|
||||
chat_service: "ChatService",
|
||||
flow_id: str,
|
||||
):
|
||||
"""
|
||||
Retrieves the next runnable vertices in the graph for a given vertex.
|
||||
|
||||
Args:
|
||||
graph (Graph): The graph object representing the flow.
|
||||
vertex (Vertex): The current vertex.
|
||||
vertex_id (str): The ID of the current vertex.
|
||||
chat_service (ChatService): The chat service object.
|
||||
flow_id (str): The ID of the flow.
|
||||
|
||||
Returns:
|
||||
list: A list of IDs of the next runnable vertices.
|
||||
|
||||
"""
|
||||
async with chat_service._cache_locks[flow_id] as lock:
|
||||
graph.remove_from_predecessors(vertex_id)
|
||||
direct_successors_ready = [v for v in vertex.successors_ids if graph.is_vertex_runnable(v)]
|
||||
if not direct_successors_ready:
|
||||
# No direct successors ready, look for runnable predecessors of successors
|
||||
next_runnable_vertices = graph.find_runnable_predecessors_for_successors(vertex_id)
|
||||
else:
|
||||
next_runnable_vertices = direct_successors_ready
|
||||
|
||||
for v_id in set(next_runnable_vertices): # Use set to avoid duplicates
|
||||
graph.vertices_to_run.remove(v_id)
|
||||
graph.remove_from_predecessors(v_id)
|
||||
await chat_service.set_cache(flow_id=flow_id, data=graph, lock=lock)
|
||||
return next_runnable_vertices
|
||||
|
||||
@staticmethod
|
||||
def get_top_level_vertices(graph, vertices_ids):
|
||||
"""
|
||||
Retrieves the top-level vertices from the given graph based on the provided vertex IDs.
|
||||
|
||||
Args:
|
||||
graph (Graph): The graph object containing the vertices.
|
||||
vertices_ids (list): A list of vertex IDs.
|
||||
|
||||
Returns:
|
||||
list: A list of top-level vertex IDs.
|
||||
|
||||
"""
|
||||
top_level_vertices = []
|
||||
for vertex_id in vertices_ids:
|
||||
vertex = graph.get_vertex(vertex_id)
|
||||
if vertex.parent_is_top_level:
|
||||
top_level_vertices.append(vertex.parent_node_id)
|
||||
else:
|
||||
top_level_vertices.append(vertex_id)
|
||||
return top_level_vertices
|
||||
Loading…
Add table
Add a link
Reference in a new issue