feat: Move vectorize to Astra DB Component (#3766)
* Move vectorize to Astra DB Component * [autofix.ci] apply automated fixes * Ruff check fixes * Update compatibility tests and add new tests * [autofix.ci] apply automated fixes * Fixes from review feedback * Restore old vectorize component, add deprecation label --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
98c1f0e8aa
commit
f6d93fc472
4 changed files with 266 additions and 53 deletions
|
|
@ -6,8 +6,8 @@ from langflow.template.field.base import Output
|
|||
|
||||
|
||||
class AstraVectorizeComponent(Component):
|
||||
display_name: str = "Astra Vectorize"
|
||||
description: str = "Configuration options for Astra Vectorize server-side embeddings."
|
||||
display_name: str = "Astra Vectorize [DEPRECATED]"
|
||||
description: str = "Configuration options for Astra Vectorize server-side embeddings. This component is deprecated. Please use the Astra DB Component directly."
|
||||
documentation: str = "https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html"
|
||||
icon = "AstraDB"
|
||||
name = "AstraVectorize"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from loguru import logger
|
|||
|
||||
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
|
||||
from langflow.helpers import docs_to_data
|
||||
from langflow.inputs import DictInput, FloatInput
|
||||
from langflow.inputs import DictInput, FloatInput, MessageTextInput
|
||||
from langflow.io import (
|
||||
BoolInput,
|
||||
DataInput,
|
||||
|
|
@ -23,6 +23,40 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
name = "AstraDB"
|
||||
icon: str = "AstraDB"
|
||||
|
||||
VECTORIZE_PROVIDERS_MAPPING = {
|
||||
"Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
|
||||
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
|
||||
"Hugging Face - Serverless": [
|
||||
"huggingface",
|
||||
[
|
||||
"sentence-transformers/all-MiniLM-L6-v2",
|
||||
"intfloat/multilingual-e5-large",
|
||||
"intfloat/multilingual-e5-large-instruct",
|
||||
"BAAI/bge-small-en-v1.5",
|
||||
"BAAI/bge-base-en-v1.5",
|
||||
"BAAI/bge-large-en-v1.5",
|
||||
],
|
||||
],
|
||||
"Jina AI": [
|
||||
"jinaAI",
|
||||
[
|
||||
"jina-embeddings-v2-base-en",
|
||||
"jina-embeddings-v2-base-de",
|
||||
"jina-embeddings-v2-base-es",
|
||||
"jina-embeddings-v2-base-code",
|
||||
"jina-embeddings-v2-base-zh",
|
||||
],
|
||||
],
|
||||
"Mistral AI": ["mistral", ["mistral-embed"]],
|
||||
"NVIDIA": ["nvidia", ["NV-Embed-QA"]],
|
||||
"OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
|
||||
"Upstage": ["upstageAI", ["solar-embedding-1-large"]],
|
||||
"Voyage AI": [
|
||||
"voyageAI",
|
||||
["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
|
||||
],
|
||||
}
|
||||
|
||||
inputs = [
|
||||
StrInput(
|
||||
name="collection_name",
|
||||
|
|
@ -59,6 +93,20 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
info="Optional namespace within Astra DB to use for the collection.",
|
||||
advanced=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="embedding_service",
|
||||
display_name="Embedding Model or Astra Vectorize",
|
||||
info="Determines whether to use Astra Vectorize for the collection.",
|
||||
options=["Embedding Model", "Astra Vectorize"],
|
||||
real_time_refresh=True,
|
||||
value="Embedding Model",
|
||||
),
|
||||
HandleInput(
|
||||
name="embedding",
|
||||
display_name="Embedding Model",
|
||||
input_types=["Embeddings"],
|
||||
info="Allows an embedding model configuration.",
|
||||
),
|
||||
DropdownInput(
|
||||
name="metric",
|
||||
display_name="Metric",
|
||||
|
|
@ -110,12 +158,6 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
info="Optional list of metadata fields to include in the indexing.",
|
||||
advanced=True,
|
||||
),
|
||||
HandleInput(
|
||||
name="embedding",
|
||||
display_name="Embedding or Astra Vectorize",
|
||||
input_types=["Embeddings", "dict"],
|
||||
info="Allows either an embedding model or an Astra Vectorize configuration.", # TODO: This should be optional, but need to refactor langchain-astradb first.
|
||||
),
|
||||
StrInput(
|
||||
name="metadata_indexing_exclude",
|
||||
display_name="Metadata Indexing Exclude",
|
||||
|
|
@ -160,7 +202,159 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
]
|
||||
|
||||
@check_cached_vector_store
|
||||
def build_vector_store(self):
|
||||
def insert_in_dict(self, build_config, field_name, new_parameters):
    """Insert ``new_parameters`` into ``build_config`` right after ``field_name``.

    ``build_config`` is modified in place (cleared and refilled) and is also
    returned.

    Each new entry is inserted directly after ``field_name``, one at a time,
    so when several entries are supplied they end up in the reverse of the
    order in which ``new_parameters`` yields them.

    If ``field_name`` is not a key of ``build_config`` -- including the case
    of an empty ``build_config`` -- the new entries are appended at the end.

    Args:
        build_config: Mapping of field name to field configuration.
        field_name: Key after which the new entries are placed.
        new_parameters: Mapping of new field name to field configuration.

    Returns:
        The same ``build_config`` object, updated.
    """
    for new_field_name, new_parameter in new_parameters.items():
        # Snapshot the current ordering as (key, value) tuples.
        items = list(build_config.items())

        # Default to appending; this also keeps the index defined when the
        # config is empty (previously an unbound loop variable).
        insert_at = len(items)
        for position, (key, _value) in enumerate(items):
            if key == field_name:
                insert_at = position + 1
                break

        items.insert(insert_at, (new_field_name, new_parameter))

        # Rebuild the same dict object so callers holding a reference see
        # the new ordering.
        build_config.clear()
        build_config.update(items)

    return build_config
|
||||
|
||||
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
|
||||
if field_name == "embedding_service":
|
||||
if field_value == "Astra Vectorize":
|
||||
for field in ["embedding"]:
|
||||
if field in build_config:
|
||||
del build_config[field]
|
||||
|
||||
new_parameter = DropdownInput(
|
||||
name="provider",
|
||||
display_name="Vectorize Provider",
|
||||
options=self.VECTORIZE_PROVIDERS_MAPPING.keys(),
|
||||
value="",
|
||||
required=True,
|
||||
real_time_refresh=True,
|
||||
).to_dict()
|
||||
|
||||
self.insert_in_dict(build_config, "embedding_service", {"provider": new_parameter})
|
||||
else:
|
||||
for field in [
|
||||
"provider",
|
||||
"z_00_model_name",
|
||||
"z_01_model_parameters",
|
||||
"z_02_api_key_name",
|
||||
"z_03_provider_api_key",
|
||||
"z_04_authentication",
|
||||
]:
|
||||
if field in build_config:
|
||||
del build_config[field]
|
||||
|
||||
new_parameter = HandleInput(
|
||||
name="embedding",
|
||||
display_name="Embedding Model",
|
||||
input_types=["Embeddings"],
|
||||
info="Allows an embedding model configuration.",
|
||||
).to_dict()
|
||||
|
||||
self.insert_in_dict(build_config, "embedding_service", {"embedding": new_parameter})
|
||||
|
||||
elif field_name == "provider":
|
||||
for field in [
|
||||
"z_00_model_name",
|
||||
"z_01_model_parameters",
|
||||
"z_02_api_key_name",
|
||||
"z_03_provider_api_key",
|
||||
"z_04_authentication",
|
||||
]:
|
||||
if field in build_config:
|
||||
del build_config[field]
|
||||
|
||||
model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1]
|
||||
|
||||
new_parameter_0 = DropdownInput(
|
||||
name="z_00_model_name",
|
||||
display_name="Model Name",
|
||||
info=f"The embedding model to use for the selected provider. Each provider has a different set of models "
|
||||
f"available (full list at https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html):\n\n{', '.join(model_options)}",
|
||||
options=model_options,
|
||||
required=True,
|
||||
).to_dict()
|
||||
|
||||
new_parameter_1 = DictInput(
|
||||
name="z_01_model_parameters",
|
||||
display_name="Model Parameters",
|
||||
is_list=True,
|
||||
).to_dict()
|
||||
|
||||
new_parameter_2 = MessageTextInput(
|
||||
name="z_02_api_key_name",
|
||||
display_name="API Key name",
|
||||
info="The name of the embeddings provider API key stored on Astra. If set, it will override the 'ProviderKey' in the authentication parameters.",
|
||||
).to_dict()
|
||||
|
||||
new_parameter_3 = SecretStrInput(
|
||||
name="z_03_provider_api_key",
|
||||
display_name="Provider API Key",
|
||||
info="An alternative to the Astra Authentication that passes an API key for the provider with each request to Astra DB. This may be used when Vectorize is configured for the collection, but no corresponding provider secret is stored within Astra's key management system.",
|
||||
).to_dict()
|
||||
|
||||
new_parameter_4 = DictInput(
|
||||
name="z_04_authentication",
|
||||
display_name="Authentication parameters",
|
||||
is_list=True,
|
||||
).to_dict()
|
||||
|
||||
self.insert_in_dict(
|
||||
build_config,
|
||||
"provider",
|
||||
{
|
||||
"z_00_model_name": new_parameter_0,
|
||||
"z_01_model_parameters": new_parameter_1,
|
||||
"z_02_api_key_name": new_parameter_2,
|
||||
"z_03_provider_api_key": new_parameter_3,
|
||||
"z_04_authentication": new_parameter_4,
|
||||
},
|
||||
)
|
||||
|
||||
return build_config
|
||||
|
||||
def build_vectorize_options(self, **kwargs):
    """Assemble the Astra Vectorize options for this component.

    Each value is read from the component's attributes first and from
    ``kwargs`` second. Two naming schemes are accepted for every value: the
    field names created by ``update_build_config`` (``z_00_model_name``,
    ``z_01_model_parameters``, ``z_02_api_key_name``,
    ``z_03_provider_api_key``, ``z_04_authentication``) and the legacy names
    this method originally read (``z_01_model_name``,
    ``z_04_model_parameters``, ``z_00_api_key_name``,
    ``z_02_authentication``).

    The provider is translated through ``VECTORIZE_PROVIDERS_MAPPING`` when it
    is one of the mapping's keys; a provider passed via ``kwargs`` that is
    not a key is used verbatim.

    Args:
        **kwargs: Fallback values, keyed by any of the names listed above
            plus ``provider``.

    Returns:
        A dict with ``collection_vector_service_options`` (keys ``provider``,
        ``modelName``, ``authentication``, ``parameters``) and
        ``collection_embedding_api_key``.
    """

    def first_set(*names, default=None):
        # Component attributes take precedence over kwargs. getattr with a
        # default avoids planting None attributes on the component.
        for name in names:
            value = getattr(self, name, None)
            if value:
                return value
        for name in names:
            value = kwargs.get(name)
            if value:
                return value
        return default

    mapping = self.VECTORIZE_PROVIDERS_MAPPING
    kwargs_provider = kwargs.get("provider")
    provider_value = (
        mapping.get(getattr(self, "provider", None), [None])[0]
        # Accept either a display name or an already-resolved provider id.
        or mapping.get(kwargs_provider, [kwargs_provider])[0]
    )

    # Copy so the caller's dict is never mutated; tolerate an explicit None.
    authentication = dict(first_set("z_04_authentication", "z_02_authentication", default={}))

    api_key_name = first_set("z_02_api_key_name", "z_00_api_key_name")
    provider_api_key = first_set("z_03_provider_api_key")
    # NOTE(review): the provider API key is also written to "providerKey",
    # exactly as before; confirm this is intended, since the API key name
    # field describes "providerKey" as the name of a key stored on Astra.
    if provider_api_key:
        authentication["providerKey"] = provider_api_key
    if api_key_name:
        # The stored key name wins over the raw provider key.
        authentication["providerKey"] = api_key_name

    return {
        # must match astrapy.info.CollectionVectorServiceOptions
        "collection_vector_service_options": {
            "provider": provider_value,
            "modelName": first_set("z_00_model_name", "z_01_model_name"),
            "authentication": authentication,
            "parameters": first_set("z_01_model_parameters", "z_04_model_parameters", default={}),
        },
        "collection_embedding_api_key": provider_api_key,
    }
|
||||
|
||||
@check_cached_vector_store
|
||||
def build_vector_store(self, vectorize_options=None):
|
||||
try:
|
||||
from langchain_astradb import AstraDBVectorStore
|
||||
from langchain_astradb.utils.astradb import SetupMode
|
||||
|
|
@ -178,22 +372,22 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
except KeyError:
|
||||
raise ValueError(f"Invalid setup mode: {self.setup_mode}")
|
||||
|
||||
if not isinstance(self.embedding, dict):
|
||||
if self.embedding:
|
||||
embedding_dict = {"embedding": self.embedding}
|
||||
else:
|
||||
from astrapy.info import CollectionVectorServiceOptions
|
||||
|
||||
dict_options = self.embedding.get("collection_vector_service_options", {})
|
||||
dict_options = vectorize_options or self.build_vectorize_options()
|
||||
dict_options["authentication"] = {
|
||||
k: v for k, v in dict_options.get("authentication", {}).items() if k and v
|
||||
}
|
||||
dict_options["parameters"] = {k: v for k, v in dict_options.get("parameters", {}).items() if k and v}
|
||||
|
||||
embedding_dict = {
|
||||
"collection_vector_service_options": CollectionVectorServiceOptions.from_dict(dict_options)
|
||||
"collection_vector_service_options": CollectionVectorServiceOptions.from_dict(
|
||||
dict_options.get("collection_vector_service_options", {})
|
||||
),
|
||||
}
|
||||
collection_embedding_api_key = self.embedding.get("collection_embedding_api_key")
|
||||
if collection_embedding_api_key:
|
||||
embedding_dict["collection_embedding_api_key"] = collection_embedding_api_key
|
||||
|
||||
vector_store_kwargs = {
|
||||
**embedding_dict,
|
||||
|
|
@ -223,6 +417,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
raise ValueError(f"Error initializing AstraDBVectorStore: {str(e)}") from e
|
||||
|
||||
self._add_documents_to_vector_store(vector_store)
|
||||
|
||||
return vector_store
|
||||
|
||||
def _add_documents_to_vector_store(self, vector_store):
|
||||
|
|
@ -262,8 +457,9 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
args["filter"] = clean_filter
|
||||
return args
|
||||
|
||||
def search_documents(self) -> list[Data]:
|
||||
vector_store = self.build_vector_store()
|
||||
def search_documents(self, vector_store=None) -> list[Data]:
|
||||
if not vector_store:
|
||||
vector_store = self.build_vector_store()
|
||||
|
||||
logger.debug(f"Search input: {self.search_input}")
|
||||
logger.debug(f"Search type: {self.search_type}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from .AstraDB import AstraVectorStoreComponent
|
||||
|
||||
__all__ = ["AstraVectorStoreComponent"]
|
||||
|
|
@ -4,13 +4,13 @@ from astrapy.db import AstraDB
|
|||
import pytest
|
||||
|
||||
from langflow.components.embeddings import OpenAIEmbeddingsComponent
|
||||
from langflow.components.vectorstores import AstraVectorStoreComponent
|
||||
from tests.api_keys import get_astradb_application_token, get_astradb_api_endpoint, get_openai_api_key
|
||||
from tests.integration.components.mock_components import TextToData
|
||||
from tests.integration.utils import ComponentInputHandle
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
from langflow.components.vectorstores.AstraDB import AstraVectorStoreComponent
|
||||
from langflow.schema.data import Data
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
|
|
@ -98,14 +98,14 @@ async def test_astra_embeds_and_search():
|
|||
def test_astra_vectorize():
|
||||
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
|
||||
|
||||
from langflow.components.embeddings.AstraVectorize import AstraVectorizeComponent
|
||||
|
||||
application_token = get_astradb_application_token()
|
||||
api_endpoint = get_astradb_api_endpoint()
|
||||
|
||||
store = None
|
||||
try:
|
||||
options = {"provider": "nvidia", "modelName": "NV-Embed-QA"}
|
||||
options_comp = {"provider": "nvidia", "z_01_model_name": "NV-Embed-QA"}
|
||||
|
||||
store = AstraDBVectorStore(
|
||||
collection_name=VECTORIZE_COLLECTION,
|
||||
api_endpoint=api_endpoint,
|
||||
|
|
@ -116,22 +116,20 @@ def test_astra_vectorize():
|
|||
documents = [Document(page_content="test1"), Document(page_content="test2")]
|
||||
records = [Data.from_document(d) for d in documents]
|
||||
|
||||
vectorize = AstraVectorizeComponent()
|
||||
vectorize.build(provider="NVIDIA", model_name="NV-Embed-QA")
|
||||
vectorize_options = vectorize.build_options()
|
||||
|
||||
component = AstraVectorStoreComponent()
|
||||
vectorize_options = component.build_vectorize_options(**options_comp)
|
||||
|
||||
component.build(
|
||||
token=application_token,
|
||||
api_endpoint=api_endpoint,
|
||||
collection_name=VECTORIZE_COLLECTION,
|
||||
ingest_data=records,
|
||||
embedding=vectorize_options,
|
||||
search_input="test",
|
||||
number_of_results=2,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
component.build_vector_store()
|
||||
records = component.search_documents()
|
||||
vector_store = component.build_vector_store(vectorize_options)
|
||||
records = component.search_documents(vector_store=vector_store)
|
||||
|
||||
assert len(records) == 2
|
||||
finally:
|
||||
|
|
@ -144,14 +142,26 @@ def test_astra_vectorize_with_provider_api_key():
|
|||
"""tests vectorize using an openai api key"""
|
||||
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
|
||||
|
||||
from langflow.components.embeddings.AstraVectorize import AstraVectorizeComponent
|
||||
|
||||
application_token = get_astradb_application_token()
|
||||
api_endpoint = get_astradb_api_endpoint()
|
||||
|
||||
store = None
|
||||
try:
|
||||
options = {"provider": "openai", "modelName": "text-embedding-3-small", "parameters": {}, "authentication": {}}
|
||||
options = {
|
||||
"provider": "openai",
|
||||
"modelName": "text-embedding-3-small",
|
||||
"parameters": {},
|
||||
"authentication": {"providerKey": "openai"},
|
||||
}
|
||||
|
||||
options_comp = {
|
||||
"provider": "openai",
|
||||
"z_01_model_name": "text-embedding-3-small",
|
||||
"z_04_model_parameters": {},
|
||||
"z_02_authentication": {},
|
||||
"z_03_provider_api_key": "openai",
|
||||
}
|
||||
|
||||
store = AstraDBVectorStore(
|
||||
collection_name=VECTORIZE_COLLECTION_OPENAI,
|
||||
api_endpoint=api_endpoint,
|
||||
|
|
@ -162,24 +172,22 @@ def test_astra_vectorize_with_provider_api_key():
|
|||
documents = [Document(page_content="test1"), Document(page_content="test2")]
|
||||
records = [Data.from_document(d) for d in documents]
|
||||
|
||||
vectorize = AstraVectorizeComponent()
|
||||
vectorize.build(
|
||||
provider="OpenAI", model_name="text-embedding-3-small", provider_api_key=os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
vectorize_options = vectorize.build_options()
|
||||
|
||||
component = AstraVectorStoreComponent()
|
||||
vectorize_options = component.build_vectorize_options(**options_comp)
|
||||
|
||||
component.build(
|
||||
token=application_token,
|
||||
api_endpoint=api_endpoint,
|
||||
collection_name=VECTORIZE_COLLECTION_OPENAI,
|
||||
ingest_data=records,
|
||||
embedding=vectorize_options,
|
||||
search_input="test",
|
||||
number_of_results=4,
|
||||
number_of_results=2,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
component.build_vector_store()
|
||||
records = component.search_documents()
|
||||
|
||||
vector_store = component.build_vector_store(vectorize_options)
|
||||
records = component.search_documents(vector_store=vector_store)
|
||||
|
||||
assert len(records) == 2
|
||||
finally:
|
||||
if store is not None:
|
||||
|
|
@ -191,44 +199,50 @@ def test_astra_vectorize_passes_authentication():
|
|||
"""tests vectorize using the authentication parameter"""
|
||||
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
|
||||
|
||||
from langflow.components.embeddings.AstraVectorize import AstraVectorizeComponent
|
||||
|
||||
store = None
|
||||
try:
|
||||
application_token = get_astradb_application_token()
|
||||
api_endpoint = get_astradb_api_endpoint()
|
||||
|
||||
options = {
|
||||
"provider": "openai",
|
||||
"modelName": "text-embedding-3-small",
|
||||
"parameters": {},
|
||||
"authentication": {"providerKey": "apikey"},
|
||||
"authentication": {"providerKey": "openai"},
|
||||
}
|
||||
options_comp = {
|
||||
"provider": "openai",
|
||||
"z_01_model_name": "text-embedding-3-small",
|
||||
"z_04_model_parameters": {},
|
||||
"z_02_authentication": {"providerKey": "openai"},
|
||||
}
|
||||
|
||||
store = AstraDBVectorStore(
|
||||
collection_name=VECTORIZE_COLLECTION_OPENAI_WITH_AUTH,
|
||||
api_endpoint=api_endpoint,
|
||||
token=application_token,
|
||||
collection_vector_service_options=CollectionVectorServiceOptions.from_dict(options),
|
||||
)
|
||||
|
||||
documents = [Document(page_content="test1"), Document(page_content="test2")]
|
||||
records = [Data.from_document(d) for d in documents]
|
||||
|
||||
vectorize = AstraVectorizeComponent()
|
||||
vectorize.build(
|
||||
provider="OpenAI", model_name="text-embedding-3-small", authentication={"providerKey": "apikey"}
|
||||
)
|
||||
vectorize_options = vectorize.build_options()
|
||||
|
||||
component = AstraVectorStoreComponent()
|
||||
vectorize_options = component.build_vectorize_options(**options_comp)
|
||||
|
||||
component.build(
|
||||
token=application_token,
|
||||
api_endpoint=api_endpoint,
|
||||
collection_name=VECTORIZE_COLLECTION_OPENAI_WITH_AUTH,
|
||||
ingest_data=records,
|
||||
embedding=vectorize_options,
|
||||
search_input="test",
|
||||
number_of_results=2,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
component.build_vector_store()
|
||||
records = component.search_documents()
|
||||
|
||||
vector_store = component.build_vector_store(vectorize_options)
|
||||
records = component.search_documents(vector_store=vector_store)
|
||||
|
||||
assert len(records) == 2
|
||||
finally:
|
||||
if store is not None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue