From 47e63d1d026edfdd2bb145db5c287753710606eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Wed, 19 Jun 2024 14:47:39 +0200 Subject: [PATCH] fix: graph could re-run input vertices twice and override the input value (#2219) --- src/backend/base/langflow/api/v1/chat.py | 9 ++- src/backend/base/langflow/graph/graph/base.py | 62 ++++++------------- .../graph/graph/runnable_vertices_manager.py | 23 ++++--- .../base/langflow/initial_setup/setup.py | 2 + 4 files changed, 41 insertions(+), 55 deletions(-) diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 0264dfca7..df4aeb294 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -1,5 +1,6 @@ import time import uuid +from functools import partial from typing import TYPE_CHECKING, Annotated, Optional from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException @@ -165,8 +166,6 @@ async def build_vertex( try: lock = chat_service._cache_locks[flow_id_str] ( - next_runnable_vertices, - top_level_vertices, result_dict, params, valid, @@ -180,6 +179,11 @@ async def build_vertex( inputs_dict=inputs.model_dump() if inputs else {}, files=files, ) + set_cache_coro = partial(get_chat_service().set_cache, key=flow_id_str) + next_runnable_vertices = await graph.run_manager.get_next_runnable_vertices( + lock, set_cache_coro, graph=graph, vertex=vertex, cache=False + ) + top_level_vertices = graph.run_manager.get_top_level_vertices(graph, next_runnable_vertices) log_obj = Log(message=vertex.artifacts_raw, type=vertex.artifacts_type) result_data_response = ResultDataResponse(**result_dict.model_dump()) @@ -214,7 +218,6 @@ async def build_vertex( result_data_response.duration = duration result_data_response.timedelta = timedelta vertex.add_build_time(timedelta) - inactivated_vertices = None inactivated_vertices = list(graph.inactivated_vertices) graph.reset_inactivated_vertices() graph.reset_activated_vertices() diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 285f841aa..569b1be4b 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -3,7 +3,7 @@ import uuid from collections import defaultdict, deque from functools import partial from itertools import chain -from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Tuple, Type, Union from langflow.graph.edge.base import ContractEdge from langflow.graph.graph.constants import lazy_load_vertex_dict @@ -258,6 +258,7 @@ class Graph: Returns: List[Optional["ResultData"]]: The outputs of the graph. """ + if input_components and not isinstance(input_components, list): raise ValueError(f"Invalid components value: {input_components}. Expected list") elif input_components is None: @@ -705,14 +706,12 @@ class Graph: async def build_vertex( self, - lock: asyncio.Lock, chat_service: ChatService, vertex_id: str, inputs_dict: Optional[Dict[str, str]] = None, files: Optional[list[str]] = None, user_id: Optional[str] = None, fallback_to_env_vars: bool = False, - cache: bool = True, ): """ Builds a vertex in the graph. @@ -766,43 +765,15 @@ class Graph: artifacts = vertex.artifacts else: raise ValueError(f"No result found for vertex {vertex_id}") - set_cache_coro = partial(chat_service.set_cache, key=self.flow_id) - next_runnable_vertices, top_level_vertices = await self.get_next_and_top_level_vertices( - lock, set_cache_coro, vertex, cache=cache - ) flow_id = self.flow_id log_transaction(flow_id, vertex, status="success") - return next_runnable_vertices, top_level_vertices, result_dict, params, valid, artifacts, vertex + return result_dict, params, valid, artifacts, vertex except Exception as exc: logger.exception(f"Error building Component: {exc}") flow_id = self.flow_id log_transaction(flow_id, vertex, status="failure", error=str(exc)) raise exc - async def get_next_and_top_level_vertices( - self, - lock: asyncio.Lock, - set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine], - vertex: Vertex, - cache: bool = True, - ): - """ - Retrieves the next runnable vertices and the top level vertices for a given vertex. - - Args: - lock (asyncio.Lock): The lock used to synchronize access to the graph. - set_cache_coro (Coroutine): The coroutine used to set the cache for the graph. - vertex (Vertex): The vertex for which to retrieve the next runnable and top level vertices. - - Returns: - Tuple[List[Vertex], List[Vertex]]: A tuple containing the next runnable vertices and the top level vertices. - """ - next_runnable_vertices = await self.run_manager.get_next_runnable_vertices( - lock, set_cache_coro, self, vertex, cache=cache - ) - top_level_vertices = self.run_manager.get_top_level_vertices(self, next_runnable_vertices) - return next_runnable_vertices, top_level_vertices - def get_vertex_edges( self, vertex_id: str, @@ -849,13 +820,11 @@ class Graph: vertex = self.get_vertex(vertex_id) task = asyncio.create_task( self.build_vertex( - lock=lock, chat_service=chat_service, vertex_id=vertex_id, user_id=self.user_id, inputs_dict={}, fallback_to_env_vars=fallback_to_env_vars, - cache=False, ), name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}", ) @@ -864,7 +833,7 @@ class Graph: logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks") try: - next_runnable_vertices = await self._execute_tasks(tasks) + next_runnable_vertices = await self._execute_tasks(tasks, lock=lock) except Exception as e: logger.error(f"Error executing tasks in layer {layer_index}: {e}") break @@ -876,10 +845,11 @@ class Graph: logger.debug("Graph processing complete") return self - async def _execute_tasks(self, tasks: List[asyncio.Task]) -> List[str]: + async def _execute_tasks(self, tasks: List[asyncio.Task], lock: asyncio.Lock) -> List[str]: """Executes tasks in parallel, handling exceptions for each task.""" results = [] completed_tasks = await asyncio.gather(*tasks, return_exceptions=True) + vertices: List[Vertex] = [] for i, result in enumerate(completed_tasks): task_name = tasks[i].get_name() @@ -889,13 +859,23 @@ class Graph: for t in tasks[i + 1 :]: t.cancel() raise result - elif isinstance(result, tuple) and len(result) == 7: - # Get the next runnable vertices - next_runnable_vertices = result[0] - results.extend(next_runnable_vertices) + elif isinstance(result, tuple) and len(result) == 5: + vertices.append(result[4]) else: raise ValueError(f"Invalid result from task {task_name}: {result}") + for v in vertices: + # set all executed vertices as non-runnable to not run them again. + # they could be calculated as predecessor or successors of parallel vertices + # This could usually happen with input vertices like ChatInput + self.run_manager.remove_vertex_from_runnables(v.id) + + set_cache_coro = partial(get_chat_service().set_cache, key=self.flow_id) + for v in vertices: + next_runnable_vertices = await self.run_manager.get_next_runnable_vertices( + lock, set_cache_coro, graph=self, vertex=v, cache=False + ) + results.extend(next_runnable_vertices) return results def topological_sort(self) -> List[Vertex]: @@ -1372,5 +1352,3 @@ class Graph: predecessor_map[edge.target_id].append(edge.source_id) successor_map[edge.source_id].append(edge.target_id) return predecessor_map, successor_map - return predecessor_map, successor_map - return predecessor_map, successor_map diff --git a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py index 9dd7c346c..7081cc2ff 100644 --- a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py +++ b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py @@ -1,6 +1,6 @@ import asyncio from collections import defaultdict -from typing import TYPE_CHECKING, Awaitable, Callable, List +from typing import TYPE_CHECKING, Callable, List, Coroutine if TYPE_CHECKING: from langflow.graph.graph.base import Graph @@ -56,20 +56,20 @@ class RunnableVerticesManager: async def get_next_runnable_vertices( self, lock: asyncio.Lock, - set_cache_coro: Callable[["Graph", asyncio.Lock], Awaitable[None]], + set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine], graph: "Graph", vertex: "Vertex", cache: bool = True, - ): + ) -> List[str]: """ Retrieves the next runnable vertices in the graph for a given vertex. Args: - graph (Graph): The graph object representing the flow. - vertex (Vertex): The current vertex. - vertex_id (str): The ID of the current vertex. - chat_service (ChatService): The chat service object. - flow_id (str): The ID of the flow. + lock (asyncio.Lock): The lock object to be used for synchronization. + set_cache_coro (Callable): The coroutine function to set the cache. + graph (Graph): The graph object containing the vertices. + vertex (Vertex): The vertex object for which the next runnable vertices are to be retrieved. + cache (bool, optional): A flag to indicate if the cache should be updated. Defaults to True. Returns: list: A list of IDs of the next runnable vertices. @@ -85,12 +85,15 @@ class RunnableVerticesManager: next_runnable_vertices = direct_successors_ready for v_id in set(next_runnable_vertices): # Use set to avoid duplicates - self.update_vertex_run_state(v_id, is_runnable=False) - self.remove_from_predecessors(v_id) + self.remove_vertex_from_runnables(v_id) if cache: await set_cache_coro(data=graph, lock=lock) # type: ignore return next_runnable_vertices + def remove_vertex_from_runnables(self, v_id): + self.update_vertex_run_state(v_id, is_runnable=False) + self.remove_from_predecessors(v_id) + @staticmethod def get_top_level_vertices(graph, vertices_ids): """ diff --git a/src/backend/base/langflow/initial_setup/setup.py b/src/backend/base/langflow/initial_setup/setup.py index 90afead0a..dc66b0371 100644 --- a/src/backend/base/langflow/initial_setup/setup.py +++ b/src/backend/base/langflow/initial_setup/setup.py @@ -341,6 +341,8 @@ def load_flows_from_directory(): if existing: logger.info(f"Updating existing flow: {flow_id} with endpoint name {flow_endpoint_name}") for key, value in flow.items(): + if key == "last_tested_version": + continue setattr(existing, key, value) existing.updated_at = datetime.utcnow() existing.user_id = user_id