Add number_of_results parameter to AstraDBSearchComponent

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-04-02 01:41:16 -03:00
commit 58851c09b9
2 changed files with 10 additions and 2 deletions

View file

@ -92,6 +92,11 @@ class AstraDBSearchComponent(LCVectorStoreComponent):
"info": "Optional dictionary defining the indexing policy for the collection.",
"advanced": True,
},
"number_of_results": {
"display_name": "Number of Results",
"info": "Number of results to return.",
"advanced": True,
},
}
def build(
@ -102,6 +107,7 @@ class AstraDBSearchComponent(LCVectorStoreComponent):
token: str,
api_endpoint: str,
search_type: str = "Similarity",
number_of_results: int = 4,
namespace: Optional[str] = None,
metric: Optional[str] = None,
batch_size: Optional[int] = None,
@ -131,4 +137,4 @@ class AstraDBSearchComponent(LCVectorStoreComponent):
metadata_indexing_exclude=metadata_indexing_exclude,
collection_indexing_policy=collection_indexing_policy,
)
return self.search_with_vector_store(input_value, search_type, vector_store)
return self.search_with_vector_store(input_value, search_type, vector_store, k=number_of_results)

View file

@ -19,6 +19,8 @@ class LCVectorStoreComponent(CustomComponent):
input_value: Text,
search_type: str,
vector_store: Union[VectorStore, BaseRetriever],
k=10,
**kwargs,
) -> List[Record]:
"""
Search for records in the vector store based on the input value and search type.
@ -37,7 +39,7 @@ class LCVectorStoreComponent(CustomComponent):
docs: List[Document] = []
if input_value and isinstance(input_value, str) and hasattr(vector_store, "search"):
docs = vector_store.search(query=input_value, search_type=search_type.lower())
docs = vector_store.search(query=input_value, search_type=search_type.lower(), k=k, **kwargs)
else:
raise ValueError("Invalid inputs provided.")
records = docs_to_records(docs)