From 1facfefb193fedfca75a4a4836206d0fced2dd6d Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 27 Nov 2023 21:50:43 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(types.py):=20pass=20graph=20?= =?UTF-8?q?parameter=20to=20Vertex=20constructors=20to=20fix=20missing=20g?= =?UTF-8?q?raph=20reference=20=E2=9C=A8=20feat(types.py):=20add=20support?= =?UTF-8?q?=20for=20passing=20graph=20parameter=20to=20Vertex=20constructo?= =?UTF-8?q?rs=20to=20ensure=20proper=20graph=20reference?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/graph/vertex/types.py | 113 ++++++++++----------- 1 file changed, 56 insertions(+), 57 deletions(-) diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index c288a4b0a..92c920d35 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -1,14 +1,14 @@ import ast from typing import Any, Dict, List, Optional, Union -from langflow.graph.utils import flatten_list +from langflow.graph.utils import UnbuiltObject, flatten_list from langflow.graph.vertex.base import Vertex from langflow.interface.utils import extract_input_variables_from_prompt class AgentVertex(Vertex): - def __init__(self, data: Dict, params: Optional[Dict] = None): - super().__init__(data, base_type="agents", params=params) + def __init__(self, data: Dict, graph, params: Optional[Dict] = None): + super().__init__(data, graph=graph, base_type="agents", params=params) self.tools: List[Union[ToolkitVertex, ToolVertex]] = [] self.chains: List[ChainVertex] = [] @@ -28,7 +28,7 @@ class AgentVertex(Vertex): for edge in self.edges: if not hasattr(edge, "source"): continue - source_node = edge.source + source_node = self.graph.get_vertex(edge.source_id) if isinstance(source_node, (ToolVertex, ToolkitVertex)): self.tools.append(source_node) elif isinstance(source_node, ChainVertex): @@ -51,16 +51,21 @@ class AgentVertex(Vertex): class ToolVertex(Vertex): - def __init__(self, data: Dict, params: Optional[Dict] = None): - super().__init__(data, base_type="tools", params=params) + def __init__( + self, + data: Dict, + graph, + params: Optional[Dict] = None, + ): + super().__init__(data, graph=graph, base_type="tools", params=params) class LLMVertex(Vertex): built_node_type = None class_built_object = None - def __init__(self, data: Dict, params: Optional[Dict] = None): - super().__init__(data, base_type="llms", params=params) + def __init__(self, data: Dict, graph, params: Optional[Dict] = None): + super().__init__(data, graph=graph, base_type="llms", params=params) async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any: # LLM is different because some models might take up too much memory @@ -77,18 +82,18 @@ class LLMVertex(Vertex): class ToolkitVertex(Vertex): - def __init__(self, data: Dict, params=None): - super().__init__(data, base_type="toolkits", params=params) + def __init__(self, data: Dict, graph, params=None): + super().__init__(data, graph=graph, base_type="toolkits", params=params) class FileToolVertex(ToolVertex): - def __init__(self, data: Dict, params=None): - super().__init__(data, params=params) + def __init__(self, data: Dict, graph, params=None): + super().__init__(data, graph=graph, params=params) class WrapperVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="wrappers") + def __init__(self, data: Dict, graph): + super().__init__(data, graph=graph, base_type="wrappers") async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any: if not self._built or force: @@ -99,14 +104,14 @@ class WrapperVertex(Vertex): class DocumentLoaderVertex(Vertex): - def __init__(self, data: Dict, params: Optional[Dict] = None): - super().__init__(data, base_type="documentloaders", params=params) + def __init__(self, data: Dict, graph, params: Optional[Dict] = None): + super().__init__(data, graph=graph, base_type="documentloaders", params=params) def _built_object_repr(self): # This built_object is a list of documents. Maybe we should # show how many documents are in the list? - if self._built_object: + if self._built_object and not isinstance(self._built_object, UnbuiltObject): avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len( self._built_object ) @@ -117,28 +122,19 @@ class DocumentLoaderVertex(Vertex): class EmbeddingVertex(Vertex): - def __init__(self, data: Dict, params: Optional[Dict] = None): - super().__init__(data, base_type="embeddings", params=params) + def __init__(self, data: Dict, graph, params: Optional[Dict] = None): + super().__init__(data, graph=graph, base_type="embeddings", params=params) class VectorStoreVertex(Vertex): - def __init__(self, data: Dict, params=None): - super().__init__(data, base_type="vectorstores") + def __init__(self, data: Dict, graph, params=None): + super().__init__(data, graph=graph, base_type="vectorstores") self.params = params or {} # VectorStores may contain databse connections # so we need to define the __reduce__ method and the __setstate__ method # to avoid pickling errors - def clean_edges_for_pickling(self): - # for each edge that has self as source - # we need to clear the _built_object of the target - # so that we don't try to pickle a database connection - for edge in self.edges: - if edge.source == self: - edge.target._built_object = None - edge.target._built = False - edge.target.params[edge.target_param] = self def remove_docs_and_texts_from_params(self): # remove documents and texts from params @@ -146,17 +142,16 @@ class VectorStoreVertex(Vertex): self.params.pop("documents", None) self.params.pop("texts", None) - def __getstate__(self): - # We want to save the params attribute - # and if "documents" or "texts" are in the params - # we want to remove them because they have already - # been processed. - params = self.params.copy() - params.pop("documents", None) - params.pop("texts", None) - self.clean_edges_for_pickling() + # def __getstate__(self): + # # We want to save the params attribute + # # and if "documents" or "texts" are in the params + # # we want to remove them because they have already + # # been processed. + # params = self.params.copy() + # params.pop("documents", None) + # params.pop("texts", None) - return super().__getstate__() + # return super().__getstate__() def __setstate__(self, state): super().__setstate__(state) @@ -164,24 +159,24 @@ class VectorStoreVertex(Vertex): class MemoryVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="memory") + def __init__(self, data: Dict, graph): + super().__init__(data, graph=graph, base_type="memory") class RetrieverVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="retrievers") + def __init__(self, data: Dict, graph): + super().__init__(data, graph=graph, base_type="retrievers") class TextSplitterVertex(Vertex): - def __init__(self, data: Dict, params: Optional[Dict] = None): - super().__init__(data, base_type="textsplitters", params=params) + def __init__(self, data: Dict, graph, params: Optional[Dict] = None): + super().__init__(data, graph=graph, base_type="textsplitters", params=params) def _built_object_repr(self): # This built_object is a list of documents. Maybe we should # show how many documents are in the list? - if self._built_object: + if self._built_object and not isinstance(self._built_object, UnbuiltObject): avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object) return f"""{self.vertex_type}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} @@ -190,8 +185,8 @@ class TextSplitterVertex(Vertex): class ChainVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="chains") + def __init__(self, data: Dict, graph): + super().__init__(data, graph=graph, base_type="chains") async def build( self, @@ -220,8 +215,8 @@ class ChainVertex(Vertex): class PromptVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="prompts") + def __init__(self, data: Dict, graph): + super().__init__(data, graph=graph, base_type="prompts") async def build( self, @@ -271,9 +266,13 @@ class PromptVertex(Vertex): # so the prompt format doesn't break artifacts.pop("handle_keys", None) try: - if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"): + if ( + not hasattr(self._built_object, "template") + and hasattr(self._built_object, "prompt") + and not isinstance(self._built_object, UnbuiltObject) + ): template = self._built_object.prompt.template - else: + elif not isinstance(self._built_object, UnbuiltObject) and hasattr(self._built_object, "template"): template = self._built_object.template for key, value in artifacts.items(): if value: @@ -285,13 +284,13 @@ class PromptVertex(Vertex): class OutputParserVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="output_parsers") + def __init__(self, data: Dict, graph): + super().__init__(data, graph=graph, base_type="output_parsers") class CustomComponentVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="custom_components", is_task=True) + def __init__(self, data: Dict, graph): + super().__init__(data, graph=graph, base_type="custom_components", is_task=True) def _built_object_repr(self): if self.task_id and self.is_task: