feat: Add more icons and QOL improvements to Astra DB component (#6918)

* Test commit

* QoL updates for Astra DB component

* [autofix.ci] apply automated fixes

* Update astradb.py

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* Update Vector Store RAG.json

* [autofix.ci] apply automated fixes

* Keep ordering proper in Vector RAG template

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Eric Hare 2025-03-05 08:58:25 -08:00 committed by GitHub
commit e1ee081d32
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 420 additions and 372 deletions

View file

@ -37,25 +37,25 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
"data": {
"node": {
"name": "create_database",
"description": "",
"description": "Please allow several minutes for creation to complete.",
"display_name": "Create new database",
"field_order": ["new_database_name", "cloud_provider", "region"],
"field_order": ["01_new_database_name", "02_cloud_provider", "03_region"],
"template": {
"new_database_name": StrInput(
"01_new_database_name": StrInput(
name="new_database_name",
display_name="Name",
info="Name of the new database to create in Astra DB.",
required=True,
),
"cloud_provider": DropdownInput(
"02_cloud_provider": DropdownInput(
name="cloud_provider",
display_name="Cloud provider",
info="Cloud provider for the new database.",
options=["Amazon Web Services", "Google Cloud Platform", "Microsoft Azure"],
options=[],
required=True,
real_time_refresh=True,
),
"region": DropdownInput(
"03_region": DropdownInput(
name="region",
display_name="Region",
info="Region for the new database.",
@ -76,37 +76,37 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
"data": {
"node": {
"name": "create_collection",
"description": "",
"description": "Please allow several seconds for creation to complete.",
"display_name": "Create new collection",
"field_order": [
"new_collection_name",
"embedding_generation_provider",
"embedding_generation_model",
"dimension",
"01_new_collection_name",
"02_embedding_generation_provider",
"03_embedding_generation_model",
"04_dimension",
],
"template": {
"new_collection_name": StrInput(
"01_new_collection_name": StrInput(
name="new_collection_name",
display_name="Name",
info="Name of the new collection to create in Astra DB.",
required=True,
),
"embedding_generation_provider": DropdownInput(
"02_embedding_generation_provider": DropdownInput(
name="embedding_generation_provider",
display_name="Embedding generation method",
info="Provider to use for generating embeddings.",
real_time_refresh=True,
required=True,
options=["Bring your own", "Nvidia"],
options=[],
),
"embedding_generation_model": DropdownInput(
"03_embedding_generation_model": DropdownInput(
name="embedding_generation_model",
display_name="Embedding model",
info="Model to use for generating embeddings.",
required=True,
options=[],
),
"dimension": IntInput(
"04_dimension": IntInput(
name="dimension",
display_name="Dimensions (Required only for `Bring your own`)",
info="Dimensions of the embeddings to generate.",
@ -254,17 +254,32 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
def map_cloud_providers(cls):
# TODO: Programmatically fetch the regions for each cloud provider
return {
"Amazon Web Services": {
"id": "aws",
"regions": ["us-east-2", "ap-south-1", "eu-west-1"],
"dev": {
"Google Cloud Platform": {
"id": "gcp",
"regions": ["us-central1"],
},
},
"Google Cloud Platform": {
"id": "gcp",
"regions": ["us-east1"],
# TODO: Check test regions
"test": {
"Google Cloud Platform": {
"id": "gcp",
"regions": ["us-central1"],
},
},
"Microsoft Azure": {
"id": "azure",
"regions": ["westus3"],
"prod": {
"Amazon Web Services": {
"id": "aws",
"regions": ["us-east-2", "ap-south-1", "eu-west-1"],
},
"Google Cloud Platform": {
"id": "gcp",
"regions": ["us-east1"],
},
"Microsoft Azure": {
"id": "azure",
"regions": ["westus3"],
},
},
}
@ -309,10 +324,13 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
# Get the admin object
admin_client = client.get_admin(token=token)
# Get the environment, defaulting to prod if it is null-like
my_env = environment or "prod"
# Call the create database function
return await admin_client.async_create_database(
name=new_database_name,
cloud_provider=cls.map_cloud_providers()[cloud_provider]["id"],
cloud_provider=cls.map_cloud_providers()[my_env][cloud_provider]["id"],
region=region,
keyspace=keyspace,
wait_until_active=False,
@ -506,18 +524,24 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
)
# If there is no provider, use the vector store icon
if not provider_name or provider_name == "bring your own":
if not provider_name or provider_name == "Bring your own":
return "vectorstores"
# Special case for certain models
# TODO: Add more icons
if provider_name == "nvidia":
return "NVIDIA"
if provider_name == "openai":
return "OpenAI"
# Map provider casings
case_map = {
"nvidia": "NVIDIA",
"openai": "OpenAI",
"amazon bedrock": "AmazonBedrockEmbeddings",
"azure openai": "AzureOpenAiEmbeddings",
"cohere": "Cohere",
"jina ai": "JinaAI",
"mistral ai": "MistralAI",
"upstage": "Upstage",
"voyage ai": "VoyageAI",
}
# Title case on the provider for the icon if no special case
return provider_name.title()
# Adjust the casing for some providers, such as nvidia
return case_map[provider_name.lower()] if provider_name.lower() in case_map else provider_name.title()
def _initialize_collection_options(self, api_endpoint: str | None = None):
# Nothing to generate if we don't have an API endpoint yet
@ -560,36 +584,41 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
# If the collection is set, allow user to see embedding options
build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
"embedding_generation_provider"
]["options"] = ["Bring your own", "Nvidia", *[key for key in vectorize_providers if key != "Nvidia"]]
"02_embedding_generation_provider"
]["options"] = [
"Bring your own",
"Nvidia",
*[key for key in vectorize_providers if key not in ["Bring your own", "Nvidia"]],
]
# For all not Bring your own or Nvidia providers, add metadata saying configure in Astra DB Portal
provider_options = build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
"embedding_generation_provider"
"02_embedding_generation_provider"
]["options"]
# Go over each possible provider and add metadata to configure in Astra DB Portal
for provider in provider_options:
# Skip Bring your own and Nvidia, automatically configured
if provider in {"Bring your own", "Nvidia"}:
build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
"embedding_generation_provider"
]["options_metadata"].append({"icon": self.get_provider_icon(provider_name=provider.lower())})
continue
# Add the icon for the provider
my_metadata = {"icon": self.get_provider_icon(provider_name=provider)}
# Add metadata to configure in Astra DB Portal
# Skip Bring your own and Nvidia, automatically configured
if provider not in {"Bring your own", "Nvidia"}:
# Add metadata to configure in Astra DB Portal
my_metadata[" "] = "Configure in Astra DB Portal"
# Add the metadata to the options metadata
build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
"embedding_generation_provider"
]["options_metadata"].append({" ": "Configure in Astra DB Portal"})
"02_embedding_generation_provider"
]["options_metadata"].append(my_metadata)
# And allow the user to see the models based on a selected provider
embedding_provider = build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
"embedding_generation_provider"
"02_embedding_generation_provider"
]["value"]
# Set the options for the embedding model based on the provider
build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
"embedding_generation_model"
"03_embedding_generation_model"
]["options"] = vectorize_providers.get(embedding_provider, [[], []])[1]
return build_config
@ -617,6 +646,12 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
# Get the list of options we have based on the token provided
database_options = self._initialize_database_options()
# Update the list of cloud providers
my_env = self.environment or "prod"
build_config["database_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"]["02_cloud_provider"][
"options"
] = list(self.map_cloud_providers()[my_env].keys())
# If we retrieved options based on the token, show the dropdown
build_config["database_name"]["options"] = [db["name"] for db in database_options]
build_config["database_name"]["options_metadata"] = [
@ -652,68 +687,82 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
# Callback for database creation
if field_name == "database_name" and isinstance(field_value, dict) and "new_database_name" in field_value:
if field_name == "database_name" and isinstance(field_value, dict) and "01_new_database_name" in field_value:
try:
await self.create_database_api(
new_database_name=field_value["new_database_name"],
new_database_name=field_value["01_new_database_name"],
token=self.token,
keyspace=self.get_keyspace(),
environment=self.environment,
cloud_provider=field_value["cloud_provider"],
region=field_value["region"],
cloud_provider=field_value["02_cloud_provider"],
region=field_value["03_region"],
)
except Exception as e:
msg = f"Error creating database: {e}"
raise ValueError(msg) from e
# Add the new database to the list of options
build_config["database_name"]["options"] += [field_value["new_database_name"]]
build_config["database_name"]["options"] += [field_value["01_new_database_name"]]
build_config["database_name"]["options_metadata"] += [{"status": "PENDING"}]
return self.reset_collection_list(build_config)
# This is the callback required to update the list of regions for a cloud provider
if field_name == "database_name" and isinstance(field_value, dict) and "new_database_name" not in field_value:
cloud_provider = field_value["cloud_provider"]
build_config["database_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"]["region"][
if (
field_name == "database_name"
and isinstance(field_value, dict)
and "01_new_database_name" not in field_value
):
# Get the cloud provider and environment
cloud_provider = field_value["02_cloud_provider"]
my_env = self.environment or "prod"
# Update the list of regions for the cloud provider
build_config["database_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"]["03_region"][
"options"
] = self.map_cloud_providers()[cloud_provider]["regions"]
] = self.map_cloud_providers()[my_env][cloud_provider]["regions"]
return build_config
# Callback for the creation of collections
if field_name == "collection_name" and isinstance(field_value, dict) and "new_collection_name" in field_value:
if (
field_name == "collection_name"
and isinstance(field_value, dict)
and "01_new_collection_name" in field_value
):
try:
# Get the dimension if it's a BYO provider
dimension = (
field_value["dimension"]
if field_value["embedding_generation_provider"] == "Bring your own"
field_value["04_dimension"]
if field_value["02_embedding_generation_provider"] == "Bring your own"
else None
)
# Create the collection
await self.create_collection_api(
new_collection_name=field_value["new_collection_name"],
new_collection_name=field_value["01_new_collection_name"],
token=self.token,
api_endpoint=build_config["api_endpoint"]["value"],
environment=self.environment,
keyspace=self.get_keyspace(),
dimension=dimension,
embedding_generation_provider=field_value["embedding_generation_provider"],
embedding_generation_model=field_value["embedding_generation_model"],
embedding_generation_provider=field_value["02_embedding_generation_provider"],
embedding_generation_model=field_value["03_embedding_generation_model"],
)
except Exception as e:
msg = f"Error creating collection: {e}"
raise ValueError(msg) from e
# Add the new collection to the list of options
build_config["collection_name"]["value"] = field_value["new_collection_name"]
build_config["collection_name"]["options"].append(field_value["new_collection_name"])
build_config["collection_name"]["value"] = field_value["01_new_collection_name"]
build_config["collection_name"]["options"].append(field_value["01_new_collection_name"])
# Get the provider and model for the new collection
generation_provider = field_value["embedding_generation_provider"]
provider = generation_provider if generation_provider != "Bring your own" else None
generation_model = field_value["embedding_generation_model"]
generation_provider = field_value["02_embedding_generation_provider"]
provider = (
generation_provider.lower() if generation_provider and generation_provider != "Bring your own" else None
)
generation_model = field_value["03_embedding_generation_model"]
model = generation_model if generation_model and generation_model != "Bring your own" else None
# Set the embedding choice
@ -721,7 +770,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
build_config["embedding_model"]["advanced"] = bool(provider)
# Add the new collection to the list of options
icon = "NVIDIA" if provider == "Nvidia" else "vectorstores"
icon = self.get_provider_icon(provider_name=generation_provider)
build_config["collection_name"]["options_metadata"] += [
{"records": 0, "provider": provider, "icon": icon, "model": model}
]
@ -732,7 +781,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
if (
field_name == "collection_name"
and isinstance(field_value, dict)
and "new_collection_name" not in field_value
and "01_new_collection_name" not in field_value
):
return self.reset_provider_options(build_config)
@ -789,7 +838,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
{
"records": 0,
"provider": None,
"icon": "",
"icon": "vectorstores",
"model": None,
}
)
@ -806,12 +855,8 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
value_of_provider = build_config["collection_name"]["options_metadata"][index_of_name]["provider"]
# If we were able to determine the Vectorize provider, set it accordingly
if value_of_provider:
build_config["embedding_model"]["advanced"] = True
build_config["embedding_choice"]["value"] = "Astra Vectorize"
else:
build_config["embedding_model"]["advanced"] = False
build_config["embedding_choice"]["value"] = "Embedding Model"
build_config["embedding_model"]["advanced"] = bool(value_of_provider)
build_config["embedding_choice"]["value"] = "Astra Vectorize" if value_of_provider else "Embedding Model"
return build_config

File diff suppressed because one or more lines are too long