📝 (base.py): refactor Graph class serialization to include necessary attributes and handle run_manager properly
📝 (utils.py): rename build_and_cache_graph_from_db to build_graph_from_db for clarity and consistency 📝 (chat.py): update function calls to use the new build_graph_from_db function name 📝 (runnable_vertices_manager.py): add serialization methods to RunnableVerticesManager class 📝 (vertex.py): improve Vertex class serialization for better handling of internal attributes
This commit is contained in:
parent
f49ebe5f9b
commit
515fe5638f
5 changed files with 73 additions and 50 deletions
|
|
@ -205,7 +205,7 @@ def format_elapsed_time(elapsed_time: float) -> str:
|
|||
return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}"
|
||||
|
||||
|
||||
async def build_and_cache_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
|
||||
async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
|
||||
"""Build and cache the graph."""
|
||||
flow: Optional[Flow] = session.get(Flow, flow_id)
|
||||
if not flow or not flow.data:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from loguru import logger
|
|||
|
||||
from langflow.api.utils import (
|
||||
build_and_cache_graph_from_data,
|
||||
build_and_cache_graph_from_db,
|
||||
build_graph_from_db,
|
||||
format_elapsed_time,
|
||||
format_exception_message,
|
||||
get_top_level_vertices,
|
||||
|
|
@ -81,7 +81,7 @@ async def retrieve_vertices_order(
|
|||
flow_id_str = str(flow_id)
|
||||
# First, we need to check if the flow_id is in the cache
|
||||
if not data:
|
||||
graph = await build_and_cache_graph_from_db(flow_id=flow_id_str, session=session, chat_service=chat_service)
|
||||
graph = await build_graph_from_db(flow_id=flow_id_str, session=session, chat_service=chat_service)
|
||||
else:
|
||||
graph = await build_and_cache_graph_from_data(
|
||||
flow_id=flow_id_str, graph_data=data.model_dump(), chat_service=chat_service
|
||||
|
|
@ -108,6 +108,7 @@ async def retrieve_vertices_order(
|
|||
run_id = uuid.uuid4()
|
||||
graph.set_run_id(run_id)
|
||||
vertices_to_run = list(graph.vertices_to_run) + get_top_level_vertices(graph, graph.vertices_to_run)
|
||||
await chat_service.set_cache(flow_id, graph)
|
||||
return VerticesOrderResponse(ids=first_layer, run_id=run_id, vertices_to_run=vertices_to_run)
|
||||
|
||||
except Exception as exc:
|
||||
|
|
@ -155,7 +156,7 @@ async def build_vertex(
|
|||
if not cache:
|
||||
# If there's no cache
|
||||
logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
|
||||
graph = await build_and_cache_graph_from_db(
|
||||
graph = await build_graph_from_db(
|
||||
flow_id=flow_id_str, session=next(get_session()), chat_service=chat_service
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ from functools import partial
|
|||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Tuple, Type, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from langflow.graph.edge.base import ContractEdge
|
||||
from langflow.graph.graph.constants import lazy_load_vertex_dict
|
||||
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
|
||||
|
|
@ -19,7 +21,6 @@ from langflow.services.cache.utils import CacheMiss
|
|||
from langflow.services.chat.service import ChatService
|
||||
from langflow.services.deps import get_chat_service
|
||||
from langflow.services.monitor.utils import log_transaction
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.schema import ResultData
|
||||
|
|
@ -511,10 +512,38 @@ class Graph:
|
|||
self._updates += 1
|
||||
|
||||
def __getstate__(self):
|
||||
return self.raw_graph_data
|
||||
# Get all attributes that are useful in runs.
|
||||
# We don't need to save the state_manager because it is
|
||||
# a singleton and it is not necessary to save it
|
||||
return {
|
||||
"vertices": self.vertices,
|
||||
"edges": self.edges,
|
||||
"flow_id": self.flow_id,
|
||||
"user_id": self.user_id,
|
||||
"raw_graph_data": self.raw_graph_data,
|
||||
"top_level_vertices": self.top_level_vertices,
|
||||
"inactivated_vertices": self.inactivated_vertices,
|
||||
"run_manager": self.run_manager.to_dict(),
|
||||
"_run_id": self._run_id,
|
||||
"in_degree_map": self.in_degree_map,
|
||||
"parent_child_map": self.parent_child_map,
|
||||
"predecessor_map": self.predecessor_map,
|
||||
"successor_map": self.successor_map,
|
||||
"activated_vertices": self.activated_vertices,
|
||||
"vertices_layers": self.vertices_layers,
|
||||
"vertices_to_run": self.vertices_to_run,
|
||||
"stop_vertex": self.stop_vertex,
|
||||
"vertex_map": self.vertex_map,
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__init__(**state)
|
||||
run_manager = state["run_manager"]
|
||||
if isinstance(run_manager, RunnableVerticesManager):
|
||||
state["run_manager"] = run_manager
|
||||
else:
|
||||
state["run_manager"] = RunnableVerticesManager.from_dict(run_manager)
|
||||
self.__dict__.update(state)
|
||||
self.state_manager = GraphStateManager()
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: Dict, flow_id: Optional[str] = None, user_id: Optional[str] = None) -> "Graph":
|
||||
|
|
|
|||
|
|
@ -13,6 +13,33 @@ class RunnableVerticesManager:
|
|||
self.run_predecessors = defaultdict(set) # Tracks predecessors for each vertex
|
||||
self.vertices_to_run = set() # Set of vertices that are ready to run
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"run_map": self.run_map,
|
||||
"run_predecessors": self.run_predecessors,
|
||||
"vertices_to_run": self.vertices_to_run,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "RunnableVerticesManager":
|
||||
instance = cls()
|
||||
instance.run_map = data["run_map"]
|
||||
instance.run_predecessors = data["run_predecessors"]
|
||||
instance.vertices_to_run = data["vertices_to_run"]
|
||||
return instance
|
||||
|
||||
def __getstate__(self) -> object:
|
||||
return {
|
||||
"run_map": self.run_map,
|
||||
"run_predecessors": self.run_predecessors,
|
||||
"vertices_to_run": self.vertices_to_run,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
self.run_map = state["run_map"]
|
||||
self.run_predecessors = state["run_predecessors"]
|
||||
self.vertices_to_run = state["vertices_to_run"]
|
||||
|
||||
def is_vertex_runnable(self, vertex_id: str) -> bool:
|
||||
"""Determines if a vertex is runnable."""
|
||||
|
||||
|
|
|
|||
|
|
@ -165,51 +165,17 @@ class Vertex:
|
|||
return self.graph.successor_map.get(self.id, [])
|
||||
|
||||
def __getstate__(self):
|
||||
return {
|
||||
"_data": self._data,
|
||||
"params": {},
|
||||
"base_type": self.base_type,
|
||||
"base_name": self.base_name,
|
||||
"is_task": self.is_task,
|
||||
"id": self.id,
|
||||
"_built_object": UnbuiltObject(),
|
||||
"_built": False,
|
||||
"parent_node_id": self.parent_node_id,
|
||||
"parent_is_top_level": self.parent_is_top_level,
|
||||
"load_from_db_fields": self.load_from_db_fields,
|
||||
"is_input": self.is_input,
|
||||
"is_output": self.is_output,
|
||||
}
|
||||
state = self.__dict__.copy()
|
||||
state["_lock"] = None # Locks are not serializable
|
||||
state["_built_object"] = None if isinstance(self._built_object, UnbuiltObject) else self._built_object
|
||||
state["_built_result"] = None if isinstance(self._built_result, UnbuiltResult) else self._built_result
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._lock = asyncio.Lock()
|
||||
self._data = state["_data"]
|
||||
self.params = state["params"]
|
||||
self.base_type = state["base_type"]
|
||||
self.is_task = state["is_task"]
|
||||
self.id = state["id"]
|
||||
self.frozen = state.get("frozen", False)
|
||||
self.is_input = state.get("is_input", False)
|
||||
self.is_output = state.get("is_output", False)
|
||||
self.base_name = state["base_name"]
|
||||
self._parse_data()
|
||||
if "_built_object" in state:
|
||||
self._built_object = state["_built_object"]
|
||||
self._built = state["_built"]
|
||||
else:
|
||||
self._built_object = UnbuiltObject()
|
||||
self._built = False
|
||||
if "_built_result" in state:
|
||||
self._built_result = state["_built_result"]
|
||||
else:
|
||||
self._built_result = UnbuiltResult()
|
||||
self.artifacts: Dict[str, Any] = {}
|
||||
self.task_id: Optional[str] = None
|
||||
self.parent_node_id = state["parent_node_id"]
|
||||
self.parent_is_top_level = state["parent_is_top_level"]
|
||||
self.load_from_db_fields = state["load_from_db_fields"]
|
||||
self.layer = state.get("layer")
|
||||
self.steps = state.get("steps", [self._build])
|
||||
self.__dict__.update(state)
|
||||
self._lock = asyncio.Lock() # Reinitialize the lock
|
||||
self._built_object = state.get("_built_object") or UnbuiltObject()
|
||||
self._built_result = state.get("_built_result") or UnbuiltResult()
|
||||
|
||||
def set_top_level(self, top_level_vertices: List[str]) -> None:
|
||||
self.parent_is_top_level = self.parent_node_id in top_level_vertices
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue