feat(cassandra/astradb): hybrid search support (#2396)

* cassandra/astradb: hybrid search support

* fix

* fix
This commit is contained in:
Nicolò Boschi 2024-07-02 16:09:11 +02:00 committed by GitHub
commit 30c369f064
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 187 additions and 55 deletions

View file

@ -79,7 +79,7 @@ class LCVectorStoreComponent(Component):
"""
vector_store = self.build_vector_store()
if hasattr(vector_store, "as_retriever"):
retriever = vector_store.as_retriever()
retriever = vector_store.as_retriever(**self.get_retriever_kwargs())
if self.status is None:
self.status = "Retriever built successfully."
return retriever
@ -106,3 +106,9 @@ class LCVectorStoreComponent(Component):
)
self.status = search_results
return search_results
def get_retriever_kwargs(self):
"""
Get the retriever kwargs. Implementations can override this method to provide custom retriever kwargs.
"""
return {}

View file

@ -1,6 +1,9 @@
from loguru import logger
from langchain_core.vectorstores import VectorStore
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.helpers import docs_to_data
from langflow.inputs import FloatInput, DictInput
from langflow.io import (
BoolInput,
DataInput,
@ -20,6 +23,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
documentation: str = "https://python.langchain.com/docs/integrations/vectorstores/astradb"
icon: str = "AstraDB"
_cached_vectorstore: VectorStore = None
inputs = [
StrInput(
name="collection_name",
@ -124,13 +129,6 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
info="Optional dictionary defining the indexing policy for the collection.",
advanced=True,
),
DropdownInput(
name="search_type",
display_name="Search Type",
options=["Similarity", "MMR"],
value="Similarity",
advanced=True,
),
IntInput(
name="number_of_results",
display_name="Number of Results",
@ -138,9 +136,33 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
advanced=True,
value=4,
),
DropdownInput(
name="search_type",
display_name="Search Type",
info="Search type to use",
options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
value="Similarity",
advanced=True,
),
FloatInput(
name="search_score_threshold",
display_name="Search Score Threshold",
info="Minimum similarity score threshold for search results. (when using 'Similarity with score threshold')",
value=0,
advanced=True,
),
DictInput(
name="search_filter",
display_name="Search Metadata Filter",
info="Optional dictionary of filters to apply to the search query.",
advanced=True,
is_list=True,
),
]
def _build_vector_store_no_ingest(self):
if self._cached_vectorstore:
return self._cached_vectorstore
try:
from langchain_astradb import AstraDBVectorStore
from langchain_astradb.utils.astradb import SetupMode
@ -199,13 +221,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
except Exception as e:
raise ValueError(f"Error initializing AstraDBVectorStore: {str(e)}") from e
self._cached_vectorstore = vector_store
return vector_store
def build_vector_store(self):
vector_store = self._build_vector_store_no_ingest()
if hasattr(self, "ingest_data") and self.ingest_data:
logger.debug("Ingesting data into the Vector Store.")
self._add_documents_to_vector_store(vector_store)
self._add_documents_to_vector_store(vector_store)
return vector_store
def _add_documents_to_vector_store(self, vector_store):
@ -216,7 +238,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
else:
raise ValueError("Vector Store Inputs must be Data objects.")
if documents and self.embedding is not None:
if documents:
logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
try:
vector_store.add_documents(documents)
@ -225,8 +247,17 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
else:
logger.debug("No documents to add to the Vector Store.")
def _map_search_type(self):
if self.search_type == "Similarity with score threshold":
return "similarity_score_threshold"
elif self.search_type == "MMR (Max Marginal Relevance)":
return "mmr"
else:
return "similarity"
def search_documents(self) -> list[Data]:
vector_store = self._build_vector_store_no_ingest()
self._add_documents_to_vector_store(vector_store)
logger.debug(f"Search input: {self.search_input}")
logger.debug(f"Search type: {self.search_type}")
@ -234,27 +265,38 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
try:
if self.search_type == "Similarity":
docs = vector_store.similarity_search(
query=self.search_input,
k=self.number_of_results,
)
elif self.search_type == "MMR":
docs = vector_store.max_marginal_relevance_search(
query=self.search_input,
k=self.number_of_results,
)
else:
raise ValueError(f"Invalid search type: {self.search_type}")
search_type = self._map_search_type()
search_args = self._build_search_args()
docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args)
except Exception as e:
raise ValueError(f"Error performing search in AstraDBVectorStore: {str(e)}") from e
logger.debug(f"Retrieved documents: {len(docs)}")
data = [Data.from_document(doc) for doc in docs]
data = docs_to_data(docs)
logger.debug(f"Converted documents to data: {len(data)}")
self.status = data
return data
else:
logger.debug("No search input provided. Skipping search.")
return []
def _build_search_args(self):
args = {
"k": self.number_of_results,
"score_threshold": self.search_score_threshold,
}
if self.search_filter:
clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
if len(clean_filter) > 0:
args["filter"] = clean_filter
return args
def get_retriever_kwargs(self):
search_args = self._build_search_args()
return {
"search_type": self._map_search_type(),
"search_kwargs": search_args,
}

View file

@ -1,10 +1,10 @@
from typing import List, Optional
from typing import List
from langchain_community.vectorstores import Cassandra
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.helpers.data import docs_to_data
from langflow.inputs import DictInput
from langflow.inputs import DictInput, FloatInput, BoolInput
from langflow.io import (
DataInput,
DropdownInput,
@ -15,6 +15,7 @@ from langflow.io import (
SecretStrInput,
)
from langflow.schema import Data
from loguru import logger
class CassandraVectorStoreComponent(LCVectorStoreComponent):
@ -23,6 +24,8 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
documentation = "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/cassandra"
icon = "Cassandra"
_cached_vectorstore: Cassandra = None
inputs = [
MessageTextInput(
name="database_ref",
@ -64,12 +67,6 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
value=16,
advanced=True,
),
MessageTextInput(
name="body_index_options",
display_name="Body Index Options",
info="Optional options used to create the body index.",
advanced=True,
),
DropdownInput(
name="setup_mode",
display_name="Setup Mode",
@ -99,14 +96,52 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
value=4,
advanced=True,
),
DropdownInput(
name="search_type",
display_name="Search Type",
info="Search type to use",
options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
value="Similarity",
advanced=True,
),
FloatInput(
name="search_score_threshold",
display_name="Search Score Threshold",
info="Minimum similarity score threshold for search results. (when using 'Similarity with score threshold')",
value=0,
advanced=True,
),
DictInput(
name="search_filter",
display_name="Search Metadata Filter",
info="Optional dictionary of filters to apply to the search query.",
advanced=True,
is_list=True,
),
MessageTextInput(
name="body_search",
display_name="Search Body",
info="Document textual search terms to apply to the search query.",
advanced=True,
),
BoolInput(
name="enable_body_search",
display_name="Enable Body Search",
info="Flag to enable body search. This must be enabled BEFORE the table is created.",
value=False,
advanced=True,
),
]
def build_vector_store(self) -> Cassandra:
return self._build_cassandra(ingest=True)
return self._build_cassandra()
def _build_cassandra(self, ingest: bool) -> Cassandra:
def _build_cassandra(self) -> Cassandra:
if self._cached_vectorstore:
return self._cached_vectorstore
try:
import cassio
from langchain_community.utilities.cassandra import SetupMode
except ImportError:
raise ImportError(
"Could not import cassio integration package. " "Please install it with `pip install cassio`."
@ -138,49 +173,73 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
password=self.token,
cluster_kwargs=self.cluster_kwargs,
)
ttl_seconds: Optional[int] = self.ttl_seconds
documents = []
if ingest:
for _input in self.ingest_data or []:
if isinstance(_input, Data):
documents.append(_input.to_lc_document())
else:
documents.append(_input)
for _input in self.ingest_data or []:
if isinstance(_input, Data):
documents.append(_input.to_lc_document())
else:
documents.append(_input)
if self.enable_body_search:
body_index_options = [("index_analyzer", "STANDARD")]
else:
body_index_options = None
if self.setup_mode == "Off":
setup_mode = SetupMode.OFF
elif self.setup_mode == "Sync":
setup_mode = SetupMode.SYNC
else:
setup_mode = SetupMode.ASYNC
if documents:
logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
table = Cassandra.from_documents(
documents=documents,
embedding=self.embedding,
table_name=self.table_name,
keyspace=self.keyspace,
ttl_seconds=ttl_seconds,
ttl_seconds=self.ttl_seconds or None,
batch_size=self.batch_size,
body_index_options=self.body_index_options,
body_index_options=body_index_options,
)
else:
logger.debug("No documents to add to the Vector Store.")
table = Cassandra(
embedding=self.embedding,
table_name=self.table_name,
keyspace=self.keyspace,
ttl_seconds=ttl_seconds,
body_index_options=self.body_index_options,
setup_mode=self.setup_mode,
ttl_seconds=self.ttl_seconds or None,
body_index_options=body_index_options,
setup_mode=setup_mode,
)
self._cached_vectorstore = table
return table
def _map_search_type(self):
if self.search_type == "Similarity with score threshold":
return "similarity_score_threshold"
elif self.search_type == "MMR (Max Marginal Relevance)":
return "mmr"
else:
return "similarity"
def search_documents(self) -> List[Data]:
vector_store = self._build_cassandra(ingest=False)
vector_store = self._build_cassandra()
logger.debug(f"Search input: {self.search_query}")
logger.debug(f"Search type: {self.search_type}")
logger.debug(f"Number of results: {self.number_of_results}")
if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
try:
docs = vector_store.similarity_search(
query=self.search_query,
k=self.number_of_results,
)
search_type = self._map_search_type()
search_args = self._build_search_args()
logger.debug(f"Search args: {str(search_args)}")
docs = vector_store.search(query=self.search_query, search_type=search_type, **search_args)
except KeyError as e:
if "content" in str(e):
raise ValueError(
@ -189,8 +248,33 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
else:
raise e
logger.debug(f"Retrieved documents: {len(docs)}")
data = docs_to_data(docs)
self.status = data
return data
else:
return []
def _build_search_args(self):
args = {
"k": self.number_of_results,
"score_threshold": self.search_score_threshold,
}
if self.search_filter:
clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
if len(clean_filter) > 0:
args["filter"] = clean_filter
if self.body_search:
if not self.enable_body_search:
raise ValueError("You should enable body search when creating the table to search the body field.")
args["body_search"] = self.body_search
return args
def get_retriever_kwargs(self):
search_args = self._build_search_args()
return {
"search_type": self._map_search_type(),
"search_kwargs": search_args,
}