fixes and refactory

This commit is contained in:
Nicolò Boschi 2024-06-24 10:21:49 +02:00 committed by Gabriel Luiz Freitas Almeida
commit 554aed4e35
2 changed files with 15 additions and 7 deletions

View file

@ -30,11 +30,10 @@ class AstraVectorize(Component):
SecretStrInput(
name="provider_api_key",
display_name="Provider API Key",
info='An alternative to the Astra Authentication that let you use directly the API key of the provider.',
advanced=True
info='An alternative to the Astra Authentication that let you use directly the API key of the provider.'
),
DictInput(
name="parameters",
name="model_parameters",
display_name="Model parameters",
info='Additional model parameters.',
advanced=True,
@ -47,11 +46,12 @@ class AstraVectorize(Component):
def build_options(self) -> dict[str, Any]:
return {
# must match exactly astra CollectionVectorServiceOptions
"collection_vector_service_options": {
"provider": self.provider,
"model_name": self.model_name,
"modelName": self.model_name,
"authentication": self.authentication,
"parameters": self.parameters
"parameters": self.model_parameters
},
"collection_embedding_api_key": self.provider_api_key
}

View file

@ -155,10 +155,18 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
except KeyError:
raise ValueError(f"Invalid setup mode: {self.setup_mode}")
if isinstance(self.embedding, dict):
if not isinstance(self.embedding, dict):
embedding_dict = {"embedding": self.embedding}
else:
embedding_dict = self.embedding.to_dict()
from astrapy.info import CollectionVectorServiceOptions
dict_options = self.embedding.get("collection_vector_service_options", {})
dict_options["authentication"] = {k: v for k, v in dict_options.get("authentication", {}).items() if k and v}
dict_options["parameters"] = {k: v for k, v in dict_options.get("parameters", {}).items() if
k and v}
embedding_dict = {
"collection_vector_service_options": CollectionVectorServiceOptions.from_dict(dict_options),
"collection_embedding_api_key": self.embedding.get("collection_embedding_api_key"),
}
vector_store_kwargs = {
**embedding_dict,
"collection_name": self.collection_name,