diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 0f5c8c224..ee26f1403 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -6,13 +6,7 @@ from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException from fastapi.responses import StreamingResponse from loguru import logger -from langflow.api.utils import ( - build_and_cache_graph, - format_elapsed_time, - format_exception_message, - get_next_runnable_vertices, - get_top_level_vertices, -) +from langflow.api.utils import build_and_cache_graph, format_elapsed_time, format_exception_message from langflow.api.v1.schemas import ( InputValueRequest, ResultDataResponse, @@ -147,23 +141,21 @@ async def build_vertex( graph = cache.get("result") result_data_response = ResultDataResponse(results={}) duration = "" - - vertex = graph.get_vertex(vertex_id) try: - if not vertex.frozen or not vertex._built: - inputs_dict = inputs.model_dump() if inputs else {} - await vertex.build(user_id=current_user.id, inputs=inputs_dict) - - if vertex.result is not None: - params = vertex._built_object_repr() - valid = True - result_dict = vertex.result - artifacts = vertex.artifacts - else: - raise ValueError(f"No result found for vertex {vertex_id}") - - next_runnable_vertices = await get_next_runnable_vertices(graph, vertex, vertex_id, chat_service, flow_id) - top_level_vertices = get_top_level_vertices(graph, next_runnable_vertices) + ( + next_runnable_vertices, + top_level_vertices, + result_dict, + params, + valid, + artifacts, + vertex, + ) = await graph.build_vertex( + chat_service=chat_service, + vertex_id=vertex_id, + user_id=current_user.id, + inputs=inputs.model_dump() if inputs else {}, + ) result_data_response = ResultDataResponse(**result_dict.model_dump()) except Exception as exc: diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 2fb1aa71c..cb911bbd5 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -7,6 +7,7 @@ from loguru import logger from langflow.graph.edge.base import ContractEdge from langflow.graph.graph.constants import lazy_load_vertex_dict +from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager from langflow.graph.graph.state_manager import GraphStateManager from langflow.graph.graph.utils import process_flow from langflow.graph.schema import InterfaceComponentTypes, RunOutputs @@ -67,6 +68,7 @@ class Graph: self.inactive_vertices: set = set() self.edges: List[ContractEdge] = [] self.vertices: List[Vertex] = [] + self.run_manager = RunnableVerticesManager() self._build_graph() self.build_graph_maps() self.define_vertices_lists() @@ -427,30 +429,6 @@ class Graph: def __setstate__(self, state): self.__init__(**state) - def build_in_degree(self): - in_degree = defaultdict(int) - for edge in self.edges: - in_degree[edge.target_id] += 1 - return in_degree - - def build_adjacency_maps(self): - """Returns the adjacency maps for the graph.""" - predecessor_map = defaultdict(list) - successor_map = defaultdict(list) - for edge in self.edges: - predecessor_map[edge.target_id].append(edge.source_id) - successor_map[edge.source_id].append(edge.target_id) - return predecessor_map, successor_map - - def build_run_map(self): - run_map = defaultdict(list) - # The run map gets the predecessor_map and maps the info like this: - # {vertex_id: every id that contains the vertex_id in the predecessor_map} - for vertex_id, predecessors in self.predecessor_map.items(): - for predecessor in predecessors: - run_map[predecessor].append(vertex_id) - return run_map - @classmethod def from_payload(cls, payload: Dict, flow_id: Optional[str] = None) -> "Graph": """ @@ -669,6 +647,32 @@ class Graph: except KeyError: raise ValueError(f"Vertex {vertex_id} not found") + async def build_vertex( + self, chat_service, vertex_id: str, inputs: Optional[Dict[str, str]] = None, user_id: Optional[str] = None + ): + vertex = self.get_vertex(vertex_id) + try: + if not vertex.frozen or not vertex._built: + inputs_dict = inputs.model_dump() if inputs else {} + await vertex.build(user_id=user_id, inputs=inputs_dict) + + if vertex.result is not None: + params = vertex._built_object_repr() + valid = True + result_dict = vertex.result + artifacts = vertex.artifacts + else: + raise ValueError(f"No result found for vertex {vertex_id}") + + next_runnable_vertices = await self.run_manager.get_next_runnable_vertices( + self, vertex, vertex_id, chat_service, self.flow_id + ) + top_level_vertices = self.run_manager.get_top_level_vertices(self, next_runnable_vertices) + return next_runnable_vertices, top_level_vertices, result_dict, params, valid, artifacts, vertex + except Exception as exc: + logger.exception(f"Error building vertex: {exc}") + raise exc + def get_vertex_edges( self, vertex_id: str, @@ -1107,41 +1111,10 @@ class Graph: # save the only the rest self.vertices_layers = vertices_layers[1:] self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)} - self.run_map, self.run_predecessors = ( - self.build_run_map(), - self.predecessor_map.copy(), - ) - + self.build_run_map() # Return just the first layer return first_layer - def is_vertex_runnable(self, vertex_id: str) -> bool: - """Returns whether a vertex is runnable.""" - return vertex_id in self.vertices_to_run and not self.run_predecessors.get(vertex_id) - - def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]: - """ - For each successor of the current vertex, find runnable predecessors if any. - This checks the direct predecessors of each successor to identify any that are - immediately runnable, expanding the search to ensure progress can be made. - """ - runnable_vertices = [] - visited = set() - - for successor_id in self.run_map.get(vertex_id, []): - for predecessor_id in self.run_predecessors.get(successor_id, []): - if predecessor_id not in visited and self.is_vertex_runnable(predecessor_id): - runnable_vertices.append(predecessor_id) - visited.add(predecessor_id) - - return runnable_vertices - - def remove_from_predecessors(self, vertex_id: str): - predecessors = self.run_map.get(vertex_id, []) - for predecessor in predecessors: - if vertex_id in self.run_predecessors[predecessor]: - self.run_predecessors[predecessor].remove(vertex_id) - def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: """Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first.""" @@ -1171,3 +1144,45 @@ class Graph: sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers] return sorted_vertices + + def is_vertex_runnable(self, vertex_id: str) -> bool: + """Returns whether a vertex is runnable.""" + return self.run_manager.is_vertex_runnable(vertex_id) + + def build_run_map(self): + """ + Builds the run map for the graph. + + This method is responsible for building the run map for the graph, + which maps each node in the graph to its corresponding run function. + + Returns: + None + """ + self.run_manager.build_run_map(self) + + def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]: + """ + For each successor of the current vertex, find runnable predecessors if any. + This checks the direct predecessors of each successor to identify any that are + immediately runnable, expanding the search to ensure progress can be made. + """ + self.run_manager.find_runnable_predecessors_for_successors(vertex_id) + + def remove_from_predecessors(self, vertex_id: str): + self.run_manager.remove_from_predecessors(vertex_id) + + def build_in_degree(self): + in_degree = defaultdict(int) + for edge in self.edges: + in_degree[edge.target_id] += 1 + return in_degree + + def build_adjacency_maps(self): + """Returns the adjacency maps for the graph.""" + predecessor_map = defaultdict(list) + successor_map = defaultdict(list) + for edge in self.edges: + predecessor_map[edge.target_id].append(edge.source_id) + successor_map[edge.source_id].append(edge.target_id) + return predecessor_map, successor_map diff --git a/src/backend/langflow/graph/graph/runnable_vertices_manager.py b/src/backend/langflow/graph/graph/runnable_vertices_manager.py new file mode 100644 index 000000000..43a9c22a2 --- /dev/null +++ b/src/backend/langflow/graph/graph/runnable_vertices_manager.py @@ -0,0 +1,111 @@ +from collections import defaultdict +from typing import TYPE_CHECKING, List + +if TYPE_CHECKING: + from langflow.graph.graph.base import Graph + from langflow.graph.vertex.base import Vertex + from langflow.services.chat.service import ChatService + + +class RunnableVerticesManager: + def __init__(self): + self.run_map = defaultdict(list) # Tracks successors of each vertex + self.run_predecessors = defaultdict(set) # Tracks predecessors for each vertex + self.vertices_to_run = set() # Set of vertices that are ready to run + + def is_vertex_runnable(self, vertex_id: str) -> bool: + """Determines if a vertex is runnable.""" + return vertex_id in self.vertices_to_run and not self.run_predecessors.get(vertex_id) + + def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]: + """Finds runnable predecessors for the successors of a given vertex.""" + runnable_vertices = [] + visited = set() + + for successor_id in self.run_map.get(vertex_id, []): + for predecessor_id in self.run_predecessors.get(successor_id, []): + if predecessor_id not in visited and self.is_vertex_runnable(predecessor_id): + runnable_vertices.append(predecessor_id) + visited.add(predecessor_id) + return runnable_vertices + + def remove_from_predecessors(self, vertex_id: str): + """Removes a vertex from the predecessor list of its successors.""" + predecessors = self.run_map.get(vertex_id, []) + for predecessor in predecessors: + if vertex_id in self.run_predecessors[predecessor]: + self.run_predecessors[predecessor].remove(vertex_id) + + def build_run_map(self, graph): + """Builds a map of vertices and their runnable successors.""" + self.run_map = defaultdict(list) + for vertex_id, predecessors in graph.predecessor_map.items(): + for predecessor in predecessors: + self.run_map[predecessor].append(vertex_id) + self.run_predecessors = {k: set(v) for k, v in self.run_map.items()} + + def update_vertex_run_state(self, vertex_id: str, is_runnable: bool): + """Updates the runnable state of a vertex.""" + if is_runnable: + self.vertices_to_run.add(vertex_id) + else: + self.vertices_to_run.discard(vertex_id) + + @staticmethod + async def get_next_runnable_vertices( + graph: "Graph", + vertex: "Vertex", + vertex_id: str, + chat_service: "ChatService", + flow_id: str, + ): + """ + Retrieves the next runnable vertices in the graph for a given vertex. + + Args: + graph (Graph): The graph object representing the flow. + vertex (Vertex): The current vertex. + vertex_id (str): The ID of the current vertex. + chat_service (ChatService): The chat service object. + flow_id (str): The ID of the flow. + + Returns: + list: A list of IDs of the next runnable vertices. + + """ + async with chat_service._cache_locks[flow_id] as lock: + graph.remove_from_predecessors(vertex_id) + direct_successors_ready = [v for v in vertex.successors_ids if graph.is_vertex_runnable(v)] + if not direct_successors_ready: + # No direct successors ready, look for runnable predecessors of successors + next_runnable_vertices = graph.find_runnable_predecessors_for_successors(vertex_id) + else: + next_runnable_vertices = direct_successors_ready + + for v_id in set(next_runnable_vertices): # Use set to avoid duplicates + graph.vertices_to_run.remove(v_id) + graph.remove_from_predecessors(v_id) + await chat_service.set_cache(flow_id=flow_id, data=graph, lock=lock) + return next_runnable_vertices + + @staticmethod + def get_top_level_vertices(graph, vertices_ids): + """ + Retrieves the top-level vertices from the given graph based on the provided vertex IDs. + + Args: + graph (Graph): The graph object containing the vertices. + vertices_ids (list): A list of vertex IDs. + + Returns: + list: A list of top-level vertex IDs. + + """ + top_level_vertices = [] + for vertex_id in vertices_ids: + vertex = graph.get_vertex(vertex_id) + if vertex.parent_is_top_level: + top_level_vertices.append(vertex.parent_node_id) + else: + top_level_vertices.append(vertex_id) + return top_level_vertices