🔀 chore(config): add DocArrayHnswSearch and DocArrayInMemorySearch to vectorstores

🐛 fix(base.py): correctly handle nested lists in Node.build() method
 feat(vector_store): add VectorStoreFrontendNode to handle vector store templates
🐛 fix(util.py): add build_template_from_method to correctly build templates from class methods
The configuration file now includes two new vector stores, DocArrayHnswSearch and DocArrayInMemorySearch. The Node.build() method now correctly handles nested lists. A new VectorStoreFrontendNode has been added to handle vector store templates. The build_template_from_method function has been added to correctly build templates from class methods.

Issue #335
This commit is contained in:
Gabriel Almeida 2023-05-23 16:51:29 -03:00
commit e597a6e20a
5 changed files with 130 additions and 20 deletions

View file

@ -83,6 +83,8 @@ embeddings:
vectorstores:
- Chroma
- DocArrayHnswSearch
- DocArrayInMemorySearch
documentloaders:
- AirbyteJSONLoader

View file

@ -180,7 +180,13 @@ class Node:
elif isinstance(value, list) and all(
isinstance(node, Node) for node in value
):
self.params[key] = [node.build() for node in value] # type: ignore
self.params[key] = []
for node in value:
built = node.build()
if isinstance(built, list):
self.params[key].extend(built)
else:
self.params[key].append(built)
# Get the class from LANGCHAIN_TYPES_DICT
# and instantiate it with the params

View file

@ -1,15 +1,20 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Type
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.custom_lists import vectorstores_type_to_cls_dict
from langflow.settings import settings
from langflow.template.nodes import VectorStoreFrontendNode
from langflow.utils.logger import logger
from langflow.utils.util import build_template_from_class
from langflow.utils.util import build_template_from_class, build_template_from_method
class VectorstoreCreator(LangChainTypeCreator):
type_name: str = "vectorstores"
@property
def frontend_node_class(self) -> Type[VectorStoreFrontendNode]:
return VectorStoreFrontendNode
@property
def type_to_loader_dict(self) -> Dict:
return vectorstores_type_to_cls_dict
@ -17,25 +22,29 @@ class VectorstoreCreator(LangChainTypeCreator):
def get_signature(self, name: str) -> Optional[Dict]:
"""Get the signature of an embedding."""
try:
signature = build_template_from_class(name, vectorstores_type_to_cls_dict)
signature = build_template_from_method(
name,
type_to_cls_dict=vectorstores_type_to_cls_dict,
method_name="from_texts",
)
# TODO: Use FrontendendNode class to build the signature
signature["template"] = {
"documents": {
"type": "TextSplitter",
"required": True,
"show": True,
"name": "documents",
"display_name": "Text Splitter",
},
"embedding": {
"type": "Embeddings",
"required": True,
"show": True,
"name": "embedding",
"display_name": "Embedding",
},
}
# signature["template"] = {
# "documents": {
# "type": "TextSplitter",
# "required": True,
# "show": True,
# "name": "documents",
# "display_name": "Text Splitter",
# },
# "embedding": {
# "type": "Embeddings",
# "required": True,
# "show": True,
# "name": "embedding",
# "display_name": "Embedding",
# },
# }
return signature
except ValueError as exc:

View file

@ -628,3 +628,32 @@ class EmbeddingFrontendNode(FrontendNode):
FrontendNode.format_field(field, name)
if field.name == "headers":
field.show = False
class VectorStoreFrontendNode(FrontendNode):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
if field.name == "texts":
field.name = "documents"
field.field_type = "TextSplitter"
field.display_name = "Text Splitter"
field.required = True
field.show = True
field.advanced = False
if "embedding" in field.name:
# for backwards compatibility
field.name = "embedding"
field.required = True
field.show = True
field.advanced = False
field.display_name = "Embedding"
field.field_type = "Embeddings"
elif field.name == "n_dim":
field.show = True
field.advanced = True
elif field.name == "work_dir":
field.show = True
field.advanced = False

View file

@ -160,6 +160,70 @@ def build_template_from_class(
}
def build_template_from_method(
class_name: str,
method_name: str,
type_to_cls_dict: Dict,
add_function: bool = False,
):
classes = [item.__name__ for item in type_to_cls_dict.values()]
# Raise error if class_name is not in classes
if class_name not in classes:
raise ValueError(f"{class_name} not found.")
for _type, v in type_to_cls_dict.items():
if v.__name__ == class_name:
_class = v
# Check if the method exists in this class
if not hasattr(_class, method_name):
raise ValueError(
f"Method {method_name} not found in class {class_name}"
)
# Get the method
method = getattr(_class, method_name)
# Get the docstring
docs = parse(method.__doc__)
# Get the signature of the method
sig = inspect.signature(method)
# Get the parameters of the method
params = sig.parameters
# Initialize the variables dictionary with method parameters
variables = {
"_type": _type,
**{
name: {
"default": param.default
if param.default != param.empty
else None,
"type": param.annotation
if param.annotation != param.empty
else None,
"required": param.default == param.empty,
}
for name, param in params.items()
},
}
base_classes = get_base_classes(_class)
# Adding function to base classes to allow the output to be a function
if add_function:
base_classes.append("function")
return {
"template": format_dict(variables, class_name),
"description": docs.short_description or "",
"base_classes": base_classes,
}
def get_base_classes(cls):
"""Get the base classes of a class.
These are used to determine the output of the nodes.