FIX: Properly pass missing vectorize params in Astra DB (#4511)

This commit is contained in:
Eric Hare 2024-11-11 15:09:50 -08:00 committed by GitHub
commit a107650cc6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -351,20 +351,29 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
# Fetch values from kwargs if any self.* attributes are None
provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.provider, [None])[0] or kwargs.get("provider")
model_name = self.z_00_model_name or kwargs.get("z_00_model_name")
authentication = {**(self.z_04_authentication or kwargs.get("z_04_authentication", {}))}
parameters = self.z_01_model_parameters or kwargs.get("z_01_model_parameters", {})
# Set the API key name if provided
api_key_name = self.z_02_api_key_name or kwargs.get("z_02_api_key_name")
provider_key = self.z_03_provider_api_key or kwargs.get("z_03_provider_api_key")
if api_key_name:
authentication["providerKey"] = api_key_name
# Set authentication and parameters to None if no values are provided
if not authentication:
authentication = None
if not parameters:
parameters = None
return {
# must match astrapy.info.CollectionVectorServiceOptions
"collection_vector_service_options": {
"provider": provider_value,
"modelName": self.z_00_model_name or kwargs.get("z_00_model_name"),
"modelName": model_name,
"authentication": authentication,
"parameters": self.z_01_model_parameters or kwargs.get("z_01_model_parameters", {}),
"parameters": parameters,
},
"collection_embedding_api_key": provider_key,
}
@ -395,15 +404,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
else:
from astrapy.info import CollectionVectorServiceOptions
# Fetch values from kwargs if any self.* attributes are None
dict_options = vectorize_options or self.build_vectorize_options()
dict_options["authentication"] = {
k: v for k, v in dict_options.get("authentication", {}).items() if k and v
}
dict_options["parameters"] = {k: v for k, v in dict_options.get("parameters", {}).items() if k and v}
# Set the embedding dictionary
embedding_dict = {
"collection_vector_service_options": CollectionVectorServiceOptions.from_dict(
dict_options.get("collection_vector_service_options", {})
dict_options.get("collection_vector_service_options")
),
"collection_embedding_api_key": dict_options.get("collection_embedding_api_key"),
}