FIX: proper parameters in Astra DB Vectorize options (#3901)

* FIX: proper parameters in vectorize options

* Update test_astra_component.py
This commit is contained in:
Eric Hare 2024-09-24 12:43:23 -07:00 committed by GitHub
commit f403c17d10
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 15 additions and 15 deletions

View file

@ -322,20 +322,20 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
def build_vectorize_options(self, **kwargs):
for attribute in [
"provider",
"z_00_api_key_name",
"z_01_model_name",
"z_02_authentication",
"z_00_model_name",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_model_parameters",
"z_04_authentication",
]:
if not hasattr(self, attribute):
setattr(self, attribute, None)
# Fetch values from kwargs if any self.* attributes are None
provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.provider, [None])[0] or kwargs.get("provider")
authentication = {**(self.z_02_authentication or kwargs.get("z_02_authentication", {}))}
authentication = {**(self.z_04_authentication or kwargs.get("z_04_authentication", {}))}
api_key_name = self.z_00_api_key_name or kwargs.get("z_00_api_key_name")
api_key_name = self.z_02_api_key_name or kwargs.get("z_02_api_key_name")
provider_key_name = self.z_03_provider_api_key or kwargs.get("z_03_provider_api_key")
if provider_key_name:
authentication["providerKey"] = provider_key_name
@ -346,9 +346,9 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
# must match astrapy.info.CollectionVectorServiceOptions
"collection_vector_service_options": {
"provider": provider_value,
"modelName": self.z_01_model_name or kwargs.get("z_01_model_name"),
"modelName": self.z_00_model_name or kwargs.get("z_00_model_name"),
"authentication": authentication,
"parameters": self.z_04_model_parameters or kwargs.get("z_04_model_parameters", {}),
"parameters": self.z_01_model_parameters or kwargs.get("z_01_model_parameters", {}),
},
"collection_embedding_api_key": self.z_03_provider_api_key or kwargs.get("z_03_provider_api_key"),
}

View file

@ -104,7 +104,7 @@ def test_astra_vectorize():
store = None
try:
options = {"provider": "nvidia", "modelName": "NV-Embed-QA"}
options_comp = {"provider": "nvidia", "z_01_model_name": "NV-Embed-QA"}
options_comp = {"provider": "nvidia", "z_00_model_name": "NV-Embed-QA"}
store = AstraDBVectorStore(
collection_name=VECTORIZE_COLLECTION,
@ -156,10 +156,10 @@ def test_astra_vectorize_with_provider_api_key():
options_comp = {
"provider": "openai",
"z_01_model_name": "text-embedding-3-small",
"z_04_model_parameters": {},
"z_02_authentication": {},
"z_00_model_name": "text-embedding-3-small",
"z_01_model_parameters": {},
"z_03_provider_api_key": "openai",
"z_04_authentication": {},
}
store = AstraDBVectorStore(
@ -212,9 +212,9 @@ def test_astra_vectorize_passes_authentication():
}
options_comp = {
"provider": "openai",
"z_01_model_name": "text-embedding-3-small",
"z_04_model_parameters": {},
"z_02_authentication": {"providerKey": "openai"},
"z_00_model_name": "text-embedding-3-small",
"z_01_model_parameters": {},
"z_04_authentication": {"providerKey": "openai"},
}
store = AstraDBVectorStore(