From 95f1c563ef8e629975c961770d6b16430480cf5b Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Wed, 19 Jun 2024 01:03:39 -0300 Subject: [PATCH] refactor: Update AstraVectorStoreComponent to inherit from LCVectorStoreComponent This commit updates the AstraVectorStoreComponent class in the AstraDB.py file to inherit from the LCVectorStoreComponent class. By doing so, it ensures that the AstraVectorStoreComponent has access to the base functionality provided by the LCVectorStoreComponent. This change improves code organization and promotes code reuse. --- .../components/vectorstores/AstraDB.py | 85 +++++++------------ 1 file changed, 29 insertions(+), 56 deletions(-) diff --git a/src/backend/base/langflow/components/vectorstores/AstraDB.py b/src/backend/base/langflow/components/vectorstores/AstraDB.py index 71b58444a..13532d266 100644 --- a/src/backend/base/langflow/components/vectorstores/AstraDB.py +++ b/src/backend/base/langflow/components/vectorstores/AstraDB.py @@ -1,19 +1,11 @@ -from langflow.custom import Component -from langflow.inputs import ( - StrInput, - IntInput, - BoolInput, - DropdownInput, - MultilineInput, - HandleInput, -) -from langflow.schema import Data -from langflow.template import Output - from loguru import logger +from langflow.base.vectorstores.model import LCVectorStoreComponent +from langflow.inputs import BoolInput, DropdownInput, HandleInput, IntInput, MultilineInput, SecretStrInput, StrInput +from langflow.schema import Data -class AstraVectorStoreComponent(Component): + +class AstraVectorStoreComponent(LCVectorStoreComponent): display_name: str = "Astra DB Vector Store" description: str = "Implementation of Vector Store using Astra DB with search capabilities" documentation: str = "https://python.langchain.com/docs/integrations/vectorstores/astradb" @@ -25,21 +17,17 @@ class AstraVectorStoreComponent(Component): display_name="Collection Name", info="The name of the collection within Astra DB where the vectors will be stored.", ), - StrInput( + SecretStrInput( name="token", display_name="Astra DB Application Token", info="Authentication token for accessing Astra DB.", - password=True, + value="ASTRA_DB_APPLICATION_TOKEN", ), - StrInput( + SecretStrInput( name="api_endpoint", display_name="API Endpoint", info="API endpoint URL for the Astra DB service.", - ), - StrInput( - name="code", - display_name="Code", - advanced=True, + value="ASTRA_DB_API_ENDPOINT", ), HandleInput( name="vector_store_inputs", @@ -95,6 +83,7 @@ class AstraVectorStoreComponent(Component): info="Configuration mode for setting up the vector store, with options like 'Sync', 'Async', or 'Off'.", options=["Sync", "Async", "Off"], advanced=True, + value="Sync", ), BoolInput( name="pre_delete_collection", @@ -144,24 +133,6 @@ class AstraVectorStoreComponent(Component): ), ] - outputs = [ - Output( - display_name="Vector Store", - name="vector_store", - method="build_vector_store", - ), - Output( - display_name="Base Retriever", - name="base_retriever", - method="build_base_retriever", - ), - Output( - display_name="Search Results", - name="search_results", - method="search_documents", - ), - ] - def build_vector_store(self): try: from langchain_astradb import AstraDBVectorStore @@ -173,6 +144,9 @@ class AstraVectorStoreComponent(Component): ) try: + if not self.setup_mode: + self.setup_mode = self._inputs["setup_mode"].options[0] + setup_mode_value = SetupMode[self.setup_mode.upper()] except KeyError: raise ValueError(f"Invalid setup mode: {self.setup_mode}") @@ -182,14 +156,14 @@ class AstraVectorStoreComponent(Component): "collection_name": self.collection_name, "token": self.token, "api_endpoint": self.api_endpoint, - "namespace": self.namespace, - "metric": self.metric, - "batch_size": self.batch_size, - "bulk_insert_batch_concurrency": self.bulk_insert_batch_concurrency, - "bulk_insert_overwrite_concurrency": self.bulk_insert_overwrite_concurrency, - "bulk_delete_concurrency": self.bulk_delete_concurrency, + "namespace": self.namespace or None, + "metric": self.metric or None, + "batch_size": self.batch_size or None, + "bulk_insert_batch_concurrency": self.bulk_insert_batch_concurrency or None, + "bulk_insert_overwrite_concurrency": self.bulk_insert_overwrite_concurrency or None, + "bulk_delete_concurrency": self.bulk_delete_concurrency or None, "setup_mode": setup_mode_value, - "pre_delete_collection": self.pre_delete_collection, + "pre_delete_collection": self.pre_delete_collection or False, } if self.metadata_indexing_include: @@ -207,11 +181,12 @@ class AstraVectorStoreComponent(Component): if self.add_to_vector_store: self._add_documents_to_vector_store(vector_store) - self.status = self._astradb_collection_to_data(vector_store.collection) return vector_store def build_base_retriever(self): - return self.build_vector_store() + vector_store = self.build_vector_store() + self.status = self._astradb_collection_to_data(vector_store.collection) + return vector_store def _add_documents_to_vector_store(self, vector_store): documents = [] @@ -256,7 +231,7 @@ class AstraVectorStoreComponent(Component): logger.debug(f"Retrieved documents: {len(docs)}") - data = self._docs_to_data(docs) + data = [Data.from_document(doc) for doc in docs] logger.debug(f"Converted documents to data: {len(data)}") self.status = data return data @@ -266,12 +241,10 @@ class AstraVectorStoreComponent(Component): def _astradb_collection_to_data(self, collection): data = [] - for item in collection["data"]: + data_dict = collection.find() + if data_dict and "data" in data_dict: + data_dict = data_dict["data"].get("documents", []) + + for item in data_dict: data.append(Data(content=item["content"])) return data - - def _docs_to_data(self, docs): - data = [] - for doc in docs: - data.append(Data(content=doc.page_content)) - return data