From 515fe5638fcc817e5f25aa3178c7751b7d4e2287 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Tue, 18 Jun 2024 00:13:43 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20(base.py):=20refactor=20Graph=20?= =?UTF-8?q?class=20serialization=20to=20include=20necessary=20attributes?= =?UTF-8?q?=20and=20handle=20run=5Fmanager=20properly?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 📝 (utils.py): rename build_and_cache_graph_from_db to build_graph_from_db for clarity and consistency 📝 (chat.py): update function calls to use the new build_graph_from_db function name 📝 (runnable_vertices_manager.py): add serialization methods to RunnableVerticesManager class 📝 (vertex.py): improve Vertex class serialization for better handling of internal attributes --- src/backend/base/langflow/api/utils.py | 2 +- src/backend/base/langflow/api/v1/chat.py | 7 +-- src/backend/base/langflow/graph/graph/base.py | 35 +++++++++++-- .../graph/graph/runnable_vertices_manager.py | 27 ++++++++++ .../base/langflow/graph/vertex/base.py | 52 ++++--------------- 5 files changed, 73 insertions(+), 50 deletions(-) diff --git a/src/backend/base/langflow/api/utils.py b/src/backend/base/langflow/api/utils.py index 4cc4aa7f2..099df66c5 100644 --- a/src/backend/base/langflow/api/utils.py +++ b/src/backend/base/langflow/api/utils.py @@ -205,7 +205,7 @@ def format_elapsed_time(elapsed_time: float) -> str: return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}" -async def build_and_cache_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"): +async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"): """Build and cache the graph.""" flow: Optional[Flow] = session.get(Flow, flow_id) if not flow or not flow.data: diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 1affd42b0..2e53818ef 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -8,7 +8,7 @@ from loguru import logger from langflow.api.utils import ( build_and_cache_graph_from_data, - build_and_cache_graph_from_db, + build_graph_from_db, format_elapsed_time, format_exception_message, get_top_level_vertices, @@ -81,7 +81,7 @@ async def retrieve_vertices_order( flow_id_str = str(flow_id) # First, we need to check if the flow_id is in the cache if not data: - graph = await build_and_cache_graph_from_db(flow_id=flow_id_str, session=session, chat_service=chat_service) + graph = await build_graph_from_db(flow_id=flow_id_str, session=session, chat_service=chat_service) else: graph = await build_and_cache_graph_from_data( flow_id=flow_id_str, graph_data=data.model_dump(), chat_service=chat_service @@ -108,6 +108,7 @@ async def retrieve_vertices_order( run_id = uuid.uuid4() graph.set_run_id(run_id) vertices_to_run = list(graph.vertices_to_run) + get_top_level_vertices(graph, graph.vertices_to_run) + await chat_service.set_cache(flow_id, graph) return VerticesOrderResponse(ids=first_layer, run_id=run_id, vertices_to_run=vertices_to_run) except Exception as exc: @@ -155,7 +156,7 @@ async def build_vertex( if not cache: # If there's no cache logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}") - graph = await build_and_cache_graph_from_db( + graph = await build_graph_from_db( flow_id=flow_id_str, session=next(get_session()), chat_service=chat_service ) else: diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index f541bb3e4..9f1b2c60e 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -5,6 +5,8 @@ from functools import partial from itertools import chain from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Tuple, Type, Union +from loguru import logger + from langflow.graph.edge.base import ContractEdge from langflow.graph.graph.constants import lazy_load_vertex_dict from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager @@ -19,7 +21,6 @@ from langflow.services.cache.utils import CacheMiss from langflow.services.chat.service import ChatService from langflow.services.deps import get_chat_service from langflow.services.monitor.utils import log_transaction -from loguru import logger if TYPE_CHECKING: from langflow.graph.schema import ResultData @@ -511,10 +512,38 @@ class Graph: self._updates += 1 def __getstate__(self): - return self.raw_graph_data + # Get all attributes that are useful in runs. + # We don't need to save the state_manager because it is + # a singleton and it is not necessary to save it + return { + "vertices": self.vertices, + "edges": self.edges, + "flow_id": self.flow_id, + "user_id": self.user_id, + "raw_graph_data": self.raw_graph_data, + "top_level_vertices": self.top_level_vertices, + "inactivated_vertices": self.inactivated_vertices, + "run_manager": self.run_manager.to_dict(), + "_run_id": self._run_id, + "in_degree_map": self.in_degree_map, + "parent_child_map": self.parent_child_map, + "predecessor_map": self.predecessor_map, + "successor_map": self.successor_map, + "activated_vertices": self.activated_vertices, + "vertices_layers": self.vertices_layers, + "vertices_to_run": self.vertices_to_run, + "stop_vertex": self.stop_vertex, + "vertex_map": self.vertex_map, + } def __setstate__(self, state): - self.__init__(**state) + run_manager = state["run_manager"] + if isinstance(run_manager, RunnableVerticesManager): + state["run_manager"] = run_manager + else: + state["run_manager"] = RunnableVerticesManager.from_dict(run_manager) + self.__dict__.update(state) + self.state_manager = GraphStateManager() @classmethod def from_payload(cls, payload: Dict, flow_id: Optional[str] = None, user_id: Optional[str] = None) -> "Graph": diff --git a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py index 9dd7c346c..4951453cf 100644 --- a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py +++ b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py @@ -13,6 +13,33 @@ class RunnableVerticesManager: self.run_predecessors = defaultdict(set) # Tracks predecessors for each vertex self.vertices_to_run = set() # Set of vertices that are ready to run + def to_dict(self) -> dict: + return { + "run_map": self.run_map, + "run_predecessors": self.run_predecessors, + "vertices_to_run": self.vertices_to_run, + } + + @classmethod + def from_dict(cls, data: dict) -> "RunnableVerticesManager": + instance = cls() + instance.run_map = data["run_map"] + instance.run_predecessors = data["run_predecessors"] + instance.vertices_to_run = data["vertices_to_run"] + return instance + + def __getstate__(self) -> object: + return { + "run_map": self.run_map, + "run_predecessors": self.run_predecessors, + "vertices_to_run": self.vertices_to_run, + } + + def __setstate__(self, state: dict) -> None: + self.run_map = state["run_map"] + self.run_predecessors = state["run_predecessors"] + self.vertices_to_run = state["vertices_to_run"] + def is_vertex_runnable(self, vertex_id: str) -> bool: """Determines if a vertex is runnable.""" diff --git a/src/backend/base/langflow/graph/vertex/base.py b/src/backend/base/langflow/graph/vertex/base.py index 736fc8f37..8e19184a2 100644 --- a/src/backend/base/langflow/graph/vertex/base.py +++ b/src/backend/base/langflow/graph/vertex/base.py @@ -165,51 +165,17 @@ class Vertex: return self.graph.successor_map.get(self.id, []) def __getstate__(self): - return { - "_data": self._data, - "params": {}, - "base_type": self.base_type, - "base_name": self.base_name, - "is_task": self.is_task, - "id": self.id, - "_built_object": UnbuiltObject(), - "_built": False, - "parent_node_id": self.parent_node_id, - "parent_is_top_level": self.parent_is_top_level, - "load_from_db_fields": self.load_from_db_fields, - "is_input": self.is_input, - "is_output": self.is_output, - } + state = self.__dict__.copy() + state["_lock"] = None # Locks are not serializable + state["_built_object"] = None if isinstance(self._built_object, UnbuiltObject) else self._built_object + state["_built_result"] = None if isinstance(self._built_result, UnbuiltResult) else self._built_result + return state def __setstate__(self, state): - self._lock = asyncio.Lock() - self._data = state["_data"] - self.params = state["params"] - self.base_type = state["base_type"] - self.is_task = state["is_task"] - self.id = state["id"] - self.frozen = state.get("frozen", False) - self.is_input = state.get("is_input", False) - self.is_output = state.get("is_output", False) - self.base_name = state["base_name"] - self._parse_data() - if "_built_object" in state: - self._built_object = state["_built_object"] - self._built = state["_built"] - else: - self._built_object = UnbuiltObject() - self._built = False - if "_built_result" in state: - self._built_result = state["_built_result"] - else: - self._built_result = UnbuiltResult() - self.artifacts: Dict[str, Any] = {} - self.task_id: Optional[str] = None - self.parent_node_id = state["parent_node_id"] - self.parent_is_top_level = state["parent_is_top_level"] - self.load_from_db_fields = state["load_from_db_fields"] - self.layer = state.get("layer") - self.steps = state.get("steps", [self._build]) + self.__dict__.update(state) + self._lock = asyncio.Lock() # Reinitialize the lock + self._built_object = state.get("_built_object") or UnbuiltObject() + self._built_result = state.get("_built_result") or UnbuiltResult() def set_top_level(self, top_level_vertices: List[str]) -> None: self.parent_is_top_level = self.parent_node_id in top_level_vertices