From 7ce6a9ee032278fb8d0ca96de6e00fbe5e00ecf0 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 14 Aug 2024 15:06:25 -0300 Subject: [PATCH] feat: Add maximum iterations limit in Graph start method. (#3336) * feat: Add maximum iterations limit in Graph start method. * feat: Add OutputConfigDict and StartConfigDict to schema.py. * feat: Add ability to apply configuration before starting graph. * feat: Add max_iterations parameter to async_start method and update schema imports. --- src/backend/base/langflow/graph/graph/base.py | 33 ++++++++++++++++--- .../base/langflow/graph/graph/schema.py | 8 +++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index e817138c7..7121d051d 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -17,10 +17,10 @@ from langflow.graph.edge.base import CycleEdge from langflow.graph.edge.schema import EdgeData from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager -from langflow.graph.graph.schema import GraphData, GraphDump, VertexBuildResult +from langflow.graph.graph.schema import GraphData, GraphDump, StartConfigDict, VertexBuildResult from langflow.graph.graph.state_manager import GraphStateManager from langflow.graph.graph.state_model import create_state_model_from_graph -from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex +from langflow.graph.graph.utils import find_start_component_id, process_flow, should_continue, sort_up_to_vertex from langflow.graph.schema import InterfaceComponentTypes, RunOutputs from langflow.graph.vertex.base import Vertex, VertexStates from langflow.graph.vertex.schema import NodeData @@ -258,17 +258,40 @@ class Graph: for key, value in _input.items(): vertex = self.get_vertex(key) vertex.set_input_value(key, value) - while True: + # I want to keep a counter of how many tyimes result.vertex.id + # has been yielded + yielded_counts: dict[str, int] = defaultdict(int) + + while should_continue(yielded_counts, max_iterations): result = await self.astep() yield result + if hasattr(result, "vertex"): + yielded_counts[result.vertex.id] += 1 if isinstance(result, Finish): return - def start(self, inputs: Optional[List[dict]] = None) -> Generator: + raise ValueError("Max iterations reached") + + def __apply_config(self, config: StartConfigDict): + for vertex in self.vertices: + if vertex._custom_component is None: + continue + for output in vertex._custom_component.outputs: + for key, value in config["output"].items(): + setattr(output, key, value) + + def start( + self, + inputs: Optional[List[dict]] = None, + max_iterations: Optional[int] = None, + config: Optional[StartConfigDict] = None, + ) -> Generator: + if config is not None: + self.__apply_config(config) #! Change this ASAP nest_asyncio.apply() loop = asyncio.get_event_loop() - async_gen = self.async_start(inputs) + async_gen = self.async_start(inputs, max_iterations) async_gen_task = asyncio.ensure_future(async_gen.__anext__()) while True: diff --git a/src/backend/base/langflow/graph/graph/schema.py b/src/backend/base/langflow/graph/graph/schema.py index 306ea7ba6..4a1dcc3a3 100644 --- a/src/backend/base/langflow/graph/graph/schema.py +++ b/src/backend/base/langflow/graph/graph/schema.py @@ -36,3 +36,11 @@ class VertexBuildResult(NamedTuple): valid: bool artifacts: dict vertex: "Vertex" + + +class OutputConfigDict(TypedDict): + cache: bool + + +class StartConfigDict(TypedDict): + output: OutputConfigDict