From f94b86f6b74507295d55b7b7836964bcfe32dd67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Wed, 26 Jun 2024 10:31:48 +0200 Subject: [PATCH] components: simplify astra vectorize --- .../components/embeddings/AstraVectorize.py | 49 +++++++++++++++---- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/backend/base/langflow/components/embeddings/AstraVectorize.py b/src/backend/base/langflow/components/embeddings/AstraVectorize.py index 8c9e6d974..00b1a9a63 100644 --- a/src/backend/base/langflow/components/embeddings/AstraVectorize.py +++ b/src/backend/base/langflow/components/embeddings/AstraVectorize.py @@ -1,6 +1,6 @@ from typing import Any from langflow.custom import Component -from langflow.inputs.inputs import DictInput, SecretStrInput, MessageTextInput +from langflow.inputs.inputs import DictInput, SecretStrInput, MessageTextInput, DropdownInput from langflow.template.field.base import Output @@ -10,32 +10,58 @@ class AstraVectorize(Component): documentation: str = "https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html" icon = "AstraDB" + VECTORIZE_PROVIDERS_MAPPING = { + "Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]], + "Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]], + "Hugging Face - Serverless": ["huggingface", + ["sentence-transformers/all-MiniLM-L6-v2", "intfloat/multilingual-e5-large", + "intfloat/multilingual-e5-large-instruct", "BAAI/bge-small-en-v1.5", + "BAAI/bge-base-en-v1.5", "BAAI/bge-large-en-v1.5"]], + "Jina AI": ["jinaAI", ["jina-embeddings-v2-base-en", "jina-embeddings-v2-base-de", "jina-embeddings-v2-base-es", + "jina-embeddings-v2-base-code", "jina-embeddings-v2-base-zh"]], + "Mistral AI": ["mistral", ["mistral-embed"]], + "NVIDIA": ["nvidia", ["NV-Embed-QA"]], + "OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]], + "Upstage": ["upstageAI", ["solar-embedding-1-large"]], + "Voyage AI": ["voyageAI", + ["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"]] + } + VECTORIZE_MODELS_STR = "\n\n".join([provider + ": " + (', '.join(models[1])) for provider, models in VECTORIZE_PROVIDERS_MAPPING.items()]) + inputs = [ - MessageTextInput( + DropdownInput( name="provider", display_name="Provider name", - info="The embedding provider to use.", + options=VECTORIZE_PROVIDERS_MAPPING.keys(), + value="", ), MessageTextInput( name="model_name", display_name="Model name", - info="The embedding model to use.", + info=f"The embedding model to use for the selected provider. Each provider has a different set of models " + f"available (full list at https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html):\n\n{VECTORIZE_MODELS_STR}", + required=True + ), + MessageTextInput( + name="api_key_name", + display_name="API Key name", + info="The name of the embeddings provider API key stored on Astra. If set, it will override the 'ProviderKey' in the authentication parameters." ), DictInput( name="authentication", - display_name="Authentication", - info="Authentication parameters. Use the Astra Portal to add the embedding provider integration to your Astra organization.", + display_name="Authentication parameters", is_list=True, + advanced=True, ), SecretStrInput( name="provider_api_key", display_name="Provider API Key", info="An alternative to the Astra Authentication that let you use directly the API key of the provider.", + advanced=True, ), DictInput( name="model_parameters", display_name="Model parameters", - info="Additional model parameters.", advanced=True, is_list=True, ), @@ -45,12 +71,17 @@ class AstraVectorize(Component): ] def build_options(self) -> dict[str, Any]: + provider_value = self.VECTORIZE_PROVIDERS_MAPPING[self.provider][0] + authentication = {**self.authentication} + api_key_name = self.api_key_name + if api_key_name: + authentication["providerKey"] = api_key_name return { # must match exactly astra CollectionVectorServiceOptions "collection_vector_service_options": { - "provider": self.provider, + "provider": provider_value, "modelName": self.model_name, - "authentication": self.authentication, + "authentication": authentication, "parameters": self.model_parameters, }, "collection_embedding_api_key": self.provider_api_key,