Add new process method and vertices lists

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-27 13:44:30 -03:00
commit e02b477fc9

View file

@ -1,3 +1,4 @@
import asyncio
from collections import defaultdict, deque
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
@ -40,8 +41,10 @@ class Graph:
self._runs = 0
self._updates = 0
self.flow_id = flow_id
self._inputs = []
self._outputs = []
self._is_input_vertices = []
self._is_output_vertices = []
self._has_session_id_vertices = []
self._sorted_vertices_layers = []
self.top_level_vertices = []
for vertex in self._vertices:
@ -54,38 +57,37 @@ class Graph:
self.inactive_vertices = set()
self._build_graph()
self.build_graph_maps()
self.define_inputs_and_outputs()
self.define_vertices_lists()
def define_inputs_and_outputs(self):
@property
def sorted_vertices_layers(self):
    """Return the cached vertex layers, computing them on first access.

    When ``_sorted_vertices_layers`` is empty, ``sort_vertices()`` is
    called first; the value of ``_sorted_vertices_layers`` after that
    call is what gets returned.
    """
    if self._sorted_vertices_layers:
        # Already populated: hand back the cached layers untouched.
        return self._sorted_vertices_layers
    self.sort_vertices()
    return self._sorted_vertices_layers
def define_vertices_lists(self):
    """
    Fill the per-category lists of vertex ids.

    For every vertex in ``self.vertices``: its id is appended to
    ``_inputs`` when ``is_input`` is truthy and to ``_outputs`` when
    ``is_output`` is truthy; then, for each of ``is_input``, ``is_output``
    and ``has_session_id`` that is truthy, the id is also appended to the
    matching ``_<flag>_vertices`` list.
    """
    flag_names = ["is_input", "is_output", "has_session_id"]
    for node in self.vertices:
        if node.is_input:
            self._inputs.append(node.id)
        if node.is_output:
            self._outputs.append(node.id)
        for flag in flag_names:
            if not getattr(node, flag):
                continue
            # The target list name is derived from the flag name.
            target = getattr(self, f"_{flag}_vertices")
            target.append(node.id)
def run(self, inputs: Dict[str, str]) -> List["ResultData"]:
async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]:
"""Runs the graph with the given inputs."""
# inputs is {"message": "Hello, world!"}
# we need to go through self.inputs and update the self._raw_params
# of the vertices that are inputs
for vertex_id in self.inputs:
vertex = self.get_vertex(vertex_id)
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
vertex.update_raw_params(inputs)
try:
self.build()
await self.process()
self.increment_run_count()
except Exception as exc:
logger.exception(exc)
raise ValueError(f"Error running graph: {exc}") from exc
# Now we get the outputs from the self.outputs
outputs = []
for vertex_id in self.outputs:
vertex = self.get_vertex(vertex_id)
@ -94,6 +96,23 @@ class Graph:
outputs.append(vertex.result)
return outputs
async def run(self, inputs: Dict[str, Union[str, list[str]]]) -> List["ResultData"]:
    """Run the graph once per input value and gather every run's outputs.

    Only the ``"input_value"`` entry of ``inputs`` is read. A list value
    triggers one ``_run`` call per element; any other value (including a
    missing key, which yields ``None``) triggers a single call. The
    outputs of all calls are concatenated in call order.
    """
    raw_value = inputs.get("input_value")
    # Normalise to a list so single and multiple values share one loop.
    values = raw_value if isinstance(raw_value, list) else [raw_value]
    collected = []
    for value in values:
        single_run = await self._run({"input_value": value})
        logger.debug(f"Run outputs: {single_run}")
        collected.extend(single_run)
    return collected
@property
def metadata(self):
return {
@ -404,6 +423,36 @@ class Graph:
raise ValueError("No root vertex found")
return await root_vertex.build()
async def process(self) -> "Graph":
    """Build the graph layer by layer, running each layer's vertices concurrently.

    Layers come from ``sorted_vertices_layers``. For each layer, one
    asyncio task per vertex is created around ``vertex.build()`` and the
    whole layer is awaited through ``_execute_tasks`` before the next
    layer starts.

    Returns:
        This graph instance.

    Raises:
        ValueError: If a vertex id in a layer cannot be resolved by
            ``get_vertex``.
    """
    vertices_layers = self.sorted_vertices_layers
    for layer_index, layer in enumerate(vertices_layers):
        # Resolve every vertex of the layer before scheduling anything, so
        # an unknown id fails fast instead of raising AttributeError on
        # ``None.build()`` and leaving already-created tasks un-awaited.
        layer_vertices = []
        for vertex_id in layer:
            vertex = self.get_vertex(vertex_id)
            if vertex is None:
                raise ValueError(f"Vertex {vertex_id} not found")
            layer_vertices.append((vertex_id, vertex))

        tasks = [
            asyncio.create_task(
                vertex.build(), name=f"layer-{layer_index}-vertex-{vertex_id}"
            )
            for vertex_id, vertex in layer_vertices
        ]
        logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
        await self._execute_tasks(tasks)
    logger.debug("Graph processing complete")
    return self
async def _execute_tasks(self, tasks):
    """Await all tasks concurrently and collect the results of those that succeed.

    Args:
        tasks: Awaitables to wait on (the visible caller passes named
            ``asyncio.Task`` objects).

    Returns:
        The results of the awaitables that finished without raising, in
        completion order. An awaitable that raises ``Exception`` is
        logged together with its name and left out of the returned list.
    """

    async def _tagged(task):
        # ``asyncio.as_completed`` yields its own awaitables rather than the
        # objects passed in, so the task name cannot be read from what the
        # loop below receives. Capture it here and return it alongside the
        # outcome instead.
        name = task.get_name() if isinstance(task, asyncio.Task) else repr(task)
        try:
            return name, await task, None
        except Exception as exc:
            return name, None, exc

    results = []
    for finished in asyncio.as_completed([_tagged(task) for task in tasks]):
        task_name, result, error = await finished
        if error is None:
            results.append(result)
        else:
            # Log the exception along with the task name for easier debugging
            logger.error(f"Task {task_name} failed with exception: {error}")
    return results
def topological_sort(self) -> List[Vertex]:
"""
Performs a topological sort of the vertices in the graph.
@ -671,6 +720,7 @@ class Graph:
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
vertices_layers = self.sort_chat_inputs_first(vertices_layers)
self.increment_run_count()
self._sorted_vertices_layers = vertices_layers
return vertices_layers
def sort_interface_components_first(