🔀 chore(config): add DocArrayHnswSearch and DocArrayInMemorySearch to vectorstores
🐛 fix(base.py): correctly handle nested lists in Node.build() method.
✨ feat(vector_store): add VectorStoreFrontendNode to handle vector store templates.
🐛 fix(util.py): add build_template_from_method to correctly build templates from class methods.
The configuration file now includes two new vector stores, DocArrayHnswSearch and DocArrayInMemorySearch. The Node.build() method now correctly handles nested lists: when a child node builds to a list, its items are merged into the parameter list instead of being nested inside it. A new VectorStoreFrontendNode has been added to handle vector store templates. The build_template_from_method function has been added to correctly build templates from class methods. Issue #335
This commit is contained in:
parent
d24404eba9
commit
e597a6e20a
5 changed files with 130 additions and 20 deletions
|
|
@ -83,6 +83,8 @@ embeddings:
|
|||
|
||||
vectorstores:
|
||||
- Chroma
|
||||
- DocArrayHnswSearch
|
||||
- DocArrayInMemorySearch
|
||||
|
||||
documentloaders:
|
||||
- AirbyteJSONLoader
|
||||
|
|
|
|||
|
|
@ -180,7 +180,13 @@ class Node:
|
|||
elif isinstance(value, list) and all(
|
||||
isinstance(node, Node) for node in value
|
||||
):
|
||||
self.params[key] = [node.build() for node in value] # type: ignore
|
||||
self.params[key] = []
|
||||
for node in value:
|
||||
built = node.build()
|
||||
if isinstance(built, list):
|
||||
self.params[key].extend(built)
|
||||
else:
|
||||
self.params[key].append(built)
|
||||
|
||||
# Get the class from LANGCHAIN_TYPES_DICT
|
||||
# and instantiate it with the params
|
||||
|
|
|
|||
|
|
@ -1,15 +1,20 @@
|
|||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import vectorstores_type_to_cls_dict
|
||||
from langflow.settings import settings
|
||||
from langflow.template.nodes import VectorStoreFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
|
||||
|
||||
class VectorstoreCreator(LangChainTypeCreator):
|
||||
type_name: str = "vectorstores"
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> Type[VectorStoreFrontendNode]:
|
||||
return VectorStoreFrontendNode
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
return vectorstores_type_to_cls_dict
|
||||
|
|
@ -17,25 +22,29 @@ class VectorstoreCreator(LangChainTypeCreator):
|
|||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
"""Get the signature of an embedding."""
|
||||
try:
|
||||
signature = build_template_from_class(name, vectorstores_type_to_cls_dict)
|
||||
signature = build_template_from_method(
|
||||
name,
|
||||
type_to_cls_dict=vectorstores_type_to_cls_dict,
|
||||
method_name="from_texts",
|
||||
)
|
||||
|
||||
# TODO: Use FrontendNode class to build the signature
|
||||
signature["template"] = {
|
||||
"documents": {
|
||||
"type": "TextSplitter",
|
||||
"required": True,
|
||||
"show": True,
|
||||
"name": "documents",
|
||||
"display_name": "Text Splitter",
|
||||
},
|
||||
"embedding": {
|
||||
"type": "Embeddings",
|
||||
"required": True,
|
||||
"show": True,
|
||||
"name": "embedding",
|
||||
"display_name": "Embedding",
|
||||
},
|
||||
}
|
||||
# signature["template"] = {
|
||||
# "documents": {
|
||||
# "type": "TextSplitter",
|
||||
# "required": True,
|
||||
# "show": True,
|
||||
# "name": "documents",
|
||||
# "display_name": "Text Splitter",
|
||||
# },
|
||||
# "embedding": {
|
||||
# "type": "Embeddings",
|
||||
# "required": True,
|
||||
# "show": True,
|
||||
# "name": "embedding",
|
||||
# "display_name": "Embedding",
|
||||
# },
|
||||
# }
|
||||
return signature
|
||||
|
||||
except ValueError as exc:
|
||||
|
|
|
|||
|
|
@ -628,3 +628,32 @@ class EmbeddingFrontendNode(FrontendNode):
|
|||
FrontendNode.format_field(field, name)
|
||||
if field.name == "headers":
|
||||
field.show = False
|
||||
|
||||
|
||||
class VectorStoreFrontendNode(FrontendNode):
    """Frontend node that adjusts vector store template fields for display."""

    @staticmethod
    def format_field(field: TemplateField, name: Optional[str] = None) -> None:
        """Format ``field`` in place for a vector store template.

        The base ``FrontendNode.format_field`` runs first; afterwards the
        field is adjusted according to its name:

        * ``texts`` is renamed to ``documents`` and typed as ``TextSplitter``.
        * any name containing ``embedding`` is normalised to ``embedding``
          and typed as ``Embeddings``.
        * ``n_dim`` is shown as an advanced field.
        * ``work_dir`` is shown as a regular (non-advanced) field.

        Args:
            field: The template field to mutate.
            name: Forwarded unchanged to ``FrontendNode.format_field``.
        """
        FrontendNode.format_field(field, name)

        def _assign(**overrides) -> None:
            # Set the attributes one by one, in the order they are given.
            for attribute, value in overrides.items():
                setattr(field, attribute, value)

        if field.name == "texts":
            _assign(
                name="documents",
                field_type="TextSplitter",
                display_name="Text Splitter",
                required=True,
                show=True,
                advanced=False,
            )

        # Evaluated after the possible rename above, so it sees the new name.
        if "embedding" in field.name:
            # for backwards compatibility
            _assign(
                name="embedding",
                required=True,
                show=True,
                advanced=False,
                display_name="Embedding",
                field_type="Embeddings",
            )
        elif field.name in ("n_dim", "work_dir"):
            field.show = True
            # Only n_dim is tucked away under the advanced settings.
            field.advanced = field.name == "n_dim"
|
||||
|
|
|
|||
|
|
@ -160,6 +160,70 @@ def build_template_from_class(
|
|||
}
|
||||
|
||||
|
||||
def build_template_from_method(
    class_name: str,
    method_name: str,
    type_to_cls_dict: Dict,
    add_function: bool = False,
):
    """Build a template description from the signature of a class method.

    Args:
        class_name: ``__name__`` of the class to look up among the values of
            ``type_to_cls_dict``.
        method_name: Name of the attribute on that class whose signature and
            docstring are used to build the template.
        type_to_cls_dict: Mapping whose values are classes; the key of the
            matching entry is stored under ``"_type"`` in the template.
        add_function: When true, ``"function"`` is appended to the reported
            base classes.

    Returns:
        A dict with the keys ``"template"`` (the formatted parameter
        variables), ``"description"`` (short description parsed from the
        method docstring, or ``""``) and ``"base_classes"``.

    Raises:
        ValueError: If no class named ``class_name`` is in the mapping, or if
            the class has no attribute called ``method_name``.
    """
    # Single pass: take the first entry whose class name matches. The key is
    # kept because it becomes the template's "_type".
    match = next(
        (
            (key, candidate)
            for key, candidate in type_to_cls_dict.items()
            if candidate.__name__ == class_name
        ),
        None,
    )
    if match is None:
        raise ValueError(f"{class_name} not found.")
    _type, _class = match

    # Check if the method exists in this class
    if not hasattr(_class, method_name):
        raise ValueError(f"Method {method_name} not found in class {class_name}")

    method = getattr(_class, method_name)

    # Parsed docstring supplies the short description returned below.
    docs = parse(method.__doc__)

    params = inspect.signature(method).parameters

    # Compare against the "empty" sentinel by identity: a default value with a
    # custom __eq__/__ne__ could otherwise give a wrong or raising comparison.
    empty = inspect.Parameter.empty

    # Initialize the variables dictionary with method parameters
    variables = {"_type": _type}
    for name, param in params.items():
        has_default = param.default is not empty
        variables[name] = {
            "default": param.default if has_default else None,
            "type": param.annotation if param.annotation is not empty else None,
            "required": not has_default,
        }

    base_classes = get_base_classes(_class)

    # Adding function to base classes to allow the output to be a function
    if add_function:
        base_classes.append("function")

    return {
        "template": format_dict(variables, class_name),
        "description": docs.short_description or "",
        "base_classes": base_classes,
    }
|
||||
|
||||
|
||||
def get_base_classes(cls):
|
||||
"""Get the base classes of a class.
|
||||
These are used to determine the output of the nodes.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue