Refactor graph processing and error handling

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-04 15:19:22 -03:00
commit 5633565156

View file

@ -111,7 +111,12 @@ class Graph:
vertex = self.get_vertex(vertex_id)
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
if not stream and hasattr(vertex, "consume_async_generator"):
if (
not vertex.result
and not stream
and hasattr(vertex, "consume_async_generator")
):
await vertex.consume_async_generator()
outputs.append(vertex.result)
return outputs
@ -472,15 +477,19 @@ class Graph:
async def process(self) -> "Graph":
"""Processes the graph with vertices in each layer run in parallel."""
vertices_layers = self.sorted_vertices_layers
vertex_task_run_count = {}
for layer_index, layer in enumerate(vertices_layers):
tasks = []
for vertex_id in layer:
vertex = self.get_vertex(vertex_id)
task = asyncio.create_task(
vertex.build(), name=f"layer-{layer_index}-vertex-{vertex_id}"
vertex.build(),
name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}",
)
tasks.append(task)
vertex_task_run_count[vertex_id] = (
vertex_task_run_count.get(vertex_id, 0) + 1
)
logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
await self._execute_tasks(tasks)
logger.debug("Graph processing complete")
@ -499,6 +508,10 @@ class Graph:
# coroutine has no attribute get_name
task_name = tasks[i].get_name()
logger.error(f"Task {task_name} failed with exception: {e}")
# Cancel all remaining tasks
for t in tasks[i:]:
t.cancel()
raise e
return results
def topological_sort(self) -> List[Vertex]:
@ -815,6 +828,7 @@ class Graph:
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
# vertices_layers = self.sort_chat_inputs_first(vertices_layers)
self.increment_run_count()
self._sorted_vertices_layers = vertices_layers
first_layer = vertices_layers[0]
# save only the rest
self.vertices_layers = vertices_layers[1:]