fix: support additional autodetect astradb params (#5254)

* fix: support additional autodetect astradb params

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Eric Hare 2024-12-13 11:23:43 -08:00 committed by GitHub
commit b28c6dcc55
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 285 additions and 199 deletions

View file

@ -239,6 +239,18 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
"See https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option",
advanced=True,
),
StrInput(
name="content_field",
display_name="Content Field",
info="Field to use as the text content field for the vector store.",
advanced=True,
),
BoolInput(
name="ignore_invalid_documents",
display_name="Ignore Invalid Documents",
info="Boolean flag to determine whether to ignore invalid documents at runtime.",
advanced=True,
),
]
def del_fields(self, build_config, field_list):
@ -357,22 +369,30 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
# Refresh the collection name options
build_config["collection_name"]["options"] = self._initialize_collection_options()
# If the collection name is set to "+ Create new collection", show embedding choice
if field_name == "collection_name" and field_value == "+ Create new collection":
build_config["embedding_choice"]["advanced"] = False
build_config["embedding_choice"]["value"] = "Embedding Model"
build_config["embedding_model"]["advanced"] = False
# Update the choice of embedding model based on collection name
if field_name == "collection_name":
# Detect if it is a new collection
is_new_collection = field_value == "+ Create new collection"
build_config["collection_name_new"]["advanced"] = False
build_config["collection_name_new"]["required"] = True
# Set the advanced and required fields based on the collection choice
build_config["embedding_choice"].update(
{
"advanced": not is_new_collection,
"value": "Embedding Model" if is_new_collection else build_config["embedding_choice"].get("value"),
}
)
# But if it's not, hide embedding choice
elif field_name == "collection_name" and field_value != "+ Create new collection":
build_config["embedding_choice"]["advanced"] = True
# Set the advanced field for the embedding model
build_config["embedding_model"]["advanced"] = not is_new_collection
build_config["collection_name_new"]["advanced"] = True
build_config["collection_name_new"]["required"] = False
build_config["collection_name_new"]["value"] = ""
# Set the advanced and required fields for the new collection name
build_config["collection_name_new"].update(
{
"advanced": not is_new_collection,
"required": is_new_collection,
"value": "" if not is_new_collection else build_config["collection_name_new"].get("value"),
}
)
# Get the collection options for the selected collection
collection_options = self.get_collection_options()
@ -382,6 +402,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
build_config["embedding_choice"]["advanced"] = True
if collection_options.service:
# Remove unnecessary fields when a service is set
self.del_fields(
build_config,
[
@ -394,12 +415,22 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
],
)
build_config["embedding_model"]["advanced"] = True
build_config["embedding_choice"]["value"] = "Astra Vectorize"
# Update the providers mapping
updates = {
"embedding_model": {"advanced": True},
"embedding_choice": {"value": "Astra Vectorize"},
}
else:
build_config["embedding_model"]["advanced"] = False
build_config["embedding_provider"]["advanced"] = False
build_config["embedding_choice"]["value"] = "Embedding Model"
# Update the providers mapping
updates = {
"embedding_model": {"advanced": False},
"embedding_provider": {"advanced": False},
"embedding_choice": {"value": "Embedding Model"},
}
# Apply updates to the build_config
for key, value in updates.items():
build_config[key].update(value)
elif field_name == "embedding_choice":
if field_value == "Astra Vectorize":
@ -572,65 +603,50 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
msg = f"Invalid setup mode: {self.setup_mode}"
raise ValueError(msg) from e
metric_value = self.metric or None
autodetect = False
# Initialize parameters based on the collection name
is_new_collection = self.collection_name == "+ Create new collection"
if self.embedding_choice == "Embedding Model":
embedding_dict = {"embedding": self.embedding_model}
# Use autodetect if the collection name is NOT set to "+ Create new collection"
elif self.collection_name != "+ Create new collection":
autodetect = True
metric_value = None
setup_mode_value = None
embedding_dict = {}
else:
# Build the list of autodetect parameters
autodetect_params = {
"autodetect": not is_new_collection,
"metric_value": self.metric if is_new_collection else None,
"metadata_indexing_include": (
[s for s in self.metadata_indexing_include if s] or None if is_new_collection else None
),
"metadata_indexing_exclude": (
[s for s in self.metadata_indexing_exclude if s] or None if is_new_collection else None
),
"collection_indexing_policy": (
orjson.dumps(self.collection_indexing_policy)
if is_new_collection and self.collection_indexing_policy
else None
),
"setup_mode": setup_mode_value if is_new_collection else None,
}
# Unpack parameters
autodetect = autodetect_params["autodetect"]
metric_value = autodetect_params["metric_value"]
metadata_indexing_include = autodetect_params["metadata_indexing_include"]
metadata_indexing_exclude = autodetect_params["metadata_indexing_exclude"]
collection_indexing_policy = autodetect_params["collection_indexing_policy"]
setup_mode = autodetect_params["setup_mode"]
# Get the embedding model
embedding_dict = {"embedding": self.embedding_model}
# Build the Astra Vectorize options when the choice is "Astra Vectorize" and the collection is not autodetected
if self.embedding_choice == "Astra Vectorize" and not autodetect:
from astrapy.info import CollectionVectorServiceOptions
# Grab the collection options if available
collection_options = self.get_collection_options()
# Ensure collection_options and its nested attributes are handled safely
authentication = getattr(self, "z_04_authentication", {}) or (
collection_options.service.authentication
if collection_options and collection_options.service and collection_options.service.authentication
else {}
)
# Build the vectorize options dictionary
dict_options = vectorize_options or self.build_vectorize_options(
embedding_provider=(
getattr(self, "embedding_provider", None)
or (
collection_options.service.provider
if collection_options and collection_options.service
else None
)
),
model=(
getattr(self, "model", None)
or (
collection_options.service.model_name
if collection_options and collection_options.service
else None
)
),
z_01_model_parameters=(
getattr(self, "z_01_model_parameters", None)
or (
collection_options.service.parameters
if collection_options and collection_options.service
else None
)
),
z_02_api_key_name=(
getattr(self, "z_02_api_key_name", None)
or (authentication.get("apiKey") if authentication else None)
),
z_03_provider_api_key=(
getattr(self, "z_03_provider_api_key", None)
or (authentication.get("providerKey") if authentication else None)
),
z_04_authentication=authentication,
embedding_provider=getattr(self, "embedding_provider", None) or None,
model=getattr(self, "model", None) or None,
z_01_model_parameters=getattr(self, "z_01_model_parameters", None) or None,
z_02_api_key_name=getattr(self, "z_02_api_key_name", None) or None,
z_03_provider_api_key=getattr(self, "z_03_provider_api_key", None) or None,
z_04_authentication=getattr(self, "z_04_authentication", {}) or {},
)
# Set the embedding dictionary
@ -641,12 +657,20 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
"collection_embedding_api_key": dict_options.get("collection_embedding_api_key"),
}
# Derive the environment from the configured API endpoint (None when no endpoint is set)
environment = (
parse_api_endpoint(getattr(self, "api_endpoint", None)).environment
if getattr(self, "api_endpoint", None)
else None
)
# Get Langflow version and platform information
__version__ = get_version_info()["version"]
langflow_prefix = ""
if os.getenv("LANGFLOW_HOST") is not None:
langflow_prefix = "ds-"
# Attempt to build the Vector Store object
try:
vector_store = AstraDBVectorStore(
token=self.token,
@ -654,23 +678,19 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
namespace=self.keyspace or None,
collection_name=self.get_collection_choice(),
autodetect_collection=autodetect,
environment=(
parse_api_endpoint(getattr(self, "api_endpoint", None)).environment
if getattr(self, "api_endpoint", None)
else None
),
content_field=self.content_field or None,
ignore_invalid_documents=self.ignore_invalid_documents,
environment=environment,
metric=metric_value,
batch_size=self.batch_size or None,
bulk_insert_batch_concurrency=self.bulk_insert_batch_concurrency or None,
bulk_insert_overwrite_concurrency=self.bulk_insert_overwrite_concurrency or None,
bulk_delete_concurrency=self.bulk_delete_concurrency or None,
setup_mode=setup_mode_value,
setup_mode=setup_mode,
pre_delete_collection=self.pre_delete_collection,
metadata_indexing_include=[s for s in self.metadata_indexing_include if s] or None,
metadata_indexing_exclude=[s for s in self.metadata_indexing_exclude if s] or None,
collection_indexing_policy=orjson.dumps(self.collection_indexing_policy)
if self.collection_indexing_policy
else None,
metadata_indexing_include=metadata_indexing_include,
metadata_indexing_exclude=metadata_indexing_exclude,
collection_indexing_policy=collection_indexing_policy,
ext_callers=[(f"{langflow_prefix}langflow", __version__)],
**embedding_dict,
)

File diff suppressed because one or more lines are too long