fix: support additional autodetect astradb params (#5254)
* fix: support additional autodetect astradb params
* [autofix.ci] apply automated fixes
* [autofix.ci] apply automated fixes (attempt 2/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ee31f152ad
commit
b28c6dcc55
2 changed files with 285 additions and 199 deletions
|
|
@ -239,6 +239,18 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
"See https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option",
|
||||
advanced=True,
|
||||
),
|
||||
StrInput(
|
||||
name="content_field",
|
||||
display_name="Content Field",
|
||||
info="Field to use as the text content field for the vector store.",
|
||||
advanced=True,
|
||||
),
|
||||
BoolInput(
|
||||
name="ignore_invalid_documents",
|
||||
display_name="Ignore Invalid Documents",
|
||||
info="Boolean flag to determine whether to ignore invalid documents at runtime.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
def del_fields(self, build_config, field_list):
|
||||
|
|
@ -357,22 +369,30 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
# Refresh the collection name options
|
||||
build_config["collection_name"]["options"] = self._initialize_collection_options()
|
||||
|
||||
# If the collection name is set to "+ Create new collection", show embedding choice
|
||||
if field_name == "collection_name" and field_value == "+ Create new collection":
|
||||
build_config["embedding_choice"]["advanced"] = False
|
||||
build_config["embedding_choice"]["value"] = "Embedding Model"
|
||||
build_config["embedding_model"]["advanced"] = False
|
||||
# Update the choice of embedding model based on collection name
|
||||
if field_name == "collection_name":
|
||||
# Detect if it is a new collection
|
||||
is_new_collection = field_value == "+ Create new collection"
|
||||
|
||||
build_config["collection_name_new"]["advanced"] = False
|
||||
build_config["collection_name_new"]["required"] = True
|
||||
# Set the advanced and required fields based on the collection choice
|
||||
build_config["embedding_choice"].update(
|
||||
{
|
||||
"advanced": not is_new_collection,
|
||||
"value": "Embedding Model" if is_new_collection else build_config["embedding_choice"].get("value"),
|
||||
}
|
||||
)
|
||||
|
||||
# But if it's not, hide embedding choice
|
||||
elif field_name == "collection_name" and field_value != "+ Create new collection":
|
||||
build_config["embedding_choice"]["advanced"] = True
|
||||
# Set the advanced field for the embedding model
|
||||
build_config["embedding_model"]["advanced"] = not is_new_collection
|
||||
|
||||
build_config["collection_name_new"]["advanced"] = True
|
||||
build_config["collection_name_new"]["required"] = False
|
||||
build_config["collection_name_new"]["value"] = ""
|
||||
# Set the advanced and required fields for the new collection name
|
||||
build_config["collection_name_new"].update(
|
||||
{
|
||||
"advanced": not is_new_collection,
|
||||
"required": is_new_collection,
|
||||
"value": "" if not is_new_collection else build_config["collection_name_new"].get("value"),
|
||||
}
|
||||
)
|
||||
|
||||
# Get the collection options for the selected collection
|
||||
collection_options = self.get_collection_options()
|
||||
|
|
@ -382,6 +402,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
build_config["embedding_choice"]["advanced"] = True
|
||||
|
||||
if collection_options.service:
|
||||
# Remove unnecessary fields when a service is set
|
||||
self.del_fields(
|
||||
build_config,
|
||||
[
|
||||
|
|
@ -394,12 +415,22 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
],
|
||||
)
|
||||
|
||||
build_config["embedding_model"]["advanced"] = True
|
||||
build_config["embedding_choice"]["value"] = "Astra Vectorize"
|
||||
# Update the providers mapping
|
||||
updates = {
|
||||
"embedding_model": {"advanced": True},
|
||||
"embedding_choice": {"value": "Astra Vectorize"},
|
||||
}
|
||||
else:
|
||||
build_config["embedding_model"]["advanced"] = False
|
||||
build_config["embedding_provider"]["advanced"] = False
|
||||
build_config["embedding_choice"]["value"] = "Embedding Model"
|
||||
# Update the providers mapping
|
||||
updates = {
|
||||
"embedding_model": {"advanced": False},
|
||||
"embedding_provider": {"advanced": False},
|
||||
"embedding_choice": {"value": "Embedding Model"},
|
||||
}
|
||||
|
||||
# Apply updates to the build_config
|
||||
for key, value in updates.items():
|
||||
build_config[key].update(value)
|
||||
|
||||
elif field_name == "embedding_choice":
|
||||
if field_value == "Astra Vectorize":
|
||||
|
|
@ -572,65 +603,50 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
msg = f"Invalid setup mode: {self.setup_mode}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
metric_value = self.metric or None
|
||||
autodetect = False
|
||||
# Initialize parameters based on the collection name
|
||||
is_new_collection = self.collection_name == "+ Create new collection"
|
||||
|
||||
if self.embedding_choice == "Embedding Model":
|
||||
embedding_dict = {"embedding": self.embedding_model}
|
||||
# Use autodetect if the collection name is NOT set to "+ Create new collection"
|
||||
elif self.collection_name != "+ Create new collection":
|
||||
autodetect = True
|
||||
metric_value = None
|
||||
setup_mode_value = None
|
||||
embedding_dict = {}
|
||||
else:
|
||||
# Build the list of autodetect parameters
|
||||
autodetect_params = {
|
||||
"autodetect": not is_new_collection,
|
||||
"metric_value": self.metric if is_new_collection else None,
|
||||
"metadata_indexing_include": (
|
||||
[s for s in self.metadata_indexing_include if s] or None if is_new_collection else None
|
||||
),
|
||||
"metadata_indexing_exclude": (
|
||||
[s for s in self.metadata_indexing_exclude if s] or None if is_new_collection else None
|
||||
),
|
||||
"collection_indexing_policy": (
|
||||
orjson.dumps(self.collection_indexing_policy)
|
||||
if is_new_collection and self.collection_indexing_policy
|
||||
else None
|
||||
),
|
||||
"setup_mode": setup_mode_value if is_new_collection else None,
|
||||
}
|
||||
|
||||
# Unpack parameters
|
||||
autodetect = autodetect_params["autodetect"]
|
||||
metric_value = autodetect_params["metric_value"]
|
||||
metadata_indexing_include = autodetect_params["metadata_indexing_include"]
|
||||
metadata_indexing_exclude = autodetect_params["metadata_indexing_exclude"]
|
||||
collection_indexing_policy = autodetect_params["collection_indexing_policy"]
|
||||
setup_mode = autodetect_params["setup_mode"]
|
||||
|
||||
# Get the embedding model
|
||||
embedding_dict = {"embedding": self.embedding_model}
|
||||
|
||||
# Build Astra Vectorize options when that choice is selected and the collection is not autodetected
|
||||
if self.embedding_choice == "Astra Vectorize" and not autodetect:
|
||||
from astrapy.info import CollectionVectorServiceOptions
|
||||
|
||||
# Grab the collection options if available
|
||||
collection_options = self.get_collection_options()
|
||||
|
||||
# Ensure collection_options and its nested attributes are handled safely
|
||||
authentication = getattr(self, "z_04_authentication", {}) or (
|
||||
collection_options.service.authentication
|
||||
if collection_options and collection_options.service and collection_options.service.authentication
|
||||
else {}
|
||||
)
|
||||
|
||||
# Build the vectorize options dictionary
|
||||
dict_options = vectorize_options or self.build_vectorize_options(
|
||||
embedding_provider=(
|
||||
getattr(self, "embedding_provider", None)
|
||||
or (
|
||||
collection_options.service.provider
|
||||
if collection_options and collection_options.service
|
||||
else None
|
||||
)
|
||||
),
|
||||
model=(
|
||||
getattr(self, "model", None)
|
||||
or (
|
||||
collection_options.service.model_name
|
||||
if collection_options and collection_options.service
|
||||
else None
|
||||
)
|
||||
),
|
||||
z_01_model_parameters=(
|
||||
getattr(self, "z_01_model_parameters", None)
|
||||
or (
|
||||
collection_options.service.parameters
|
||||
if collection_options and collection_options.service
|
||||
else None
|
||||
)
|
||||
),
|
||||
z_02_api_key_name=(
|
||||
getattr(self, "z_02_api_key_name", None)
|
||||
or (authentication.get("apiKey") if authentication else None)
|
||||
),
|
||||
z_03_provider_api_key=(
|
||||
getattr(self, "z_03_provider_api_key", None)
|
||||
or (authentication.get("providerKey") if authentication else None)
|
||||
),
|
||||
z_04_authentication=authentication,
|
||||
embedding_provider=getattr(self, "embedding_provider", None) or None,
|
||||
model=getattr(self, "model", None) or None,
|
||||
z_01_model_parameters=getattr(self, "z_01_model_parameters", None) or None,
|
||||
z_02_api_key_name=getattr(self, "z_02_api_key_name", None) or None,
|
||||
z_03_provider_api_key=getattr(self, "z_03_provider_api_key", None) or None,
|
||||
z_04_authentication=getattr(self, "z_04_authentication", {}) or {},
|
||||
)
|
||||
|
||||
# Set the embedding dictionary
|
||||
|
|
@ -641,12 +657,20 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
"collection_embedding_api_key": dict_options.get("collection_embedding_api_key"),
|
||||
}
|
||||
|
||||
# Get the running environment for Langflow
|
||||
environment = (
|
||||
parse_api_endpoint(getattr(self, "api_endpoint", None)).environment
|
||||
if getattr(self, "api_endpoint", None)
|
||||
else None
|
||||
)
|
||||
|
||||
# Get Langflow version and platform information
|
||||
__version__ = get_version_info()["version"]
|
||||
langflow_prefix = ""
|
||||
if os.getenv("LANGFLOW_HOST") is not None:
|
||||
langflow_prefix = "ds-"
|
||||
|
||||
# Attempt to build the Vector Store object
|
||||
try:
|
||||
vector_store = AstraDBVectorStore(
|
||||
token=self.token,
|
||||
|
|
@ -654,23 +678,19 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
namespace=self.keyspace or None,
|
||||
collection_name=self.get_collection_choice(),
|
||||
autodetect_collection=autodetect,
|
||||
environment=(
|
||||
parse_api_endpoint(getattr(self, "api_endpoint", None)).environment
|
||||
if getattr(self, "api_endpoint", None)
|
||||
else None
|
||||
),
|
||||
content_field=self.content_field or None,
|
||||
ignore_invalid_documents=self.ignore_invalid_documents,
|
||||
environment=environment,
|
||||
metric=metric_value,
|
||||
batch_size=self.batch_size or None,
|
||||
bulk_insert_batch_concurrency=self.bulk_insert_batch_concurrency or None,
|
||||
bulk_insert_overwrite_concurrency=self.bulk_insert_overwrite_concurrency or None,
|
||||
bulk_delete_concurrency=self.bulk_delete_concurrency or None,
|
||||
setup_mode=setup_mode_value,
|
||||
setup_mode=setup_mode,
|
||||
pre_delete_collection=self.pre_delete_collection,
|
||||
metadata_indexing_include=[s for s in self.metadata_indexing_include if s] or None,
|
||||
metadata_indexing_exclude=[s for s in self.metadata_indexing_exclude if s] or None,
|
||||
collection_indexing_policy=orjson.dumps(self.collection_indexing_policy)
|
||||
if self.collection_indexing_policy
|
||||
else None,
|
||||
metadata_indexing_include=metadata_indexing_include,
|
||||
metadata_indexing_exclude=metadata_indexing_exclude,
|
||||
collection_indexing_policy=collection_indexing_policy,
|
||||
ext_callers=[(f"{langflow_prefix}langflow", __version__)],
|
||||
**embedding_dict,
|
||||
)
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue