🚀 feat(langflow): add support for MongoDB Atlas Vector Search in vectorstores

 feat(langflow): add support for search_kwargs field in VectorStoreFrontendNode
The changes add support for MongoDB Atlas Vector Search in the vectorstores. The `MongoDBAtlasVectorSearch` class is now imported and initialized in `vector_store.py`. The `initialize_mongodb` function is added to initialize the MongoDB Atlas Vector Search class. The `VectorStoreFrontendNode` class is updated to add the `mongodb_atlas_cluster_uri`, `collection_name`, and `db_name` fields. The `search_kwargs` field is also added to the `VectorStoreFrontendNode` class to allow users to pass additional search parameters to the vector store.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-23 12:38:02 -03:00
commit 4cc2fae52b
3 changed files with 100 additions and 0 deletions

View file

@ -133,6 +133,7 @@ vectorstores:
- FAISS
- Pinecone
- SupabaseVectorStore
- MongoDBAtlasVectorSearch
wrappers:
- RequestsWrapper
# - ChatPromptTemplate

View file

@ -7,7 +7,9 @@ from langchain.vectorstores import (
FAISS,
Weaviate,
SupabaseVectorStore,
MongoDBAtlasVectorSearch,
)
import os
def docs_in_params(params: dict) -> bool:
@ -18,6 +20,38 @@ def docs_in_params(params: dict) -> bool:
)
def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dict):
    """Initialize mongodb and return the class object

    Opens a ``MongoClient`` on the configured cluster URI, selects the target
    collection and builds the vector store from it.

    Args:
        class_object: The vector store class to instantiate.
        params: Build parameters. Keys read here:
            ``mongodb_atlas_cluster_uri`` (required), ``db_name`` and
            ``collection_name`` (optional; blank or missing values fall back
            to the built-in defaults). These three keys are removed from
            ``params``. A ``texts`` entry, if present, is renamed to
            ``documents``; ``collection`` and ``index_name`` are added before
            the remaining entries are forwarded to ``from_documents``.

    Returns:
        ``class_object(collection=..., index_name=..., embedding=...)`` when
        ``docs_in_params(params)`` is falsy, otherwise the result of
        ``class_object.from_documents(**params)``.

    Raises:
        ValueError: If no cluster URI is supplied.
    """
    # Connection settings are consumed here. Popping them keeps them out of
    # the ``from_documents(**params)`` call below, which would otherwise
    # receive them as stray keyword arguments.
    cluster_uri = params.pop("mongodb_atlas_cluster_uri", None)
    if not cluster_uri:
        raise ValueError("Mongodb atlas cluster uri must be provided in the params")

    from pymongo import MongoClient

    # Honour the user-supplied database/collection names. An empty string
    # counts as "not provided", so fall back to the names this function used
    # before these fields existed. NOTE: the "lanchain_db" spelling is kept
    # on purpose so existing setups keep pointing at the same database.
    db_name = params.pop("db_name", None) or "lanchain_db"
    collection_name = params.pop("collection_name", None) or "langchain_col"

    client = MongoClient(cluster_uri)
    collection = client[db_name][collection_name]
    # Any "index_name" already in params is overridden by this fixed value.
    index_name = "langchain_demo"

    if not docs_in_params(params):
        # __init__ requires collection, embedding and index_name
        init_args = {
            "collection": collection,
            "index_name": index_name,
            "embedding": params.get("embedding"),
        }
        return class_object(**init_args)

    if "texts" in params:
        params["documents"] = params.pop("texts")

    params["collection"] = collection
    params["index_name"] = index_name

    return class_object.from_documents(**params)
def initialize_supabase(class_object: Type[SupabaseVectorStore], params: dict):
"""Initialize supabase and return the class object"""
from supabase.client import Client, create_client
@ -89,6 +123,12 @@ def initialize_pinecone(class_object: Type[Pinecone], params: dict):
pinecone_api_key = params.get("pinecone_api_key")
pinecone_env = params.get("pinecone_env")
if pinecone_api_key is None or pinecone_env is None:
if os.getenv("PINECONE_API_KEY") is not None:
pinecone_api_key = os.getenv("PINECONE_API_KEY")
if os.getenv("PINECONE_ENV") is not None:
pinecone_env = os.getenv("PINECONE_ENV")
if pinecone_api_key is None or pinecone_env is None:
raise ValueError(
"Pinecone API key and environment must be provided in the params"

View file

@ -7,6 +7,18 @@ from langflow.template.frontend_node.base import FrontendNode
class VectorStoreFrontendNode(FrontendNode):
def add_extra_fields(self) -> None:
extra_fields: List[TemplateField] = []
# Add search_kwargs field
extra_field = TemplateField(
name="search_kwargs",
field_type="code",
required=False,
placeholder="",
show=True,
advanced=True,
multiline=False,
value="{}",
)
extra_fields.append(extra_field)
if self.template.type_name == "Weaviate":
extra_field = TemplateField(
name="weaviate_url",
@ -134,6 +146,45 @@ class VectorStoreFrontendNode(FrontendNode):
)
extra_fields.extend((extra_field, extra_field2, extra_field3, extra_field4))
elif self.template.type_name == "MongoDBAtlasVectorSearch":
# add "mongodb_atlas_cluster_uri",
# "collection_name",
# "db_name",
extra_field = TemplateField(
name="mongodb_atlas_cluster_uri",
field_type="str",
required=False,
placeholder="",
show=True,
advanced=True,
multiline=False,
display_name="MongoDB Atlas Cluster URI",
value="",
)
extra_field2 = TemplateField(
name="collection_name",
field_type="str",
required=False,
placeholder="",
show=True,
advanced=True,
multiline=False,
display_name="Collection Name",
value="",
)
extra_field3 = TemplateField(
name="db_name",
field_type="str",
required=False,
placeholder="",
show=True,
advanced=True,
multiline=False,
display_name="Database Name",
value="",
)
extra_fields.extend((extra_field, extra_field2, extra_field3))
if extra_fields:
for field in extra_fields:
self.template.add_field(field)
@ -160,6 +211,9 @@ class VectorStoreFrontendNode(FrontendNode):
"query_name",
"supabase_url",
"supabase_service_key",
"mongodb_atlas_cluster_uri",
"collection_name",
"db_name",
]
advanced_fields = [
"n_dim",
@ -179,10 +233,15 @@ class VectorStoreFrontendNode(FrontendNode):
"pinecone_api_key",
"pinecone_env",
"client_kwargs",
"search_kwargs",
]
# Check and set field attributes
if field.name == "texts":
# if field.name is "texts" it has to be replaced
# when instantiating the vectorstores
field.name = "documents"
field.field_type = "TextSplitter"
field.display_name = "Documents"
field.required = False