🐛 fix(base.py): add __setstate__ method to Edge class to properly set state when unpickling
🐛 fix(base.py): add reset method to Edge class to reset source and target params when needed 🐛 fix(base.py): add __setstate__ method to Graph class to properly set state when unpickling 🐛 fix(base.py): add __eq__ method to Graph class to compare graphs based on their string representation 🐛 fix(types.py): add __getstate__ and __setstate__ methods to AgentVertex class to properly set and get state when pickling and unpickling 🐛 fix(types.py): add __getstate__ and __setstate__ methods to ToolVertex class to properly set and get state when pickling and unpickling 🐛 fix(types.py): add __getstate__ and __setstate__ methods to LLMVertex class to properly set and get state when pickling and unpickling 🐛 fix(types.py): add __getstate__ and __setstate__ methods to ToolkitVertex class to properly set and get state when pickling and unpickling 🐛 fix(types.py): add __getstate__ and __setstate__ methods to FileToolVertex class to properly set and get state when pickling and unpickling 🐛 fix(types.py): add __getstate__ and __setstate__ methods to DocumentLoaderVertex class to properly set and get state when pickling and unpickling 🐛 fix(types.py): add __getstate__ and __setstate__ methods to EmbeddingVertex class to properly set and get state when pickling and unpickling 🐛 fix(types.py): add __getstate__ and __setstate__ methods to VectorStoreVertex class to properly set and get state when pickling and unpickling 🐛 fix(types.py): add __getstate__ and __setstate__ methods to TextSplitterVertex class to properly set and get state when pickling and unpickling ✨ feat(types.py): add reset method to AgentVertex class to reset source and target params when needed ✨ feat(types.py): add reset method to ToolVertex class to reset source and target params when needed ✨ feat(types.py): add reset method to LLMVertex class to reset source and target params when needed ✨ feat(types.py):
This commit is contained in:
parent
2b10cfe96d
commit
bea1328a3e
3 changed files with 95 additions and 20 deletions
|
|
@ -17,6 +17,17 @@ class Edge:
|
|||
|
||||
self.validate_edge()
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.source = state["source"]
|
||||
self.target = state["target"]
|
||||
self.target_param = state["target_param"]
|
||||
self.source_handle = state["source_handle"]
|
||||
self.target_handle = state["target_handle"]
|
||||
|
||||
def reset(self) -> None:
|
||||
self.source._build_params()
|
||||
self.target._build_params()
|
||||
|
||||
def validate_edge(self) -> None:
|
||||
# Validate that the outputs of the source node are valid inputs
|
||||
# for the target node
|
||||
|
|
|
|||
|
|
@ -26,6 +26,12 @@ class Graph:
|
|||
self._edges = edges
|
||||
self._build_graph()
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
for edge in self.edges:
|
||||
edge.reset()
|
||||
edge.validate_edge()
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: Dict) -> "Graph":
|
||||
"""
|
||||
|
|
@ -48,6 +54,11 @@ class Graph:
|
|||
f"Invalid payload. Expected keys 'nodes' and 'edges'. Found {list(payload.keys())}"
|
||||
) from exc
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Graph):
|
||||
return False
|
||||
return self.__repr__() == other.__repr__()
|
||||
|
||||
def _build_graph(self) -> None:
|
||||
"""Builds the graph from the nodes and edges."""
|
||||
self.nodes = self._build_vertices()
|
||||
|
|
|
|||
|
|
@ -4,17 +4,31 @@ from typing import Any, Dict, List, Optional, Union
|
|||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.utils import flatten_list
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from zmq import has
|
||||
|
||||
|
||||
class AgentVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="agents")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="agents", params=params)
|
||||
|
||||
self.tools: List[Union[ToolkitVertex, ToolVertex]] = []
|
||||
self.chains: List[ChainVertex] = []
|
||||
|
||||
def __getstate__(self):
|
||||
state = super().__getstate__()
|
||||
state["tools"] = self.tools
|
||||
state["chains"] = self.chains
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.tools = state["tools"]
|
||||
self.chains = state["chains"]
|
||||
super().__setstate__(state)
|
||||
|
||||
def _set_tools_and_chains(self) -> None:
|
||||
for edge in self.edges:
|
||||
if not hasattr(edge, "source"):
|
||||
continue
|
||||
source_node = edge.source
|
||||
if isinstance(source_node, (ToolVertex, ToolkitVertex)):
|
||||
self.tools.append(source_node)
|
||||
|
|
@ -38,16 +52,16 @@ class AgentVertex(Vertex):
|
|||
|
||||
|
||||
class ToolVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="tools")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="tools", params=params)
|
||||
|
||||
|
||||
class LLMVertex(Vertex):
|
||||
built_node_type = None
|
||||
class_built_object = None
|
||||
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="llms")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="llms", params=params)
|
||||
|
||||
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
# LLM is different because some models might take up too much memory
|
||||
|
|
@ -64,13 +78,13 @@ class LLMVertex(Vertex):
|
|||
|
||||
|
||||
class ToolkitVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="toolkits")
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, base_type="toolkits", params=params)
|
||||
|
||||
|
||||
class FileToolVertex(ToolVertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data)
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, params=params)
|
||||
|
||||
|
||||
class WrapperVertex(Vertex):
|
||||
|
|
@ -86,17 +100,19 @@ class WrapperVertex(Vertex):
|
|||
|
||||
|
||||
class DocumentLoaderVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="documentloaders")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="documentloaders", params=params)
|
||||
|
||||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
# show how many documents are in the list?
|
||||
|
||||
if self._built_object:
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
|
||||
self._built_object
|
||||
)
|
||||
avg_length = sum(
|
||||
len(doc.page_content)
|
||||
for doc in self._built_object
|
||||
if hasattr(doc, "page_content")
|
||||
) / len(self._built_object)
|
||||
return f"""{self.vertex_type}({len(self._built_object)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
Documents: {self._built_object[:3]}..."""
|
||||
|
|
@ -104,13 +120,50 @@ class DocumentLoaderVertex(Vertex):
|
|||
|
||||
|
||||
class EmbeddingVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="embeddings")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="embeddings", params=params)
|
||||
|
||||
|
||||
class VectorStoreVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, base_type="vectorstores")
|
||||
if params:
|
||||
self.params = params
|
||||
|
||||
# VectorStores may contain databse connections
|
||||
# so we need to define the __reduce__ method and the __setstate__ method
|
||||
# to avoid pickling errors
|
||||
def clean_edges_for_pickling(self):
|
||||
# for each edge that has self as source
|
||||
# we need to clear the _built_object of the target
|
||||
# so that we don't try to pickle a database connection
|
||||
for edge in self.edges:
|
||||
if edge.source == self:
|
||||
edge.target._built_object = None
|
||||
edge.target._built = False
|
||||
edge.target.params[edge.target_param] = self
|
||||
|
||||
def remove_docs_and_texts_from_params(self):
|
||||
# remove documents and texts from params
|
||||
# so that we don't try to pickle a database connection
|
||||
self.params.pop("documents", None)
|
||||
self.params.pop("texts", None)
|
||||
|
||||
def __getstate__(self):
|
||||
# We want to save the params attribute
|
||||
# and if "documents" or "texts" are in the params
|
||||
# we want to remove them because they have already
|
||||
# been processed.
|
||||
params = self.params.copy()
|
||||
params.pop("documents", None)
|
||||
params.pop("texts", None)
|
||||
self.clean_edges_for_pickling()
|
||||
|
||||
return super().__getstate__()
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
self.remove_docs_and_texts_from_params()
|
||||
|
||||
|
||||
class MemoryVertex(Vertex):
|
||||
|
|
@ -124,8 +177,8 @@ class RetrieverVertex(Vertex):
|
|||
|
||||
|
||||
class TextSplitterVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="textsplitters")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="textsplitters", params=params)
|
||||
|
||||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue