feat: Add support for dynamic providers in Astra DB Comp (#4627)
* feat: Add support for dynamic providers in Astra DB Comp * [autofix.ci] apply automated fixes * Make sure we return a default dict * Rename params in starter template * Update test_vector_store_rag.py --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
cd2517f7e2
commit
31885175e5
5 changed files with 129 additions and 82 deletions
|
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
import orjson
|
||||
from astrapy import DataAPIClient
|
||||
from astrapy.admin import parse_api_endpoint
|
||||
from langchain_astradb import AstraDBVectorStore
|
||||
|
||||
|
|
@ -29,39 +31,45 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
_cached_vector_store: AstraDBVectorStore | None = None
|
||||
|
||||
VECTORIZE_PROVIDERS_MAPPING = {
|
||||
"Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
|
||||
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
|
||||
"Hugging Face - Serverless": [
|
||||
"huggingface",
|
||||
[
|
||||
"sentence-transformers/all-MiniLM-L6-v2",
|
||||
"intfloat/multilingual-e5-large",
|
||||
"intfloat/multilingual-e5-large-instruct",
|
||||
"BAAI/bge-small-en-v1.5",
|
||||
"BAAI/bge-base-en-v1.5",
|
||||
"BAAI/bge-large-en-v1.5",
|
||||
VECTORIZE_PROVIDERS_MAPPING = defaultdict(
|
||||
list,
|
||||
{
|
||||
"Azure OpenAI": [
|
||||
"azureOpenAI",
|
||||
["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
|
||||
],
|
||||
],
|
||||
"Jina AI": [
|
||||
"jinaAI",
|
||||
[
|
||||
"jina-embeddings-v2-base-en",
|
||||
"jina-embeddings-v2-base-de",
|
||||
"jina-embeddings-v2-base-es",
|
||||
"jina-embeddings-v2-base-code",
|
||||
"jina-embeddings-v2-base-zh",
|
||||
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
|
||||
"Hugging Face - Serverless": [
|
||||
"huggingface",
|
||||
[
|
||||
"sentence-transformers/all-MiniLM-L6-v2",
|
||||
"intfloat/multilingual-e5-large",
|
||||
"intfloat/multilingual-e5-large-instruct",
|
||||
"BAAI/bge-small-en-v1.5",
|
||||
"BAAI/bge-base-en-v1.5",
|
||||
"BAAI/bge-large-en-v1.5",
|
||||
],
|
||||
],
|
||||
],
|
||||
"Mistral AI": ["mistral", ["mistral-embed"]],
|
||||
"NVIDIA": ["nvidia", ["NV-Embed-QA"]],
|
||||
"OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
|
||||
"Upstage": ["upstageAI", ["solar-embedding-1-large"]],
|
||||
"Voyage AI": [
|
||||
"voyageAI",
|
||||
["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
|
||||
],
|
||||
}
|
||||
"Jina AI": [
|
||||
"jinaAI",
|
||||
[
|
||||
"jina-embeddings-v2-base-en",
|
||||
"jina-embeddings-v2-base-de",
|
||||
"jina-embeddings-v2-base-es",
|
||||
"jina-embeddings-v2-base-code",
|
||||
"jina-embeddings-v2-base-zh",
|
||||
],
|
||||
],
|
||||
"Mistral AI": ["mistral", ["mistral-embed"]],
|
||||
"NVIDIA": ["nvidia", ["NV-Embed-QA"]],
|
||||
"OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
|
||||
"Upstage": ["upstageAI", ["solar-embedding-1-large"]],
|
||||
"Voyage AI": [
|
||||
"voyageAI",
|
||||
["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
inputs = [
|
||||
SecretStrInput(
|
||||
|
|
@ -109,7 +117,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
value="Embedding Model",
|
||||
),
|
||||
HandleInput(
|
||||
name="embedding",
|
||||
name="embedding_model",
|
||||
display_name="Embedding Model",
|
||||
input_types=["Embeddings"],
|
||||
info="Allows an embedding model configuration.",
|
||||
|
|
@ -247,15 +255,52 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
return build_config
|
||||
|
||||
def update_providers_mapping(self):
|
||||
# If we don't have token or api_endpoint, we can't fetch the list of providers
|
||||
if not self.token or not self.api_endpoint:
|
||||
self.log("Astra DB token and API endpoint are required to fetch the list of Vectorize providers.")
|
||||
|
||||
return self.VECTORIZE_PROVIDERS_MAPPING
|
||||
|
||||
try:
|
||||
self.log("Dynamically updating list of Vectorize providers.")
|
||||
|
||||
# Get the admin object
|
||||
client = DataAPIClient(token=self.token)
|
||||
admin = client.get_admin()
|
||||
|
||||
# Get the embedding providers
|
||||
db_admin = admin.get_database_admin(self.api_endpoint)
|
||||
embedding_providers = db_admin.find_embedding_providers().as_dict()
|
||||
|
||||
vectorize_providers_mapping = {}
|
||||
|
||||
# Map the provider display name to the provider key and models
|
||||
for provider_key, provider_data in embedding_providers["embeddingProviders"].items():
|
||||
display_name = provider_data["displayName"]
|
||||
models = [model["name"] for model in provider_data["models"]]
|
||||
|
||||
vectorize_providers_mapping[display_name] = [provider_key, models]
|
||||
|
||||
# Sort the resulting dictionary
|
||||
return defaultdict(list, dict(sorted(vectorize_providers_mapping.items())))
|
||||
except Exception as e: # noqa: BLE001
|
||||
self.log(f"Error fetching Vectorize providers: {e}")
|
||||
|
||||
return self.VECTORIZE_PROVIDERS_MAPPING
|
||||
|
||||
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
|
||||
if field_name == "embedding_choice":
|
||||
if field_value == "Astra Vectorize":
|
||||
self.del_fields(build_config, ["embedding"])
|
||||
self.del_fields(build_config, ["embedding_model"])
|
||||
|
||||
# Update the providers mapping
|
||||
vectorize_providers = self.update_providers_mapping()
|
||||
|
||||
new_parameter = DropdownInput(
|
||||
name="embedding_provider",
|
||||
display_name="Embedding Provider",
|
||||
options=self.VECTORIZE_PROVIDERS_MAPPING.keys(),
|
||||
options=vectorize_providers.keys(),
|
||||
value="",
|
||||
required=True,
|
||||
real_time_refresh=True,
|
||||
|
|
@ -276,13 +321,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
)
|
||||
|
||||
new_parameter = HandleInput(
|
||||
name="embedding",
|
||||
name="embedding_model",
|
||||
display_name="Embedding Model",
|
||||
input_types=["Embeddings"],
|
||||
info="Allows an embedding model configuration.",
|
||||
).to_dict()
|
||||
|
||||
self.insert_in_dict(build_config, "embedding_choice", {"embedding": new_parameter})
|
||||
self.insert_in_dict(build_config, "embedding_choice", {"embedding_model": new_parameter})
|
||||
|
||||
elif field_name == "embedding_provider":
|
||||
self.del_fields(
|
||||
|
|
@ -290,7 +335,9 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
["model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"],
|
||||
)
|
||||
|
||||
model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1]
|
||||
# Update the providers mapping
|
||||
vectorize_providers = self.update_providers_mapping()
|
||||
model_options = vectorize_providers[field_value][1]
|
||||
|
||||
new_parameter = DropdownInput(
|
||||
name="model",
|
||||
|
|
@ -420,7 +467,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
raise ValueError(msg) from e
|
||||
|
||||
if self.embedding_choice == "Embedding Model":
|
||||
embedding_dict = {"embedding": self.embedding}
|
||||
embedding_dict = {"embedding": self.embedding_model}
|
||||
else:
|
||||
from astrapy.info import CollectionVectorServiceOptions
|
||||
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@
|
|||
"output_types": ["Embeddings"]
|
||||
},
|
||||
"targetHandle": {
|
||||
"fieldName": "embedding",
|
||||
"fieldName": "embedding_model",
|
||||
"id": "AstraDB-3buPx",
|
||||
"inputTypes": ["Embeddings"],
|
||||
"type": "other"
|
||||
|
|
@ -196,7 +196,7 @@
|
|||
"output_types": ["Embeddings"]
|
||||
},
|
||||
"targetHandle": {
|
||||
"fieldName": "embedding",
|
||||
"fieldName": "embedding_model",
|
||||
"id": "AstraDB-laybz",
|
||||
"inputTypes": ["Embeddings"],
|
||||
"type": "other"
|
||||
|
|
@ -1601,7 +1601,7 @@
|
|||
"ingest_data",
|
||||
"namespace",
|
||||
"embedding_service",
|
||||
"embedding",
|
||||
"embedding_model",
|
||||
"metric",
|
||||
"batch_size",
|
||||
"bulk_insert_batch_concurrency",
|
||||
|
|
@ -1781,23 +1781,6 @@
|
|||
"type": "str",
|
||||
"value": ""
|
||||
},
|
||||
"embedding": {
|
||||
"_input_type": "HandleInput",
|
||||
"advanced": false,
|
||||
"display_name": "Embedding Model",
|
||||
"dynamic": false,
|
||||
"info": "Allows an embedding model configuration.",
|
||||
"input_types": ["Embeddings"],
|
||||
"list": false,
|
||||
"name": "embedding",
|
||||
"placeholder": "",
|
||||
"required": false,
|
||||
"show": true,
|
||||
"title_case": false,
|
||||
"trace_as_metadata": true,
|
||||
"type": "other",
|
||||
"value": ""
|
||||
},
|
||||
"embedding_choice": {
|
||||
"_input_type": "DropdownInput",
|
||||
"advanced": false,
|
||||
|
|
@ -1817,6 +1800,23 @@
|
|||
"type": "str",
|
||||
"value": "Embedding Model"
|
||||
},
|
||||
"embedding_model": {
|
||||
"_input_type": "HandleInput",
|
||||
"advanced": false,
|
||||
"display_name": "Embedding Model",
|
||||
"dynamic": false,
|
||||
"info": "Allows an embedding model configuration.",
|
||||
"input_types": ["Embeddings"],
|
||||
"list": false,
|
||||
"name": "embedding_model",
|
||||
"placeholder": "",
|
||||
"required": false,
|
||||
"show": true,
|
||||
"title_case": false,
|
||||
"trace_as_metadata": true,
|
||||
"type": "other",
|
||||
"value": ""
|
||||
},
|
||||
"ingest_data": {
|
||||
"_input_type": "DataInput",
|
||||
"advanced": false,
|
||||
|
|
@ -2556,7 +2556,7 @@
|
|||
"ingest_data",
|
||||
"namespace",
|
||||
"embedding_service",
|
||||
"embedding",
|
||||
"embedding_model",
|
||||
"metric",
|
||||
"batch_size",
|
||||
"bulk_insert_batch_concurrency",
|
||||
|
|
@ -2736,23 +2736,6 @@
|
|||
"type": "str",
|
||||
"value": "test"
|
||||
},
|
||||
"embedding": {
|
||||
"_input_type": "HandleInput",
|
||||
"advanced": false,
|
||||
"display_name": "Embedding Model",
|
||||
"dynamic": false,
|
||||
"info": "Allows an embedding model configuration.",
|
||||
"input_types": ["Embeddings"],
|
||||
"list": false,
|
||||
"name": "embedding",
|
||||
"placeholder": "",
|
||||
"required": false,
|
||||
"show": true,
|
||||
"title_case": false,
|
||||
"trace_as_metadata": true,
|
||||
"type": "other",
|
||||
"value": ""
|
||||
},
|
||||
"embedding_choice": {
|
||||
"_input_type": "DropdownInput",
|
||||
"advanced": false,
|
||||
|
|
@ -2772,6 +2755,23 @@
|
|||
"type": "str",
|
||||
"value": "Embedding Model"
|
||||
},
|
||||
"embedding_model": {
|
||||
"_input_type": "HandleInput",
|
||||
"advanced": false,
|
||||
"display_name": "Embedding Model",
|
||||
"dynamic": false,
|
||||
"info": "Allows an embedding model configuration.",
|
||||
"input_types": ["Embeddings"],
|
||||
"list": false,
|
||||
"name": "embedding_model",
|
||||
"placeholder": "",
|
||||
"required": false,
|
||||
"show": true,
|
||||
"title_case": false,
|
||||
"trace_as_metadata": true,
|
||||
"type": "other",
|
||||
"value": ""
|
||||
},
|
||||
"ingest_data": {
|
||||
"_input_type": "DataInput",
|
||||
"advanced": false,
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ def ingestion_graph():
|
|||
openai_embeddings = OpenAIEmbeddingsComponent()
|
||||
vector_store = AstraVectorStoreComponent()
|
||||
vector_store.set(
|
||||
embedding=openai_embeddings.build_embeddings,
|
||||
embedding_model=openai_embeddings.build_embeddings,
|
||||
ingest_data=text_splitter.split_text,
|
||||
)
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ def rag_graph():
|
|||
rag_vector_store = AstraVectorStoreComponent()
|
||||
rag_vector_store.set(
|
||||
search_input=chat_input.message_response,
|
||||
embedding=openai_embeddings.build_embeddings,
|
||||
embedding_model=openai_embeddings.build_embeddings,
|
||||
)
|
||||
|
||||
parse_data = ParseDataComponent()
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ async def test_base(astradb_client: AstraDB):
|
|||
"token": application_token,
|
||||
"api_endpoint": api_endpoint,
|
||||
"collection_name": BASIC_COLLECTION,
|
||||
"embedding": ComponentInputHandle(
|
||||
"embedding_model": ComponentInputHandle(
|
||||
clazz=OpenAIEmbeddingsComponent,
|
||||
inputs={"openai_api_key": get_openai_api_key()},
|
||||
output_name="embeddings",
|
||||
|
|
@ -79,7 +79,7 @@ async def test_astra_embeds_and_search():
|
|||
"ingest_data": ComponentInputHandle(
|
||||
clazz=TextToData, inputs={"text_data": ["test1", "test2"]}, output_name="from_text"
|
||||
),
|
||||
"embedding": ComponentInputHandle(
|
||||
"embedding_model": ComponentInputHandle(
|
||||
clazz=OpenAIEmbeddingsComponent,
|
||||
inputs={"openai_api_key": get_openai_api_key()},
|
||||
output_name="embeddings",
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def ingestion_graph():
|
|||
)
|
||||
vector_store = AstraVectorStoreComponent(_id="vector-store-123")
|
||||
vector_store.set(
|
||||
embedding=openai_embeddings.build_embeddings,
|
||||
embedding_model=openai_embeddings.build_embeddings,
|
||||
ingest_data=text_splitter.split_text,
|
||||
api_endpoint="https://astra.example.com",
|
||||
token="token", # noqa: S106
|
||||
|
|
@ -53,7 +53,7 @@ def rag_graph():
|
|||
search_input=chat_input.message_response,
|
||||
api_endpoint="https://astra.example.com",
|
||||
token="token", # noqa: S106
|
||||
embedding=openai_embeddings.build_embeddings,
|
||||
embedding_model=openai_embeddings.build_embeddings,
|
||||
)
|
||||
# Mock search_documents
|
||||
rag_vector_store.set_on_output(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue