fix: issue with dynamic inputs when selecting model (#4538)
This commit is contained in:
parent
0dc6cce8dc
commit
1dfa160385
2 changed files with 56 additions and 44 deletions
|
|
@ -210,6 +210,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
),
|
||||
]
|
||||
|
||||
def del_fields(self, build_config, field_list):
|
||||
for field in field_list:
|
||||
if field in build_config:
|
||||
del build_config[field]
|
||||
|
||||
return build_config
|
||||
|
||||
def insert_in_dict(self, build_config, field_name, new_parameters):
|
||||
# Insert the new key-value pair after the found key
|
||||
for new_field_name, new_parameter in new_parameters.items():
|
||||
|
|
@ -234,31 +241,30 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
|
||||
if field_name == "embedding_service":
|
||||
if field_value == "Astra Vectorize":
|
||||
for field in ["embedding"]:
|
||||
if field in build_config:
|
||||
del build_config[field]
|
||||
self.del_fields(build_config, ["embedding"])
|
||||
|
||||
new_parameter = DropdownInput(
|
||||
name="provider",
|
||||
display_name="Vectorize Provider",
|
||||
name="embedding_provider",
|
||||
display_name="Embedding Provider",
|
||||
options=self.VECTORIZE_PROVIDERS_MAPPING.keys(),
|
||||
value="",
|
||||
required=True,
|
||||
real_time_refresh=True,
|
||||
).to_dict()
|
||||
|
||||
self.insert_in_dict(build_config, "embedding_service", {"provider": new_parameter})
|
||||
self.insert_in_dict(build_config, "embedding_service", {"embedding_provider": new_parameter})
|
||||
else:
|
||||
for field in [
|
||||
"provider",
|
||||
"z_00_model_name",
|
||||
"z_01_model_parameters",
|
||||
"z_02_api_key_name",
|
||||
"z_03_provider_api_key",
|
||||
"z_04_authentication",
|
||||
]:
|
||||
if field in build_config:
|
||||
del build_config[field]
|
||||
self.del_fields(
|
||||
build_config,
|
||||
[
|
||||
"embedding_provider",
|
||||
"model",
|
||||
"z_01_model_parameters",
|
||||
"z_02_api_key_name",
|
||||
"z_03_provider_api_key",
|
||||
"z_04_authentication",
|
||||
],
|
||||
)
|
||||
|
||||
new_parameter = HandleInput(
|
||||
name="embedding",
|
||||
|
|
@ -269,32 +275,35 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
self.insert_in_dict(build_config, "embedding_service", {"embedding": new_parameter})
|
||||
|
||||
elif field_name == "provider":
|
||||
for field in [
|
||||
"z_00_model_name",
|
||||
"z_01_model_parameters",
|
||||
"z_02_api_key_name",
|
||||
"z_03_provider_api_key",
|
||||
"z_04_authentication",
|
||||
]:
|
||||
if field in build_config:
|
||||
del build_config[field]
|
||||
elif field_name == "embedding_provider":
|
||||
self.del_fields(
|
||||
build_config,
|
||||
["model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"],
|
||||
)
|
||||
|
||||
model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1]
|
||||
|
||||
new_parameter_0 = DropdownInput(
|
||||
name="z_00_model_name",
|
||||
display_name="Model Name",
|
||||
new_parameter = DropdownInput(
|
||||
name="model",
|
||||
display_name="Model",
|
||||
info="The embedding model to use for the selected provider. Each provider has a different set of "
|
||||
"models available (full list at "
|
||||
"https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html):\n\n"
|
||||
f"{', '.join(model_options)}",
|
||||
options=model_options,
|
||||
placeholder="Select a model",
|
||||
value=model_options[0],
|
||||
value=None,
|
||||
required=True,
|
||||
real_time_refresh=True,
|
||||
).to_dict()
|
||||
|
||||
self.insert_in_dict(build_config, "embedding_provider", {"model": new_parameter})
|
||||
|
||||
elif field_name == "model":
|
||||
self.del_fields(
|
||||
build_config,
|
||||
["z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"],
|
||||
)
|
||||
|
||||
new_parameter_1 = DictInput(
|
||||
name="z_01_model_parameters",
|
||||
display_name="Model Parameters",
|
||||
|
|
@ -303,12 +312,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
new_parameter_2 = MessageTextInput(
|
||||
name="z_02_api_key_name",
|
||||
display_name="API Key name",
|
||||
display_name="API Key Name",
|
||||
info="The name of the embeddings provider API key stored on Astra. "
|
||||
"If set, it will override the 'ProviderKey' in the authentication parameters.",
|
||||
).to_dict()
|
||||
|
||||
new_parameter_3 = SecretStrInput(
|
||||
load_from_db=False,
|
||||
name="z_03_provider_api_key",
|
||||
display_name="Provider API Key",
|
||||
info="An alternative to the Astra Authentication that passes an API key for the provider "
|
||||
|
|
@ -319,15 +329,14 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
new_parameter_4 = DictInput(
|
||||
name="z_04_authentication",
|
||||
display_name="Authentication parameters",
|
||||
display_name="Authentication Parameters",
|
||||
is_list=True,
|
||||
).to_dict()
|
||||
|
||||
self.insert_in_dict(
|
||||
build_config,
|
||||
"provider",
|
||||
"model",
|
||||
{
|
||||
"z_00_model_name": new_parameter_0,
|
||||
"z_01_model_parameters": new_parameter_1,
|
||||
"z_02_api_key_name": new_parameter_2,
|
||||
"z_03_provider_api_key": new_parameter_3,
|
||||
|
|
@ -339,8 +348,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
def build_vectorize_options(self, **kwargs):
|
||||
for attribute in [
|
||||
"provider",
|
||||
"z_00_model_name",
|
||||
"embedding_provider",
|
||||
"model",
|
||||
"z_01_model_parameters",
|
||||
"z_02_api_key_name",
|
||||
"z_03_provider_api_key",
|
||||
|
|
@ -350,8 +359,10 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
setattr(self, attribute, None)
|
||||
|
||||
# Fetch values from kwargs if any self.* attributes are None
|
||||
provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.provider, [None])[0] or kwargs.get("provider")
|
||||
model_name = self.z_00_model_name or kwargs.get("z_00_model_name")
|
||||
provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.embedding_provider, [None])[0] or kwargs.get(
|
||||
"embedding_provider"
|
||||
)
|
||||
model_name = self.model or kwargs.get("model")
|
||||
authentication = {**(self.z_04_authentication or kwargs.get("z_04_authentication", {}))}
|
||||
parameters = self.z_01_model_parameters or kwargs.get("z_01_model_parameters", {})
|
||||
|
||||
|
|
@ -414,6 +425,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
),
|
||||
"collection_embedding_api_key": dict_options.get("collection_embedding_api_key"),
|
||||
}
|
||||
|
||||
try:
|
||||
vector_store = AstraDBVectorStore(
|
||||
collection_name=self.collection_name,
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ def test_astra_vectorize():
|
|||
store = None
|
||||
try:
|
||||
options = {"provider": "nvidia", "modelName": "NV-Embed-QA"}
|
||||
options_comp = {"provider": "nvidia", "z_00_model_name": "NV-Embed-QA"}
|
||||
options_comp = {"embedding_provider": "nvidia", "model": "NV-Embed-QA"}
|
||||
|
||||
store = AstraDBVectorStore(
|
||||
collection_name=VECTORIZE_COLLECTION,
|
||||
|
|
@ -150,8 +150,8 @@ def test_astra_vectorize_with_provider_api_key():
|
|||
}
|
||||
|
||||
options_comp = {
|
||||
"provider": "openai",
|
||||
"z_00_model_name": "text-embedding-3-small",
|
||||
"embedding_provider": "openai",
|
||||
"model": "text-embedding-3-small",
|
||||
"z_01_model_parameters": {},
|
||||
"z_03_provider_api_key": "openai",
|
||||
"z_04_authentication": {},
|
||||
|
|
@ -206,8 +206,8 @@ def test_astra_vectorize_passes_authentication():
|
|||
"authentication": {"providerKey": "openai"},
|
||||
}
|
||||
options_comp = {
|
||||
"provider": "openai",
|
||||
"z_00_model_name": "text-embedding-3-small",
|
||||
"embedding_provider": "openai",
|
||||
"model": "text-embedding-3-small",
|
||||
"z_01_model_parameters": {},
|
||||
"z_04_authentication": {"providerKey": "openai"},
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue