From 99121d95c178cb9e67651f653ccc64d4b98e1234 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 23 Jun 2023 14:32:15 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=80=20refactor(loading.py):=20use=20di?= =?UTF-8?q?ctionary=20to=20initialize=20vector=20stores=20The=20`instantia?= =?UTF-8?q?te=5Fvectorstore`=20function=20now=20uses=20a=20dictionary=20to?= =?UTF-8?q?=20initialize=20vector=20stores=20instead=20of=20a=20series=20o?= =?UTF-8?q?f=20if-else=20statements.=20This=20improves=20the=20readability?= =?UTF-8?q?=20and=20maintainability=20of=20the=20code.=20A=20new=20diction?= =?UTF-8?q?ary=20`vecstore=5Finitializer`=20is=20added=20to=20`vector=5Fst?= =?UTF-8?q?ore.py`=20to=20map=20the=20class=20names=20of=20vector=20stores?= =?UTF-8?q?=20to=20their=20respective=20initialization=20functions.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langflow/interface/initialize/loading.py | 33 +++---------------- .../interface/initialize/vector_store.py | 18 ++++++++-- 2 files changed, 20 insertions(+), 31 deletions(-) diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 41e4e9488..2d547c6d8 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -6,15 +6,8 @@ from langchain.agents import agent as agent_module from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.agents.tools import BaseTool -from langflow.interface.initialize.vector_store import ( - initialize_chroma, - initialize_faiss, - initialize_mongodb, - initialize_pinecone, - initialize_qdrant, - initialize_supabase, - initialize_weaviate, -) +from langflow.interface.initialize.vector_store import vecstore_initializer + from pydantic import ValidationError from langflow.interface.custom_lists import CUSTOM_NODES @@ -151,29 +144,11 @@ def instantiate_embedding(class_object, params): def instantiate_vectorstore(class_object, params): search_kwargs = params.pop("search_kwargs", {}) - # could be documents or texts - if class_object.__name__ == "Pinecone": - vecstore = initialize_pinecone(class_object, params) - # Chroma requires all metadata values to not be None - elif class_object.__name__ == "Chroma": - vecstore = initialize_chroma(class_object, params) - - elif class_object.__name__ == "Qdrant": - vecstore = initialize_qdrant(class_object, params) - - elif class_object.__name__ == "Weaviate": - vecstore = initialize_weaviate(class_object, params) - elif class_object.__name__ == "FAISS": - vecstore = initialize_faiss(class_object, params) - elif class_object.__name__ == "SupabaseVectorStore": - vecstore = initialize_supabase(class_object, params) - elif class_object.__name__ == "MongoDBAtlasVectorSearch": - vecstore = initialize_mongodb(class_object, params) - + if initializer := vecstore_initializer.get(class_object.__name__): + vecstore = initializer(class_object, params) else: if "texts" in params: params["documents"] = params.pop("texts") - vecstore = class_object.from_documents(**params) # ! This might not work. Need to test diff --git a/src/backend/langflow/interface/initialize/vector_store.py b/src/backend/langflow/interface/initialize/vector_store.py index 2f7e9dfb1..d4bdb0155 100644 --- a/src/backend/langflow/interface/initialize/vector_store.py +++ b/src/backend/langflow/interface/initialize/vector_store.py @@ -1,5 +1,5 @@ import json -from typing import Type +from typing import Any, Callable, Dict, Type from langchain.vectorstores import ( Pinecone, Qdrant, @@ -9,6 +9,7 @@ from langchain.vectorstores import ( SupabaseVectorStore, MongoDBAtlasVectorSearch, ) + import os @@ -29,7 +30,9 @@ def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dic from pymongo import MongoClient import certifi - client = MongoClient(MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where()) + client: MongoClient = MongoClient( + MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where() + ) db_name = params.pop("db_name", None) collection_name = params.pop("collection_name", None) if not db_name or not collection_name: @@ -207,3 +210,14 @@ def initialize_qdrant(class_object: Type[Qdrant], params: dict): return class_object(client=client, **lc_params) return class_object.from_documents(**params) + + +vecstore_initializer: Dict[str, Callable[[Type[Any], dict], Any]] = { + "Pinecone": initialize_pinecone, + "Chroma": initialize_chroma, + "Qdrant": initialize_qdrant, + "Weaviate": initialize_weaviate, + "FAISS": initialize_faiss, + "SupabaseVectorStore": initialize_supabase, + "MongoDBAtlasVectorSearch": initialize_mongodb, +}