📝 (base.py): refactor Graph class serialization to include necessary attributes and handle run_manager properly

📝 (utils.py): rename build_and_cache_graph_from_db to build_graph_from_db for clarity and consistency
📝 (chat.py): update function calls to use the new build_graph_from_db function name
📝 (runnable_vertices_manager.py): add serialization methods to RunnableVerticesManager class
📝 (vertex.py): improve Vertex class serialization for better handling of internal attributes
This commit is contained in:
ogabrielluiz 2024-06-18 00:13:43 -03:00
commit 515fe5638f
5 changed files with 73 additions and 50 deletions

View file

@@ -205,7 +205,7 @@ def format_elapsed_time(elapsed_time: float) -> str:
return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}"
async def build_and_cache_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
"""Build and cache the graph."""
flow: Optional[Flow] = session.get(Flow, flow_id)
if not flow or not flow.data:

View file

@@ -8,7 +8,7 @@ from loguru import logger
from langflow.api.utils import (
build_and_cache_graph_from_data,
build_and_cache_graph_from_db,
build_graph_from_db,
format_elapsed_time,
format_exception_message,
get_top_level_vertices,
@@ -81,7 +81,7 @@ async def retrieve_vertices_order(
flow_id_str = str(flow_id)
# First, we need to check if the flow_id is in the cache
if not data:
graph = await build_and_cache_graph_from_db(flow_id=flow_id_str, session=session, chat_service=chat_service)
graph = await build_graph_from_db(flow_id=flow_id_str, session=session, chat_service=chat_service)
else:
graph = await build_and_cache_graph_from_data(
flow_id=flow_id_str, graph_data=data.model_dump(), chat_service=chat_service
@@ -108,6 +108,7 @@ async def retrieve_vertices_order(
run_id = uuid.uuid4()
graph.set_run_id(run_id)
vertices_to_run = list(graph.vertices_to_run) + get_top_level_vertices(graph, graph.vertices_to_run)
await chat_service.set_cache(flow_id, graph)
return VerticesOrderResponse(ids=first_layer, run_id=run_id, vertices_to_run=vertices_to_run)
except Exception as exc:
@@ -155,7 +156,7 @@ async def build_vertex(
if not cache:
# If there's no cache
logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
graph = await build_and_cache_graph_from_db(
graph = await build_graph_from_db(
flow_id=flow_id_str, session=next(get_session()), chat_service=chat_service
)
else:

View file

@@ -5,6 +5,8 @@ from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Tuple, Type, Union
from loguru import logger
from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
@@ -19,7 +21,6 @@ from langflow.services.cache.utils import CacheMiss
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service
from langflow.services.monitor.utils import log_transaction
from loguru import logger
if TYPE_CHECKING:
from langflow.graph.schema import ResultData
@@ -511,10 +512,38 @@ class Graph:
self._updates += 1
def __getstate__(self):
    """Pickle hook: collect the attributes needed to resume a run.

    The state_manager is intentionally omitted from the snapshot: it is a
    singleton and is recreated in ``__setstate__`` instead of serialized.
    """
    # NOTE: the stale `return self.raw_graph_data` from the previous
    # implementation was left above the new dict body, which made the
    # refactored serialization unreachable; it has been removed.
    return {
        "vertices": self.vertices,
        "edges": self.edges,
        "flow_id": self.flow_id,
        "user_id": self.user_id,
        "raw_graph_data": self.raw_graph_data,
        "top_level_vertices": self.top_level_vertices,
        "inactivated_vertices": self.inactivated_vertices,
        # Serialized as a plain dict so the payload does not depend on
        # RunnableVerticesManager's internal layout.
        "run_manager": self.run_manager.to_dict(),
        "_run_id": self._run_id,
        "in_degree_map": self.in_degree_map,
        "parent_child_map": self.parent_child_map,
        "predecessor_map": self.predecessor_map,
        "successor_map": self.successor_map,
        "activated_vertices": self.activated_vertices,
        "vertices_layers": self.vertices_layers,
        "vertices_to_run": self.vertices_to_run,
        "stop_vertex": self.stop_vertex,
        "vertex_map": self.vertex_map,
    }
def __setstate__(self, state):
    """Pickle hook: restore the attributes saved by ``__getstate__``.

    ``state["run_manager"]`` may be either a live RunnableVerticesManager
    (never round-tripped) or the plain dict produced by ``__getstate__``;
    normalize it to a manager instance before updating ``__dict__``.
    """
    # NOTE: the stale `self.__init__(**state)` call from the previous
    # implementation was removed — the new protocol restores attributes
    # directly via __dict__.update instead of re-running __init__.
    run_manager = state["run_manager"]
    if not isinstance(run_manager, RunnableVerticesManager):
        state["run_manager"] = RunnableVerticesManager.from_dict(run_manager)
    self.__dict__.update(state)
    # Recreate rather than deserialize: the state manager is singleton-backed.
    self.state_manager = GraphStateManager()
@classmethod
def from_payload(cls, payload: Dict, flow_id: Optional[str] = None, user_id: Optional[str] = None) -> "Graph":

View file

@@ -13,6 +13,33 @@ class RunnableVerticesManager:
self.run_predecessors = defaultdict(set) # Tracks predecessors for each vertex
self.vertices_to_run = set() # Set of vertices that are ready to run
def to_dict(self) -> dict:
    """Serialize the manager's run-tracking state into a plain dict."""
    tracked = ("run_map", "run_predecessors", "vertices_to_run")
    return {name: getattr(self, name) for name in tracked}
@classmethod
def from_dict(cls, data: dict) -> "RunnableVerticesManager":
    """Rebuild a manager from the dict produced by ``to_dict``."""
    manager = cls()
    for attr in ("run_map", "run_predecessors", "vertices_to_run"):
        setattr(manager, attr, data[attr])
    return manager
def __getstate__(self) -> object:
    """Pickle hook: expose only the three tracked collections."""
    return dict(
        run_map=self.run_map,
        run_predecessors=self.run_predecessors,
        vertices_to_run=self.vertices_to_run,
    )
def __setstate__(self, state: dict) -> None:
    """Pickle hook: restore the three tracked collections from ``state``."""
    for attr in ("run_map", "run_predecessors", "vertices_to_run"):
        setattr(self, attr, state[attr])
def is_vertex_runnable(self, vertex_id: str) -> bool:
"""Determines if a vertex is runnable."""

View file

@@ -165,51 +165,17 @@ class Vertex:
return self.graph.successor_map.get(self.id, [])
def __getstate__(self):
    """Pickle hook: snapshot ``__dict__``, dropping unpicklable members.

    Locks cannot be pickled, and the Unbuilt* placeholders carry no
    information worth serializing, so both are replaced with None; they
    are reconstructed in ``__setstate__``.
    """
    # NOTE: the previous field-by-field dict literal was left above the
    # new implementation as unreachable code after its `return`; it has
    # been removed in favor of the __dict__-copy approach below.
    state = self.__dict__.copy()
    state["_lock"] = None  # asyncio locks are not picklable
    state["_built_object"] = None if isinstance(self._built_object, UnbuiltObject) else self._built_object
    state["_built_result"] = None if isinstance(self._built_result, UnbuiltResult) else self._built_result
    return state
def __setstate__(self, state):
    """Pickle hook: restore ``__dict__`` and rebuild unpicklable members."""
    # NOTE: the previous implementation's field-by-field restoration was
    # left interleaved here by the refactor; the __dict__.update approach
    # below supersedes it entirely.
    self.__dict__.update(state)
    self._lock = asyncio.Lock()  # Reinitialize the lock; locks are not pickled.
    # __getstate__ stores None for unbuilt placeholders; restore the sentinels.
    self._built_object = state.get("_built_object") or UnbuiltObject()
    self._built_result = state.get("_built_result") or UnbuiltResult()
def set_top_level(self, top_level_vertices: List[str]) -> None:
    """Mark whether this vertex's parent is one of the given top-level ids."""
    # Simple membership test against the caller-provided id list.
    self.parent_is_top_level = self.parent_node_id in top_level_vertices