From e597a6e20a59f7a6ff8690e634ca202df7d89271 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 23 May 2023 16:51:29 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=80=20chore(config):=20add=20DocArrayH?= =?UTF-8?q?nswSearch=20and=20DocArrayInMemorySearch=20to=20vectorstores=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(base.py):=20correctly=20handle=20nested=20li?= =?UTF-8?q?sts=20in=20Node.build()=20method=20=E2=9C=A8=20feat(vector=5Fst?= =?UTF-8?q?ore):=20add=20VectorStoreFrontendNode=20to=20handle=20vector=20?= =?UTF-8?q?store=20templates=20=F0=9F=90=9B=20fix(util.py):=20add=20build?= =?UTF-8?q?=5Ftemplate=5Ffrom=5Fmethod=20to=20correctly=20build=20template?= =?UTF-8?q?s=20from=20class=20methods=20The=20configuration=20file=20now?= =?UTF-8?q?=20includes=20two=20new=20vector=20stores,=20DocArrayHnswSearch?= =?UTF-8?q?=20and=20DocArrayInMemorySearch.=20The=20Node.build()=20method?= =?UTF-8?q?=20now=20correctly=20handles=20nested=20lists.=20A=20new=20Vect?= =?UTF-8?q?orStoreFrontendNode=20has=20been=20added=20to=20handle=20vector?= =?UTF-8?q?=20store=20templates.=20The=20build=5Ftemplate=5Ffrom=5Fmethod?= =?UTF-8?q?=20function=20has=20been=20added=20to=20correctly=20build=20tem?= =?UTF-8?q?plates=20from=20class=20methods.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue #335 --- src/backend/langflow/config.yaml | 2 + src/backend/langflow/graph/base.py | 8 ++- .../langflow/interface/vector_store/base.py | 47 ++++++++------ src/backend/langflow/template/nodes.py | 29 +++++++++ src/backend/langflow/utils/util.py | 64 +++++++++++++++++++ 5 files changed, 130 insertions(+), 20 deletions(-) diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index 84fd12fcd..36668c737 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -83,6 +83,8 @@ embeddings: vectorstores: - Chroma + - DocArrayHnswSearch + - DocArrayInMemorySearch documentloaders: - AirbyteJSONLoader diff --git a/src/backend/langflow/graph/base.py b/src/backend/langflow/graph/base.py index 976a9c1cf..187d2983e 100644 --- a/src/backend/langflow/graph/base.py +++ b/src/backend/langflow/graph/base.py @@ -180,7 +180,13 @@ class Node: elif isinstance(value, list) and all( isinstance(node, Node) for node in value ): - self.params[key] = [node.build() for node in value] # type: ignore + self.params[key] = [] + for node in value: + built = node.build() + if isinstance(built, list): + self.params[key].extend(built) + else: + self.params[key].append(built) # Get the class from LANGCHAIN_TYPES_DICT # and instantiate it with the params diff --git a/src/backend/langflow/interface/vector_store/base.py b/src/backend/langflow/interface/vector_store/base.py index 7fca2ba0c..425fa9559 100644 --- a/src/backend/langflow/interface/vector_store/base.py +++ b/src/backend/langflow/interface/vector_store/base.py @@ -1,15 +1,20 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Type from langflow.interface.base import LangChainTypeCreator from langflow.interface.custom_lists import vectorstores_type_to_cls_dict from langflow.settings import settings +from langflow.template.nodes import VectorStoreFrontendNode from langflow.utils.logger import logger -from langflow.utils.util import build_template_from_class +from langflow.utils.util import build_template_from_class, build_template_from_method class VectorstoreCreator(LangChainTypeCreator): type_name: str = "vectorstores" + @property + def frontend_node_class(self) -> Type[VectorStoreFrontendNode]: + return VectorStoreFrontendNode + @property def type_to_loader_dict(self) -> Dict: return vectorstores_type_to_cls_dict @@ -17,25 +22,29 @@ class VectorstoreCreator(LangChainTypeCreator): def get_signature(self, name: str) -> Optional[Dict]: """Get the signature of an embedding.""" try: - signature = build_template_from_class(name, vectorstores_type_to_cls_dict) + signature = build_template_from_method( + name, + type_to_cls_dict=vectorstores_type_to_cls_dict, + method_name="from_texts", + ) # TODO: Use FrontendendNode class to build the signature - signature["template"] = { - "documents": { - "type": "TextSplitter", - "required": True, - "show": True, - "name": "documents", - "display_name": "Text Splitter", - }, - "embedding": { - "type": "Embeddings", - "required": True, - "show": True, - "name": "embedding", - "display_name": "Embedding", - }, - } + # signature["template"] = { + # "documents": { + # "type": "TextSplitter", + # "required": True, + # "show": True, + # "name": "documents", + # "display_name": "Text Splitter", + # }, + # "embedding": { + # "type": "Embeddings", + # "required": True, + # "show": True, + # "name": "embedding", + # "display_name": "Embedding", + # }, + # } return signature except ValueError as exc: diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index 9fad19508..f01a99144 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -628,3 +628,32 @@ class EmbeddingFrontendNode(FrontendNode): FrontendNode.format_field(field, name) if field.name == "headers": field.show = False + + +class VectorStoreFrontendNode(FrontendNode): + @staticmethod + def format_field(field: TemplateField, name: Optional[str] = None) -> None: + FrontendNode.format_field(field, name) + if field.name == "texts": + field.name = "documents" + field.field_type = "TextSplitter" + field.display_name = "Text Splitter" + field.required = True + field.show = True + field.advanced = False + + if "embedding" in field.name: + # for backwards compatibility + field.name = "embedding" + field.required = True + field.show = True + field.advanced = False + field.display_name = "Embedding" + field.field_type = "Embeddings" + + elif field.name == "n_dim": + field.show = True + field.advanced = True + elif field.name == "work_dir": + field.show = True + field.advanced = False diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index e959b0103..9f2e53bc5 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -160,6 +160,70 @@ def build_template_from_class( } +def build_template_from_method( + class_name: str, + method_name: str, + type_to_cls_dict: Dict, + add_function: bool = False, +): + classes = [item.__name__ for item in type_to_cls_dict.values()] + + # Raise error if class_name is not in classes + if class_name not in classes: + raise ValueError(f"{class_name} not found.") + + for _type, v in type_to_cls_dict.items(): + if v.__name__ == class_name: + _class = v + + # Check if the method exists in this class + if not hasattr(_class, method_name): + raise ValueError( + f"Method {method_name} not found in class {class_name}" + ) + + # Get the method + method = getattr(_class, method_name) + + # Get the docstring + docs = parse(method.__doc__) + + # Get the signature of the method + sig = inspect.signature(method) + + # Get the parameters of the method + params = sig.parameters + + # Initialize the variables dictionary with method parameters + variables = { + "_type": _type, + **{ + name: { + "default": param.default + if param.default != param.empty + else None, + "type": param.annotation + if param.annotation != param.empty + else None, + "required": param.default == param.empty, + } + for name, param in params.items() + }, + } + + base_classes = get_base_classes(_class) + + # Adding function to base classes to allow the output to be a function + if add_function: + base_classes.append("function") + + return { + "template": format_dict(variables, class_name), + "description": docs.short_description or "", + "base_classes": base_classes, + } + + def get_base_classes(cls): """Get the base classes of a class. These are used to determine the output of the nodes.