feat: Move vectorize to Astra DB Component (#3766)

* Move vectorize to Astra DB Component

* [autofix.ci] apply automated fixes

* Ruff check fixes

* Update compatibility tests and add new tests

* [autofix.ci] apply automated fixes

* Fixes from review feedback

* Restore old vectorize component, add deprecation label

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Eric Hare 2024-09-19 06:11:31 -07:00 committed by GitHub
commit f6d93fc472
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 266 additions and 53 deletions

View file

@ -6,8 +6,8 @@ from langflow.template.field.base import Output
class AstraVectorizeComponent(Component):
display_name: str = "Astra Vectorize"
description: str = "Configuration options for Astra Vectorize server-side embeddings."
display_name: str = "Astra Vectorize [DEPRECATED]"
description: str = "Configuration options for Astra Vectorize server-side embeddings. This component is deprecated. Please use the Astra DB Component directly."
documentation: str = "https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html"
icon = "AstraDB"
name = "AstraVectorize"

View file

@ -2,7 +2,7 @@ from loguru import logger
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers import docs_to_data
from langflow.inputs import DictInput, FloatInput
from langflow.inputs import DictInput, FloatInput, MessageTextInput
from langflow.io import (
BoolInput,
DataInput,
@ -23,6 +23,40 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
name = "AstraDB"
icon: str = "AstraDB"
VECTORIZE_PROVIDERS_MAPPING = {
"Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
"Hugging Face - Serverless": [
"huggingface",
[
"sentence-transformers/all-MiniLM-L6-v2",
"intfloat/multilingual-e5-large",
"intfloat/multilingual-e5-large-instruct",
"BAAI/bge-small-en-v1.5",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-large-en-v1.5",
],
],
"Jina AI": [
"jinaAI",
[
"jina-embeddings-v2-base-en",
"jina-embeddings-v2-base-de",
"jina-embeddings-v2-base-es",
"jina-embeddings-v2-base-code",
"jina-embeddings-v2-base-zh",
],
],
"Mistral AI": ["mistral", ["mistral-embed"]],
"NVIDIA": ["nvidia", ["NV-Embed-QA"]],
"OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Upstage": ["upstageAI", ["solar-embedding-1-large"]],
"Voyage AI": [
"voyageAI",
["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
],
}
inputs = [
StrInput(
name="collection_name",
@ -59,6 +93,20 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
info="Optional namespace within Astra DB to use for the collection.",
advanced=True,
),
DropdownInput(
name="embedding_service",
display_name="Embedding Model or Astra Vectorize",
info="Determines whether to use Astra Vectorize for the collection.",
options=["Embedding Model", "Astra Vectorize"],
real_time_refresh=True,
value="Embedding Model",
),
HandleInput(
name="embedding",
display_name="Embedding Model",
input_types=["Embeddings"],
info="Allows an embedding model configuration.",
),
DropdownInput(
name="metric",
display_name="Metric",
@ -110,12 +158,6 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
info="Optional list of metadata fields to include in the indexing.",
advanced=True,
),
HandleInput(
name="embedding",
display_name="Embedding or Astra Vectorize",
input_types=["Embeddings", "dict"],
info="Allows either an embedding model or an Astra Vectorize configuration.", # TODO: This should be optional, but need to refactor langchain-astradb first.
),
StrInput(
name="metadata_indexing_exclude",
display_name="Metadata Indexing Exclude",
@ -160,7 +202,159 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
]
@check_cached_vector_store
def build_vector_store(self):
def insert_in_dict(self, build_config, field_name, new_parameters):
# Insert the new key-value pair after the found key
for new_field_name, new_parameter in new_parameters.items():
# Get all the items as a list of tuples (key, value)
items = list(build_config.items())
# Find the index of the key to insert after
for i, (key, value) in enumerate(items):
if key == field_name:
break
items.insert(i + 1, (new_field_name, new_parameter))
# Clear the original dictionary and update with the modified items
build_config.clear()
build_config.update(items)
return build_config
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
if field_name == "embedding_service":
if field_value == "Astra Vectorize":
for field in ["embedding"]:
if field in build_config:
del build_config[field]
new_parameter = DropdownInput(
name="provider",
display_name="Vectorize Provider",
options=self.VECTORIZE_PROVIDERS_MAPPING.keys(),
value="",
required=True,
real_time_refresh=True,
).to_dict()
self.insert_in_dict(build_config, "embedding_service", {"provider": new_parameter})
else:
for field in [
"provider",
"z_00_model_name",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_authentication",
]:
if field in build_config:
del build_config[field]
new_parameter = HandleInput(
name="embedding",
display_name="Embedding Model",
input_types=["Embeddings"],
info="Allows an embedding model configuration.",
).to_dict()
self.insert_in_dict(build_config, "embedding_service", {"embedding": new_parameter})
elif field_name == "provider":
for field in [
"z_00_model_name",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_authentication",
]:
if field in build_config:
del build_config[field]
model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1]
new_parameter_0 = DropdownInput(
name="z_00_model_name",
display_name="Model Name",
info=f"The embedding model to use for the selected provider. Each provider has a different set of models "
f"available (full list at https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html):\n\n{', '.join(model_options)}",
options=model_options,
required=True,
).to_dict()
new_parameter_1 = DictInput(
name="z_01_model_parameters",
display_name="Model Parameters",
is_list=True,
).to_dict()
new_parameter_2 = MessageTextInput(
name="z_02_api_key_name",
display_name="API Key name",
info="The name of the embeddings provider API key stored on Astra. If set, it will override the 'ProviderKey' in the authentication parameters.",
).to_dict()
new_parameter_3 = SecretStrInput(
name="z_03_provider_api_key",
display_name="Provider API Key",
info="An alternative to the Astra Authentication that passes an API key for the provider with each request to Astra DB. This may be used when Vectorize is configured for the collection, but no corresponding provider secret is stored within Astra's key management system.",
).to_dict()
new_parameter_4 = DictInput(
name="z_04_authentication",
display_name="Authentication parameters",
is_list=True,
).to_dict()
self.insert_in_dict(
build_config,
"provider",
{
"z_00_model_name": new_parameter_0,
"z_01_model_parameters": new_parameter_1,
"z_02_api_key_name": new_parameter_2,
"z_03_provider_api_key": new_parameter_3,
"z_04_authentication": new_parameter_4,
},
)
return build_config
def build_vectorize_options(self, **kwargs):
for attribute in [
"provider",
"z_00_api_key_name",
"z_01_model_name",
"z_02_authentication",
"z_03_provider_api_key",
"z_04_model_parameters",
]:
if not hasattr(self, attribute):
setattr(self, attribute, None)
# Fetch values from kwargs if any self.* attributes are None
provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.provider, [None])[0] or kwargs.get("provider")
authentication = {**(self.z_02_authentication or kwargs.get("z_02_authentication", {}))}
api_key_name = self.z_00_api_key_name or kwargs.get("z_00_api_key_name")
provider_key_name = self.z_03_provider_api_key or kwargs.get("z_03_provider_api_key")
if provider_key_name:
authentication["providerKey"] = provider_key_name
if api_key_name:
authentication["providerKey"] = api_key_name
return {
# must match astrapy.info.CollectionVectorServiceOptions
"collection_vector_service_options": {
"provider": provider_value,
"modelName": self.z_01_model_name or kwargs.get("z_01_model_name"),
"authentication": authentication,
"parameters": self.z_04_model_parameters or kwargs.get("z_04_model_parameters", {}),
},
"collection_embedding_api_key": self.z_03_provider_api_key or kwargs.get("z_03_provider_api_key"),
}
@check_cached_vector_store
def build_vector_store(self, vectorize_options=None):
try:
from langchain_astradb import AstraDBVectorStore
from langchain_astradb.utils.astradb import SetupMode
@ -178,22 +372,22 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
except KeyError:
raise ValueError(f"Invalid setup mode: {self.setup_mode}")
if not isinstance(self.embedding, dict):
if self.embedding:
embedding_dict = {"embedding": self.embedding}
else:
from astrapy.info import CollectionVectorServiceOptions
dict_options = self.embedding.get("collection_vector_service_options", {})
dict_options = vectorize_options or self.build_vectorize_options()
dict_options["authentication"] = {
k: v for k, v in dict_options.get("authentication", {}).items() if k and v
}
dict_options["parameters"] = {k: v for k, v in dict_options.get("parameters", {}).items() if k and v}
embedding_dict = {
"collection_vector_service_options": CollectionVectorServiceOptions.from_dict(dict_options)
"collection_vector_service_options": CollectionVectorServiceOptions.from_dict(
dict_options.get("collection_vector_service_options", {})
),
}
collection_embedding_api_key = self.embedding.get("collection_embedding_api_key")
if collection_embedding_api_key:
embedding_dict["collection_embedding_api_key"] = collection_embedding_api_key
vector_store_kwargs = {
**embedding_dict,
@ -223,6 +417,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
raise ValueError(f"Error initializing AstraDBVectorStore: {str(e)}") from e
self._add_documents_to_vector_store(vector_store)
return vector_store
def _add_documents_to_vector_store(self, vector_store):
@ -262,8 +457,9 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
args["filter"] = clean_filter
return args
def search_documents(self) -> list[Data]:
vector_store = self.build_vector_store()
def search_documents(self, vector_store=None) -> list[Data]:
if not vector_store:
vector_store = self.build_vector_store()
logger.debug(f"Search input: {self.search_input}")
logger.debug(f"Search type: {self.search_type}")

View file

@ -0,0 +1,3 @@
from .AstraDB import AstraVectorStoreComponent
__all__ = ["AstraVectorStoreComponent"]

View file

@ -4,13 +4,13 @@ from astrapy.db import AstraDB
import pytest
from langflow.components.embeddings import OpenAIEmbeddingsComponent
from langflow.components.vectorstores import AstraVectorStoreComponent
from tests.api_keys import get_astradb_application_token, get_astradb_api_endpoint, get_openai_api_key
from tests.integration.components.mock_components import TextToData
from tests.integration.utils import ComponentInputHandle
from langchain_core.documents import Document
from langflow.components.vectorstores.AstraDB import AstraVectorStoreComponent
from langflow.schema.data import Data
from tests.integration.utils import run_single_component
@ -98,14 +98,14 @@ async def test_astra_embeds_and_search():
def test_astra_vectorize():
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
from langflow.components.embeddings.AstraVectorize import AstraVectorizeComponent
application_token = get_astradb_application_token()
api_endpoint = get_astradb_api_endpoint()
store = None
try:
options = {"provider": "nvidia", "modelName": "NV-Embed-QA"}
options_comp = {"provider": "nvidia", "z_01_model_name": "NV-Embed-QA"}
store = AstraDBVectorStore(
collection_name=VECTORIZE_COLLECTION,
api_endpoint=api_endpoint,
@ -116,22 +116,20 @@ def test_astra_vectorize():
documents = [Document(page_content="test1"), Document(page_content="test2")]
records = [Data.from_document(d) for d in documents]
vectorize = AstraVectorizeComponent()
vectorize.build(provider="NVIDIA", model_name="NV-Embed-QA")
vectorize_options = vectorize.build_options()
component = AstraVectorStoreComponent()
vectorize_options = component.build_vectorize_options(**options_comp)
component.build(
token=application_token,
api_endpoint=api_endpoint,
collection_name=VECTORIZE_COLLECTION,
ingest_data=records,
embedding=vectorize_options,
search_input="test",
number_of_results=2,
pre_delete_collection=True,
)
component.build_vector_store()
records = component.search_documents()
vector_store = component.build_vector_store(vectorize_options)
records = component.search_documents(vector_store=vector_store)
assert len(records) == 2
finally:
@ -144,14 +142,26 @@ def test_astra_vectorize_with_provider_api_key():
"""tests vectorize using an openai api key"""
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
from langflow.components.embeddings.AstraVectorize import AstraVectorizeComponent
application_token = get_astradb_application_token()
api_endpoint = get_astradb_api_endpoint()
store = None
try:
options = {"provider": "openai", "modelName": "text-embedding-3-small", "parameters": {}, "authentication": {}}
options = {
"provider": "openai",
"modelName": "text-embedding-3-small",
"parameters": {},
"authentication": {"providerKey": "openai"},
}
options_comp = {
"provider": "openai",
"z_01_model_name": "text-embedding-3-small",
"z_04_model_parameters": {},
"z_02_authentication": {},
"z_03_provider_api_key": "openai",
}
store = AstraDBVectorStore(
collection_name=VECTORIZE_COLLECTION_OPENAI,
api_endpoint=api_endpoint,
@ -162,24 +172,22 @@ def test_astra_vectorize_with_provider_api_key():
documents = [Document(page_content="test1"), Document(page_content="test2")]
records = [Data.from_document(d) for d in documents]
vectorize = AstraVectorizeComponent()
vectorize.build(
provider="OpenAI", model_name="text-embedding-3-small", provider_api_key=os.getenv("OPENAI_API_KEY")
)
vectorize_options = vectorize.build_options()
component = AstraVectorStoreComponent()
vectorize_options = component.build_vectorize_options(**options_comp)
component.build(
token=application_token,
api_endpoint=api_endpoint,
collection_name=VECTORIZE_COLLECTION_OPENAI,
ingest_data=records,
embedding=vectorize_options,
search_input="test",
number_of_results=4,
number_of_results=2,
pre_delete_collection=True,
)
component.build_vector_store()
records = component.search_documents()
vector_store = component.build_vector_store(vectorize_options)
records = component.search_documents(vector_store=vector_store)
assert len(records) == 2
finally:
if store is not None:
@ -191,44 +199,50 @@ def test_astra_vectorize_passes_authentication():
"""tests vectorize using the authentication parameter"""
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
from langflow.components.embeddings.AstraVectorize import AstraVectorizeComponent
store = None
try:
application_token = get_astradb_application_token()
api_endpoint = get_astradb_api_endpoint()
options = {
"provider": "openai",
"modelName": "text-embedding-3-small",
"parameters": {},
"authentication": {"providerKey": "apikey"},
"authentication": {"providerKey": "openai"},
}
options_comp = {
"provider": "openai",
"z_01_model_name": "text-embedding-3-small",
"z_04_model_parameters": {},
"z_02_authentication": {"providerKey": "openai"},
}
store = AstraDBVectorStore(
collection_name=VECTORIZE_COLLECTION_OPENAI_WITH_AUTH,
api_endpoint=api_endpoint,
token=application_token,
collection_vector_service_options=CollectionVectorServiceOptions.from_dict(options),
)
documents = [Document(page_content="test1"), Document(page_content="test2")]
records = [Data.from_document(d) for d in documents]
vectorize = AstraVectorizeComponent()
vectorize.build(
provider="OpenAI", model_name="text-embedding-3-small", authentication={"providerKey": "apikey"}
)
vectorize_options = vectorize.build_options()
component = AstraVectorStoreComponent()
vectorize_options = component.build_vectorize_options(**options_comp)
component.build(
token=application_token,
api_endpoint=api_endpoint,
collection_name=VECTORIZE_COLLECTION_OPENAI_WITH_AUTH,
ingest_data=records,
embedding=vectorize_options,
search_input="test",
number_of_results=2,
pre_delete_collection=True,
)
component.build_vector_store()
records = component.search_documents()
vector_store = component.build_vector_store(vectorize_options)
records = component.search_documents(vector_store=vector_store)
assert len(records) == 2
finally:
if store is not None: