Refactor graph processing and error handling
This commit is contained in:
parent
a31f824457
commit
815e9cfb59
1 changed files with 17 additions and 3 deletions
|
|
@ -170,7 +170,12 @@ class Graph:
|
|||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
if not stream and hasattr(vertex, "consume_async_generator"):
|
||||
|
||||
if (
|
||||
not vertex.result
|
||||
and not stream
|
||||
and hasattr(vertex, "consume_async_generator")
|
||||
):
|
||||
await vertex.consume_async_generator()
|
||||
outputs.append(vertex.result)
|
||||
return outputs
|
||||
|
|
@ -524,15 +529,19 @@ class Graph:
|
|||
async def process(self) -> "Graph":
|
||||
"""Processes the graph with vertices in each layer run in parallel."""
|
||||
vertices_layers = self.sorted_vertices_layers
|
||||
|
||||
vertex_task_run_count = {}
|
||||
for layer_index, layer in enumerate(vertices_layers):
|
||||
tasks = []
|
||||
for vertex_id in layer:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
task = asyncio.create_task(
|
||||
vertex.build(), name=f"layer-{layer_index}-vertex-{vertex_id}"
|
||||
vertex.build(),
|
||||
name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}",
|
||||
)
|
||||
tasks.append(task)
|
||||
vertex_task_run_count[vertex_id] = (
|
||||
vertex_task_run_count.get(vertex_id, 0) + 1
|
||||
)
|
||||
logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
|
||||
await self._execute_tasks(tasks)
|
||||
logger.debug("Graph processing complete")
|
||||
|
|
@ -551,6 +560,10 @@ class Graph:
|
|||
# coroutine has not attribute get_name
|
||||
task_name = tasks[i].get_name()
|
||||
logger.error(f"Task {task_name} failed with exception: {e}")
|
||||
# Cancel all remaining tasks
|
||||
for t in tasks[i:]:
|
||||
t.cancel()
|
||||
raise e
|
||||
return results
|
||||
|
||||
def topological_sort(self) -> List[Vertex]:
|
||||
|
|
@ -931,6 +944,7 @@ class Graph:
|
|||
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
|
||||
# vertices_layers = self.sort_chat_inputs_first(vertices_layers)
|
||||
self.increment_run_count()
|
||||
self._sorted_vertices_layers = vertices_layers
|
||||
first_layer = vertices_layers[0]
|
||||
# save the only the rest
|
||||
self.vertices_layers = vertices_layers[1:]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue