Refactor process method to align it with endpoint logic (#1700)

* Refactor Graph class to improve parallel processing in base.py

* Fix type hint for run_id parameter in set_run_id method
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-04-15 08:59:27 -03:00 committed by GitHub
commit 0d75e2905f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -1,5 +1,7 @@
import asyncio
import uuid
from collections import defaultdict, deque
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Type, Union
@@ -16,6 +18,7 @@ from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, R
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.schema import Record
from langflow.schema.schema import INPUT_FIELD_NAME, InputType
from langflow.services.deps import get_chat_service
if TYPE_CHECKING:
from langflow.graph.schema import ResultData
@@ -164,13 +167,14 @@ class Graph:
raise ValueError("Run ID not set")
return self._run_id
def set_run_id(self, run_id: str | uuid.UUID):
    """
    Sets the ID of the current run and subscribes every vertex's
    graph-state updater to it.

    Args:
        run_id (str | uuid.UUID): The run ID. Coerced to ``str`` so the
            state manager always subscribes under a string key,
            regardless of whether the caller passed a UUID object.
    """
    # Normalize once up front; everything downstream keys on the string form.
    run_id = str(run_id)
    for vertex in self.vertices:
        self.state_manager.subscribe(run_id, vertex.update_graph_state)
    self._run_id = run_id
@@ -748,31 +752,53 @@ class Graph:
async def process(self, start_component_id: Optional[str] = None) -> "Graph":
    """Processes the graph, building runnable vertices in parallel batches.

    The first sorted layer seeds a work queue; each batch is built
    concurrently via ``build_vertex`` tasks, and the vertices each task
    reports as runnable next are queued for the following batch, until
    the queue drains.

    Args:
        start_component_id (Optional[str]): Optional vertex ID to start
            sorting/processing from; forwarded to ``sort_vertices``.

    Returns:
        Graph: self, to allow call chaining.
    """
    first_layer = self.sort_vertices(start_component_id=start_component_id)
    # Tracks how many times each vertex has been built, for unique task names.
    vertex_task_run_count: Dict[str, int] = {}
    to_process = deque(first_layer)
    layer_index = 0
    chat_service = get_chat_service()
    run_id = uuid.uuid4()
    self.set_run_id(run_id)
    while to_process:
        current_batch = list(to_process)  # snapshot this batch
        to_process.clear()  # next-runnable vertices accumulate for the next batch
        tasks = []
        for vertex_id in current_batch:
            vertex = self.get_vertex(vertex_id)
            # NOTE(review): reaches into chat_service._cache_locks (private
            # attribute) — consider exposing a public accessor on ChatService.
            lock = chat_service._cache_locks[self.run_id]
            set_cache_coro = partial(chat_service.set_cache, flow_id=self.run_id)
            task = asyncio.create_task(
                self.build_vertex(
                    lock=lock,
                    set_cache_coro=set_cache_coro,
                    vertex_id=vertex_id,
                    user_id=None,
                    inputs_dict={},
                ),
                name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}",
            )
            tasks.append(task)
            vertex_task_run_count[vertex_id] = vertex_task_run_count.get(vertex_id, 0) + 1
        logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
        next_runnable_vertices = await self._execute_tasks(tasks)
        to_process.extend(next_runnable_vertices)
        # Fix: advance the counter so the debug log does not report every
        # batch as "layer 0".
        layer_index += 1
    logger.debug("Graph processing complete")
    return self
async def _execute_tasks(self, tasks):
async def _execute_tasks(self, tasks: List[asyncio.Task]) -> List[str]:
"""Executes tasks in parallel, handling exceptions for each task."""
results = []
for i, task in enumerate(asyncio.as_completed(tasks)):
try:
result = await task
results.append(result)
if isinstance(result, tuple) and len(result) == 7:
# Get the next runnable vertices
next_runnable_vertices = result[0]
results.extend(next_runnable_vertices)
else:
raise ValueError(f"Invalid result: {result}")
except Exception as e:
# Log the exception along with the task name for easier debugging
# task_name = task.get_name()