508 lines
20 KiB
Python
508 lines
20 KiB
Python
import os
|
|
|
|
import orjson
|
|
from astrapy.admin import parse_api_endpoint
|
|
from loguru import logger
|
|
|
|
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
|
|
from langflow.helpers import docs_to_data
|
|
from langflow.inputs import DictInput, FloatInput, MessageTextInput
|
|
from langflow.io import (
|
|
BoolInput,
|
|
DataInput,
|
|
DropdownInput,
|
|
HandleInput,
|
|
IntInput,
|
|
MultilineInput,
|
|
SecretStrInput,
|
|
StrInput,
|
|
)
|
|
from langflow.schema import Data
|
|
|
|
|
|
class AstraVectorStoreComponent(LCVectorStoreComponent):
    """Vector store component backed by Astra DB.

    Embeddings come either from an embedding model connected to the
    ``embedding`` input or from Astra Vectorize. In the Vectorize case the
    provider and model are selected through fields that
    ``update_build_config`` inserts into the build configuration on demand.
    """

    display_name: str = "Astra DB"
    description: str = "Implementation of Vector Store using Astra DB with search capabilities"
    documentation: str = "https://docs.langflow.org/starter-projects-vector-store-rag"
    name = "AstraDB"
    icon: str = "AstraDB"

    # Provider display name -> [provider identifier used in the vectorize
    # service options, list of selectable model names].
    VECTORIZE_PROVIDERS_MAPPING = {
        "Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
        "Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
        "Hugging Face - Serverless": [
            "huggingface",
            [
                "sentence-transformers/all-MiniLM-L6-v2",
                "intfloat/multilingual-e5-large",
                "intfloat/multilingual-e5-large-instruct",
                "BAAI/bge-small-en-v1.5",
                "BAAI/bge-base-en-v1.5",
                "BAAI/bge-large-en-v1.5",
            ],
        ],
        "Jina AI": [
            "jinaAI",
            [
                "jina-embeddings-v2-base-en",
                "jina-embeddings-v2-base-de",
                "jina-embeddings-v2-base-es",
                "jina-embeddings-v2-base-code",
                "jina-embeddings-v2-base-zh",
            ],
        ],
        "Mistral AI": ["mistral", ["mistral-embed"]],
        "NVIDIA": ["nvidia", ["NV-Embed-QA"]],
        "OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
        "Upstage": ["upstageAI", ["solar-embedding-1-large"]],
        "Voyage AI": [
            "voyageAI",
            ["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
        ],
    }

    # Build-config fields that only exist while a Vectorize provider is selected.
    _PROVIDER_DEPENDENT_FIELDS = (
        "z_00_model_name",
        "z_01_model_parameters",
        "z_02_api_key_name",
        "z_03_provider_api_key",
        "z_04_authentication",
    )

    inputs = [
        SecretStrInput(
            name="token",
            display_name="Astra DB Application Token",
            info="Authentication token for accessing Astra DB.",
            value="ASTRA_DB_APPLICATION_TOKEN",
            required=True,
            advanced=os.getenv("ASTRA_ENHANCED", "false").lower() == "true",
        ),
        SecretStrInput(
            name="api_endpoint",
            display_name="Database" if os.getenv("ASTRA_ENHANCED", "false").lower() == "true" else "API Endpoint",
            info="API endpoint URL for the Astra DB service.",
            value="ASTRA_DB_API_ENDPOINT",
            required=True,
        ),
        StrInput(
            name="collection_name",
            display_name="Collection Name",
            info="The name of the collection within Astra DB where the vectors will be stored.",
            required=True,
        ),
        MultilineInput(
            name="search_input",
            display_name="Search Input",
        ),
        DataInput(
            name="ingest_data",
            display_name="Ingest Data",
            is_list=True,
        ),
        StrInput(
            name="namespace",
            display_name="Namespace",
            info="Optional namespace within Astra DB to use for the collection.",
            advanced=True,
        ),
        DropdownInput(
            name="embedding_service",
            display_name="Embedding Model or Astra Vectorize",
            info="Determines whether to use Astra Vectorize for the collection.",
            options=["Embedding Model", "Astra Vectorize"],
            real_time_refresh=True,
            value="Embedding Model",
        ),
        HandleInput(
            name="embedding",
            display_name="Embedding Model",
            input_types=["Embeddings"],
            info="Allows an embedding model configuration.",
        ),
        DropdownInput(
            name="metric",
            display_name="Metric",
            info="Optional distance metric for vector comparisons in the vector store.",
            options=["cosine", "dot_product", "euclidean"],
            value="cosine",
            advanced=True,
        ),
        IntInput(
            name="batch_size",
            display_name="Batch Size",
            info="Optional number of data to process in a single batch.",
            advanced=True,
        ),
        IntInput(
            name="bulk_insert_batch_concurrency",
            display_name="Bulk Insert Batch Concurrency",
            info="Optional concurrency level for bulk insert operations.",
            advanced=True,
        ),
        IntInput(
            name="bulk_insert_overwrite_concurrency",
            display_name="Bulk Insert Overwrite Concurrency",
            info="Optional concurrency level for bulk insert operations that overwrite existing data.",
            advanced=True,
        ),
        IntInput(
            name="bulk_delete_concurrency",
            display_name="Bulk Delete Concurrency",
            info="Optional concurrency level for bulk delete operations.",
            advanced=True,
        ),
        DropdownInput(
            name="setup_mode",
            display_name="Setup Mode",
            info="Configuration mode for setting up the vector store, with options like 'Sync' or 'Off'.",
            options=["Sync", "Off"],
            advanced=True,
            value="Sync",
        ),
        BoolInput(
            name="pre_delete_collection",
            display_name="Pre Delete Collection",
            info="Boolean flag to determine whether to delete the collection before creating a new one.",
            advanced=True,
        ),
        StrInput(
            name="metadata_indexing_include",
            display_name="Metadata Indexing Include",
            info="Optional list of metadata fields to include in the indexing.",
            is_list=True,
            advanced=True,
        ),
        StrInput(
            name="metadata_indexing_exclude",
            display_name="Metadata Indexing Exclude",
            info="Optional list of metadata fields to exclude from the indexing.",
            is_list=True,
            advanced=True,
        ),
        StrInput(
            name="collection_indexing_policy",
            display_name="Collection Indexing Policy",
            info='Optional JSON string for the "indexing" field of the collection. '
            "See https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option",
            advanced=True,
        ),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            advanced=True,
            value=4,
        ),
        DropdownInput(
            name="search_type",
            display_name="Search Type",
            info="Search type to use",
            options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
            value="Similarity",
            advanced=True,
        ),
        FloatInput(
            name="search_score_threshold",
            display_name="Search Score Threshold",
            info="Minimum similarity score threshold for search results. "
            "(when using 'Similarity with score threshold')",
            value=0,
            advanced=True,
        ),
        DictInput(
            name="search_filter",
            display_name="Search Metadata Filter",
            info="Optional dictionary of filters to apply to the search query.",
            advanced=True,
            is_list=True,
        ),
    ]

    def insert_in_dict(self, build_config, field_name, new_parameters):
        """Insert ``new_parameters`` into ``build_config`` right after ``field_name``.

        The dictionary is modified in place and also returned. Every new entry
        is placed directly after ``field_name`` (or appended at the end when
        ``field_name`` is absent). Consequently, when ``field_name`` is present
        and several entries are given, they end up in reverse order of
        ``new_parameters``; this ordering is kept as-is for compatibility.
        """
        for new_field_name, new_parameter in new_parameters.items():
            items = list(build_config.items())

            # Position just past the anchor key, or the end if it is missing.
            idx = next(
                (i + 1 for i, (key, _value) in enumerate(items) if key == field_name),
                len(items),
            )
            items.insert(idx, (new_field_name, new_parameter))

            # Rebuild in place so callers holding a reference see the new order.
            build_config.clear()
            build_config.update(items)

        return build_config

    def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
        """Adapt the build configuration after a real-time-refresh field changed.

        * ``embedding_service`` set to "Astra Vectorize": the ``embedding``
          field is removed and a ``provider`` dropdown is inserted.
        * ``embedding_service`` set to anything else: all Vectorize fields are
          removed and the ``embedding`` handle is inserted again.
        * ``provider``: the provider-dependent fields are rebuilt for the
          selected provider. If the provider is not a key of
          ``VECTORIZE_PROVIDERS_MAPPING`` (for example the empty initial
          value), the stale fields are removed and nothing is added.

        Returns the (mutated) ``build_config``.
        """
        if field_name == "embedding_service":
            if field_value == "Astra Vectorize":
                build_config.pop("embedding", None)

                new_parameter = DropdownInput(
                    name="provider",
                    display_name="Vectorize Provider",
                    # A plain list rather than a dict view, so the option set is
                    # a regular serialisable sequence.
                    options=list(self.VECTORIZE_PROVIDERS_MAPPING),
                    value="",
                    required=True,
                    real_time_refresh=True,
                ).to_dict()

                self.insert_in_dict(build_config, "embedding_service", {"provider": new_parameter})
            else:
                for field in ("provider", *self._PROVIDER_DEPENDENT_FIELDS):
                    build_config.pop(field, None)

                new_parameter = HandleInput(
                    name="embedding",
                    display_name="Embedding Model",
                    input_types=["Embeddings"],
                    info="Allows an embedding model configuration.",
                ).to_dict()

                self.insert_in_dict(build_config, "embedding_service", {"embedding": new_parameter})

        elif field_name == "provider":
            for field in self._PROVIDER_DEPENDENT_FIELDS:
                build_config.pop(field, None)

            provider_entry = self.VECTORIZE_PROVIDERS_MAPPING.get(field_value)
            if provider_entry is None:
                # Unknown or cleared provider: there are no models to offer.
                return build_config

            model_options = provider_entry[1]

            new_parameter_0 = DropdownInput(
                name="z_00_model_name",
                display_name="Model Name",
                info="The embedding model to use for the selected provider. Each provider has a different set of "
                "models available (full list at "
                "https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html):\n\n"
                f"{', '.join(model_options)}",
                options=model_options,
                required=True,
            ).to_dict()

            new_parameter_1 = DictInput(
                name="z_01_model_parameters",
                display_name="Model Parameters",
                is_list=True,
            ).to_dict()

            new_parameter_2 = MessageTextInput(
                name="z_02_api_key_name",
                display_name="API Key name",
                info="The name of the embeddings provider API key stored on Astra. "
                "If set, it will override the 'ProviderKey' in the authentication parameters.",
            ).to_dict()

            new_parameter_3 = SecretStrInput(
                name="z_03_provider_api_key",
                display_name="Provider API Key",
                info="An alternative to the Astra Authentication that passes an API key for the provider "
                "with each request to Astra DB. "
                "This may be used when Vectorize is configured for the collection, "
                "but no corresponding provider secret is stored within Astra's key management system.",
            ).to_dict()

            new_parameter_4 = DictInput(
                name="z_04_authentication",
                display_name="Authentication parameters",
                is_list=True,
            ).to_dict()

            self.insert_in_dict(
                build_config,
                "provider",
                {
                    "z_00_model_name": new_parameter_0,
                    "z_01_model_parameters": new_parameter_1,
                    "z_02_api_key_name": new_parameter_2,
                    "z_03_provider_api_key": new_parameter_3,
                    "z_04_authentication": new_parameter_4,
                },
            )

        return build_config

    def build_vectorize_options(self, **kwargs):
        """Assemble the Astra Vectorize options from the component fields.

        Each value is read from the attribute of the same name on ``self`` and
        falls back to ``kwargs`` when that attribute is missing or empty.
        Vectorize attributes that do not exist on the instance are created
        with the value ``None``.

        Returns:
            A dict with two keys: ``collection_vector_service_options`` (a dict
            with ``provider``, ``modelName``, ``authentication`` and
            ``parameters``) and ``collection_embedding_api_key``.
        """
        for attribute in ("provider", *self._PROVIDER_DEPENDENT_FIELDS):
            if not hasattr(self, attribute):
                setattr(self, attribute, None)

        # Translate the provider display name to its identifier; fall back to kwargs.
        provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.provider, [None])[0] or kwargs.get("provider")

        # ``or {}`` also covers a key passed explicitly as None in kwargs.
        authentication = {**(self.z_04_authentication or kwargs.get("z_04_authentication") or {})}

        api_key_name = self.z_02_api_key_name or kwargs.get("z_02_api_key_name")
        provider_key = self.z_03_provider_api_key or kwargs.get("z_03_provider_api_key")
        if api_key_name:
            # An explicit key name takes precedence over any "providerKey"
            # supplied in the authentication parameters.
            authentication["providerKey"] = api_key_name

        return {
            # must match astrapy.info.CollectionVectorServiceOptions
            "collection_vector_service_options": {
                "provider": provider_value,
                "modelName": self.z_00_model_name or kwargs.get("z_00_model_name"),
                "authentication": authentication,
                "parameters": self.z_01_model_parameters or kwargs.get("z_01_model_parameters") or {},
            },
            "collection_embedding_api_key": provider_key,
        }

    @check_cached_vector_store
    def build_vector_store(self, vectorize_options=None):
        """Create the ``AstraDBVectorStore`` and ingest any pending documents.

        Args:
            vectorize_options: Optional dict in the shape returned by
                ``build_vectorize_options``. Only used when no embedding model
                is connected; when falsy, ``build_vectorize_options()`` is
                called instead.

        Returns:
            The initialised vector store.

        Raises:
            ImportError: If ``langchain-astradb`` is not installed.
            ValueError: If the setup mode is invalid, or if creating the store
                fails for any reason (including an indexing policy that is not
                valid JSON).
        """
        try:
            from langchain_astradb import AstraDBVectorStore
            from langchain_astradb.utils.astradb import SetupMode
        except ImportError as e:
            msg = (
                "Could not import langchain Astra DB integration package. "
                "Please install it with `pip install langchain-astradb`."
            )
            raise ImportError(msg) from e

        try:
            if not self.setup_mode:
                self.setup_mode = self._inputs["setup_mode"].options[0]

            setup_mode_value = SetupMode[self.setup_mode.upper()]
        except KeyError as e:
            msg = f"Invalid setup mode: {self.setup_mode}"
            raise ValueError(msg) from e

        if self.embedding:
            embedding_dict = {"embedding": self.embedding}
        else:
            from astrapy.info import CollectionVectorServiceOptions

            dict_options = vectorize_options or self.build_vectorize_options()

            # The authentication/parameters dicts live inside the service
            # options; strip empty keys and values there. Work on a copy so
            # the caller's dict is left untouched.
            service_options = dict(dict_options.get("collection_vector_service_options") or {})
            for section in ("authentication", "parameters"):
                service_options[section] = {
                    k: v for k, v in (service_options.get(section) or {}).items() if k and v
                }

            embedding_dict = {
                "collection_vector_service_options": CollectionVectorServiceOptions.from_dict(service_options),
            }

            # Forward the per-request provider API key only when one was given.
            provider_api_key = dict_options.get("collection_embedding_api_key")
            if provider_api_key:
                embedding_dict["collection_embedding_api_key"] = provider_api_key

        try:
            # The input holds a JSON string; the store expects the parsed value.
            indexing_policy = (
                orjson.loads(self.collection_indexing_policy) if self.collection_indexing_policy else None
            )

            vector_store = AstraDBVectorStore(
                collection_name=self.collection_name,
                token=self.token,
                api_endpoint=self.api_endpoint,
                namespace=self.namespace or None,
                environment=parse_api_endpoint(self.api_endpoint).environment,
                metric=self.metric,
                batch_size=self.batch_size or None,
                bulk_insert_batch_concurrency=self.bulk_insert_batch_concurrency or None,
                bulk_insert_overwrite_concurrency=self.bulk_insert_overwrite_concurrency or None,
                bulk_delete_concurrency=self.bulk_delete_concurrency or None,
                setup_mode=setup_mode_value,
                pre_delete_collection=self.pre_delete_collection,
                metadata_indexing_include=[s for s in self.metadata_indexing_include or [] if s] or None,
                metadata_indexing_exclude=[s for s in self.metadata_indexing_exclude or [] if s] or None,
                collection_indexing_policy=indexing_policy,
                **embedding_dict,
            )
        except Exception as e:
            msg = f"Error initializing AstraDBVectorStore: {e}"
            raise ValueError(msg) from e

        self._add_documents_to_vector_store(vector_store)

        return vector_store

    def _add_documents_to_vector_store(self, vector_store) -> None:
        """Convert ``ingest_data`` to documents and add them to ``vector_store``.

        Raises:
            TypeError: If an ingested item is not a ``Data`` object.
            ValueError: If the vector store fails to add the documents.
        """
        documents = []
        for _input in self.ingest_data or []:
            if not isinstance(_input, Data):
                msg = "Vector Store Inputs must be Data objects."
                raise TypeError(msg)
            documents.append(_input.to_lc_document())

        if not documents:
            logger.debug("No documents to add to the Vector Store.")
            return

        logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
        try:
            vector_store.add_documents(documents)
        except Exception as e:
            msg = f"Error adding documents to AstraDBVectorStore: {e}"
            raise ValueError(msg) from e

    def _map_search_type(self) -> str:
        """Translate the ``search_type`` dropdown label to the search type name.

        Any label other than the two handled explicitly maps to "similarity".
        """
        if self.search_type == "Similarity with score threshold":
            return "similarity_score_threshold"
        if self.search_type == "MMR (Max Marginal Relevance)":
            return "mmr"
        return "similarity"

    def _build_search_args(self):
        """Build the keyword arguments for a search.

        Always contains ``k`` and ``score_threshold``. A ``filter`` entry is
        added only when ``search_filter`` still has entries after dropping
        pairs with an empty key or an empty value.
        """
        args = {
            "k": self.number_of_results,
            "score_threshold": self.search_score_threshold,
        }

        if self.search_filter:
            clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
            if clean_filter:
                args["filter"] = clean_filter
        return args

    def search_documents(self, vector_store=None) -> list[Data]:
        """Search the vector store with ``search_input``.

        Args:
            vector_store: Store to query; built via ``build_vector_store()``
                when not given.

        Returns:
            The matches converted to ``Data`` objects, or an empty list when
            ``search_input`` is not a non-blank string.

        Raises:
            ValueError: If the search itself fails.
        """
        if not vector_store:
            vector_store = self.build_vector_store()

        logger.debug(f"Search input: {self.search_input}")
        logger.debug(f"Search type: {self.search_type}")
        logger.debug(f"Number of results: {self.number_of_results}")

        has_query = self.search_input and isinstance(self.search_input, str) and self.search_input.strip()
        if not has_query:
            logger.debug("No search input provided. Skipping search.")
            return []

        try:
            search_type = self._map_search_type()
            search_args = self._build_search_args()

            docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args)
        except Exception as e:
            msg = f"Error performing search in AstraDBVectorStore: {e}"
            raise ValueError(msg) from e

        logger.debug(f"Retrieved documents: {len(docs)}")

        data = docs_to_data(docs)
        logger.debug(f"Converted documents to data: {len(data)}")
        self.status = data
        return data

    def get_retriever_kwargs(self):
        """Return the ``search_type`` and ``search_kwargs`` for a retriever."""
        search_args = self._build_search_args()
        return {
            "search_type": self._map_search_type(),
            "search_kwargs": search_args,
        }
|