diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index 2df20cbde..2c60b0288 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -17,6 +17,17 @@ class Edge: self.validate_edge() + def __setstate__(self, state): + self.source = state["source"] + self.target = state["target"] + self.target_param = state["target_param"] + self.source_handle = state["source_handle"] + self.target_handle = state["target_handle"] + + def reset(self) -> None: + self.source._build_params() + self.target._build_params() + def validate_edge(self) -> None: # Validate that the outputs of the source node are valid inputs # for the target node diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index a8b9ee592..227b04bf9 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -26,6 +26,12 @@ class Graph: self._edges = edges self._build_graph() + def __setstate__(self, state): + self.__dict__.update(state) + for edge in self.edges: + edge.reset() + edge.validate_edge() + @classmethod def from_payload(cls, payload: Dict) -> "Graph": """ @@ -48,6 +54,11 @@ class Graph: f"Invalid payload. Expected keys 'nodes' and 'edges'. 
Found {list(payload.keys())}" ) from exc + def __eq__(self, other: object) -> bool: + if not isinstance(other, Graph): + return False + return self.__repr__() == other.__repr__() + def _build_graph(self) -> None: """Builds the graph from the nodes and edges.""" self.nodes = self._build_vertices() diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index fecf75728..a64324285 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -4,17 +4,31 @@ from typing import Any, Dict, List, Optional, Union from langflow.graph.vertex.base import Vertex from langflow.graph.utils import flatten_list from langflow.interface.utils import extract_input_variables_from_prompt +from zmq import has class AgentVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="agents") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="agents", params=params) self.tools: List[Union[ToolkitVertex, ToolVertex]] = [] self.chains: List[ChainVertex] = [] + def __getstate__(self): + state = super().__getstate__() + state["tools"] = self.tools + state["chains"] = self.chains + return state + + def __setstate__(self, state): + self.tools = state["tools"] + self.chains = state["chains"] + super().__setstate__(state) + def _set_tools_and_chains(self) -> None: for edge in self.edges: + if not hasattr(edge, "source"): + continue source_node = edge.source if isinstance(source_node, (ToolVertex, ToolkitVertex)): self.tools.append(source_node) @@ -38,16 +52,16 @@ class AgentVertex(Vertex): class ToolVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="tools") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="tools", params=params) class LLMVertex(Vertex): built_node_type = None class_built_object = None - def __init__(self, data: Dict): - 
super().__init__(data, base_type="llms") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="llms", params=params) def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any: # LLM is different because some models might take up too much memory @@ -64,13 +78,13 @@ class LLMVertex(Vertex): class ToolkitVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="toolkits") + def __init__(self, data: Dict, params=None): + super().__init__(data, base_type="toolkits", params=params) class FileToolVertex(ToolVertex): - def __init__(self, data: Dict): - super().__init__(data) + def __init__(self, data: Dict, params=None): + super().__init__(data, params=params) class WrapperVertex(Vertex): @@ -86,17 +100,19 @@ class WrapperVertex(Vertex): class DocumentLoaderVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="documentloaders") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="documentloaders", params=params) def _built_object_repr(self): # This built_object is a list of documents. Maybe we should # show how many documents are in the list? if self._built_object: - avg_length = sum(len(doc.page_content) for doc in self._built_object) / len( - self._built_object - ) + avg_length = sum( + len(doc.page_content) + for doc in self._built_object + if hasattr(doc, "page_content") + ) / len(self._built_object) return f"""{self.vertex_type}({len(self._built_object)} documents) \nAvg. 
Document Length (characters): {int(avg_length)} Documents: {self._built_object[:3]}...""" @@ -104,13 +120,50 @@ class DocumentLoaderVertex(Vertex): class EmbeddingVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="embeddings") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="embeddings", params=params) class VectorStoreVertex(Vertex): - def __init__(self, data: Dict): + def __init__(self, data: Dict, params=None): super().__init__(data, base_type="vectorstores") + if params: + self.params = params + + # VectorStores may contain database connections + # so we need to define the __reduce__ method and the __setstate__ method + # to avoid pickling errors + def clean_edges_for_pickling(self): + # for each edge that has self as source + # we need to clear the _built_object of the target + # so that we don't try to pickle a database connection + for edge in self.edges: + if edge.source == self: + edge.target._built_object = None + edge.target._built = False + edge.target.params[edge.target_param] = self + + def remove_docs_and_texts_from_params(self): + # remove documents and texts from params + # so that we don't try to pickle a database connection + self.params.pop("documents", None) + self.params.pop("texts", None) + + def __getstate__(self): + # We want to save the params attribute + # and if "documents" or "texts" are in the params + # we want to remove them because they have already + # been processed. 
+ params = self.params.copy() + params.pop("documents", None) + params.pop("texts", None) + self.clean_edges_for_pickling() + + return super().__getstate__() + + def __setstate__(self, state): + super().__setstate__(state) + self.remove_docs_and_texts_from_params() class MemoryVertex(Vertex): @@ -124,8 +177,8 @@ class RetrieverVertex(Vertex): class TextSplitterVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="textsplitters") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="textsplitters", params=params) def _built_object_repr(self): # This built_object is a list of documents. Maybe we should