🐛 fix(types.py): pass graph parameter to Vertex constructors to fix missing graph reference

 feat(types.py): add support for passing graph parameter to Vertex constructors to ensure proper graph reference
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-11-27 21:50:43 -03:00
commit 1facfefb19

View file

@ -1,14 +1,14 @@
import ast
from typing import Any, Dict, List, Optional, Union
from langflow.graph.utils import flatten_list
from langflow.graph.utils import UnbuiltObject, flatten_list
from langflow.graph.vertex.base import Vertex
from langflow.interface.utils import extract_input_variables_from_prompt
class AgentVertex(Vertex):
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="agents", params=params)
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="agents", params=params)
self.tools: List[Union[ToolkitVertex, ToolVertex]] = []
self.chains: List[ChainVertex] = []
@ -28,7 +28,7 @@ class AgentVertex(Vertex):
for edge in self.edges:
if not hasattr(edge, "source"):
continue
source_node = edge.source
source_node = self.graph.get_vertex(edge.source_id)
if isinstance(source_node, (ToolVertex, ToolkitVertex)):
self.tools.append(source_node)
elif isinstance(source_node, ChainVertex):
@ -51,16 +51,21 @@ class AgentVertex(Vertex):
class ToolVertex(Vertex):
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="tools", params=params)
def __init__(
self,
data: Dict,
graph,
params: Optional[Dict] = None,
):
super().__init__(data, graph=graph, base_type="tools", params=params)
class LLMVertex(Vertex):
built_node_type = None
class_built_object = None
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="llms", params=params)
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="llms", params=params)
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
# LLM is different because some models might take up too much memory
@ -77,18 +82,18 @@ class LLMVertex(Vertex):
class ToolkitVertex(Vertex):
def __init__(self, data: Dict, params=None):
super().__init__(data, base_type="toolkits", params=params)
def __init__(self, data: Dict, graph, params=None):
super().__init__(data, graph=graph, base_type="toolkits", params=params)
class FileToolVertex(ToolVertex):
def __init__(self, data: Dict, params=None):
super().__init__(data, params=params)
def __init__(self, data: Dict, graph, params=None):
super().__init__(data, graph=graph, params=params)
class WrapperVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="wrappers")
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="wrappers")
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
if not self._built or force:
@ -99,14 +104,14 @@ class WrapperVertex(Vertex):
class DocumentLoaderVertex(Vertex):
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="documentloaders", params=params)
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="documentloaders", params=params)
def _built_object_repr(self):
# This built_object is a list of documents. Maybe we should
# show how many documents are in the list?
if self._built_object:
if self._built_object and not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len(
self._built_object
)
@ -117,28 +122,19 @@ class DocumentLoaderVertex(Vertex):
class EmbeddingVertex(Vertex):
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="embeddings", params=params)
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="embeddings", params=params)
class VectorStoreVertex(Vertex):
def __init__(self, data: Dict, params=None):
super().__init__(data, base_type="vectorstores")
def __init__(self, data: Dict, graph, params=None):
super().__init__(data, graph=graph, base_type="vectorstores")
self.params = params or {}
# VectorStores may contain databse connections
# so we need to define the __reduce__ method and the __setstate__ method
# to avoid pickling errors
def clean_edges_for_pickling(self):
# for each edge that has self as source
# we need to clear the _built_object of the target
# so that we don't try to pickle a database connection
for edge in self.edges:
if edge.source == self:
edge.target._built_object = None
edge.target._built = False
edge.target.params[edge.target_param] = self
def remove_docs_and_texts_from_params(self):
# remove documents and texts from params
@ -146,17 +142,16 @@ class VectorStoreVertex(Vertex):
self.params.pop("documents", None)
self.params.pop("texts", None)
def __getstate__(self):
# We want to save the params attribute
# and if "documents" or "texts" are in the params
# we want to remove them because they have already
# been processed.
params = self.params.copy()
params.pop("documents", None)
params.pop("texts", None)
self.clean_edges_for_pickling()
# def __getstate__(self):
# # We want to save the params attribute
# # and if "documents" or "texts" are in the params
# # we want to remove them because they have already
# # been processed.
# params = self.params.copy()
# params.pop("documents", None)
# params.pop("texts", None)
return super().__getstate__()
# return super().__getstate__()
def __setstate__(self, state):
super().__setstate__(state)
@ -164,24 +159,24 @@ class VectorStoreVertex(Vertex):
class MemoryVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="memory")
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="memory")
class RetrieverVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="retrievers")
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="retrievers")
class TextSplitterVertex(Vertex):
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="textsplitters", params=params)
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="textsplitters", params=params)
def _built_object_repr(self):
# This built_object is a list of documents. Maybe we should
# show how many documents are in the list?
if self._built_object:
if self._built_object and not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
@ -190,8 +185,8 @@ class TextSplitterVertex(Vertex):
class ChainVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="chains")
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="chains")
async def build(
self,
@ -220,8 +215,8 @@ class ChainVertex(Vertex):
class PromptVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="prompts")
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="prompts")
async def build(
self,
@ -271,9 +266,13 @@ class PromptVertex(Vertex):
# so the prompt format doesn't break
artifacts.pop("handle_keys", None)
try:
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
if (
not hasattr(self._built_object, "template")
and hasattr(self._built_object, "prompt")
and not isinstance(self._built_object, UnbuiltObject)
):
template = self._built_object.prompt.template
else:
elif not isinstance(self._built_object, UnbuiltObject) and hasattr(self._built_object, "template"):
template = self._built_object.template
for key, value in artifacts.items():
if value:
@ -285,13 +284,13 @@ class PromptVertex(Vertex):
class OutputParserVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="output_parsers")
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="output_parsers")
class CustomComponentVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="custom_components", is_task=True)
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="custom_components", is_task=True)
def _built_object_repr(self):
if self.task_id and self.is_task: