feat: add NestedDictInput filter and non-vector search for AstraVectorStoreComponent (#4564)

* NestedDictInput filter and non-vector search for AstraVectorStoreComponent

* [autofix.ci] apply automated fixes

* addressing Ruff linting

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Hare <ericrhare@gmail.com>
This commit is contained in:
Phil Miesle 2024-11-14 20:55:25 +00:00 committed by GitHub
commit 44b0531f6b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2,10 +2,11 @@ import os
import orjson
from astrapy.admin import parse_api_endpoint
from langchain_astradb import AstraDBVectorStore
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers import docs_to_data
from langflow.inputs import DictInput, FloatInput, MessageTextInput
from langflow.inputs import DictInput, FloatInput, MessageTextInput, NestedDictInput
from langflow.io import (
BoolInput,
DataInput,
@ -26,6 +27,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
name = "AstraDB"
icon: str = "AstraDB"
_cached_vector_store: AstraDBVectorStore | None = None
VECTORIZE_PROVIDERS_MAPPING = {
"Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
@ -201,11 +204,17 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
value=0,
advanced=True,
),
DictInput(
name="search_filter",
NestedDictInput(
name="advanced_search_filter",
display_name="Search Metadata Filter",
info="Optional dictionary of filters to apply to the search query.",
advanced=True,
),
DictInput(
name="search_filter",
display_name="[DEPRECATED] Search Metadata Filter",
info="Deprecated: use advanced_search_filter. Optional dictionary of filters to apply to the search query.",
advanced=True,
is_list=True,
),
]
@ -482,43 +491,69 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
return "similarity"
def _build_search_args(self):
args = {
"k": self.number_of_results,
"score_threshold": self.search_score_threshold,
}
query = self.search_input if isinstance(self.search_input, str) and self.search_input.strip() else None
search_filter = (
{k: v for k, v in self.search_filter.items() if k and v and k.strip()} if self.search_filter else None
)
if query:
args = {
"query": query,
"search_type": self._map_search_type(),
"k": self.number_of_results,
"score_threshold": self.search_score_threshold,
}
elif self.advanced_search_filter or search_filter:
args = {
"n": self.number_of_results,
}
else:
return {}
filter_arg = self.advanced_search_filter or {}
if search_filter:
self.log(self.log(f"`search_filter` is deprecated. Use `advanced_search_filter`. Cleaned: {search_filter}"))
filter_arg.update(search_filter)
if filter_arg:
args["filter"] = filter_arg
if self.search_filter:
clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
if len(clean_filter) > 0:
args["filter"] = clean_filter
return args
def search_documents(self, vector_store=None) -> list[Data]:
if not vector_store:
vector_store = self.build_vector_store()
vector_store = vector_store or self.build_vector_store()
self.log(f"Search input: {self.search_input}")
self.log(f"Search type: {self.search_type}")
self.log(f"Number of results: {self.number_of_results}")
if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
try:
search_type = self._map_search_type()
search_args = self._build_search_args()
try:
search_args = self._build_search_args()
except Exception as e:
msg = f"Error in AstraDBVectorStore._build_search_args: {e}"
raise ValueError(msg) from e
docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args)
except Exception as e:
msg = f"Error performing search in AstraDBVectorStore: {e}"
raise ValueError(msg) from e
if not search_args:
self.log("No search input or filters provided. Skipping search.")
return []
self.log(f"Retrieved documents: {len(docs)}")
docs = []
search_method = "search" if "query" in search_args else "metadata_search"
data = docs_to_data(docs)
self.log(f"Converted documents to data: {len(data)}")
self.status = data
return data
self.log("No search input provided. Skipping search.")
return []
try:
self.log(f"Calling vector_store.{search_method} with args: {search_args}")
docs = getattr(vector_store, search_method)(**search_args)
except Exception as e:
msg = f"Error performing {search_method} in AstraDBVectorStore: {e}"
raise ValueError(msg) from e
self.log(f"Retrieved documents: {len(docs)}")
data = docs_to_data(docs)
self.log(f"Converted documents to data: {len(data)}")
self.status = data
return data
def get_retriever_kwargs(self):
search_args = self._build_search_args()