diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 41e4e9488..2d547c6d8 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -6,15 +6,8 @@ from langchain.agents import agent as agent_module from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.agents.tools import BaseTool -from langflow.interface.initialize.vector_store import ( - initialize_chroma, - initialize_faiss, - initialize_mongodb, - initialize_pinecone, - initialize_qdrant, - initialize_supabase, - initialize_weaviate, -) +from langflow.interface.initialize.vector_store import vecstore_initializer + from pydantic import ValidationError from langflow.interface.custom_lists import CUSTOM_NODES @@ -151,29 +144,11 @@ def instantiate_embedding(class_object, params): def instantiate_vectorstore(class_object, params): search_kwargs = params.pop("search_kwargs", {}) - # could be documents or texts - if class_object.__name__ == "Pinecone": - vecstore = initialize_pinecone(class_object, params) - # Chroma requires all metadata values to not be None - elif class_object.__name__ == "Chroma": - vecstore = initialize_chroma(class_object, params) - - elif class_object.__name__ == "Qdrant": - vecstore = initialize_qdrant(class_object, params) - - elif class_object.__name__ == "Weaviate": - vecstore = initialize_weaviate(class_object, params) - elif class_object.__name__ == "FAISS": - vecstore = initialize_faiss(class_object, params) - elif class_object.__name__ == "SupabaseVectorStore": - vecstore = initialize_supabase(class_object, params) - elif class_object.__name__ == "MongoDBAtlasVectorSearch": - vecstore = initialize_mongodb(class_object, params) - + if initializer := vecstore_initializer.get(class_object.__name__): + vecstore = initializer(class_object, params) else: if "texts" in params: params["documents"] = params.pop("texts") - vecstore = class_object.from_documents(**params) # ! This might not work. Need to test diff --git a/src/backend/langflow/interface/initialize/vector_store.py b/src/backend/langflow/interface/initialize/vector_store.py index 2f7e9dfb1..d4bdb0155 100644 --- a/src/backend/langflow/interface/initialize/vector_store.py +++ b/src/backend/langflow/interface/initialize/vector_store.py @@ -1,5 +1,5 @@ import json -from typing import Type +from typing import Any, Callable, Dict, Type from langchain.vectorstores import ( Pinecone, Qdrant, @@ -9,6 +9,7 @@ from langchain.vectorstores import ( SupabaseVectorStore, MongoDBAtlasVectorSearch, ) + import os @@ -29,7 +30,9 @@ def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dic from pymongo import MongoClient import certifi - client = MongoClient(MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where()) + client: MongoClient = MongoClient( + MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where() + ) db_name = params.pop("db_name", None) collection_name = params.pop("collection_name", None) if not db_name or not collection_name: @@ -207,3 +210,14 @@ def initialize_qdrant(class_object: Type[Qdrant], params: dict): return class_object(client=client, **lc_params) return class_object.from_documents(**params) + + +vecstore_initializer: Dict[str, Callable[[Type[Any], dict], Any]] = { + "Pinecone": initialize_pinecone, + "Chroma": initialize_chroma, + "Qdrant": initialize_qdrant, + "Weaviate": initialize_weaviate, + "FAISS": initialize_faiss, + "SupabaseVectorStore": initialize_supabase, + "MongoDBAtlasVectorSearch": initialize_mongodb, +}