🐛 fix(types.py): pass graph parameter to Vertex constructors to fix missing graph reference
✨ feat(types.py): add support for passing graph parameter to Vertex constructors to ensure proper graph reference
This commit is contained in:
parent
3d20c8dc38
commit
1facfefb19
1 changed files with 56 additions and 57 deletions
|
|
@ -1,14 +1,14 @@
|
|||
import ast
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langflow.graph.utils import flatten_list
|
||||
from langflow.graph.utils import UnbuiltObject, flatten_list
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
|
||||
|
||||
class AgentVertex(Vertex):
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="agents", params=params)
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="agents", params=params)
|
||||
|
||||
self.tools: List[Union[ToolkitVertex, ToolVertex]] = []
|
||||
self.chains: List[ChainVertex] = []
|
||||
|
|
@ -28,7 +28,7 @@ class AgentVertex(Vertex):
|
|||
for edge in self.edges:
|
||||
if not hasattr(edge, "source"):
|
||||
continue
|
||||
source_node = edge.source
|
||||
source_node = self.graph.get_vertex(edge.source_id)
|
||||
if isinstance(source_node, (ToolVertex, ToolkitVertex)):
|
||||
self.tools.append(source_node)
|
||||
elif isinstance(source_node, ChainVertex):
|
||||
|
|
@ -51,16 +51,21 @@ class AgentVertex(Vertex):
|
|||
|
||||
|
||||
class ToolVertex(Vertex):
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="tools", params=params)
|
||||
def __init__(
|
||||
self,
|
||||
data: Dict,
|
||||
graph,
|
||||
params: Optional[Dict] = None,
|
||||
):
|
||||
super().__init__(data, graph=graph, base_type="tools", params=params)
|
||||
|
||||
|
||||
class LLMVertex(Vertex):
|
||||
built_node_type = None
|
||||
class_built_object = None
|
||||
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="llms", params=params)
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="llms", params=params)
|
||||
|
||||
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
# LLM is different because some models might take up too much memory
|
||||
|
|
@ -77,18 +82,18 @@ class LLMVertex(Vertex):
|
|||
|
||||
|
||||
class ToolkitVertex(Vertex):
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, base_type="toolkits", params=params)
|
||||
def __init__(self, data: Dict, graph, params=None):
|
||||
super().__init__(data, graph=graph, base_type="toolkits", params=params)
|
||||
|
||||
|
||||
class FileToolVertex(ToolVertex):
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, params=params)
|
||||
def __init__(self, data: Dict, graph, params=None):
|
||||
super().__init__(data, graph=graph, params=params)
|
||||
|
||||
|
||||
class WrapperVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="wrappers")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="wrappers")
|
||||
|
||||
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
if not self._built or force:
|
||||
|
|
@ -99,14 +104,14 @@ class WrapperVertex(Vertex):
|
|||
|
||||
|
||||
class DocumentLoaderVertex(Vertex):
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="documentloaders", params=params)
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="documentloaders", params=params)
|
||||
|
||||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
# show how many documents are in the list?
|
||||
|
||||
if self._built_object:
|
||||
if self._built_object and not isinstance(self._built_object, UnbuiltObject):
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len(
|
||||
self._built_object
|
||||
)
|
||||
|
|
@ -117,28 +122,19 @@ class DocumentLoaderVertex(Vertex):
|
|||
|
||||
|
||||
class EmbeddingVertex(Vertex):
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="embeddings", params=params)
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="embeddings", params=params)
|
||||
|
||||
|
||||
class VectorStoreVertex(Vertex):
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, base_type="vectorstores")
|
||||
def __init__(self, data: Dict, graph, params=None):
|
||||
super().__init__(data, graph=graph, base_type="vectorstores")
|
||||
|
||||
self.params = params or {}
|
||||
|
||||
# VectorStores may contain databse connections
|
||||
# so we need to define the __reduce__ method and the __setstate__ method
|
||||
# to avoid pickling errors
|
||||
def clean_edges_for_pickling(self):
|
||||
# for each edge that has self as source
|
||||
# we need to clear the _built_object of the target
|
||||
# so that we don't try to pickle a database connection
|
||||
for edge in self.edges:
|
||||
if edge.source == self:
|
||||
edge.target._built_object = None
|
||||
edge.target._built = False
|
||||
edge.target.params[edge.target_param] = self
|
||||
|
||||
def remove_docs_and_texts_from_params(self):
|
||||
# remove documents and texts from params
|
||||
|
|
@ -146,17 +142,16 @@ class VectorStoreVertex(Vertex):
|
|||
self.params.pop("documents", None)
|
||||
self.params.pop("texts", None)
|
||||
|
||||
def __getstate__(self):
|
||||
# We want to save the params attribute
|
||||
# and if "documents" or "texts" are in the params
|
||||
# we want to remove them because they have already
|
||||
# been processed.
|
||||
params = self.params.copy()
|
||||
params.pop("documents", None)
|
||||
params.pop("texts", None)
|
||||
self.clean_edges_for_pickling()
|
||||
# def __getstate__(self):
|
||||
# # We want to save the params attribute
|
||||
# # and if "documents" or "texts" are in the params
|
||||
# # we want to remove them because they have already
|
||||
# # been processed.
|
||||
# params = self.params.copy()
|
||||
# params.pop("documents", None)
|
||||
# params.pop("texts", None)
|
||||
|
||||
return super().__getstate__()
|
||||
# return super().__getstate__()
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
|
|
@ -164,24 +159,24 @@ class VectorStoreVertex(Vertex):
|
|||
|
||||
|
||||
class MemoryVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="memory")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="memory")
|
||||
|
||||
|
||||
class RetrieverVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="retrievers")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="retrievers")
|
||||
|
||||
|
||||
class TextSplitterVertex(Vertex):
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="textsplitters", params=params)
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="textsplitters", params=params)
|
||||
|
||||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
# show how many documents are in the list?
|
||||
|
||||
if self._built_object:
|
||||
if self._built_object and not isinstance(self._built_object, UnbuiltObject):
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
|
||||
return f"""{self.vertex_type}({len(self._built_object)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
|
|
@ -190,8 +185,8 @@ class TextSplitterVertex(Vertex):
|
|||
|
||||
|
||||
class ChainVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="chains")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="chains")
|
||||
|
||||
async def build(
|
||||
self,
|
||||
|
|
@ -220,8 +215,8 @@ class ChainVertex(Vertex):
|
|||
|
||||
|
||||
class PromptVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="prompts")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="prompts")
|
||||
|
||||
async def build(
|
||||
self,
|
||||
|
|
@ -271,9 +266,13 @@ class PromptVertex(Vertex):
|
|||
# so the prompt format doesn't break
|
||||
artifacts.pop("handle_keys", None)
|
||||
try:
|
||||
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
|
||||
if (
|
||||
not hasattr(self._built_object, "template")
|
||||
and hasattr(self._built_object, "prompt")
|
||||
and not isinstance(self._built_object, UnbuiltObject)
|
||||
):
|
||||
template = self._built_object.prompt.template
|
||||
else:
|
||||
elif not isinstance(self._built_object, UnbuiltObject) and hasattr(self._built_object, "template"):
|
||||
template = self._built_object.template
|
||||
for key, value in artifacts.items():
|
||||
if value:
|
||||
|
|
@ -285,13 +284,13 @@ class PromptVertex(Vertex):
|
|||
|
||||
|
||||
class OutputParserVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="output_parsers")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="output_parsers")
|
||||
|
||||
|
||||
class CustomComponentVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="custom_components", is_task=True)
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="custom_components", is_task=True)
|
||||
|
||||
def _built_object_repr(self):
|
||||
if self.task_id and self.is_task:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue