🐛 fix(base.py): add __setstate__ method to Edge class to properly set state when unpickling

🐛 fix(base.py): add reset method to Edge class to reset source and target params when needed
🐛 fix(base.py): add __setstate__ method to Graph class to properly set state when unpickling
🐛 fix(base.py): add __eq__ method to Graph class to compare graphs based on their string representation
🐛 fix(types.py): add __getstate__ and __setstate__ methods to AgentVertex class to properly set and get state when pickling and unpickling
🐛 fix(types.py): add __getstate__ and __setstate__ methods to ToolVertex class to properly set and get state when pickling and unpickling
🐛 fix(types.py): add __getstate__ and __setstate__ methods to LLMVertex class to properly set and get state when pickling and unpickling
🐛 fix(types.py): add __getstate__ and __setstate__ methods to ToolkitVertex class to properly set and get state when pickling and unpickling
🐛 fix(types.py): add __getstate__ and __setstate__ methods to FileToolVertex class to properly set and get state when pickling and unpickling
🐛 fix(types.py): add __getstate__ and __setstate__ methods to DocumentLoaderVertex class to properly set and get state when pickling and unpickling
🐛 fix(types.py): add __getstate__ and __setstate__ methods to EmbeddingVertex class to properly set and get state when pickling and unpickling
🐛 fix(types.py): add __getstate__ and __setstate__ methods to VectorStoreVertex class to properly set and get state when pickling and unpickling
🐛 fix(types.py): add __getstate__ and __setstate__ methods to TextSplitterVertex class to properly set and get state when pickling and unpickling
 feat(types.py): add reset method to AgentVertex class to reset source and target params when needed
 feat(types.py): add reset method to ToolVertex class to reset source and target params when needed
 feat(types.py): add reset method to LLMVertex class to reset source and target params when needed
 feat(types.py):
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-09-22 10:58:27 -03:00
commit bea1328a3e
3 changed files with 95 additions and 20 deletions

View file

@ -17,6 +17,17 @@ class Edge:
self.validate_edge()
def __setstate__(self, state):
    """Restore this edge's attributes from an unpickled ``state`` mapping.

    The five edge fields are read from ``state`` by key, in a fixed order,
    and assigned onto the instance. A missing key raises ``KeyError``.
    """
    restored_fields = (
        "source",
        "target",
        "target_param",
        "source_handle",
        "target_handle",
    )
    for field_name in restored_fields:
        setattr(self, field_name, state[field_name])
def reset(self) -> None:
    """Rebuild the params of both endpoints, source first and then target."""
    for endpoint_name in ("source", "target"):
        getattr(self, endpoint_name)._build_params()
def validate_edge(self) -> None:
# Validate that the outputs of the source node are valid inputs
# for the target node

View file

@ -26,6 +26,12 @@ class Graph:
self._edges = edges
self._build_graph()
def __setstate__(self, state):
    """Restore the graph from ``state`` and refresh each of its edges.

    The whole ``state`` mapping is merged into the instance dictionary;
    afterwards every edge in ``self.edges`` is reset and then re-validated,
    one edge at a time.
    """
    vars(self).update(state)
    for restored_edge in self.edges:
        restored_edge.reset()
        restored_edge.validate_edge()
@classmethod
def from_payload(cls, payload: Dict) -> "Graph":
"""
@ -48,6 +54,11 @@ class Graph:
f"Invalid payload. Expected keys 'nodes' and 'edges'. Found {list(payload.keys())}"
) from exc
def __eq__(self, other: object) -> bool:
    """Compare two graphs by their ``__repr__`` output.

    Anything that is not a ``Graph`` compares unequal.

    NOTE(review): a class that defines ``__eq__`` without ``__hash__``
    becomes unhashable; confirm that is acceptable for ``Graph``.
    """
    return isinstance(other, Graph) and self.__repr__() == other.__repr__()
def _build_graph(self) -> None:
"""Builds the graph from the nodes and edges."""
self.nodes = self._build_vertices()

View file

@ -4,17 +4,31 @@ from typing import Any, Dict, List, Optional, Union
from langflow.graph.vertex.base import Vertex
from langflow.graph.utils import flatten_list
from langflow.interface.utils import extract_input_variables_from_prompt
from zmq import has
class AgentVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="agents")
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="agents", params=params)
self.tools: List[Union[ToolkitVertex, ToolVertex]] = []
self.chains: List[ChainVertex] = []
def __getstate__(self):
    """Return the parent state extended with ``tools`` and ``chains``."""
    pickled_state = super().__getstate__()
    for attr_name in ("tools", "chains"):
        pickled_state[attr_name] = getattr(self, attr_name)
    return pickled_state
def __setstate__(self, state):
    """Restore ``tools`` and ``chains``, then let the parent restore the rest.

    Both keys must be present in ``state``; a missing one raises ``KeyError``
    before the parent ``__setstate__`` runs.
    """
    for attr_name in ("tools", "chains"):
        setattr(self, attr_name, state[attr_name])
    super().__setstate__(state)
def _set_tools_and_chains(self) -> None:
for edge in self.edges:
if not hasattr(edge, "source"):
continue
source_node = edge.source
if isinstance(source_node, (ToolVertex, ToolkitVertex)):
self.tools.append(source_node)
@ -38,16 +52,16 @@ class AgentVertex(Vertex):
class ToolVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="tools")
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="tools", params=params)
class LLMVertex(Vertex):
built_node_type = None
class_built_object = None
def __init__(self, data: Dict):
super().__init__(data, base_type="llms")
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="llms", params=params)
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
# LLM is different because some models might take up too much memory
@ -64,13 +78,13 @@ class LLMVertex(Vertex):
class ToolkitVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="toolkits")
def __init__(self, data: Dict, params=None):
super().__init__(data, base_type="toolkits", params=params)
class FileToolVertex(ToolVertex):
def __init__(self, data: Dict):
super().__init__(data)
def __init__(self, data: Dict, params=None):
super().__init__(data, params=params)
class WrapperVertex(Vertex):
@ -86,17 +100,19 @@ class WrapperVertex(Vertex):
class DocumentLoaderVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="documentloaders")
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="documentloaders", params=params)
def _built_object_repr(self):
# This built_object is a list of documents. Maybe we should
# show how many documents are in the list?
if self._built_object:
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
self._built_object
)
avg_length = sum(
len(doc.page_content)
for doc in self._built_object
if hasattr(doc, "page_content")
) / len(self._built_object)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
Documents: {self._built_object[:3]}..."""
@ -104,13 +120,50 @@ class DocumentLoaderVertex(Vertex):
class EmbeddingVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="embeddings")
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="embeddings", params=params)
class VectorStoreVertex(Vertex):
def __init__(self, data: Dict):
def __init__(self, data: Dict, params=None):
super().__init__(data, base_type="vectorstores")
if params:
self.params = params
# VectorStores may contain database connections,
# so we define the __getstate__ and __setstate__ methods
# to avoid pickling errors
def clean_edges_for_pickling(self):
    """Detach built objects from the vertices this vertex feeds into.

    For every edge whose source is this vertex, the target is marked as
    unbuilt (its ``_built_object`` is cleared and ``_built`` set to False)
    and the target's param for that edge is pointed back at this vertex,
    so that a database connection is not pickled along with the target.
    """
    for edge in self.edges:
        # Only outgoing edges matter; leave everything else untouched.
        if not (edge.source == self):
            continue
        downstream = edge.target
        downstream._built_object = None
        downstream._built = False
        downstream.params[edge.target_param] = self
def remove_docs_and_texts_from_params(self):
    """Drop the "documents" and "texts" entries from ``self.params``.

    Keys that are absent are ignored; ``self.params`` is modified in place.
    """
    for stale_key in ("documents", "texts"):
        self.params.pop(stale_key, None)
def __getstate__(self):
    """Return the picklable state after detaching downstream built objects.

    ``clean_edges_for_pickling`` is called first so that targets fed by this
    vertex no longer hold a built object, then the parent ``__getstate__``
    produces the state.

    The "documents" and "texts" params are not removed here: popping them
    from a local copy of ``self.params`` has no effect on the state the
    parent returns. ``__setstate__`` drops them from ``self.params`` after
    the state has been restored.
    """
    self.clean_edges_for_pickling()
    return super().__getstate__()
def __setstate__(self, state):
    """Restore the parent state, then strip "documents"/"texts" from params."""
    super().__setstate__(state)
    # Must run after the parent call: it operates on the restored params.
    self.remove_docs_and_texts_from_params()
class MemoryVertex(Vertex):
@ -124,8 +177,8 @@ class RetrieverVertex(Vertex):
class TextSplitterVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="textsplitters")
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="textsplitters", params=params)
def _built_object_repr(self):
# This built_object is a list of documents. Maybe we should