From 31885175e5504cb7869b832d2372152a1cceeaea Mon Sep 17 00:00:00 2001 From: Eric Hare Date: Mon, 18 Nov 2024 14:42:15 -0800 Subject: [PATCH] feat: Add support for dynamic providers in Astra DB Comp (#4627) * feat: Add support for dynamic providers in Astra DB Comp * [autofix.ci] apply automated fixes * Make sure we return a default dict * Rename params in starter template * Update test_vector_store_rag.py --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../components/vectorstores/astradb.py | 123 ++++++++++++------ .../starter_projects/Vector Store RAG.json | 76 +++++------ .../starter_projects/vector_store_rag.py | 4 +- .../components/astra/test_astra_component.py | 4 +- .../starter_projects/test_vector_store_rag.py | 4 +- 5 files changed, 129 insertions(+), 82 deletions(-) diff --git a/src/backend/base/langflow/components/vectorstores/astradb.py b/src/backend/base/langflow/components/vectorstores/astradb.py index 54fcd5c48..3f8f2dd43 100644 --- a/src/backend/base/langflow/components/vectorstores/astradb.py +++ b/src/backend/base/langflow/components/vectorstores/astradb.py @@ -1,6 +1,8 @@ import os +from collections import defaultdict import orjson +from astrapy import DataAPIClient from astrapy.admin import parse_api_endpoint from langchain_astradb import AstraDBVectorStore @@ -29,39 +31,45 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): _cached_vector_store: AstraDBVectorStore | None = None - VECTORIZE_PROVIDERS_MAPPING = { - "Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]], - "Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]], - "Hugging Face - Serverless": [ - "huggingface", - [ - "sentence-transformers/all-MiniLM-L6-v2", - "intfloat/multilingual-e5-large", - "intfloat/multilingual-e5-large-instruct", - "BAAI/bge-small-en-v1.5", - "BAAI/bge-base-en-v1.5", - "BAAI/bge-large-en-v1.5", + VECTORIZE_PROVIDERS_MAPPING = defaultdict( + list, + { + "Azure OpenAI": [ + "azureOpenAI", + ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"], ], - ], - "Jina AI": [ - "jinaAI", - [ - "jina-embeddings-v2-base-en", - "jina-embeddings-v2-base-de", - "jina-embeddings-v2-base-es", - "jina-embeddings-v2-base-code", - "jina-embeddings-v2-base-zh", + "Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]], + "Hugging Face - Serverless": [ + "huggingface", + [ + "sentence-transformers/all-MiniLM-L6-v2", + "intfloat/multilingual-e5-large", + "intfloat/multilingual-e5-large-instruct", + "BAAI/bge-small-en-v1.5", + "BAAI/bge-base-en-v1.5", + "BAAI/bge-large-en-v1.5", + ], ], - ], - "Mistral AI": ["mistral", ["mistral-embed"]], - "NVIDIA": ["nvidia", ["NV-Embed-QA"]], - "OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]], - "Upstage": ["upstageAI", ["solar-embedding-1-large"]], - "Voyage AI": [ - "voyageAI", - ["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"], - ], - } + "Jina AI": [ + "jinaAI", + [ + "jina-embeddings-v2-base-en", + "jina-embeddings-v2-base-de", + "jina-embeddings-v2-base-es", + "jina-embeddings-v2-base-code", + "jina-embeddings-v2-base-zh", + ], + ], + "Mistral AI": ["mistral", ["mistral-embed"]], + "NVIDIA": ["nvidia", ["NV-Embed-QA"]], + "OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]], + "Upstage": ["upstageAI", ["solar-embedding-1-large"]], + "Voyage AI": [ + "voyageAI", + ["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"], + ], + }, + ) inputs = [ SecretStrInput( @@ -109,7 +117,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): value="Embedding Model", ), HandleInput( - name="embedding", + name="embedding_model", display_name="Embedding Model", input_types=["Embeddings"], info="Allows an embedding model configuration.", @@ -247,15 +255,52 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): return build_config + def update_providers_mapping(self): + # If we don't have token or api_endpoint, we can't fetch the list of providers + if not self.token or not self.api_endpoint: + self.log("Astra DB token and API endpoint are required to fetch the list of Vectorize providers.") + + return self.VECTORIZE_PROVIDERS_MAPPING + + try: + self.log("Dynamically updating list of Vectorize providers.") + + # Get the admin object + client = DataAPIClient(token=self.token) + admin = client.get_admin() + + # Get the embedding providers + db_admin = admin.get_database_admin(self.api_endpoint) + embedding_providers = db_admin.find_embedding_providers().as_dict() + + vectorize_providers_mapping = {} + + # Map the provider display name to the provider key and models + for provider_key, provider_data in embedding_providers["embeddingProviders"].items(): + display_name = provider_data["displayName"] + models = [model["name"] for model in provider_data["models"]] + + vectorize_providers_mapping[display_name] = [provider_key, models] + + # Sort the resulting dictionary + return defaultdict(list, dict(sorted(vectorize_providers_mapping.items()))) + except Exception as e: # noqa: BLE001 + self.log(f"Error fetching Vectorize providers: {e}") + + return self.VECTORIZE_PROVIDERS_MAPPING + def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None): if field_name == "embedding_choice": if field_value == "Astra Vectorize": - self.del_fields(build_config, ["embedding"]) + self.del_fields(build_config, ["embedding_model"]) + + # Update the providers mapping + vectorize_providers = self.update_providers_mapping() new_parameter = DropdownInput( name="embedding_provider", display_name="Embedding Provider", - options=self.VECTORIZE_PROVIDERS_MAPPING.keys(), + options=vectorize_providers.keys(), value="", required=True, real_time_refresh=True, @@ -276,13 +321,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): ) new_parameter = HandleInput( - name="embedding", + name="embedding_model", display_name="Embedding Model", input_types=["Embeddings"], info="Allows an embedding model configuration.", ).to_dict() - self.insert_in_dict(build_config, "embedding_choice", {"embedding": new_parameter}) + self.insert_in_dict(build_config, "embedding_choice", {"embedding_model": new_parameter}) elif field_name == "embedding_provider": self.del_fields( @@ -290,7 +335,9 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): ["model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"], ) - model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1] + # Update the providers mapping + vectorize_providers = self.update_providers_mapping() + model_options = vectorize_providers[field_value][1] new_parameter = DropdownInput( name="model", @@ -420,7 +467,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): raise ValueError(msg) from e if self.embedding_choice == "Embedding Model": - embedding_dict = {"embedding": self.embedding} + embedding_dict = {"embedding": self.embedding_model} else: from astrapy.info import CollectionVectorServiceOptions diff --git a/src/backend/base/langflow/initial_setup/starter_projects/Vector Store RAG.json b/src/backend/base/langflow/initial_setup/starter_projects/Vector Store RAG.json index 95efad77b..8595d8cb6 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/Vector Store RAG.json +++ b/src/backend/base/langflow/initial_setup/starter_projects/Vector Store RAG.json @@ -104,7 +104,7 @@ "output_types": ["Embeddings"] }, "targetHandle": { - "fieldName": "embedding", + "fieldName": "embedding_model", "id": "AstraDB-3buPx", "inputTypes": ["Embeddings"], "type": "other" @@ -196,7 +196,7 @@ "output_types": ["Embeddings"] }, "targetHandle": { - "fieldName": "embedding", + "fieldName": "embedding_model", "id": "AstraDB-laybz", "inputTypes": ["Embeddings"], "type": "other" @@ -1601,7 +1601,7 @@ "ingest_data", "namespace", "embedding_service", - "embedding", + "embedding_model", "metric", "batch_size", "bulk_insert_batch_concurrency", @@ -1781,23 +1781,6 @@ "type": "str", "value": "" }, - "embedding": { - "_input_type": "HandleInput", - "advanced": false, - "display_name": "Embedding Model", - "dynamic": false, - "info": "Allows an embedding model configuration.", - "input_types": ["Embeddings"], - "list": false, - "name": "embedding", - "placeholder": "", - "required": false, - "show": true, - "title_case": false, - "trace_as_metadata": true, - "type": "other", - "value": "" - }, "embedding_choice": { "_input_type": "DropdownInput", "advanced": false, @@ -1817,6 +1800,23 @@ "type": "str", "value": "Embedding Model" }, + "embedding_model": { + "_input_type": "HandleInput", + "advanced": false, + "display_name": "Embedding Model", + "dynamic": false, + "info": "Allows an embedding model configuration.", + "input_types": ["Embeddings"], + "list": false, + "name": "embedding_model", + "placeholder": "", + "required": false, + "show": true, + "title_case": false, + "trace_as_metadata": true, + "type": "other", + "value": "" + }, "ingest_data": { "_input_type": "DataInput", "advanced": false, @@ -2556,7 +2556,7 @@ "ingest_data", "namespace", "embedding_service", - "embedding", + "embedding_model", "metric", "batch_size", "bulk_insert_batch_concurrency", @@ -2736,23 +2736,6 @@ "type": "str", "value": "test" }, - "embedding": { - "_input_type": "HandleInput", - "advanced": false, - "display_name": "Embedding Model", - "dynamic": false, - "info": "Allows an embedding model configuration.", - "input_types": ["Embeddings"], - "list": false, - "name": "embedding", - "placeholder": "", - "required": false, - "show": true, - "title_case": false, - "trace_as_metadata": true, - "type": "other", - "value": "" - }, "embedding_choice": { "_input_type": "DropdownInput", "advanced": false, @@ -2772,6 +2755,23 @@ "type": "str", "value": "Embedding Model" }, + "embedding_model": { + "_input_type": "HandleInput", + "advanced": false, + "display_name": "Embedding Model", + "dynamic": false, + "info": "Allows an embedding model configuration.", + "input_types": ["Embeddings"], + "list": false, + "name": "embedding_model", + "placeholder": "", + "required": false, + "show": true, + "title_case": false, + "trace_as_metadata": true, + "type": "other", + "value": "" + }, "ingest_data": { "_input_type": "DataInput", "advanced": false, diff --git a/src/backend/base/langflow/initial_setup/starter_projects/vector_store_rag.py b/src/backend/base/langflow/initial_setup/starter_projects/vector_store_rag.py index c7fe0ef06..f7ffdb0eb 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/vector_store_rag.py +++ b/src/backend/base/langflow/initial_setup/starter_projects/vector_store_rag.py @@ -20,7 +20,7 @@ def ingestion_graph(): openai_embeddings = OpenAIEmbeddingsComponent() vector_store = AstraVectorStoreComponent() vector_store.set( - embedding=openai_embeddings.build_embeddings, + embedding_model=openai_embeddings.build_embeddings, ingest_data=text_splitter.split_text, ) @@ -34,7 +34,7 @@ def rag_graph(): rag_vector_store = AstraVectorStoreComponent() rag_vector_store.set( search_input=chat_input.message_response, - embedding=openai_embeddings.build_embeddings, + embedding_model=openai_embeddings.build_embeddings, ) parse_data = ParseDataComponent() diff --git a/src/backend/tests/integration/components/astra/test_astra_component.py b/src/backend/tests/integration/components/astra/test_astra_component.py index 866898ed6..fee54315c 100644 --- a/src/backend/tests/integration/components/astra/test_astra_component.py +++ b/src/backend/tests/integration/components/astra/test_astra_component.py @@ -48,7 +48,7 @@ async def test_base(astradb_client: AstraDB): "token": application_token, "api_endpoint": api_endpoint, "collection_name": BASIC_COLLECTION, - "embedding": ComponentInputHandle( + "embedding_model": ComponentInputHandle( clazz=OpenAIEmbeddingsComponent, inputs={"openai_api_key": get_openai_api_key()}, output_name="embeddings", @@ -79,7 +79,7 @@ async def test_astra_embeds_and_search(): "ingest_data": ComponentInputHandle( clazz=TextToData, inputs={"text_data": ["test1", "test2"]}, output_name="from_text" ), - "embedding": ComponentInputHandle( + "embedding_model": ComponentInputHandle( clazz=OpenAIEmbeddingsComponent, inputs={"openai_api_key": get_openai_api_key()}, output_name="embeddings", diff --git a/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py b/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py index 967e5eb92..d11014881 100644 --- a/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py +++ b/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py @@ -31,7 +31,7 @@ def ingestion_graph(): ) vector_store = AstraVectorStoreComponent(_id="vector-store-123") vector_store.set( - embedding=openai_embeddings.build_embeddings, + embedding_model=openai_embeddings.build_embeddings, ingest_data=text_splitter.split_text, api_endpoint="https://astra.example.com", token="token", # noqa: S106 @@ -53,7 +53,7 @@ def rag_graph(): search_input=chat_input.message_response, api_endpoint="https://astra.example.com", token="token", # noqa: S106 - embedding=openai_embeddings.build_embeddings, + embedding_model=openai_embeddings.build_embeddings, ) # Mock search_documents rag_vector_store.set_on_output(