feat: Add support for dynamic providers in Astra DB Comp (#4627)

* feat: Add support for dynamic providers in Astra DB Comp

* [autofix.ci] apply automated fixes

* Make sure we return a default dict

* Rename params in starter template

* Update test_vector_store_rag.py

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Eric Hare 2024-11-18 14:42:15 -08:00 committed by GitHub
commit 31885175e5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 129 additions and 82 deletions

View file

@@ -1,6 +1,8 @@
import os
from collections import defaultdict
import orjson
from astrapy import DataAPIClient
from astrapy.admin import parse_api_endpoint
from langchain_astradb import AstraDBVectorStore
@@ -29,39 +31,45 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
_cached_vector_store: AstraDBVectorStore | None = None
VECTORIZE_PROVIDERS_MAPPING = {
"Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
"Hugging Face - Serverless": [
"huggingface",
[
"sentence-transformers/all-MiniLM-L6-v2",
"intfloat/multilingual-e5-large",
"intfloat/multilingual-e5-large-instruct",
"BAAI/bge-small-en-v1.5",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-large-en-v1.5",
VECTORIZE_PROVIDERS_MAPPING = defaultdict(
list,
{
"Azure OpenAI": [
"azureOpenAI",
["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
],
],
"Jina AI": [
"jinaAI",
[
"jina-embeddings-v2-base-en",
"jina-embeddings-v2-base-de",
"jina-embeddings-v2-base-es",
"jina-embeddings-v2-base-code",
"jina-embeddings-v2-base-zh",
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
"Hugging Face - Serverless": [
"huggingface",
[
"sentence-transformers/all-MiniLM-L6-v2",
"intfloat/multilingual-e5-large",
"intfloat/multilingual-e5-large-instruct",
"BAAI/bge-small-en-v1.5",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-large-en-v1.5",
],
],
],
"Mistral AI": ["mistral", ["mistral-embed"]],
"NVIDIA": ["nvidia", ["NV-Embed-QA"]],
"OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Upstage": ["upstageAI", ["solar-embedding-1-large"]],
"Voyage AI": [
"voyageAI",
["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
],
}
"Jina AI": [
"jinaAI",
[
"jina-embeddings-v2-base-en",
"jina-embeddings-v2-base-de",
"jina-embeddings-v2-base-es",
"jina-embeddings-v2-base-code",
"jina-embeddings-v2-base-zh",
],
],
"Mistral AI": ["mistral", ["mistral-embed"]],
"NVIDIA": ["nvidia", ["NV-Embed-QA"]],
"OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Upstage": ["upstageAI", ["solar-embedding-1-large"]],
"Voyage AI": [
"voyageAI",
["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
],
},
)
inputs = [
SecretStrInput(
@@ -109,7 +117,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
value="Embedding Model",
),
HandleInput(
name="embedding",
name="embedding_model",
display_name="Embedding Model",
input_types=["Embeddings"],
info="Allows an embedding model configuration.",
@@ -247,15 +255,52 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
return build_config
def update_providers_mapping(self):
# If we don't have token or api_endpoint, we can't fetch the list of providers
if not self.token or not self.api_endpoint:
self.log("Astra DB token and API endpoint are required to fetch the list of Vectorize providers.")
return self.VECTORIZE_PROVIDERS_MAPPING
try:
self.log("Dynamically updating list of Vectorize providers.")
# Get the admin object
client = DataAPIClient(token=self.token)
admin = client.get_admin()
# Get the embedding providers
db_admin = admin.get_database_admin(self.api_endpoint)
embedding_providers = db_admin.find_embedding_providers().as_dict()
vectorize_providers_mapping = {}
# Map the provider display name to the provider key and models
for provider_key, provider_data in embedding_providers["embeddingProviders"].items():
display_name = provider_data["displayName"]
models = [model["name"] for model in provider_data["models"]]
vectorize_providers_mapping[display_name] = [provider_key, models]
# Sort the resulting dictionary
return defaultdict(list, dict(sorted(vectorize_providers_mapping.items())))
except Exception as e: # noqa: BLE001
self.log(f"Error fetching Vectorize providers: {e}")
return self.VECTORIZE_PROVIDERS_MAPPING
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
if field_name == "embedding_choice":
if field_value == "Astra Vectorize":
self.del_fields(build_config, ["embedding"])
self.del_fields(build_config, ["embedding_model"])
# Update the providers mapping
vectorize_providers = self.update_providers_mapping()
new_parameter = DropdownInput(
name="embedding_provider",
display_name="Embedding Provider",
options=self.VECTORIZE_PROVIDERS_MAPPING.keys(),
options=vectorize_providers.keys(),
value="",
required=True,
real_time_refresh=True,
@@ -276,13 +321,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
)
new_parameter = HandleInput(
name="embedding",
name="embedding_model",
display_name="Embedding Model",
input_types=["Embeddings"],
info="Allows an embedding model configuration.",
).to_dict()
self.insert_in_dict(build_config, "embedding_choice", {"embedding": new_parameter})
self.insert_in_dict(build_config, "embedding_choice", {"embedding_model": new_parameter})
elif field_name == "embedding_provider":
self.del_fields(
@@ -290,7 +335,9 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
["model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"],
)
model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1]
# Update the providers mapping
vectorize_providers = self.update_providers_mapping()
model_options = vectorize_providers[field_value][1]
new_parameter = DropdownInput(
name="model",
@@ -420,7 +467,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
raise ValueError(msg) from e
if self.embedding_choice == "Embedding Model":
embedding_dict = {"embedding": self.embedding}
embedding_dict = {"embedding": self.embedding_model}
else:
from astrapy.info import CollectionVectorServiceOptions

View file

@@ -104,7 +104,7 @@
"output_types": ["Embeddings"]
},
"targetHandle": {
"fieldName": "embedding",
"fieldName": "embedding_model",
"id": "AstraDB-3buPx",
"inputTypes": ["Embeddings"],
"type": "other"
@@ -196,7 +196,7 @@
"output_types": ["Embeddings"]
},
"targetHandle": {
"fieldName": "embedding",
"fieldName": "embedding_model",
"id": "AstraDB-laybz",
"inputTypes": ["Embeddings"],
"type": "other"
@@ -1601,7 +1601,7 @@
"ingest_data",
"namespace",
"embedding_service",
"embedding",
"embedding_model",
"metric",
"batch_size",
"bulk_insert_batch_concurrency",
@@ -1781,23 +1781,6 @@
"type": "str",
"value": ""
},
"embedding": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"embedding_choice": {
"_input_type": "DropdownInput",
"advanced": false,
@@ -1817,6 +1800,23 @@
"type": "str",
"value": "Embedding Model"
},
"embedding_model": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding_model",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"ingest_data": {
"_input_type": "DataInput",
"advanced": false,
@@ -2556,7 +2556,7 @@
"ingest_data",
"namespace",
"embedding_service",
"embedding",
"embedding_model",
"metric",
"batch_size",
"bulk_insert_batch_concurrency",
@@ -2736,23 +2736,6 @@
"type": "str",
"value": "test"
},
"embedding": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"embedding_choice": {
"_input_type": "DropdownInput",
"advanced": false,
@@ -2772,6 +2755,23 @@
"type": "str",
"value": "Embedding Model"
},
"embedding_model": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding_model",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"ingest_data": {
"_input_type": "DataInput",
"advanced": false,

View file

@@ -20,7 +20,7 @@ def ingestion_graph():
openai_embeddings = OpenAIEmbeddingsComponent()
vector_store = AstraVectorStoreComponent()
vector_store.set(
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
ingest_data=text_splitter.split_text,
)
@@ -34,7 +34,7 @@ def rag_graph():
rag_vector_store = AstraVectorStoreComponent()
rag_vector_store.set(
search_input=chat_input.message_response,
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
)
parse_data = ParseDataComponent()

View file

@@ -48,7 +48,7 @@ async def test_base(astradb_client: AstraDB):
"token": application_token,
"api_endpoint": api_endpoint,
"collection_name": BASIC_COLLECTION,
"embedding": ComponentInputHandle(
"embedding_model": ComponentInputHandle(
clazz=OpenAIEmbeddingsComponent,
inputs={"openai_api_key": get_openai_api_key()},
output_name="embeddings",
@@ -79,7 +79,7 @@ async def test_astra_embeds_and_search():
"ingest_data": ComponentInputHandle(
clazz=TextToData, inputs={"text_data": ["test1", "test2"]}, output_name="from_text"
),
"embedding": ComponentInputHandle(
"embedding_model": ComponentInputHandle(
clazz=OpenAIEmbeddingsComponent,
inputs={"openai_api_key": get_openai_api_key()},
output_name="embeddings",

View file

@@ -31,7 +31,7 @@ def ingestion_graph():
)
vector_store = AstraVectorStoreComponent(_id="vector-store-123")
vector_store.set(
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
ingest_data=text_splitter.split_text,
api_endpoint="https://astra.example.com",
token="token", # noqa: S106
@@ -53,7 +53,7 @@ def rag_graph():
search_input=chat_input.message_response,
api_endpoint="https://astra.example.com",
token="token", # noqa: S106
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
)
# Mock search_documents
rag_vector_store.set_on_output(