diff --git a/src/backend/base/langflow/components/vectorsearch/AstraDBSearch.py b/src/backend/base/langflow/components/vectorsearch/AstraDBSearch.py index 1337bffb9..f11c04748 100644 --- a/src/backend/base/langflow/components/vectorsearch/AstraDBSearch.py +++ b/src/backend/base/langflow/components/vectorsearch/AstraDBSearch.py @@ -92,6 +92,11 @@ class AstraDBSearchComponent(LCVectorStoreComponent): "info": "Optional dictionary defining the indexing policy for the collection.", "advanced": True, }, + "number_of_results": { + "display_name": "Number of Results", + "info": "Number of results to return.", + "advanced": True, + }, } def build( @@ -102,6 +107,7 @@ class AstraDBSearchComponent(LCVectorStoreComponent): token: str, api_endpoint: str, search_type: str = "Similarity", + number_of_results: int = 4, namespace: Optional[str] = None, metric: Optional[str] = None, batch_size: Optional[int] = None, @@ -131,4 +137,4 @@ class AstraDBSearchComponent(LCVectorStoreComponent): metadata_indexing_exclude=metadata_indexing_exclude, collection_indexing_policy=collection_indexing_policy, ) - return self.search_with_vector_store(input_value, search_type, vector_store) + return self.search_with_vector_store(input_value, search_type, vector_store, k=number_of_results) diff --git a/src/backend/base/langflow/components/vectorstores/base/model.py b/src/backend/base/langflow/components/vectorstores/base/model.py index eef99ef21..668c5eff2 100644 --- a/src/backend/base/langflow/components/vectorstores/base/model.py +++ b/src/backend/base/langflow/components/vectorstores/base/model.py @@ -19,6 +19,8 @@ class LCVectorStoreComponent(CustomComponent): input_value: Text, search_type: str, vector_store: Union[VectorStore, BaseRetriever], + k=10, + **kwargs, ) -> List[Record]: """ Search for records in the vector store based on the input value and search type. @@ -37,7 +39,7 @@ class LCVectorStoreComponent(CustomComponent): docs: List[Document] = [] if input_value and isinstance(input_value, str) and hasattr(vector_store, "search"): - docs = vector_store.search(query=input_value, search_type=search_type.lower()) + docs = vector_store.search(query=input_value, search_type=search_type.lower(), k=k, **kwargs) else: raise ValueError("Invalid inputs provided.") records = docs_to_records(docs)