From f6d93fc472568f13ae90b7c70c8d5cc6ae727eb5 Mon Sep 17 00:00:00 2001 From: Eric Hare Date: Thu, 19 Sep 2024 06:11:31 -0700 Subject: [PATCH] feat: Move vectorize to Astra DB Component (#3766) * Move vectorize to Astra DB Component * [autofix.ci] apply automated fixes * Ruff check fixes * Update compatibility tests and add new tests * [autofix.ci] apply automated fixes * Fixes from review feedback * Restore old vectorize component, add deprecation label --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../components/embeddings/AstraVectorize.py | 4 +- .../components/vectorstores/AstraDB.py | 228 ++++++++++++++++-- .../components/vectorstores/__init__.py | 3 + .../components/astra/test_astra_component.py | 84 ++++--- 4 files changed, 266 insertions(+), 53 deletions(-) diff --git a/src/backend/base/langflow/components/embeddings/AstraVectorize.py b/src/backend/base/langflow/components/embeddings/AstraVectorize.py index 4de49eb75..5305385fd 100644 --- a/src/backend/base/langflow/components/embeddings/AstraVectorize.py +++ b/src/backend/base/langflow/components/embeddings/AstraVectorize.py @@ -6,8 +6,8 @@ from langflow.template.field.base import Output class AstraVectorizeComponent(Component): - display_name: str = "Astra Vectorize" - description: str = "Configuration options for Astra Vectorize server-side embeddings." + display_name: str = "Astra Vectorize [DEPRECATED]" + description: str = "Configuration options for Astra Vectorize server-side embeddings. This component is deprecated. Please use the Astra DB Component directly." documentation: str = "https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html" icon = "AstraDB" name = "AstraVectorize" diff --git a/src/backend/base/langflow/components/vectorstores/AstraDB.py b/src/backend/base/langflow/components/vectorstores/AstraDB.py index 6c7d3e451..060726b94 100644 --- a/src/backend/base/langflow/components/vectorstores/AstraDB.py +++ b/src/backend/base/langflow/components/vectorstores/AstraDB.py @@ -2,7 +2,7 @@ from loguru import logger from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store from langflow.helpers import docs_to_data -from langflow.inputs import DictInput, FloatInput +from langflow.inputs import DictInput, FloatInput, MessageTextInput from langflow.io import ( BoolInput, DataInput, @@ -23,6 +23,40 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): name = "AstraDB" icon: str = "AstraDB" + VECTORIZE_PROVIDERS_MAPPING = { + "Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]], + "Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]], + "Hugging Face - Serverless": [ + "huggingface", + [ + "sentence-transformers/all-MiniLM-L6-v2", + "intfloat/multilingual-e5-large", + "intfloat/multilingual-e5-large-instruct", + "BAAI/bge-small-en-v1.5", + "BAAI/bge-base-en-v1.5", + "BAAI/bge-large-en-v1.5", + ], + ], + "Jina AI": [ + "jinaAI", + [ + "jina-embeddings-v2-base-en", + "jina-embeddings-v2-base-de", + "jina-embeddings-v2-base-es", + "jina-embeddings-v2-base-code", + "jina-embeddings-v2-base-zh", + ], + ], + "Mistral AI": ["mistral", ["mistral-embed"]], + "NVIDIA": ["nvidia", ["NV-Embed-QA"]], + "OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]], + "Upstage": ["upstageAI", ["solar-embedding-1-large"]], + "Voyage AI": [ + "voyageAI", + ["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"], + ], + } + inputs = [ StrInput( name="collection_name", @@ -59,6 +93,20 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): info="Optional namespace within Astra DB to use for the collection.", advanced=True, ), + DropdownInput( + name="embedding_service", + display_name="Embedding Model or Astra Vectorize", + info="Determines whether to use Astra Vectorize for the collection.", + options=["Embedding Model", "Astra Vectorize"], + real_time_refresh=True, + value="Embedding Model", + ), + HandleInput( + name="embedding", + display_name="Embedding Model", + input_types=["Embeddings"], + info="Allows an embedding model configuration.", + ), DropdownInput( name="metric", display_name="Metric", @@ -110,12 +158,6 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): info="Optional list of metadata fields to include in the indexing.", advanced=True, ), - HandleInput( - name="embedding", - display_name="Embedding or Astra Vectorize", - input_types=["Embeddings", "dict"], - info="Allows either an embedding model or an Astra Vectorize configuration.", # TODO: This should be optional, but need to refactor langchain-astradb first. - ), StrInput( name="metadata_indexing_exclude", display_name="Metadata Indexing Exclude", @@ -160,7 +202,159 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): ] @check_cached_vector_store - def build_vector_store(self): + def insert_in_dict(self, build_config, field_name, new_parameters): + # Insert the new key-value pair after the found key + for new_field_name, new_parameter in new_parameters.items(): + # Get all the items as a list of tuples (key, value) + items = list(build_config.items()) + + # Find the index of the key to insert after + for i, (key, value) in enumerate(items): + if key == field_name: + break + + items.insert(i + 1, (new_field_name, new_parameter)) + + # Clear the original dictionary and update with the modified items + build_config.clear() + build_config.update(items) + + return build_config + + def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None): + if field_name == "embedding_service": + if field_value == "Astra Vectorize": + for field in ["embedding"]: + if field in build_config: + del build_config[field] + + new_parameter = DropdownInput( + name="provider", + display_name="Vectorize Provider", + options=self.VECTORIZE_PROVIDERS_MAPPING.keys(), + value="", + required=True, + real_time_refresh=True, + ).to_dict() + + self.insert_in_dict(build_config, "embedding_service", {"provider": new_parameter}) + else: + for field in [ + "provider", + "z_00_model_name", + "z_01_model_parameters", + "z_02_api_key_name", + "z_03_provider_api_key", + "z_04_authentication", + ]: + if field in build_config: + del build_config[field] + + new_parameter = HandleInput( + name="embedding", + display_name="Embedding Model", + input_types=["Embeddings"], + info="Allows an embedding model configuration.", + ).to_dict() + + self.insert_in_dict(build_config, "embedding_service", {"embedding": new_parameter}) + + elif field_name == "provider": + for field in [ + "z_00_model_name", + "z_01_model_parameters", + "z_02_api_key_name", + "z_03_provider_api_key", + "z_04_authentication", + ]: + if field in build_config: + del build_config[field] + + model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1] + + new_parameter_0 = DropdownInput( + name="z_00_model_name", + display_name="Model Name", + info=f"The embedding model to use for the selected provider. Each provider has a different set of models " + f"available (full list at https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html):\n\n{', '.join(model_options)}", + options=model_options, + required=True, + ).to_dict() + + new_parameter_1 = DictInput( + name="z_01_model_parameters", + display_name="Model Parameters", + is_list=True, + ).to_dict() + + new_parameter_2 = MessageTextInput( + name="z_02_api_key_name", + display_name="API Key name", + info="The name of the embeddings provider API key stored on Astra. If set, it will override the 'ProviderKey' in the authentication parameters.", + ).to_dict() + + new_parameter_3 = SecretStrInput( + name="z_03_provider_api_key", + display_name="Provider API Key", + info="An alternative to the Astra Authentication that passes an API key for the provider with each request to Astra DB. This may be used when Vectorize is configured for the collection, but no corresponding provider secret is stored within Astra's key management system.", + ).to_dict() + + new_parameter_4 = DictInput( + name="z_04_authentication", + display_name="Authentication parameters", + is_list=True, + ).to_dict() + + self.insert_in_dict( + build_config, + "provider", + { + "z_00_model_name": new_parameter_0, + "z_01_model_parameters": new_parameter_1, + "z_02_api_key_name": new_parameter_2, + "z_03_provider_api_key": new_parameter_3, + "z_04_authentication": new_parameter_4, + }, + ) + + return build_config + + def build_vectorize_options(self, **kwargs): + for attribute in [ + "provider", + "z_00_api_key_name", + "z_01_model_name", + "z_02_authentication", + "z_03_provider_api_key", + "z_04_model_parameters", + ]: + if not hasattr(self, attribute): + setattr(self, attribute, None) + + # Fetch values from kwargs if any self.* attributes are None + provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.provider, [None])[0] or kwargs.get("provider") + authentication = {**(self.z_02_authentication or kwargs.get("z_02_authentication", {}))} + + api_key_name = self.z_00_api_key_name or kwargs.get("z_00_api_key_name") + provider_key_name = self.z_03_provider_api_key or kwargs.get("z_03_provider_api_key") + if provider_key_name: + authentication["providerKey"] = provider_key_name + if api_key_name: + authentication["providerKey"] = api_key_name + + return { + # must match astrapy.info.CollectionVectorServiceOptions + "collection_vector_service_options": { + "provider": provider_value, + "modelName": self.z_01_model_name or kwargs.get("z_01_model_name"), + "authentication": authentication, + "parameters": self.z_04_model_parameters or kwargs.get("z_04_model_parameters", {}), + }, + "collection_embedding_api_key": self.z_03_provider_api_key or kwargs.get("z_03_provider_api_key"), + } + + @check_cached_vector_store + def build_vector_store(self, vectorize_options=None): try: from langchain_astradb import AstraDBVectorStore from langchain_astradb.utils.astradb import SetupMode @@ -178,22 +372,22 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): except KeyError: raise ValueError(f"Invalid setup mode: {self.setup_mode}") - if not isinstance(self.embedding, dict): + if self.embedding: embedding_dict = {"embedding": self.embedding} else: from astrapy.info import CollectionVectorServiceOptions - dict_options = self.embedding.get("collection_vector_service_options", {}) + dict_options = vectorize_options or self.build_vectorize_options() dict_options["authentication"] = { k: v for k, v in dict_options.get("authentication", {}).items() if k and v } dict_options["parameters"] = {k: v for k, v in dict_options.get("parameters", {}).items() if k and v} + embedding_dict = { - "collection_vector_service_options": CollectionVectorServiceOptions.from_dict(dict_options) + "collection_vector_service_options": CollectionVectorServiceOptions.from_dict( + dict_options.get("collection_vector_service_options", {}) + ), } - collection_embedding_api_key = self.embedding.get("collection_embedding_api_key") - if collection_embedding_api_key: - embedding_dict["collection_embedding_api_key"] = collection_embedding_api_key vector_store_kwargs = { **embedding_dict, @@ -223,6 +417,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): raise ValueError(f"Error initializing AstraDBVectorStore: {str(e)}") from e self._add_documents_to_vector_store(vector_store) + return vector_store def _add_documents_to_vector_store(self, vector_store): @@ -262,8 +457,9 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): args["filter"] = clean_filter return args - def search_documents(self) -> list[Data]: - vector_store = self.build_vector_store() + def search_documents(self, vector_store=None) -> list[Data]: + if not vector_store: + vector_store = self.build_vector_store() logger.debug(f"Search input: {self.search_input}") logger.debug(f"Search type: {self.search_type}") diff --git a/src/backend/base/langflow/components/vectorstores/__init__.py b/src/backend/base/langflow/components/vectorstores/__init__.py index e69de29bb..724c5d72f 100644 --- a/src/backend/base/langflow/components/vectorstores/__init__.py +++ b/src/backend/base/langflow/components/vectorstores/__init__.py @@ -0,0 +1,3 @@ +from .AstraDB import AstraVectorStoreComponent + +__all__ = ["AstraVectorStoreComponent"] diff --git a/src/backend/tests/integration/components/astra/test_astra_component.py b/src/backend/tests/integration/components/astra/test_astra_component.py index 5bbb25dbf..641e9d44c 100644 --- a/src/backend/tests/integration/components/astra/test_astra_component.py +++ b/src/backend/tests/integration/components/astra/test_astra_component.py @@ -4,13 +4,13 @@ from astrapy.db import AstraDB import pytest from langflow.components.embeddings import OpenAIEmbeddingsComponent +from langflow.components.vectorstores import AstraVectorStoreComponent from tests.api_keys import get_astradb_application_token, get_astradb_api_endpoint, get_openai_api_key from tests.integration.components.mock_components import TextToData from tests.integration.utils import ComponentInputHandle from langchain_core.documents import Document -from langflow.components.vectorstores.AstraDB import AstraVectorStoreComponent from langflow.schema.data import Data from tests.integration.utils import run_single_component @@ -98,14 +98,14 @@ async def test_astra_embeds_and_search(): def test_astra_vectorize(): from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions - from langflow.components.embeddings.AstraVectorize import AstraVectorizeComponent - application_token = get_astradb_application_token() api_endpoint = get_astradb_api_endpoint() store = None try: options = {"provider": "nvidia", "modelName": "NV-Embed-QA"} + options_comp = {"provider": "nvidia", "z_01_model_name": "NV-Embed-QA"} + store = AstraDBVectorStore( collection_name=VECTORIZE_COLLECTION, api_endpoint=api_endpoint, @@ -116,22 +116,20 @@ def test_astra_vectorize(): documents = [Document(page_content="test1"), Document(page_content="test2")] records = [Data.from_document(d) for d in documents] - vectorize = AstraVectorizeComponent() - vectorize.build(provider="NVIDIA", model_name="NV-Embed-QA") - vectorize_options = vectorize.build_options() - component = AstraVectorStoreComponent() + vectorize_options = component.build_vectorize_options(**options_comp) + component.build( token=application_token, api_endpoint=api_endpoint, collection_name=VECTORIZE_COLLECTION, ingest_data=records, - embedding=vectorize_options, search_input="test", number_of_results=2, + pre_delete_collection=True, ) - component.build_vector_store() - records = component.search_documents() + vector_store = component.build_vector_store(vectorize_options) + records = component.search_documents(vector_store=vector_store) assert len(records) == 2 finally: @@ -144,14 +142,26 @@ def test_astra_vectorize_with_provider_api_key(): """tests vectorize using an openai api key""" from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions - from langflow.components.embeddings.AstraVectorize import AstraVectorizeComponent - application_token = get_astradb_application_token() api_endpoint = get_astradb_api_endpoint() store = None try: - options = {"provider": "openai", "modelName": "text-embedding-3-small", "parameters": {}, "authentication": {}} + options = { + "provider": "openai", + "modelName": "text-embedding-3-small", + "parameters": {}, + "authentication": {"providerKey": "openai"}, + } + + options_comp = { + "provider": "openai", + "z_01_model_name": "text-embedding-3-small", + "z_04_model_parameters": {}, + "z_02_authentication": {}, + "z_03_provider_api_key": "openai", + } + store = AstraDBVectorStore( collection_name=VECTORIZE_COLLECTION_OPENAI, api_endpoint=api_endpoint, @@ -162,24 +172,22 @@ def test_astra_vectorize_with_provider_api_key(): documents = [Document(page_content="test1"), Document(page_content="test2")] records = [Data.from_document(d) for d in documents] - vectorize = AstraVectorizeComponent() - vectorize.build( - provider="OpenAI", model_name="text-embedding-3-small", provider_api_key=os.getenv("OPENAI_API_KEY") - ) - vectorize_options = vectorize.build_options() - component = AstraVectorStoreComponent() + vectorize_options = component.build_vectorize_options(**options_comp) + component.build( token=application_token, api_endpoint=api_endpoint, collection_name=VECTORIZE_COLLECTION_OPENAI, ingest_data=records, - embedding=vectorize_options, search_input="test", - number_of_results=4, + number_of_results=2, + pre_delete_collection=True, ) - component.build_vector_store() - records = component.search_documents() + + vector_store = component.build_vector_store(vectorize_options) + records = component.search_documents(vector_store=vector_store) + assert len(records) == 2 finally: if store is not None: @@ -191,44 +199,50 @@ def test_astra_vectorize_passes_authentication(): """tests vectorize using the authentication parameter""" from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions - from langflow.components.embeddings.AstraVectorize import AstraVectorizeComponent - store = None try: application_token = get_astradb_application_token() api_endpoint = get_astradb_api_endpoint() + options = { "provider": "openai", "modelName": "text-embedding-3-small", "parameters": {}, - "authentication": {"providerKey": "apikey"}, + "authentication": {"providerKey": "openai"}, } + options_comp = { + "provider": "openai", + "z_01_model_name": "text-embedding-3-small", + "z_04_model_parameters": {}, + "z_02_authentication": {"providerKey": "openai"}, + } + store = AstraDBVectorStore( collection_name=VECTORIZE_COLLECTION_OPENAI_WITH_AUTH, api_endpoint=api_endpoint, token=application_token, collection_vector_service_options=CollectionVectorServiceOptions.from_dict(options), ) + documents = [Document(page_content="test1"), Document(page_content="test2")] records = [Data.from_document(d) for d in documents] - vectorize = AstraVectorizeComponent() - vectorize.build( - provider="OpenAI", model_name="text-embedding-3-small", authentication={"providerKey": "apikey"} - ) - vectorize_options = vectorize.build_options() - component = AstraVectorStoreComponent() + vectorize_options = component.build_vectorize_options(**options_comp) + component.build( token=application_token, api_endpoint=api_endpoint, collection_name=VECTORIZE_COLLECTION_OPENAI_WITH_AUTH, ingest_data=records, - embedding=vectorize_options, search_input="test", + number_of_results=2, + pre_delete_collection=True, ) - component.build_vector_store() - records = component.search_documents() + + vector_store = component.build_vector_store(vectorize_options) + records = component.search_documents(vector_store=vector_store) + assert len(records) == 2 finally: if store is not None: