feat(cassandra/astradb): hybrid search support (#2396)
* cassandra/astradb: hybrid search support * fix * fix
This commit is contained in:
parent
805df8298a
commit
30c369f064
3 changed files with 187 additions and 55 deletions
|
|
@ -79,7 +79,7 @@ class LCVectorStoreComponent(Component):
|
|||
"""
|
||||
vector_store = self.build_vector_store()
|
||||
if hasattr(vector_store, "as_retriever"):
|
||||
retriever = vector_store.as_retriever()
|
||||
retriever = vector_store.as_retriever(**self.get_retriever_kwargs())
|
||||
if self.status is None:
|
||||
self.status = "Retriever built successfully."
|
||||
return retriever
|
||||
|
|
@ -106,3 +106,9 @@ class LCVectorStoreComponent(Component):
|
|||
)
|
||||
self.status = search_results
|
||||
return search_results
|
||||
|
||||
def get_retriever_kwargs(self):
|
||||
"""
|
||||
Get the retriever kwargs. Implementations can override this method to provide custom retriever kwargs.
|
||||
"""
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
from loguru import logger
|
||||
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from langflow.base.vectorstores.model import LCVectorStoreComponent
|
||||
from langflow.helpers import docs_to_data
|
||||
from langflow.inputs import FloatInput, DictInput
|
||||
from langflow.io import (
|
||||
BoolInput,
|
||||
DataInput,
|
||||
|
|
@ -20,6 +23,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
documentation: str = "https://python.langchain.com/docs/integrations/vectorstores/astradb"
|
||||
icon: str = "AstraDB"
|
||||
|
||||
_cached_vectorstore: VectorStore = None
|
||||
|
||||
inputs = [
|
||||
StrInput(
|
||||
name="collection_name",
|
||||
|
|
@ -124,13 +129,6 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
info="Optional dictionary defining the indexing policy for the collection.",
|
||||
advanced=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="search_type",
|
||||
display_name="Search Type",
|
||||
options=["Similarity", "MMR"],
|
||||
value="Similarity",
|
||||
advanced=True,
|
||||
),
|
||||
IntInput(
|
||||
name="number_of_results",
|
||||
display_name="Number of Results",
|
||||
|
|
@ -138,9 +136,33 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
advanced=True,
|
||||
value=4,
|
||||
),
|
||||
DropdownInput(
|
||||
name="search_type",
|
||||
display_name="Search Type",
|
||||
info="Search type to use",
|
||||
options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
|
||||
value="Similarity",
|
||||
advanced=True,
|
||||
),
|
||||
FloatInput(
|
||||
name="search_score_threshold",
|
||||
display_name="Search Score Threshold",
|
||||
info="Minimum similarity score threshold for search results. (when using 'Similarity with score threshold')",
|
||||
value=0,
|
||||
advanced=True,
|
||||
),
|
||||
DictInput(
|
||||
name="search_filter",
|
||||
display_name="Search Metadata Filter",
|
||||
info="Optional dictionary of filters to apply to the search query.",
|
||||
advanced=True,
|
||||
is_list=True,
|
||||
),
|
||||
]
|
||||
|
||||
def _build_vector_store_no_ingest(self):
|
||||
if self._cached_vectorstore:
|
||||
return self._cached_vectorstore
|
||||
try:
|
||||
from langchain_astradb import AstraDBVectorStore
|
||||
from langchain_astradb.utils.astradb import SetupMode
|
||||
|
|
@ -199,13 +221,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
except Exception as e:
|
||||
raise ValueError(f"Error initializing AstraDBVectorStore: {str(e)}") from e
|
||||
|
||||
self._cached_vectorstore = vector_store
|
||||
|
||||
return vector_store
|
||||
|
||||
def build_vector_store(self):
|
||||
vector_store = self._build_vector_store_no_ingest()
|
||||
if hasattr(self, "ingest_data") and self.ingest_data:
|
||||
logger.debug("Ingesting data into the Vector Store.")
|
||||
self._add_documents_to_vector_store(vector_store)
|
||||
self._add_documents_to_vector_store(vector_store)
|
||||
return vector_store
|
||||
|
||||
def _add_documents_to_vector_store(self, vector_store):
|
||||
|
|
@ -216,7 +238,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
else:
|
||||
raise ValueError("Vector Store Inputs must be Data objects.")
|
||||
|
||||
if documents and self.embedding is not None:
|
||||
if documents:
|
||||
logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
|
||||
try:
|
||||
vector_store.add_documents(documents)
|
||||
|
|
@ -225,8 +247,17 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
else:
|
||||
logger.debug("No documents to add to the Vector Store.")
|
||||
|
||||
def _map_search_type(self):
|
||||
if self.search_type == "Similarity with score threshold":
|
||||
return "similarity_score_threshold"
|
||||
elif self.search_type == "MMR (Max Marginal Relevance)":
|
||||
return "mmr"
|
||||
else:
|
||||
return "similarity"
|
||||
|
||||
def search_documents(self) -> list[Data]:
|
||||
vector_store = self._build_vector_store_no_ingest()
|
||||
self._add_documents_to_vector_store(vector_store)
|
||||
|
||||
logger.debug(f"Search input: {self.search_input}")
|
||||
logger.debug(f"Search type: {self.search_type}")
|
||||
|
|
@ -234,27 +265,38 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
|
||||
try:
|
||||
if self.search_type == "Similarity":
|
||||
docs = vector_store.similarity_search(
|
||||
query=self.search_input,
|
||||
k=self.number_of_results,
|
||||
)
|
||||
elif self.search_type == "MMR":
|
||||
docs = vector_store.max_marginal_relevance_search(
|
||||
query=self.search_input,
|
||||
k=self.number_of_results,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid search type: {self.search_type}")
|
||||
search_type = self._map_search_type()
|
||||
search_args = self._build_search_args()
|
||||
|
||||
docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error performing search in AstraDBVectorStore: {str(e)}") from e
|
||||
|
||||
logger.debug(f"Retrieved documents: {len(docs)}")
|
||||
|
||||
data = [Data.from_document(doc) for doc in docs]
|
||||
data = docs_to_data(docs)
|
||||
logger.debug(f"Converted documents to data: {len(data)}")
|
||||
self.status = data
|
||||
return data
|
||||
else:
|
||||
logger.debug("No search input provided. Skipping search.")
|
||||
return []
|
||||
|
||||
def _build_search_args(self):
|
||||
args = {
|
||||
"k": self.number_of_results,
|
||||
"score_threshold": self.search_score_threshold,
|
||||
}
|
||||
|
||||
if self.search_filter:
|
||||
clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
|
||||
if len(clean_filter) > 0:
|
||||
args["filter"] = clean_filter
|
||||
return args
|
||||
|
||||
def get_retriever_kwargs(self):
|
||||
search_args = self._build_search_args()
|
||||
return {
|
||||
"search_type": self._map_search_type(),
|
||||
"search_kwargs": search_args,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from langchain_community.vectorstores import Cassandra
|
||||
|
||||
from langflow.base.vectorstores.model import LCVectorStoreComponent
|
||||
from langflow.helpers.data import docs_to_data
|
||||
from langflow.inputs import DictInput
|
||||
from langflow.inputs import DictInput, FloatInput, BoolInput
|
||||
from langflow.io import (
|
||||
DataInput,
|
||||
DropdownInput,
|
||||
|
|
@ -15,6 +15,7 @@ from langflow.io import (
|
|||
SecretStrInput,
|
||||
)
|
||||
from langflow.schema import Data
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class CassandraVectorStoreComponent(LCVectorStoreComponent):
|
||||
|
|
@ -23,6 +24,8 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
|
|||
documentation = "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/cassandra"
|
||||
icon = "Cassandra"
|
||||
|
||||
_cached_vectorstore: Cassandra = None
|
||||
|
||||
inputs = [
|
||||
MessageTextInput(
|
||||
name="database_ref",
|
||||
|
|
@ -64,12 +67,6 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
|
|||
value=16,
|
||||
advanced=True,
|
||||
),
|
||||
MessageTextInput(
|
||||
name="body_index_options",
|
||||
display_name="Body Index Options",
|
||||
info="Optional options used to create the body index.",
|
||||
advanced=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="setup_mode",
|
||||
display_name="Setup Mode",
|
||||
|
|
@ -99,14 +96,52 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
|
|||
value=4,
|
||||
advanced=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="search_type",
|
||||
display_name="Search Type",
|
||||
info="Search type to use",
|
||||
options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
|
||||
value="Similarity",
|
||||
advanced=True,
|
||||
),
|
||||
FloatInput(
|
||||
name="search_score_threshold",
|
||||
display_name="Search Score Threshold",
|
||||
info="Minimum similarity score threshold for search results. (when using 'Similarity with score threshold')",
|
||||
value=0,
|
||||
advanced=True,
|
||||
),
|
||||
DictInput(
|
||||
name="search_filter",
|
||||
display_name="Search Metadata Filter",
|
||||
info="Optional dictionary of filters to apply to the search query.",
|
||||
advanced=True,
|
||||
is_list=True,
|
||||
),
|
||||
MessageTextInput(
|
||||
name="body_search",
|
||||
display_name="Search Body",
|
||||
info="Document textual search terms to apply to the search query.",
|
||||
advanced=True,
|
||||
),
|
||||
BoolInput(
|
||||
name="enable_body_search",
|
||||
display_name="Enable Body Search",
|
||||
info="Flag to enable body search. This must be enabled BEFORE the table is created.",
|
||||
value=False,
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
def build_vector_store(self) -> Cassandra:
|
||||
return self._build_cassandra(ingest=True)
|
||||
return self._build_cassandra()
|
||||
|
||||
def _build_cassandra(self, ingest: bool) -> Cassandra:
|
||||
def _build_cassandra(self) -> Cassandra:
|
||||
if self._cached_vectorstore:
|
||||
return self._cached_vectorstore
|
||||
try:
|
||||
import cassio
|
||||
from langchain_community.utilities.cassandra import SetupMode
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import cassio integration package. " "Please install it with `pip install cassio`."
|
||||
|
|
@ -138,49 +173,73 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
|
|||
password=self.token,
|
||||
cluster_kwargs=self.cluster_kwargs,
|
||||
)
|
||||
ttl_seconds: Optional[int] = self.ttl_seconds
|
||||
|
||||
documents = []
|
||||
|
||||
if ingest:
|
||||
for _input in self.ingest_data or []:
|
||||
if isinstance(_input, Data):
|
||||
documents.append(_input.to_lc_document())
|
||||
else:
|
||||
documents.append(_input)
|
||||
for _input in self.ingest_data or []:
|
||||
if isinstance(_input, Data):
|
||||
documents.append(_input.to_lc_document())
|
||||
else:
|
||||
documents.append(_input)
|
||||
|
||||
if self.enable_body_search:
|
||||
body_index_options = [("index_analyzer", "STANDARD")]
|
||||
else:
|
||||
body_index_options = None
|
||||
|
||||
if self.setup_mode == "Off":
|
||||
setup_mode = SetupMode.OFF
|
||||
elif self.setup_mode == "Sync":
|
||||
setup_mode = SetupMode.SYNC
|
||||
else:
|
||||
setup_mode = SetupMode.ASYNC
|
||||
|
||||
if documents:
|
||||
logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
|
||||
table = Cassandra.from_documents(
|
||||
documents=documents,
|
||||
embedding=self.embedding,
|
||||
table_name=self.table_name,
|
||||
keyspace=self.keyspace,
|
||||
ttl_seconds=ttl_seconds,
|
||||
ttl_seconds=self.ttl_seconds or None,
|
||||
batch_size=self.batch_size,
|
||||
body_index_options=self.body_index_options,
|
||||
body_index_options=body_index_options,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.debug("No documents to add to the Vector Store.")
|
||||
table = Cassandra(
|
||||
embedding=self.embedding,
|
||||
table_name=self.table_name,
|
||||
keyspace=self.keyspace,
|
||||
ttl_seconds=ttl_seconds,
|
||||
body_index_options=self.body_index_options,
|
||||
setup_mode=self.setup_mode,
|
||||
ttl_seconds=self.ttl_seconds or None,
|
||||
body_index_options=body_index_options,
|
||||
setup_mode=setup_mode,
|
||||
)
|
||||
|
||||
self._cached_vectorstore = table
|
||||
return table
|
||||
|
||||
def _map_search_type(self):
|
||||
if self.search_type == "Similarity with score threshold":
|
||||
return "similarity_score_threshold"
|
||||
elif self.search_type == "MMR (Max Marginal Relevance)":
|
||||
return "mmr"
|
||||
else:
|
||||
return "similarity"
|
||||
|
||||
def search_documents(self) -> List[Data]:
|
||||
vector_store = self._build_cassandra(ingest=False)
|
||||
vector_store = self._build_cassandra()
|
||||
|
||||
logger.debug(f"Search input: {self.search_query}")
|
||||
logger.debug(f"Search type: {self.search_type}")
|
||||
logger.debug(f"Number of results: {self.number_of_results}")
|
||||
|
||||
if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
|
||||
try:
|
||||
docs = vector_store.similarity_search(
|
||||
query=self.search_query,
|
||||
k=self.number_of_results,
|
||||
)
|
||||
search_type = self._map_search_type()
|
||||
search_args = self._build_search_args()
|
||||
|
||||
logger.debug(f"Search args: {str(search_args)}")
|
||||
|
||||
docs = vector_store.search(query=self.search_query, search_type=search_type, **search_args)
|
||||
except KeyError as e:
|
||||
if "content" in str(e):
|
||||
raise ValueError(
|
||||
|
|
@ -189,8 +248,33 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
|
|||
else:
|
||||
raise e
|
||||
|
||||
logger.debug(f"Retrieved documents: {len(docs)}")
|
||||
|
||||
data = docs_to_data(docs)
|
||||
self.status = data
|
||||
return data
|
||||
else:
|
||||
return []
|
||||
|
||||
def _build_search_args(self):
|
||||
args = {
|
||||
"k": self.number_of_results,
|
||||
"score_threshold": self.search_score_threshold,
|
||||
}
|
||||
|
||||
if self.search_filter:
|
||||
clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
|
||||
if len(clean_filter) > 0:
|
||||
args["filter"] = clean_filter
|
||||
if self.body_search:
|
||||
if not self.enable_body_search:
|
||||
raise ValueError("You should enable body search when creating the table to search the body field.")
|
||||
args["body_search"] = self.body_search
|
||||
return args
|
||||
|
||||
def get_retriever_kwargs(self):
|
||||
search_args = self._build_search_args()
|
||||
return {
|
||||
"search_type": self._map_search_type(),
|
||||
"search_kwargs": search_args,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue