Refactor process method to align it with endpoint logic (#1700)
* Refactor Graph class to improve parallel processing in base.py * Fix type hint for run_id parameter in set_run_id method
This commit is contained in:
parent
05104117ba
commit
0d75e2905f
1 changed file with 35 additions and 9 deletions
|
|
@@ -1,5 +1,7 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Type, Union
|
||||
|
||||
|
|
@@ -16,6 +18,7 @@ from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, R
|
|||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.schema import Record
|
||||
from langflow.schema.schema import INPUT_FIELD_NAME, InputType
|
||||
from langflow.services.deps import get_chat_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.schema import ResultData
|
||||
|
|
@@ -164,13 +167,14 @@ class Graph:
|
|||
raise ValueError("Run ID not set")
|
||||
return self._run_id
|
||||
|
||||
def set_run_id(self, run_id: str):
|
||||
def set_run_id(self, run_id: str | uuid.UUID):
|
||||
"""
|
||||
Sets the ID of the current run.
|
||||
|
||||
Args:
|
||||
run_id (str): The run ID.
|
||||
"""
|
||||
run_id = str(run_id)
|
||||
for vertex in self.vertices:
|
||||
self.state_manager.subscribe(run_id, vertex.update_graph_state)
|
||||
self._run_id = run_id
|
||||
|
|
@@ -748,31 +752,53 @@ class Graph:
|
|||
async def process(self, start_component_id: Optional[str] = None) -> "Graph":
|
||||
"""Processes the graph with vertices in each layer run in parallel."""
|
||||
|
||||
self.sort_vertices(start_component_id=start_component_id)
|
||||
vertices_layers = self.sorted_vertices_layers
|
||||
first_layer = self.sort_vertices(start_component_id=start_component_id)
|
||||
vertex_task_run_count: Dict[str, int] = {}
|
||||
for layer_index, layer in enumerate(vertices_layers):
|
||||
to_process = deque(first_layer)
|
||||
layer_index = 0
|
||||
chat_service = get_chat_service()
|
||||
run_id = uuid.uuid4()
|
||||
self.set_run_id(run_id)
|
||||
while to_process:
|
||||
current_batch = list(to_process) # Copy current deque items to a list
|
||||
to_process.clear() # Clear the deque for new items
|
||||
tasks = []
|
||||
for vertex_id in layer:
|
||||
for vertex_id in current_batch:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
lock = chat_service._cache_locks[self.run_id]
|
||||
set_cache_coro = partial(chat_service.set_cache, flow_id=self.run_id)
|
||||
task = asyncio.create_task(
|
||||
vertex.build(),
|
||||
self.build_vertex(
|
||||
lock=lock,
|
||||
set_cache_coro=set_cache_coro,
|
||||
vertex_id=vertex_id,
|
||||
user_id=None,
|
||||
inputs_dict={},
|
||||
),
|
||||
name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}",
|
||||
)
|
||||
tasks.append(task)
|
||||
vertex_task_run_count[vertex_id] = vertex_task_run_count.get(vertex_id, 0) + 1
|
||||
|
||||
logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
|
||||
await self._execute_tasks(tasks)
|
||||
next_runnable_vertices = await self._execute_tasks(tasks)
|
||||
to_process.extend(next_runnable_vertices)
|
||||
|
||||
logger.debug("Graph processing complete")
|
||||
return self
|
||||
|
||||
async def _execute_tasks(self, tasks):
|
||||
async def _execute_tasks(self, tasks: List[asyncio.Task]) -> List[str]:
|
||||
"""Executes tasks in parallel, handling exceptions for each task."""
|
||||
results = []
|
||||
for i, task in enumerate(asyncio.as_completed(tasks)):
|
||||
try:
|
||||
result = await task
|
||||
results.append(result)
|
||||
if isinstance(result, tuple) and len(result) == 7:
|
||||
# Get the next runnable vertices
|
||||
next_runnable_vertices = result[0]
|
||||
results.extend(next_runnable_vertices)
|
||||
else:
|
||||
raise ValueError(f"Invalid result: {result}")
|
||||
except Exception as e:
|
||||
# Log the exception along with the task name for easier debugging
|
||||
# task_name = task.get_name()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue