feat: Add more icons and QOL improvements to Astra DB component (#6918)
* Test commit * QoL updates for Astra DB component * [autofix.ci] apply automated fixes * Update astradb.py * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Update Vector Store RAG.json * [autofix.ci] apply automated fixes * Keep ordering proper in Vector RAG template * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
389325b67b
commit
e1ee081d32
2 changed files with 420 additions and 372 deletions
|
|
@ -37,25 +37,25 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
"data": {
|
||||
"node": {
|
||||
"name": "create_database",
|
||||
"description": "",
|
||||
"description": "Please allow several minutes for creation to complete.",
|
||||
"display_name": "Create new database",
|
||||
"field_order": ["new_database_name", "cloud_provider", "region"],
|
||||
"field_order": ["01_new_database_name", "02_cloud_provider", "03_region"],
|
||||
"template": {
|
||||
"new_database_name": StrInput(
|
||||
"01_new_database_name": StrInput(
|
||||
name="new_database_name",
|
||||
display_name="Name",
|
||||
info="Name of the new database to create in Astra DB.",
|
||||
required=True,
|
||||
),
|
||||
"cloud_provider": DropdownInput(
|
||||
"02_cloud_provider": DropdownInput(
|
||||
name="cloud_provider",
|
||||
display_name="Cloud provider",
|
||||
info="Cloud provider for the new database.",
|
||||
options=["Amazon Web Services", "Google Cloud Platform", "Microsoft Azure"],
|
||||
options=[],
|
||||
required=True,
|
||||
real_time_refresh=True,
|
||||
),
|
||||
"region": DropdownInput(
|
||||
"03_region": DropdownInput(
|
||||
name="region",
|
||||
display_name="Region",
|
||||
info="Region for the new database.",
|
||||
|
|
@ -76,37 +76,37 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
"data": {
|
||||
"node": {
|
||||
"name": "create_collection",
|
||||
"description": "",
|
||||
"description": "Please allow several seconds for creation to complete.",
|
||||
"display_name": "Create new collection",
|
||||
"field_order": [
|
||||
"new_collection_name",
|
||||
"embedding_generation_provider",
|
||||
"embedding_generation_model",
|
||||
"dimension",
|
||||
"01_new_collection_name",
|
||||
"02_embedding_generation_provider",
|
||||
"03_embedding_generation_model",
|
||||
"04_dimension",
|
||||
],
|
||||
"template": {
|
||||
"new_collection_name": StrInput(
|
||||
"01_new_collection_name": StrInput(
|
||||
name="new_collection_name",
|
||||
display_name="Name",
|
||||
info="Name of the new collection to create in Astra DB.",
|
||||
required=True,
|
||||
),
|
||||
"embedding_generation_provider": DropdownInput(
|
||||
"02_embedding_generation_provider": DropdownInput(
|
||||
name="embedding_generation_provider",
|
||||
display_name="Embedding generation method",
|
||||
info="Provider to use for generating embeddings.",
|
||||
real_time_refresh=True,
|
||||
required=True,
|
||||
options=["Bring your own", "Nvidia"],
|
||||
options=[],
|
||||
),
|
||||
"embedding_generation_model": DropdownInput(
|
||||
"03_embedding_generation_model": DropdownInput(
|
||||
name="embedding_generation_model",
|
||||
display_name="Embedding model",
|
||||
info="Model to use for generating embeddings.",
|
||||
required=True,
|
||||
options=[],
|
||||
),
|
||||
"dimension": IntInput(
|
||||
"04_dimension": IntInput(
|
||||
name="dimension",
|
||||
display_name="Dimensions (Required only for `Bring your own`)",
|
||||
info="Dimensions of the embeddings to generate.",
|
||||
|
|
@ -254,17 +254,32 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
def map_cloud_providers(cls):
|
||||
# TODO: Programmatically fetch the regions for each cloud provider
|
||||
return {
|
||||
"Amazon Web Services": {
|
||||
"id": "aws",
|
||||
"regions": ["us-east-2", "ap-south-1", "eu-west-1"],
|
||||
"dev": {
|
||||
"Google Cloud Platform": {
|
||||
"id": "gcp",
|
||||
"regions": ["us-central1"],
|
||||
},
|
||||
},
|
||||
"Google Cloud Platform": {
|
||||
"id": "gcp",
|
||||
"regions": ["us-east1"],
|
||||
# TODO: Check test regions
|
||||
"test": {
|
||||
"Google Cloud Platform": {
|
||||
"id": "gcp",
|
||||
"regions": ["us-central1"],
|
||||
},
|
||||
},
|
||||
"Microsoft Azure": {
|
||||
"id": "azure",
|
||||
"regions": ["westus3"],
|
||||
"prod": {
|
||||
"Amazon Web Services": {
|
||||
"id": "aws",
|
||||
"regions": ["us-east-2", "ap-south-1", "eu-west-1"],
|
||||
},
|
||||
"Google Cloud Platform": {
|
||||
"id": "gcp",
|
||||
"regions": ["us-east1"],
|
||||
},
|
||||
"Microsoft Azure": {
|
||||
"id": "azure",
|
||||
"regions": ["westus3"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -309,10 +324,13 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
# Get the admin object
|
||||
admin_client = client.get_admin(token=token)
|
||||
|
||||
# Get the environment, defaulting to prod if it is null-like
|
||||
my_env = environment or "prod"
|
||||
|
||||
# Call the create database function
|
||||
return await admin_client.async_create_database(
|
||||
name=new_database_name,
|
||||
cloud_provider=cls.map_cloud_providers()[cloud_provider]["id"],
|
||||
cloud_provider=cls.map_cloud_providers()[my_env][cloud_provider]["id"],
|
||||
region=region,
|
||||
keyspace=keyspace,
|
||||
wait_until_active=False,
|
||||
|
|
@ -506,18 +524,24 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
)
|
||||
|
||||
# If there is no provider, use the vector store icon
|
||||
if not provider_name or provider_name == "bring your own":
|
||||
if not provider_name or provider_name == "Bring your own":
|
||||
return "vectorstores"
|
||||
|
||||
# Special case for certain models
|
||||
# TODO: Add more icons
|
||||
if provider_name == "nvidia":
|
||||
return "NVIDIA"
|
||||
if provider_name == "openai":
|
||||
return "OpenAI"
|
||||
# Map provider casings
|
||||
case_map = {
|
||||
"nvidia": "NVIDIA",
|
||||
"openai": "OpenAI",
|
||||
"amazon bedrock": "AmazonBedrockEmbeddings",
|
||||
"azure openai": "AzureOpenAiEmbeddings",
|
||||
"cohere": "Cohere",
|
||||
"jina ai": "JinaAI",
|
||||
"mistral ai": "MistralAI",
|
||||
"upstage": "Upstage",
|
||||
"voyage ai": "VoyageAI",
|
||||
}
|
||||
|
||||
# Title case on the provider for the icon if no special case
|
||||
return provider_name.title()
|
||||
# Adjust the casing for some providers, such as nvidia
|
||||
return case_map[provider_name.lower()] if provider_name.lower() in case_map else provider_name.title()
|
||||
|
||||
def _initialize_collection_options(self, api_endpoint: str | None = None):
|
||||
# Nothing to generate if we don't have an API endpoint yet
|
||||
|
|
@ -560,36 +584,41 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
# If the collection is set, allow user to see embedding options
|
||||
build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
|
||||
"embedding_generation_provider"
|
||||
]["options"] = ["Bring your own", "Nvidia", *[key for key in vectorize_providers if key != "Nvidia"]]
|
||||
"02_embedding_generation_provider"
|
||||
]["options"] = [
|
||||
"Bring your own",
|
||||
"Nvidia",
|
||||
*[key for key in vectorize_providers if key not in ["Bring your own", "Nvidia"]],
|
||||
]
|
||||
|
||||
# For every provider other than Bring your own or Nvidia, add metadata saying to configure it in the Astra DB Portal
|
||||
provider_options = build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
|
||||
"embedding_generation_provider"
|
||||
"02_embedding_generation_provider"
|
||||
]["options"]
|
||||
|
||||
# Go over each possible provider and add metadata to configure in Astra DB Portal
|
||||
for provider in provider_options:
|
||||
# Skip Bring your own and Nvidia, automatically configured
|
||||
if provider in {"Bring your own", "Nvidia"}:
|
||||
build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
|
||||
"embedding_generation_provider"
|
||||
]["options_metadata"].append({"icon": self.get_provider_icon(provider_name=provider.lower())})
|
||||
continue
|
||||
# Add the icon for the provider
|
||||
my_metadata = {"icon": self.get_provider_icon(provider_name=provider)}
|
||||
|
||||
# Add metadata to configure in Astra DB Portal
|
||||
# Skip Bring your own and Nvidia, automatically configured
|
||||
if provider not in {"Bring your own", "Nvidia"}:
|
||||
# Add metadata to configure in Astra DB Portal
|
||||
my_metadata[" "] = "Configure in Astra DB Portal"
|
||||
|
||||
# Add the metadata to the options metadata
|
||||
build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
|
||||
"embedding_generation_provider"
|
||||
]["options_metadata"].append({" ": "Configure in Astra DB Portal"})
|
||||
"02_embedding_generation_provider"
|
||||
]["options_metadata"].append(my_metadata)
|
||||
|
||||
# And allow the user to see the models based on a selected provider
|
||||
embedding_provider = build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
|
||||
"embedding_generation_provider"
|
||||
"02_embedding_generation_provider"
|
||||
]["value"]
|
||||
|
||||
# Set the options for the embedding model based on the provider
|
||||
build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][
|
||||
"embedding_generation_model"
|
||||
"03_embedding_generation_model"
|
||||
]["options"] = vectorize_providers.get(embedding_provider, [[], []])[1]
|
||||
|
||||
return build_config
|
||||
|
|
@ -617,6 +646,12 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
# Get the list of options we have based on the token provided
|
||||
database_options = self._initialize_database_options()
|
||||
|
||||
# Update the list of cloud providers
|
||||
my_env = self.environment or "prod"
|
||||
build_config["database_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"]["02_cloud_provider"][
|
||||
"options"
|
||||
] = list(self.map_cloud_providers()[my_env].keys())
|
||||
|
||||
# If we retrieved options based on the token, show the dropdown
|
||||
build_config["database_name"]["options"] = [db["name"] for db in database_options]
|
||||
build_config["database_name"]["options_metadata"] = [
|
||||
|
|
@ -652,68 +687,82 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
|
||||
# Callback for database creation
|
||||
if field_name == "database_name" and isinstance(field_value, dict) and "new_database_name" in field_value:
|
||||
if field_name == "database_name" and isinstance(field_value, dict) and "01_new_database_name" in field_value:
|
||||
try:
|
||||
await self.create_database_api(
|
||||
new_database_name=field_value["new_database_name"],
|
||||
new_database_name=field_value["01_new_database_name"],
|
||||
token=self.token,
|
||||
keyspace=self.get_keyspace(),
|
||||
environment=self.environment,
|
||||
cloud_provider=field_value["cloud_provider"],
|
||||
region=field_value["region"],
|
||||
cloud_provider=field_value["02_cloud_provider"],
|
||||
region=field_value["03_region"],
|
||||
)
|
||||
except Exception as e:
|
||||
msg = f"Error creating database: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
# Add the new database to the list of options
|
||||
build_config["database_name"]["options"] += [field_value["new_database_name"]]
|
||||
build_config["database_name"]["options"] += [field_value["01_new_database_name"]]
|
||||
build_config["database_name"]["options_metadata"] += [{"status": "PENDING"}]
|
||||
|
||||
return self.reset_collection_list(build_config)
|
||||
|
||||
# This is the callback required to update the list of regions for a cloud provider
|
||||
if field_name == "database_name" and isinstance(field_value, dict) and "new_database_name" not in field_value:
|
||||
cloud_provider = field_value["cloud_provider"]
|
||||
build_config["database_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"]["region"][
|
||||
if (
|
||||
field_name == "database_name"
|
||||
and isinstance(field_value, dict)
|
||||
and "01_new_database_name" not in field_value
|
||||
):
|
||||
# Get the cloud provider and environment
|
||||
cloud_provider = field_value["02_cloud_provider"]
|
||||
my_env = self.environment or "prod"
|
||||
|
||||
# Update the list of regions for the cloud provider
|
||||
build_config["database_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"]["03_region"][
|
||||
"options"
|
||||
] = self.map_cloud_providers()[cloud_provider]["regions"]
|
||||
] = self.map_cloud_providers()[my_env][cloud_provider]["regions"]
|
||||
|
||||
return build_config
|
||||
|
||||
# Callback for the creation of collections
|
||||
if field_name == "collection_name" and isinstance(field_value, dict) and "new_collection_name" in field_value:
|
||||
if (
|
||||
field_name == "collection_name"
|
||||
and isinstance(field_value, dict)
|
||||
and "01_new_collection_name" in field_value
|
||||
):
|
||||
try:
|
||||
# Get the dimension if it's a BYO provider
|
||||
dimension = (
|
||||
field_value["dimension"]
|
||||
if field_value["embedding_generation_provider"] == "Bring your own"
|
||||
field_value["04_dimension"]
|
||||
if field_value["02_embedding_generation_provider"] == "Bring your own"
|
||||
else None
|
||||
)
|
||||
|
||||
# Create the collection
|
||||
await self.create_collection_api(
|
||||
new_collection_name=field_value["new_collection_name"],
|
||||
new_collection_name=field_value["01_new_collection_name"],
|
||||
token=self.token,
|
||||
api_endpoint=build_config["api_endpoint"]["value"],
|
||||
environment=self.environment,
|
||||
keyspace=self.get_keyspace(),
|
||||
dimension=dimension,
|
||||
embedding_generation_provider=field_value["embedding_generation_provider"],
|
||||
embedding_generation_model=field_value["embedding_generation_model"],
|
||||
embedding_generation_provider=field_value["02_embedding_generation_provider"],
|
||||
embedding_generation_model=field_value["03_embedding_generation_model"],
|
||||
)
|
||||
except Exception as e:
|
||||
msg = f"Error creating collection: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
# Add the new collection to the list of options
|
||||
build_config["collection_name"]["value"] = field_value["new_collection_name"]
|
||||
build_config["collection_name"]["options"].append(field_value["new_collection_name"])
|
||||
build_config["collection_name"]["value"] = field_value["01_new_collection_name"]
|
||||
build_config["collection_name"]["options"].append(field_value["01_new_collection_name"])
|
||||
|
||||
# Get the provider and model for the new collection
|
||||
generation_provider = field_value["embedding_generation_provider"]
|
||||
provider = generation_provider if generation_provider != "Bring your own" else None
|
||||
generation_model = field_value["embedding_generation_model"]
|
||||
generation_provider = field_value["02_embedding_generation_provider"]
|
||||
provider = (
|
||||
generation_provider.lower() if generation_provider and generation_provider != "Bring your own" else None
|
||||
)
|
||||
generation_model = field_value["03_embedding_generation_model"]
|
||||
model = generation_model if generation_model and generation_model != "Bring your own" else None
|
||||
|
||||
# Set the embedding choice
|
||||
|
|
@ -721,7 +770,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
build_config["embedding_model"]["advanced"] = bool(provider)
|
||||
|
||||
# Add the new collection to the list of options
|
||||
icon = "NVIDIA" if provider == "Nvidia" else "vectorstores"
|
||||
icon = self.get_provider_icon(provider_name=generation_provider)
|
||||
build_config["collection_name"]["options_metadata"] += [
|
||||
{"records": 0, "provider": provider, "icon": icon, "model": model}
|
||||
]
|
||||
|
|
@ -732,7 +781,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
if (
|
||||
field_name == "collection_name"
|
||||
and isinstance(field_value, dict)
|
||||
and "new_collection_name" not in field_value
|
||||
and "01_new_collection_name" not in field_value
|
||||
):
|
||||
return self.reset_provider_options(build_config)
|
||||
|
||||
|
|
@ -789,7 +838,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
{
|
||||
"records": 0,
|
||||
"provider": None,
|
||||
"icon": "",
|
||||
"icon": "vectorstores",
|
||||
"model": None,
|
||||
}
|
||||
)
|
||||
|
|
@ -806,12 +855,8 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
value_of_provider = build_config["collection_name"]["options_metadata"][index_of_name]["provider"]
|
||||
|
||||
# If we were able to determine the Vectorize provider, set it accordingly
|
||||
if value_of_provider:
|
||||
build_config["embedding_model"]["advanced"] = True
|
||||
build_config["embedding_choice"]["value"] = "Astra Vectorize"
|
||||
else:
|
||||
build_config["embedding_model"]["advanced"] = False
|
||||
build_config["embedding_choice"]["value"] = "Embedding Model"
|
||||
build_config["embedding_model"]["advanced"] = bool(value_of_provider)
|
||||
build_config["embedding_choice"]["value"] = "Astra Vectorize" if value_of_provider else "Embedding Model"
|
||||
|
||||
return build_config
|
||||
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue