Add new process method and vertices lists
This commit is contained in:
parent
53af441ec9
commit
e02b477fc9
1 changed file with 68 additions and 18 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
from collections import defaultdict, deque
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
|
||||
|
||||
|
|
@ -40,8 +41,10 @@ class Graph:
|
|||
self._runs = 0
|
||||
self._updates = 0
|
||||
self.flow_id = flow_id
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
self._is_input_vertices = []
|
||||
self._is_output_vertices = []
|
||||
self._has_session_id_vertices = []
|
||||
self._sorted_vertices_layers = []
|
||||
|
||||
self.top_level_vertices = []
|
||||
for vertex in self._vertices:
|
||||
|
|
@ -54,38 +57,37 @@ class Graph:
|
|||
self.inactive_vertices = set()
|
||||
self._build_graph()
|
||||
self.build_graph_maps()
|
||||
self.define_inputs_and_outputs()
|
||||
self.define_vertices_lists()
|
||||
|
||||
def define_inputs_and_outputs(self):
|
||||
@property
def sorted_vertices_layers(self):
    """Return the cached vertex layers, sorting the vertices first if the cache is empty.

    When ``self._sorted_vertices_layers`` is falsy, ``self.sort_vertices()`` is
    called (it is expected to populate the cache) before the attribute is
    returned.
    """
    # Fast path: layers were already computed, hand back the cached value.
    if self._sorted_vertices_layers:
        return self._sorted_vertices_layers

    self.sort_vertices()
    return self._sorted_vertices_layers
|
||||
|
||||
def define_vertices_lists(self):
|
||||
"""
|
||||
Defines the input and output vertices of the graph.
|
||||
Defines the lists of vertices that are inputs, outputs, and have session_id.
|
||||
"""
|
||||
attributes = ["is_input", "is_output", "has_session_id"]
|
||||
for vertex in self.vertices:
|
||||
if vertex.is_input:
|
||||
self._inputs.append(vertex.id)
|
||||
if vertex.is_output:
|
||||
self._outputs.append(vertex.id)
|
||||
for attribute in attributes:
|
||||
if getattr(vertex, attribute):
|
||||
getattr(self, f"_{attribute}_vertices").append(vertex.id)
|
||||
|
||||
def run(self, inputs: Dict[str, str]) -> List["ResultData"]:
|
||||
async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
|
||||
# inputs is {"message": "Hello, world!"}
|
||||
# we need to go through self.inputs and update the self._raw_params
|
||||
# of the vertices that are inputs
|
||||
|
||||
for vertex_id in self.inputs:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
vertex.update_raw_params(inputs)
|
||||
try:
|
||||
self.build()
|
||||
await self.process()
|
||||
self.increment_run_count()
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise ValueError(f"Error running graph: {exc}") from exc
|
||||
|
||||
# Now we get the outputs from the self.outputs
|
||||
outputs = []
|
||||
for vertex_id in self.outputs:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
|
|
@ -94,6 +96,23 @@ class Graph:
|
|||
outputs.append(vertex.result)
|
||||
return outputs
|
||||
|
||||
async def run(self, inputs: Dict[str, Union[str, List[str]]]) -> List["ResultData"]:
    """Run the graph once for every value supplied under ``"input_value"``.

    Args:
        inputs: Mapping such as ``{"input_value": "Hello, world!"}``. Only the
            ``"input_value"`` entry is read. It may hold a single value or a
            list of values; a list triggers one run per element. A missing
            key results in a single run with ``None`` as the value.

    Returns:
        The outputs of all runs, concatenated in the order of the inputs.
    """
    input_values = inputs.get("input_value")
    # Normalize a single value to a one-element list so that both shapes
    # share the same code path below.
    if not isinstance(input_values, list):
        input_values = [input_values]

    outputs = []
    for input_value in input_values:
        run_outputs = await self._run({"input_value": input_value})
        logger.debug(f"Run outputs: {run_outputs}")
        outputs.extend(run_outputs)
    return outputs
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return {
|
||||
|
|
@ -404,6 +423,36 @@ class Graph:
|
|||
raise ValueError("No root vertex found")
|
||||
return await root_vertex.build()
|
||||
|
||||
async def process(self) -> "Graph":
    """Build every vertex of the graph, one layer at a time.

    The layers are taken from ``self.sorted_vertices_layers``. All vertices
    of a layer are started as concurrent tasks, and the next layer is only
    started once ``self._execute_tasks`` has returned for the current one.

    Returns:
        The graph itself.

    Raises:
        ValueError: If a layer references a vertex id that ``get_vertex``
            cannot resolve.
    """
    for layer_index, layer in enumerate(self.sorted_vertices_layers):
        # Resolve the whole layer before scheduling anything, so a missing
        # vertex fails fast with a clear message instead of an
        # AttributeError raised after some tasks are already running.
        resolved = []
        for vertex_id in layer:
            vertex = self.get_vertex(vertex_id)
            if vertex is None:
                raise ValueError(f"Vertex {vertex_id} not found")
            resolved.append((vertex_id, vertex))

        tasks = [
            asyncio.create_task(
                vertex.build(), name=f"layer-{layer_index}-vertex-{vertex_id}"
            )
            for vertex_id, vertex in resolved
        ]
        logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
        await self._execute_tasks(tasks)

    logger.debug("Graph processing complete")
    return self
|
||||
|
||||
async def _execute_tasks(self, tasks):
    """Await a batch of tasks concurrently and collect the successful results.

    A task that fails with an ``Exception`` is logged together with its name
    and skipped; the remaining tasks still run to completion.

    Args:
        tasks: Awaitables (normally ``asyncio.Task`` objects) to wait for.

    Returns:
        The results of the tasks that succeeded, in the order the tasks were
        passed in.
    """
    # Keep a handle on every awaitable so each outcome can be matched to the
    # object that produced it. Plain iteration over ``asyncio.as_completed``
    # yields new coroutines rather than the original tasks, so calling
    # ``get_name()`` on them raises AttributeError and hides the real error.
    futures = [asyncio.ensure_future(task) for task in tasks]
    outcomes = await asyncio.gather(*futures, return_exceptions=True)

    results = []
    for future, outcome in zip(futures, outcomes):
        if isinstance(outcome, Exception):
            # Log the exception along with the task name for easier debugging
            get_name = getattr(future, "get_name", None)
            task_name = get_name() if callable(get_name) else repr(future)
            logger.error(f"Task {task_name} failed with exception: {outcome}")
        elif isinstance(outcome, BaseException):
            # Cancellation and other non-Exception errors are not swallowed.
            raise outcome
        else:
            results.append(outcome)
    return results
|
||||
|
||||
def topological_sort(self) -> List[Vertex]:
|
||||
"""
|
||||
Performs a topological sort of the vertices in the graph.
|
||||
|
|
@ -671,6 +720,7 @@ class Graph:
|
|||
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
|
||||
vertices_layers = self.sort_chat_inputs_first(vertices_layers)
|
||||
self.increment_run_count()
|
||||
self._sorted_vertices_layers = vertices_layers
|
||||
return vertices_layers
|
||||
|
||||
def sort_interface_components_first(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue