diff --git a/src/backend/base/langflow/components/vectorstores/AstraDB.py b/src/backend/base/langflow/components/vectorstores/AstraDB.py index 71b58444a..13532d266 100644 --- a/src/backend/base/langflow/components/vectorstores/AstraDB.py +++ b/src/backend/base/langflow/components/vectorstores/AstraDB.py @@ -1,19 +1,11 @@ -from langflow.custom import Component -from langflow.inputs import ( - StrInput, - IntInput, - BoolInput, - DropdownInput, - MultilineInput, - HandleInput, -) -from langflow.schema import Data -from langflow.template import Output - from loguru import logger +from langflow.base.vectorstores.model import LCVectorStoreComponent +from langflow.inputs import BoolInput, DropdownInput, HandleInput, IntInput, MultilineInput, SecretStrInput, StrInput +from langflow.schema import Data -class AstraVectorStoreComponent(Component): + +class AstraVectorStoreComponent(LCVectorStoreComponent): display_name: str = "Astra DB Vector Store" description: str = "Implementation of Vector Store using Astra DB with search capabilities" documentation: str = "https://python.langchain.com/docs/integrations/vectorstores/astradb" @@ -25,21 +17,17 @@ class AstraVectorStoreComponent(Component): display_name="Collection Name", info="The name of the collection within Astra DB where the vectors will be stored.", ), - StrInput( + SecretStrInput( name="token", display_name="Astra DB Application Token", info="Authentication token for accessing Astra DB.", - password=True, + value="ASTRA_DB_APPLICATION_TOKEN", ), - StrInput( + SecretStrInput( name="api_endpoint", display_name="API Endpoint", info="API endpoint URL for the Astra DB service.", - ), - StrInput( - name="code", - display_name="Code", - advanced=True, + value="ASTRA_DB_API_ENDPOINT", ), HandleInput( name="vector_store_inputs", @@ -95,6 +83,7 @@ class AstraVectorStoreComponent(Component): info="Configuration mode for setting up the vector store, with options like 'Sync', 'Async', or 'Off'.", options=["Sync", "Async", "Off"], advanced=True, + value="Sync", ), BoolInput( name="pre_delete_collection", @@ -144,24 +133,6 @@ class AstraVectorStoreComponent(Component): ), ] - outputs = [ - Output( - display_name="Vector Store", - name="vector_store", - method="build_vector_store", - ), - Output( - display_name="Base Retriever", - name="base_retriever", - method="build_base_retriever", - ), - Output( - display_name="Search Results", - name="search_results", - method="search_documents", - ), - ] - def build_vector_store(self): try: from langchain_astradb import AstraDBVectorStore @@ -173,6 +144,9 @@ class AstraVectorStoreComponent(Component): ) try: + if not self.setup_mode: + self.setup_mode = self._inputs["setup_mode"].options[0] + setup_mode_value = SetupMode[self.setup_mode.upper()] except KeyError: raise ValueError(f"Invalid setup mode: {self.setup_mode}") @@ -182,14 +156,14 @@ class AstraVectorStoreComponent(Component): "collection_name": self.collection_name, "token": self.token, "api_endpoint": self.api_endpoint, - "namespace": self.namespace, - "metric": self.metric, - "batch_size": self.batch_size, - "bulk_insert_batch_concurrency": self.bulk_insert_batch_concurrency, - "bulk_insert_overwrite_concurrency": self.bulk_insert_overwrite_concurrency, - "bulk_delete_concurrency": self.bulk_delete_concurrency, + "namespace": self.namespace or None, + "metric": self.metric or None, + "batch_size": self.batch_size or None, + "bulk_insert_batch_concurrency": self.bulk_insert_batch_concurrency or None, + "bulk_insert_overwrite_concurrency": self.bulk_insert_overwrite_concurrency or None, + "bulk_delete_concurrency": self.bulk_delete_concurrency or None, "setup_mode": setup_mode_value, - "pre_delete_collection": self.pre_delete_collection, + "pre_delete_collection": self.pre_delete_collection or False, } if self.metadata_indexing_include: @@ -207,11 +181,12 @@ class AstraVectorStoreComponent(Component): if self.add_to_vector_store: self._add_documents_to_vector_store(vector_store) - self.status = self._astradb_collection_to_data(vector_store.collection) return vector_store def build_base_retriever(self): - return self.build_vector_store() + vector_store = self.build_vector_store() + self.status = self._astradb_collection_to_data(vector_store.collection) + return vector_store def _add_documents_to_vector_store(self, vector_store): documents = [] @@ -256,7 +231,7 @@ class AstraVectorStoreComponent(Component): logger.debug(f"Retrieved documents: {len(docs)}") - data = self._docs_to_data(docs) + data = [Data.from_document(doc) for doc in docs] logger.debug(f"Converted documents to data: {len(data)}") self.status = data return data @@ -266,12 +241,10 @@ class AstraVectorStoreComponent(Component): def _astradb_collection_to_data(self, collection): data = [] - for item in collection["data"]: + data_dict = collection.find() + if data_dict and "data" in data_dict: + data_dict = data_dict["data"].get("documents", []) + + for item in data_dict: data.append(Data(content=item["content"])) return data - - def _docs_to_data(self, docs): - data = [] - for doc in docs: - data.append(Data(content=doc.page_content)) - return data