From bea1328a3e3ccf87a3cb3cf86afd9973549a69a7 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 22 Sep 2023 10:58:27 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(base.py):=20add=20=5F=5Fsets?= =?UTF-8?q?tate=5F=5F=20method=20to=20Edge=20class=20to=20properly=20set?= =?UTF-8?q?=20state=20when=20unpickling=20=F0=9F=90=9B=20fix(base.py):=20a?= =?UTF-8?q?dd=20reset=20method=20to=20Edge=20class=20to=20reset=20source?= =?UTF-8?q?=20and=20target=20params=20when=20needed=20=F0=9F=90=9B=20fix(b?= =?UTF-8?q?ase.py):=20add=20=5F=5Fsetstate=5F=5F=20method=20to=20Graph=20c?= =?UTF-8?q?lass=20to=20properly=20set=20state=20when=20unpickling=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(base.py):=20add=20=5F=5Feq=5F=5F=20method=20?= =?UTF-8?q?to=20Graph=20class=20to=20compare=20graphs=20based=20on=20their?= =?UTF-8?q?=20string=20representation=20=F0=9F=90=9B=20fix(types.py):=20ad?= =?UTF-8?q?d=20=5F=5Fgetstate=5F=5F=20and=20=5F=5Fsetstate=5F=5F=20methods?= =?UTF-8?q?=20to=20AgentVertex=20class=20to=20properly=20set=20and=20get?= =?UTF-8?q?=20state=20when=20pickling=20and=20unpickling=20=F0=9F=90=9B=20?= =?UTF-8?q?fix(types.py):=20add=20=5F=5Fgetstate=5F=5F=20and=20=5F=5Fsetst?= =?UTF-8?q?ate=5F=5F=20methods=20to=20ToolVertex=20class=20to=20properly?= =?UTF-8?q?=20set=20and=20get=20state=20when=20pickling=20and=20unpickling?= =?UTF-8?q?=20=F0=9F=90=9B=20fix(types.py):=20add=20=5F=5Fgetstate=5F=5F?= =?UTF-8?q?=20and=20=5F=5Fsetstate=5F=5F=20methods=20to=20LLMVertex=20clas?= =?UTF-8?q?s=20to=20properly=20set=20and=20get=20state=20when=20pickling?= =?UTF-8?q?=20and=20unpickling=20=F0=9F=90=9B=20fix(types.py):=20add=20=5F?= =?UTF-8?q?=5Fgetstate=5F=5F=20and=20=5F=5Fsetstate=5F=5F=20methods=20to?= =?UTF-8?q?=20ToolkitVertex=20class=20to=20properly=20set=20and=20get=20st?= =?UTF-8?q?ate=20when=20pickling=20and=20unpickling=20=F0=9F=90=9B=20fix(t?= =?UTF-8?q?ypes.py):=20add=20=5F=5Fgetstate=5F=5F=20and=20=5F=5Fsetstate?= =?UTF-8?q?=5F=5F=20methods=20to=20FileToolVertex=20class=20to=20properly?= =?UTF-8?q?=20set=20and=20get=20state=20when=20pickling=20and=20unpickling?= =?UTF-8?q?=20=F0=9F=90=9B=20fix(types.py):=20add=20=5F=5Fgetstate=5F=5F?= =?UTF-8?q?=20and=20=5F=5Fsetstate=5F=5F=20methods=20to=20DocumentLoaderVe?= =?UTF-8?q?rtex=20class=20to=20properly=20set=20and=20get=20state=20when?= =?UTF-8?q?=20pickling=20and=20unpickling=20=F0=9F=90=9B=20fix(types.py):?= =?UTF-8?q?=20add=20=5F=5Fgetstate=5F=5F=20and=20=5F=5Fsetstate=5F=5F=20me?= =?UTF-8?q?thods=20to=20EmbeddingVertex=20class=20to=20properly=20set=20an?= =?UTF-8?q?d=20get=20state=20when=20pickling=20and=20unpickling=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(types.py):=20add=20=5F=5Fgetstate=5F=5F=20an?= =?UTF-8?q?d=20=5F=5Fsetstate=5F=5F=20methods=20to=20VectorStoreVertex=20c?= =?UTF-8?q?lass=20to=20properly=20set=20and=20get=20state=20when=20picklin?= =?UTF-8?q?g=20and=20unpickling=20=F0=9F=90=9B=20fix(types.py):=20add=20?= =?UTF-8?q?=5F=5Fgetstate=5F=5F=20and=20=5F=5Fsetstate=5F=5F=20methods=20t?= =?UTF-8?q?o=20TextSplitterVertex=20class=20to=20properly=20set=20and=20ge?= =?UTF-8?q?t=20state=20when=20pickling=20and=20unpickling=20=E2=9C=A8=20fe?= =?UTF-8?q?at(types.py):=20add=20reset=20method=20to=20AgentVertex=20class?= =?UTF-8?q?=20to=20reset=20source=20and=20target=20params=20when=20needed?= =?UTF-8?q?=20=E2=9C=A8=20feat(types.py):=20add=20reset=20method=20to=20To?= =?UTF-8?q?olVertex=20class=20to=20reset=20source=20and=20target=20params?= =?UTF-8?q?=20when=20needed=20=E2=9C=A8=20feat(types.py):=20add=20reset=20?= =?UTF-8?q?method=20to=20LLMVertex=20class=20to=20reset=20source=20and=20t?= =?UTF-8?q?arget=20params=20when=20needed=20=E2=9C=A8=20feat(types.py):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/graph/edge/base.py | 11 +++ src/backend/langflow/graph/graph/base.py | 11 +++ src/backend/langflow/graph/vertex/types.py | 93 +++++++++++++++++----- 3 files changed, 95 insertions(+), 20 deletions(-) diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index 2df20cbde..2c60b0288 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -17,6 +17,17 @@ class Edge: self.validate_edge() + def __setstate__(self, state): + self.source = state["source"] + self.target = state["target"] + self.target_param = state["target_param"] + self.source_handle = state["source_handle"] + self.target_handle = state["target_handle"] + + def reset(self) -> None: + self.source._build_params() + self.target._build_params() + def validate_edge(self) -> None: # Validate that the outputs of the source node are valid inputs # for the target node diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index a8b9ee592..227b04bf9 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -26,6 +26,12 @@ class Graph: self._edges = edges self._build_graph() + def __setstate__(self, state): + self.__dict__.update(state) + for edge in self.edges: + edge.reset() + edge.validate_edge() + @classmethod def from_payload(cls, payload: Dict) -> "Graph": """ @@ -48,6 +54,11 @@ class Graph: f"Invalid payload. Expected keys 'nodes' and 'edges'. Found {list(payload.keys())}" ) from exc + def __eq__(self, other: object) -> bool: + if not isinstance(other, Graph): + return False + return self.__repr__() == other.__repr__() + def _build_graph(self) -> None: """Builds the graph from the nodes and edges.""" self.nodes = self._build_vertices() diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index fecf75728..a64324285 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -4,17 +4,31 @@ from typing import Any, Dict, List, Optional, Union from langflow.graph.vertex.base import Vertex from langflow.graph.utils import flatten_list from langflow.interface.utils import extract_input_variables_from_prompt +from zmq import has class AgentVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="agents") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="agents", params=params) self.tools: List[Union[ToolkitVertex, ToolVertex]] = [] self.chains: List[ChainVertex] = [] + def __getstate__(self): + state = super().__getstate__() + state["tools"] = self.tools + state["chains"] = self.chains + return state + + def __setstate__(self, state): + self.tools = state["tools"] + self.chains = state["chains"] + super().__setstate__(state) + def _set_tools_and_chains(self) -> None: for edge in self.edges: + if not hasattr(edge, "source"): + continue source_node = edge.source if isinstance(source_node, (ToolVertex, ToolkitVertex)): self.tools.append(source_node) @@ -38,16 +52,16 @@ class AgentVertex(Vertex): class ToolVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="tools") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="tools", params=params) class LLMVertex(Vertex): built_node_type = None class_built_object = None - def __init__(self, data: Dict): - super().__init__(data, base_type="llms") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="llms", params=params) def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any: # LLM is different because some models might take up too much memory @@ -64,13 +78,13 @@ class LLMVertex(Vertex): class ToolkitVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="toolkits") + def __init__(self, data: Dict, params=None): + super().__init__(data, base_type="toolkits", params=params) class FileToolVertex(ToolVertex): - def __init__(self, data: Dict): - super().__init__(data) + def __init__(self, data: Dict, params=None): + super().__init__(data, params=params) class WrapperVertex(Vertex): @@ -86,17 +100,19 @@ class WrapperVertex(Vertex): class DocumentLoaderVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="documentloaders") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="documentloaders", params=params) def _built_object_repr(self): # This built_object is a list of documents. Maybe we should # show how many documents are in the list? if self._built_object: - avg_length = sum(len(doc.page_content) for doc in self._built_object) / len( - self._built_object - ) + avg_length = sum( + len(doc.page_content) + for doc in self._built_object + if hasattr(doc, "page_content") + ) / len(self._built_object) return f"""{self.vertex_type}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} Documents: {self._built_object[:3]}...""" @@ -104,13 +120,50 @@ class DocumentLoaderVertex(Vertex): class EmbeddingVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="embeddings") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="embeddings", params=params) class VectorStoreVertex(Vertex): - def __init__(self, data: Dict): + def __init__(self, data: Dict, params=None): super().__init__(data, base_type="vectorstores") + if params: + self.params = params + + # VectorStores may contain databse connections + # so we need to define the __reduce__ method and the __setstate__ method + # to avoid pickling errors + def clean_edges_for_pickling(self): + # for each edge that has self as source + # we need to clear the _built_object of the target + # so that we don't try to pickle a database connection + for edge in self.edges: + if edge.source == self: + edge.target._built_object = None + edge.target._built = False + edge.target.params[edge.target_param] = self + + def remove_docs_and_texts_from_params(self): + # remove documents and texts from params + # so that we don't try to pickle a database connection + self.params.pop("documents", None) + self.params.pop("texts", None) + + def __getstate__(self): + # We want to save the params attribute + # and if "documents" or "texts" are in the params + # we want to remove them because they have already + # been processed. + params = self.params.copy() + params.pop("documents", None) + params.pop("texts", None) + self.clean_edges_for_pickling() + + return super().__getstate__() + + def __setstate__(self, state): + super().__setstate__(state) + self.remove_docs_and_texts_from_params() class MemoryVertex(Vertex): @@ -124,8 +177,8 @@ class RetrieverVertex(Vertex): class TextSplitterVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="textsplitters") + def __init__(self, data: Dict, params: Optional[Dict] = None): + super().__init__(data, base_type="textsplitters", params=params) def _built_object_repr(self): # This built_object is a list of documents. Maybe we should