diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index 84fd12fcd..36668c737 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -83,6 +83,8 @@ embeddings: vectorstores: - Chroma + - DocArrayHnswSearch + - DocArrayInMemorySearch documentloaders: - AirbyteJSONLoader diff --git a/src/backend/langflow/graph/base.py b/src/backend/langflow/graph/base.py index 976a9c1cf..187d2983e 100644 --- a/src/backend/langflow/graph/base.py +++ b/src/backend/langflow/graph/base.py @@ -180,7 +180,13 @@ class Node: elif isinstance(value, list) and all( isinstance(node, Node) for node in value ): - self.params[key] = [node.build() for node in value] # type: ignore + self.params[key] = [] + for node in value: + built = node.build() + if isinstance(built, list): + self.params[key].extend(built) + else: + self.params[key].append(built) # Get the class from LANGCHAIN_TYPES_DICT # and instantiate it with the params diff --git a/src/backend/langflow/interface/vector_store/base.py b/src/backend/langflow/interface/vector_store/base.py index 7fca2ba0c..425fa9559 100644 --- a/src/backend/langflow/interface/vector_store/base.py +++ b/src/backend/langflow/interface/vector_store/base.py @@ -1,15 +1,20 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Type from langflow.interface.base import LangChainTypeCreator from langflow.interface.custom_lists import vectorstores_type_to_cls_dict from langflow.settings import settings +from langflow.template.nodes import VectorStoreFrontendNode from langflow.utils.logger import logger -from langflow.utils.util import build_template_from_class +from langflow.utils.util import build_template_from_class, build_template_from_method class VectorstoreCreator(LangChainTypeCreator): type_name: str = "vectorstores" + @property + def frontend_node_class(self) -> Type[VectorStoreFrontendNode]: + return VectorStoreFrontendNode + @property def type_to_loader_dict(self) -> Dict: return vectorstores_type_to_cls_dict @@ -17,25 +22,29 @@ class VectorstoreCreator(LangChainTypeCreator): def get_signature(self, name: str) -> Optional[Dict]: """Get the signature of an embedding.""" try: - signature = build_template_from_class(name, vectorstores_type_to_cls_dict) + signature = build_template_from_method( + name, + type_to_cls_dict=vectorstores_type_to_cls_dict, + method_name="from_texts", + ) # TODO: Use FrontendendNode class to build the signature - signature["template"] = { - "documents": { - "type": "TextSplitter", - "required": True, - "show": True, - "name": "documents", - "display_name": "Text Splitter", - }, - "embedding": { - "type": "Embeddings", - "required": True, - "show": True, - "name": "embedding", - "display_name": "Embedding", - }, - } + # signature["template"] = { + # "documents": { + # "type": "TextSplitter", + # "required": True, + # "show": True, + # "name": "documents", + # "display_name": "Text Splitter", + # }, + # "embedding": { + # "type": "Embeddings", + # "required": True, + # "show": True, + # "name": "embedding", + # "display_name": "Embedding", + # }, + # } return signature except ValueError as exc: diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index 9fad19508..f01a99144 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -628,3 +628,32 @@ class EmbeddingFrontendNode(FrontendNode): FrontendNode.format_field(field, name) if field.name == "headers": field.show = False + + +class VectorStoreFrontendNode(FrontendNode): + @staticmethod + def format_field(field: TemplateField, name: Optional[str] = None) -> None: + FrontendNode.format_field(field, name) + if field.name == "texts": + field.name = "documents" + field.field_type = "TextSplitter" + field.display_name = "Text Splitter" + field.required = True + field.show = True + field.advanced = False + + if "embedding" in field.name: + # for backwards compatibility + field.name = "embedding" + field.required = True + field.show = True + field.advanced = False + field.display_name = "Embedding" + field.field_type = "Embeddings" + + elif field.name == "n_dim": + field.show = True + field.advanced = True + elif field.name == "work_dir": + field.show = True + field.advanced = False diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index e959b0103..9f2e53bc5 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -160,6 +160,70 @@ def build_template_from_class( } +def build_template_from_method( + class_name: str, + method_name: str, + type_to_cls_dict: Dict, + add_function: bool = False, +): + classes = [item.__name__ for item in type_to_cls_dict.values()] + + # Raise error if class_name is not in classes + if class_name not in classes: + raise ValueError(f"{class_name} not found.") + + for _type, v in type_to_cls_dict.items(): + if v.__name__ == class_name: + _class = v + + # Check if the method exists in this class + if not hasattr(_class, method_name): + raise ValueError( + f"Method {method_name} not found in class {class_name}" + ) + + # Get the method + method = getattr(_class, method_name) + + # Get the docstring + docs = parse(method.__doc__) + + # Get the signature of the method + sig = inspect.signature(method) + + # Get the parameters of the method + params = sig.parameters + + # Initialize the variables dictionary with method parameters + variables = { + "_type": _type, + **{ + name: { + "default": param.default + if param.default != param.empty + else None, + "type": param.annotation + if param.annotation != param.empty + else None, + "required": param.default == param.empty, + } + for name, param in params.items() + }, + } + + base_classes = get_base_classes(_class) + + # Adding function to base classes to allow the output to be a function + if add_function: + base_classes.append("function") + + return { + "template": format_dict(variables, class_name), + "description": docs.short_description or "", + "base_classes": base_classes, + } + + def get_base_classes(cls): """Get the base classes of a class. These are used to determine the output of the nodes.