feat: Add maximum iterations limit in Graph start method. (#3336)

* feat: Add maximum iterations limit in Graph start method.

* feat: Add OutputConfigDict and StartConfigDict to schema.py.

* feat: Add ability to apply configuration before starting graph.

* feat: Add max_iterations parameter to async_start method and update schema imports.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-08-14 15:06:25 -03:00 committed by GitHub
commit 7ce6a9ee03
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 36 additions and 5 deletions

View file

@ -17,10 +17,10 @@ from langflow.graph.edge.base import CycleEdge
from langflow.graph.edge.schema import EdgeData
from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.schema import GraphData, GraphDump, VertexBuildResult
from langflow.graph.graph.schema import GraphData, GraphDump, StartConfigDict, VertexBuildResult
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.state_model import create_state_model_from_graph
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
from langflow.graph.graph.utils import find_start_component_id, process_flow, should_continue, sort_up_to_vertex
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
from langflow.graph.vertex.schema import NodeData
@ -258,17 +258,40 @@ class Graph:
for key, value in _input.items():
vertex = self.get_vertex(key)
vertex.set_input_value(key, value)
while True:
# I want to keep a counter of how many tyimes result.vertex.id
# has been yielded
yielded_counts: dict[str, int] = defaultdict(int)
while should_continue(yielded_counts, max_iterations):
result = await self.astep()
yield result
if hasattr(result, "vertex"):
yielded_counts[result.vertex.id] += 1
if isinstance(result, Finish):
return
def start(self, inputs: Optional[List[dict]] = None) -> Generator:
raise ValueError("Max iterations reached")
def __apply_config(self, config: StartConfigDict):
for vertex in self.vertices:
if vertex._custom_component is None:
continue
for output in vertex._custom_component.outputs:
for key, value in config["output"].items():
setattr(output, key, value)
def start(
self,
inputs: Optional[List[dict]] = None,
max_iterations: Optional[int] = None,
config: Optional[StartConfigDict] = None,
) -> Generator:
if config is not None:
self.__apply_config(config)
#! Change this ASAP
nest_asyncio.apply()
loop = asyncio.get_event_loop()
async_gen = self.async_start(inputs)
async_gen = self.async_start(inputs, max_iterations)
async_gen_task = asyncio.ensure_future(async_gen.__anext__())
while True:

View file

@ -36,3 +36,11 @@ class VertexBuildResult(NamedTuple):
valid: bool
artifacts: dict
vertex: "Vertex"
class OutputConfigDict(TypedDict):
cache: bool
class StartConfigDict(TypedDict):
output: OutputConfigDict