diff --git a/src/backend/base/langflow/components/vectorstores/astradb_graph.py b/src/backend/base/langflow/components/vectorstores/astradb_graph.py index d8471dd4f..3e5b398d2 100644 --- a/src/backend/base/langflow/components/vectorstores/astradb_graph.py +++ b/src/backend/base/langflow/components/vectorstores/astradb_graph.py @@ -6,11 +6,12 @@ from loguru import logger from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store from langflow.helpers import docs_to_data -from langflow.inputs import DictInput, FloatInput -from langflow.io import ( +from langflow.inputs import ( BoolInput, DataInput, + DictInput, DropdownInput, + FloatInput, HandleInput, IntInput, MultilineInput, @@ -71,11 +72,10 @@ class AstraDBGraphVectorStoreComponent(LCVectorStoreComponent): advanced=True, ), HandleInput( - name="embedding", + name="embedding_model", display_name="Embedding Model", input_types=["Embeddings"], - info="Embedding model.", - required=True, + info="Allows an embedding model configuration.", ), DropdownInput( name="metric", @@ -156,8 +156,14 @@ class AstraDBGraphVectorStoreComponent(LCVectorStoreComponent): name="search_type", display_name="Search Type", info="Search type to use", - options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"], - value="Similarity", + options=[ + "Similarity", + "Similarity with score threshold", + "MMR (Max Marginal Relevance)", + "Graph Traversal", + "MMR (Max Marginal Relevance) Graph Traversal", + ], + value="MMR (Max Marginal Relevance) Graph Traversal", advanced=True, ), FloatInput( @@ -199,8 +205,10 @@ class AstraDBGraphVectorStoreComponent(LCVectorStoreComponent): raise ValueError(msg) from e try: + logger.debug(f"Initializing Graph Vector Store {self.collection_name}") + vector_store = AstraDBGraphVectorStore( - embedding=self.embedding, + embedding=self.embedding_model, collection_name=self.collection_name, metadata_incoming_links_key=self.metadata_incoming_links_key or "incoming_links", token=self.token, @@ -216,7 +224,7 @@ class AstraDBGraphVectorStoreComponent(LCVectorStoreComponent): pre_delete_collection=self.pre_delete_collection, metadata_indexing_include=[s for s in self.metadata_indexing_include if s] or None, metadata_indexing_exclude=[s for s in self.metadata_indexing_exclude if s] or None, - collection_indexing_policy=orjson.dumps(self.collection_indexing_policy) + collection_indexing_policy=orjson.loads(self.collection_indexing_policy.encode("utf-8")) if self.collection_indexing_policy else None, ) @@ -224,6 +232,7 @@ class AstraDBGraphVectorStoreComponent(LCVectorStoreComponent): msg = f"Error initializing AstraDBGraphVectorStore: {e}" raise ValueError(msg) from e + logger.debug(f"Vector Store initialized: {vector_store.astra_env.collection_name}") self._add_documents_to_vector_store(vector_store) return vector_store @@ -248,11 +257,19 @@ class AstraDBGraphVectorStoreComponent(LCVectorStoreComponent): logger.debug("No documents to add to the Vector Store.") def _map_search_type(self) -> str: - if self.search_type == "Similarity with score threshold": - return "similarity_score_threshold" - if self.search_type == "MMR (Max Marginal Relevance)": - return "mmr" - return "similarity" + match self.search_type: + case "Similarity": + return "similarity" + case "Similarity with score threshold": + return "similarity_score_threshold" + case "MMR (Max Marginal Relevance)": + return "mmr" + case "Graph Traversal": + return "traversal" + case "MMR (Max Marginal Relevance) Graph Traversal": + return "mmr_traversal" + case _: + return "similarity" def _build_search_args(self): args = { @@ -270,6 +287,7 @@ class AstraDBGraphVectorStoreComponent(LCVectorStoreComponent): if not vector_store: vector_store = self.build_vector_store() + logger.debug("Searching for documents in AstraDBGraphVectorStore.") logger.debug(f"Search input: {self.search_input}") logger.debug(f"Search type: {self.search_type}") logger.debug(f"Number of results: {self.number_of_results}") @@ -280,6 +298,14 @@ class AstraDBGraphVectorStoreComponent(LCVectorStoreComponent): search_args = self._build_search_args() docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args) + + # Drop links from the metadata. At this point the links don't add any value for building the + # context and haven't been restored to json which causes the conversion to fail. + logger.debug("Removing links from metadata.") + for doc in docs: + if "links" in doc.metadata: + doc.metadata.pop("links") + except Exception as e: msg = f"Error performing search in AstraDBGraphVectorStore: {e}" raise ValueError(msg) from e @@ -287,7 +313,9 @@ class AstraDBGraphVectorStoreComponent(LCVectorStoreComponent): logger.debug(f"Retrieved documents: {len(docs)}") data = docs_to_data(docs) + logger.debug(f"Converted documents to data: {len(data)}") + self.status = data return data logger.debug("No search input provided. Skipping search.")