fix: graph could re-run input vertices twice and override the input value (#2219)
This commit is contained in:
parent
6e49a2ec3b
commit
47e63d1d02
4 changed files with 41 additions and 55 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import time
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
|
||||
|
|
@ -165,8 +166,6 @@ async def build_vertex(
|
|||
try:
|
||||
lock = chat_service._cache_locks[flow_id_str]
|
||||
(
|
||||
next_runnable_vertices,
|
||||
top_level_vertices,
|
||||
result_dict,
|
||||
params,
|
||||
valid,
|
||||
|
|
@ -180,6 +179,11 @@ async def build_vertex(
|
|||
inputs_dict=inputs.model_dump() if inputs else {},
|
||||
files=files,
|
||||
)
|
||||
set_cache_coro = partial(get_chat_service().set_cache, key=flow_id_str)
|
||||
next_runnable_vertices = await graph.run_manager.get_next_runnable_vertices(
|
||||
lock, set_cache_coro, graph=graph, vertex=vertex, cache=False
|
||||
)
|
||||
top_level_vertices = graph.run_manager.get_top_level_vertices(graph, next_runnable_vertices)
|
||||
log_obj = Log(message=vertex.artifacts_raw, type=vertex.artifacts_type)
|
||||
result_data_response = ResultDataResponse(**result_dict.model_dump())
|
||||
|
||||
|
|
@ -214,7 +218,6 @@ async def build_vertex(
|
|||
result_data_response.duration = duration
|
||||
result_data_response.timedelta = timedelta
|
||||
vertex.add_build_time(timedelta)
|
||||
inactivated_vertices = None
|
||||
inactivated_vertices = list(graph.inactivated_vertices)
|
||||
graph.reset_inactivated_vertices()
|
||||
graph.reset_activated_vertices()
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import uuid
|
|||
from collections import defaultdict, deque
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Tuple, Type, Union
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Tuple, Type, Union
|
||||
|
||||
from langflow.graph.edge.base import ContractEdge
|
||||
from langflow.graph.graph.constants import lazy_load_vertex_dict
|
||||
|
|
@ -258,6 +258,7 @@ class Graph:
|
|||
Returns:
|
||||
List[Optional["ResultData"]]: The outputs of the graph.
|
||||
"""
|
||||
|
||||
if input_components and not isinstance(input_components, list):
|
||||
raise ValueError(f"Invalid components value: {input_components}. Expected list")
|
||||
elif input_components is None:
|
||||
|
|
@ -705,14 +706,12 @@ class Graph:
|
|||
|
||||
async def build_vertex(
|
||||
self,
|
||||
lock: asyncio.Lock,
|
||||
chat_service: ChatService,
|
||||
vertex_id: str,
|
||||
inputs_dict: Optional[Dict[str, str]] = None,
|
||||
files: Optional[list[str]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
fallback_to_env_vars: bool = False,
|
||||
cache: bool = True,
|
||||
):
|
||||
"""
|
||||
Builds a vertex in the graph.
|
||||
|
|
@ -766,43 +765,15 @@ class Graph:
|
|||
artifacts = vertex.artifacts
|
||||
else:
|
||||
raise ValueError(f"No result found for vertex {vertex_id}")
|
||||
set_cache_coro = partial(chat_service.set_cache, key=self.flow_id)
|
||||
next_runnable_vertices, top_level_vertices = await self.get_next_and_top_level_vertices(
|
||||
lock, set_cache_coro, vertex, cache=cache
|
||||
)
|
||||
flow_id = self.flow_id
|
||||
log_transaction(flow_id, vertex, status="success")
|
||||
return next_runnable_vertices, top_level_vertices, result_dict, params, valid, artifacts, vertex
|
||||
return result_dict, params, valid, artifacts, vertex
|
||||
except Exception as exc:
|
||||
logger.exception(f"Error building Component: {exc}")
|
||||
flow_id = self.flow_id
|
||||
log_transaction(flow_id, vertex, status="failure", error=str(exc))
|
||||
raise exc
|
||||
|
||||
async def get_next_and_top_level_vertices(
|
||||
self,
|
||||
lock: asyncio.Lock,
|
||||
set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine],
|
||||
vertex: Vertex,
|
||||
cache: bool = True,
|
||||
):
|
||||
"""
|
||||
Retrieves the next runnable vertices and the top level vertices for a given vertex.
|
||||
|
||||
Args:
|
||||
lock (asyncio.Lock): The lock used to synchronize access to the graph.
|
||||
set_cache_coro (Coroutine): The coroutine used to set the cache for the graph.
|
||||
vertex (Vertex): The vertex for which to retrieve the next runnable and top level vertices.
cache (bool, optional): Passed through unchanged to run_manager.get_next_runnable_vertices. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Vertex], List[Vertex]]: A tuple containing the next runnable vertices and the top level vertices.
NOTE(review): run_manager.get_next_runnable_vertices documents its result as a list of vertex IDs, so the first element is presumably a list of IDs rather than Vertex objects - confirm.
|
||||
"""
|
||||
# Delegates to the run manager; this graph instance is passed as the graph argument.
next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(
|
||||
lock, set_cache_coro, self, vertex, cache=cache
|
||||
)
|
||||
# The top level vertices are derived from the runnable vertices computed above.
top_level_vertices = self.run_manager.get_top_level_vertices(self, next_runnable_vertices)
|
||||
return next_runnable_vertices, top_level_vertices
|
||||
|
||||
def get_vertex_edges(
|
||||
self,
|
||||
vertex_id: str,
|
||||
|
|
@ -849,13 +820,11 @@ class Graph:
|
|||
vertex = self.get_vertex(vertex_id)
|
||||
task = asyncio.create_task(
|
||||
self.build_vertex(
|
||||
lock=lock,
|
||||
chat_service=chat_service,
|
||||
vertex_id=vertex_id,
|
||||
user_id=self.user_id,
|
||||
inputs_dict={},
|
||||
fallback_to_env_vars=fallback_to_env_vars,
|
||||
cache=False,
|
||||
),
|
||||
name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}",
|
||||
)
|
||||
|
|
@ -864,7 +833,7 @@ class Graph:
|
|||
|
||||
logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
|
||||
try:
|
||||
next_runnable_vertices = await self._execute_tasks(tasks)
|
||||
next_runnable_vertices = await self._execute_tasks(tasks, lock=lock)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tasks in layer {layer_index}: {e}")
|
||||
break
|
||||
|
|
@ -876,10 +845,11 @@ class Graph:
|
|||
logger.debug("Graph processing complete")
|
||||
return self
|
||||
|
||||
async def _execute_tasks(self, tasks: List[asyncio.Task]) -> List[str]:
|
||||
async def _execute_tasks(self, tasks: List[asyncio.Task], lock: asyncio.Lock) -> List[str]:
|
||||
"""Executes tasks in parallel, handling exceptions for each task."""
|
||||
results = []
|
||||
completed_tasks = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
vertices: List[Vertex] = []
|
||||
|
||||
for i, result in enumerate(completed_tasks):
|
||||
task_name = tasks[i].get_name()
|
||||
|
|
@ -889,13 +859,23 @@ class Graph:
|
|||
for t in tasks[i + 1 :]:
|
||||
t.cancel()
|
||||
raise result
|
||||
elif isinstance(result, tuple) and len(result) == 7:
|
||||
# Get the next runnable vertices
|
||||
next_runnable_vertices = result[0]
|
||||
results.extend(next_runnable_vertices)
|
||||
elif isinstance(result, tuple) and len(result) == 5:
|
||||
vertices.append(result[4])
|
||||
else:
|
||||
raise ValueError(f"Invalid result from task {task_name}: {result}")
|
||||
|
||||
for v in vertices:
|
||||
# Mark every executed vertex as non-runnable so that it is not run again.
|
||||
# Otherwise it could be picked up again as a predecessor or successor of a parallel vertex.
|
||||
# This typically happens with input vertices such as ChatInput.
|
||||
self.run_manager.remove_vertex_from_runnables(v.id)
|
||||
|
||||
set_cache_coro = partial(get_chat_service().set_cache, key=self.flow_id)
|
||||
for v in vertices:
|
||||
next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(
|
||||
lock, set_cache_coro, graph=self, vertex=v, cache=False
|
||||
)
|
||||
results.extend(next_runnable_vertices)
|
||||
return results
|
||||
|
||||
def topological_sort(self) -> List[Vertex]:
|
||||
|
|
@ -1372,5 +1352,3 @@ class Graph:
|
|||
predecessor_map[edge.target_id].append(edge.source_id)
|
||||
successor_map[edge.source_id].append(edge.target_id)
|
||||
return predecessor_map, successor_map
|
||||
return predecessor_map, successor_map
|
||||
return predecessor_map, successor_map
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, List
|
||||
from typing import TYPE_CHECKING, Callable, List, Coroutine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.graph.base import Graph
|
||||
|
|
@ -56,20 +56,20 @@ class RunnableVerticesManager:
|
|||
async def get_next_runnable_vertices(
|
||||
self,
|
||||
lock: asyncio.Lock,
|
||||
set_cache_coro: Callable[["Graph", asyncio.Lock], Awaitable[None]],
|
||||
set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine],
|
||||
graph: "Graph",
|
||||
vertex: "Vertex",
|
||||
cache: bool = True,
|
||||
):
|
||||
) -> List[str]:
|
||||
"""
|
||||
Retrieves the next runnable vertices in the graph for a given vertex.
|
||||
|
||||
Args:
|
||||
graph (Graph): The graph object representing the flow.
|
||||
vertex (Vertex): The current vertex.
|
||||
vertex_id (str): The ID of the current vertex.
|
||||
chat_service (ChatService): The chat service object.
|
||||
flow_id (str): The ID of the flow.
|
||||
lock (asyncio.Lock): The lock object to be used for synchronization.
|
||||
set_cache_coro (Callable): The coroutine function to set the cache.
|
||||
graph (Graph): The graph object containing the vertices.
|
||||
vertex (Vertex): The vertex object for which the next runnable vertices are to be retrieved.
|
||||
cache (bool, optional): A flag to indicate if the cache should be updated. Defaults to True.
|
||||
|
||||
Returns:
|
||||
list: A list of IDs of the next runnable vertices.
|
||||
|
|
@ -85,12 +85,15 @@ class RunnableVerticesManager:
|
|||
next_runnable_vertices = direct_successors_ready
|
||||
|
||||
for v_id in set(next_runnable_vertices): # Use set to avoid duplicates
|
||||
self.update_vertex_run_state(v_id, is_runnable=False)
|
||||
self.remove_from_predecessors(v_id)
|
||||
self.remove_vertex_from_runnables(v_id)
|
||||
if cache:
|
||||
await set_cache_coro(data=graph, lock=lock) # type: ignore
|
||||
return next_runnable_vertices
|
||||
|
||||
# Flags the vertex identified by v_id as not runnable, then removes it from the
# predecessor bookkeeping. It does this by calling
# update_vertex_run_state(v_id, is_runnable=False) followed by
# remove_from_predecessors(v_id), in that order. Returns None.
def remove_vertex_from_runnables(self, v_id):
|
||||
self.update_vertex_run_state(v_id, is_runnable=False)
|
||||
self.remove_from_predecessors(v_id)
|
||||
|
||||
@staticmethod
|
||||
def get_top_level_vertices(graph, vertices_ids):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -341,6 +341,8 @@ def load_flows_from_directory():
|
|||
if existing:
|
||||
logger.info(f"Updating existing flow: {flow_id} with endpoint name {flow_endpoint_name}")
|
||||
for key, value in flow.items():
|
||||
if key == "last_tested_version":
|
||||
continue
|
||||
setattr(existing, key, value)
|
||||
existing.updated_at = datetime.utcnow()
|
||||
existing.user_id = user_id
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue