diff --git a/src/backend/base/langflow/components/vectorstores/astradb.py b/src/backend/base/langflow/components/vectorstores/astradb.py index 52fb9b8ba..d0c338b88 100644 --- a/src/backend/base/langflow/components/vectorstores/astradb.py +++ b/src/backend/base/langflow/components/vectorstores/astradb.py @@ -210,6 +210,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): ), ] + def del_fields(self, build_config, field_list): + for field in field_list: + if field in build_config: + del build_config[field] + + return build_config + def insert_in_dict(self, build_config, field_name, new_parameters): # Insert the new key-value pair after the found key for new_field_name, new_parameter in new_parameters.items(): @@ -234,31 +241,30 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None): if field_name == "embedding_service": if field_value == "Astra Vectorize": - for field in ["embedding"]: - if field in build_config: - del build_config[field] + self.del_fields(build_config, ["embedding"]) new_parameter = DropdownInput( - name="provider", - display_name="Vectorize Provider", + name="embedding_provider", + display_name="Embedding Provider", options=self.VECTORIZE_PROVIDERS_MAPPING.keys(), value="", required=True, real_time_refresh=True, ).to_dict() - self.insert_in_dict(build_config, "embedding_service", {"provider": new_parameter}) + self.insert_in_dict(build_config, "embedding_service", {"embedding_provider": new_parameter}) else: - for field in [ - "provider", - "z_00_model_name", - "z_01_model_parameters", - "z_02_api_key_name", - "z_03_provider_api_key", - "z_04_authentication", - ]: - if field in build_config: - del build_config[field] + self.del_fields( + build_config, + [ + "embedding_provider", + "model", + "z_01_model_parameters", + "z_02_api_key_name", + "z_03_provider_api_key", + "z_04_authentication", + ], + ) new_parameter = HandleInput( name="embedding", @@ -269,32 +275,35 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): self.insert_in_dict(build_config, "embedding_service", {"embedding": new_parameter}) - elif field_name == "provider": - for field in [ - "z_00_model_name", - "z_01_model_parameters", - "z_02_api_key_name", - "z_03_provider_api_key", - "z_04_authentication", - ]: - if field in build_config: - del build_config[field] + elif field_name == "embedding_provider": + self.del_fields( + build_config, + ["model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"], + ) model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1] - new_parameter_0 = DropdownInput( - name="z_00_model_name", - display_name="Model Name", + new_parameter = DropdownInput( + name="model", + display_name="Model", info="The embedding model to use for the selected provider. Each provider has a different set of " "models available (full list at " "https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html):\n\n" f"{', '.join(model_options)}", options=model_options, - placeholder="Select a model", - value=model_options[0], + value=None, required=True, + real_time_refresh=True, ).to_dict() + self.insert_in_dict(build_config, "embedding_provider", {"model": new_parameter}) + + elif field_name == "model": + self.del_fields( + build_config, + ["z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"], + ) + new_parameter_1 = DictInput( name="z_01_model_parameters", display_name="Model Parameters", @@ -303,12 +312,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): new_parameter_2 = MessageTextInput( name="z_02_api_key_name", - display_name="API Key name", + display_name="API Key Name", info="The name of the embeddings provider API key stored on Astra. " "If set, it will override the 'ProviderKey' in the authentication parameters.", ).to_dict() new_parameter_3 = SecretStrInput( + load_from_db=False, name="z_03_provider_api_key", display_name="Provider API Key", info="An alternative to the Astra Authentication that passes an API key for the provider " @@ -319,15 +329,14 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): new_parameter_4 = DictInput( name="z_04_authentication", - display_name="Authentication parameters", + display_name="Authentication Parameters", is_list=True, ).to_dict() self.insert_in_dict( build_config, - "provider", + "model", { - "z_00_model_name": new_parameter_0, "z_01_model_parameters": new_parameter_1, "z_02_api_key_name": new_parameter_2, "z_03_provider_api_key": new_parameter_3, @@ -339,8 +348,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): def build_vectorize_options(self, **kwargs): for attribute in [ - "provider", - "z_00_model_name", + "embedding_provider", + "model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", @@ -350,8 +359,10 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): setattr(self, attribute, None) # Fetch values from kwargs if any self.* attributes are None - provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.provider, [None])[0] or kwargs.get("provider") - model_name = self.z_00_model_name or kwargs.get("z_00_model_name") + provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.embedding_provider, [None])[0] or kwargs.get( + "embedding_provider" + ) + model_name = self.model or kwargs.get("model") authentication = {**(self.z_04_authentication or kwargs.get("z_04_authentication", {}))} parameters = self.z_01_model_parameters or kwargs.get("z_01_model_parameters", {}) @@ -414,6 +425,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): ), "collection_embedding_api_key": dict_options.get("collection_embedding_api_key"), } + try: vector_store = AstraDBVectorStore( collection_name=self.collection_name, diff --git a/src/backend/tests/integration/components/astra/test_astra_component.py b/src/backend/tests/integration/components/astra/test_astra_component.py index 3d598a3c1..866898ed6 100644 --- a/src/backend/tests/integration/components/astra/test_astra_component.py +++ b/src/backend/tests/integration/components/astra/test_astra_component.py @@ -99,7 +99,7 @@ def test_astra_vectorize(): store = None try: options = {"provider": "nvidia", "modelName": "NV-Embed-QA"} - options_comp = {"provider": "nvidia", "z_00_model_name": "NV-Embed-QA"} + options_comp = {"embedding_provider": "nvidia", "model": "NV-Embed-QA"} store = AstraDBVectorStore( collection_name=VECTORIZE_COLLECTION, @@ -150,8 +150,8 @@ def test_astra_vectorize_with_provider_api_key(): } options_comp = { - "provider": "openai", - "z_00_model_name": "text-embedding-3-small", + "embedding_provider": "openai", + "model": "text-embedding-3-small", "z_01_model_parameters": {}, "z_03_provider_api_key": "openai", "z_04_authentication": {}, @@ -206,8 +206,8 @@ def test_astra_vectorize_passes_authentication(): "authentication": {"providerKey": "openai"}, } options_comp = { - "provider": "openai", - "z_00_model_name": "text-embedding-3-small", + "embedding_provider": "openai", + "model": "text-embedding-3-small", "z_01_model_parameters": {}, "z_04_authentication": {"providerKey": "openai"}, }