feat(vectorstores.py): add support for index_name parameter in MongoDBAtlasVectorSearch template

The hardcoded values for db_name, collection_name, and index_name have been removed from the initialize_mongodb function and are now required parameters. This makes the function more flexible and allows it to be used with different databases and collections. The support for the index_name parameter has been added to the MongoDBAtlasVectorSearch template in vectorstores.py, which allows the user to specify the name of the index to be used in the search.
🐛 fix(vector_store.py): remove hardcoded values for db_name, collection_name, and index_name and make them required parameters
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-23 13:19:30 -03:00
commit ea0231025f
2 changed files with 26 additions and 9 deletions

View file

@ -23,16 +23,23 @@ def docs_in_params(params: dict) -> bool:
def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dict):
"""Initialize mongodb and return the class object"""
MONGODB_ATLAS_CLUSTER_URI = params.get("mongodb_atlas_cluster_uri")
MONGODB_ATLAS_CLUSTER_URI = params.pop("mongodb_atlas_cluster_uri")
if not MONGODB_ATLAS_CLUSTER_URI:
raise ValueError("Mongodb atlas cluster uri must be provided in the params")
from pymongo import MongoClient
import certifi
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where())
db_name = params.pop("db_name", None)
collection_name = params.pop("collection_name", None)
if not db_name or not collection_name:
raise ValueError("db_name and collection_name must be provided in the params")
index_name = params.pop("index_name", None)
if not index_name:
raise ValueError("index_name must be provided in the params")
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
db_name = "lanchain_db"
collection_name = "langchain_col"
collection = client[db_name][collection_name]
index_name = "langchain_demo"
if not docs_in_params(params):
# __init__ requires collection, embedding and index_name
init_args = {

View file

@ -147,9 +147,8 @@ class VectorStoreFrontendNode(FrontendNode):
extra_fields.extend((extra_field, extra_field2, extra_field3, extra_field4))
elif self.template.type_name == "MongoDBAtlasVectorSearch":
# add "mongodb_atlas_cluster_uri",
# "collection_name",
# "db_name",
self.display_name = "MongoDB Atlas"
extra_field = TemplateField(
name="mongodb_atlas_cluster_uri",
field_type="str",
@ -183,7 +182,18 @@ class VectorStoreFrontendNode(FrontendNode):
display_name="Database Name",
value="",
)
extra_fields.extend((extra_field, extra_field2, extra_field3))
extra_field4 = TemplateField(
name="index_name",
field_type="str",
required=False,
placeholder="",
show=True,
advanced=True,
multiline=False,
display_name="Index Name",
value="",
)
extra_fields.extend((extra_field, extra_field2, extra_field3, extra_field4))
if extra_fields:
for field in extra_fields: