fix: moves caching of vector store to LCVectorStoreComponent level (#3435)

* refactor LCVectorStoreComponent to use a cached vector store to prevent multiple embeddings
This commit is contained in:
Jordan Frazier 2024-08-21 14:38:06 -07:00 committed by GitHub
commit 96ca71dab8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 163 additions and 91 deletions

View file

@ -1,3 +1,5 @@
from abc import ABC, ABCMeta, abstractmethod
from functools import wraps
from typing import List, cast
from langchain_core.documents import Document
@ -10,7 +12,48 @@ from langflow.io import Output
from langflow.schema import Data
class LCVectorStoreComponent(Component):
def check_cached_vector_store(f):
    """Decorator that short-circuits a builder method when a vector store
    has already been cached on the component.

    On the first call the wrapped builder runs and its result is stored in
    ``self._cached_vector_store``; subsequent calls return that cached store
    without re-running the builder. The wrapper is tagged with
    ``_is_cached_vector_store_checked`` so metaclass enforcement can verify
    implementations used this decorator.
    """

    @wraps(f)
    def check_cached(self, *args, **kwargs):
        cached = self._cached_vector_store
        if cached is not None:
            return cached
        self._cached_vector_store = f(self, *args, **kwargs)
        return self._cached_vector_store

    # Marker inspected by EnforceCacheDecoratorMeta-style checks.
    check_cached._is_cached_vector_store_checked = True
    return check_cached
class EnforceCacheDecoratorMeta(ABCMeta):
    """
    Enforces that abstract methods marked with @check_cached_vector_store are implemented with the decorator.

    Raises:
        TypeError: at class-creation time, when an abstract method in the class
            body is not wrapped by ``@check_cached_vector_store``.
    """

    def __init__(cls, name, bases, dct):
        # Iterate with a distinct variable: the original code reused ``name``
        # as the loop variable, clobbering the class name passed to
        # ``super().__init__``.
        for attr_name, attr_value in dct.items():
            # Use getattr truthiness, not hasattr: ``property`` objects always
            # expose ``__isabstractmethod__`` (usually False), so hasattr would
            # incorrectly flag every non-abstract property.
            if getattr(attr_value, "__isabstractmethod__", False):
                cls._check_method_decorator(attr_name, cls)
        super().__init__(name, bases, dct)

    @staticmethod
    def _check_method_decorator(name, cls):
        method = getattr(cls, name)
        # Check if the method has been marked as decorated by `check_cached_vector_store`
        if not getattr(method, "_is_cached_vector_store_checked", False):
            raise TypeError(f"Concrete implementation of '{name}' must use '@check_cached_vector_store' decorator.")
class LCVectorStoreComponent(Component, ABC, metaclass=EnforceCacheDecoratorMeta):
# Used to ensure a single vector store is built for each run of the flow
_cached_vector_store: VectorStore | None = None
trace_type = "retriever"
outputs = [
Output(
@ -32,7 +75,11 @@ class LCVectorStoreComponent(Component):
def _validate_outputs(self):
# At least these three outputs must be defined
required_output_methods = ["build_base_retriever", "search_documents", "build_vector_store"]
required_output_methods = [
"build_base_retriever",
"search_documents",
"build_vector_store",
]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
if method_name not in output_names:
@ -75,17 +122,16 @@ class LCVectorStoreComponent(Component):
def cast_vector_store(self) -> VectorStore:
    """Build the vector store and cast it to ``VectorStore`` for type-checkers."""
    store = self.build_vector_store()
    return cast(VectorStore, store)
def build_vector_store(self) -> VectorStore:
    """
    Builds the Vector Store object.

    Subclasses must override this; the base implementation always raises.
    """
    raise NotImplementedError("build_vector_store method must be implemented.")
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
"""
Builds the BaseRetriever object.
"""
vector_store = self.build_vector_store()
if self._cached_vector_store is not None:
vector_store = self._cached_vector_store
else:
vector_store = self.build_vector_store()
self._cached_vector_store = vector_store
if hasattr(vector_store, "as_retriever"):
retriever = vector_store.as_retriever(**self.get_retriever_kwargs())
if self.status is None:
@ -103,7 +149,11 @@ class LCVectorStoreComponent(Component):
self.status = ""
return []
vector_store = self.build_vector_store()
if self._cached_vector_store is not None:
vector_store = self._cached_vector_store
else:
vector_store = self.build_vector_store()
self._cached_vector_store = vector_store
logger.debug(f"Search input: {search_query}")
logger.debug(f"Search type: {self.search_type}")
@ -120,3 +170,11 @@ class LCVectorStoreComponent(Component):
Get the retriever kwargs. Implementations can override this method to provide custom retriever kwargs.
"""
return {}
@abstractmethod
@check_cached_vector_store
def build_vector_store(self) -> VectorStore:
    """
    Builds the Vector Store object.

    NOTE(review): decorator order is load-bearing — ``@abstractmethod`` is
    outermost so the abstract flag is set on the caching wrapper, which the
    metaclass inspects for the ``_is_cached_vector_store_checked`` marker.
    """
    raise NotImplementedError("build_vector_store method must be implemented.")

View file

@ -4,9 +4,17 @@ from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.field_typing import Retriever
from langflow.io import DropdownInput, HandleInput, IntInput, MessageTextInput, MultilineInput, SecretStrInput
from langflow.field_typing import Retriever, VectorStore
from langflow.io import (
DropdownInput,
HandleInput,
IntInput,
MessageTextInput,
MultilineInput,
SecretStrInput,
)
from langflow.schema import Data
from langflow.template.field.base import Output
class CohereRerankComponent(LCVectorStoreComponent):
@ -33,13 +41,34 @@ class CohereRerankComponent(LCVectorStoreComponent):
),
SecretStrInput(name="api_key", display_name="API Key"),
IntInput(name="top_n", display_name="Top N", value=3),
MessageTextInput(name="user_agent", display_name="User Agent", value="langflow", advanced=True),
MessageTextInput(
name="user_agent",
display_name="User Agent",
value="langflow",
advanced=True,
),
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
]
outputs = [
Output(
display_name="Retriever",
name="base_retriever",
method="build_base_retriever",
),
Output(
display_name="Search Results",
name="search_results",
method="search_documents",
),
]
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
cohere_reranker = CohereRerank(
cohere_api_key=self.api_key, model=self.model, top_n=self.top_n, user_agent=self.user_agent
cohere_api_key=self.api_key,
model=self.model,
top_n=self.top_n,
user_agent=self.user_agent,
)
retriever = ContextualCompressionRetriever(base_compressor=cohere_reranker, base_retriever=self.retriever)
return cast(Retriever, retriever)
@ -50,3 +79,6 @@ class CohereRerankComponent(LCVectorStoreComponent):
data = self.to_data(documents)
self.status = data
return data
def build_vector_store(self) -> VectorStore:
    """Unsupported: Cohere Rerank is a retriever-style component and does not build a vector store."""
    raise NotImplementedError("Cohere Rerank does not support vector stores.")

View file

@ -3,10 +3,11 @@ from typing import Any, List, cast
from langchain.retrievers import ContextualCompressionRetriever
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.field_typing import Retriever
from langflow.field_typing import Retriever, VectorStore
from langflow.io import DropdownInput, HandleInput, MultilineInput, SecretStrInput, StrInput
from langflow.schema import Data
from langflow.schema.dotdict import dotdict
from langflow.template.field.base import Output
class NvidiaRerankComponent(LCVectorStoreComponent):
@ -33,6 +34,19 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
]
outputs = [
Output(
display_name="Retriever",
name="base_retriever",
method="build_base_retriever",
),
Output(
display_name="Search Results",
name="search_results",
method="search_documents",
),
]
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "base_url" and field_value:
try:
@ -62,3 +76,6 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
data = self.to_data(documents)
self.status = data
return data
def build_vector_store(self) -> VectorStore:
    """Unsupported: NVIDIA Rerank is a retriever-style component and does not build a vector store."""
    raise NotImplementedError("NVIDIA Rerank does not support vector stores.")

View file

@ -1,7 +1,6 @@
from langchain_core.vectorstores import VectorStore
from loguru import logger
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers import docs_to_data
from langflow.inputs import DictInput, FloatInput
from langflow.io import (
@ -24,8 +23,6 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
name = "AstraDB"
icon: str = "AstraDB"
_cached_vectorstore: VectorStore | None = None
inputs = [
StrInput(
name="collection_name",
@ -162,11 +159,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
),
]
def _build_vector_store(self):
# cache the vector store to avoid re-initializing and ingest data again
if self._cached_vectorstore:
return self._cached_vectorstore
@check_cached_vector_store
def build_vector_store(self):
try:
from langchain_astradb import AstraDBVectorStore
from langchain_astradb.utils.astradb import SetupMode
@ -229,9 +223,6 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
raise ValueError(f"Error initializing AstraDBVectorStore: {str(e)}") from e
self._add_documents_to_vector_store(vector_store)
self._cached_vectorstore = vector_store
return vector_store
def _add_documents_to_vector_store(self, vector_store):
@ -272,7 +263,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
return args
def search_documents(self) -> list[Data]:
vector_store = self._build_vector_store()
vector_store = self.build_vector_store()
logger.debug(f"Search input: {self.search_input}")
logger.debug(f"Search type: {self.search_type}")
@ -303,7 +294,3 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
"search_type": self._map_search_type(),
"search_kwargs": search_args,
}
def build_vector_store(self):
vector_store = self._build_vector_store()
return vector_store

View file

@ -3,7 +3,7 @@ from typing import List
from langchain_community.vectorstores import Cassandra
from loguru import logger
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.inputs import BoolInput, DictInput, FloatInput
from langflow.io import (
@ -25,8 +25,6 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
name = "Cassandra"
icon = "Cassandra"
_cached_vectorstore: Cassandra | None = None
inputs = [
MessageTextInput(
name="database_ref",
@ -134,12 +132,8 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
),
]
@check_cached_vector_store
def build_vector_store(self) -> Cassandra:
return self._build_cassandra()
def _build_cassandra(self) -> Cassandra:
if self._cached_vectorstore:
return self._cached_vectorstore
try:
import cassio
from langchain_community.utilities.cassandra import SetupMode
@ -215,7 +209,6 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
body_index_options=body_index_options,
setup_mode=setup_mode,
)
self._cached_vectorstore = table
return table
def _map_search_type(self):
@ -227,7 +220,7 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
return "similarity"
def search_documents(self) -> List[Data]:
vector_store = self._build_cassandra()
vector_store = self.build_vector_store()
logger.debug(f"Search input: {self.search_query}")
logger.debug(f"Search type: {self.search_type}")

View file

@ -5,7 +5,7 @@ from chromadb.config import Settings
from langchain_chroma.vectorstores import Chroma
from loguru import logger
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.base.vectorstores.utils import chroma_collection_to_data
from langflow.io import BoolInput, DataInput, DropdownInput, HandleInput, IntInput, StrInput, MultilineInput
from langflow.schema import Data
@ -98,6 +98,7 @@ class ChromaVectorStoreComponent(LCVectorStoreComponent):
),
]
@check_cached_vector_store
def build_vector_store(self) -> Chroma:
"""
Builds the Chroma object.

View file

@ -3,7 +3,7 @@ from typing import List
from langchain_community.vectorstores import CouchbaseVectorStore
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import HandleInput, IntInput, StrInput, SecretStrInput, DataInput, MultilineInput
from langflow.schema import Data
@ -42,10 +42,8 @@ class CouchbaseVectorStoreComponent(LCVectorStoreComponent):
),
]
@check_cached_vector_store
def build_vector_store(self) -> CouchbaseVectorStore:
return self._build_couchbase()
def _build_couchbase(self) -> CouchbaseVectorStore:
try:
from couchbase.auth import PasswordAuthenticator # type: ignore
from couchbase.cluster import Cluster # type: ignore
@ -95,7 +93,7 @@ class CouchbaseVectorStoreComponent(LCVectorStoreComponent):
return couchbase_vs
def search_documents(self) -> List[Data]:
vector_store = self._build_couchbase()
vector_store = self.build_vector_store()
if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
docs = vector_store.similarity_search(

View file

@ -3,7 +3,7 @@ from typing import List
from langchain_community.vectorstores import FAISS
from loguru import logger
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, DataInput, HandleInput, IntInput, MultilineInput, StrInput
from langflow.schema import Data
@ -57,6 +57,7 @@ class FaissVectorStoreComponent(LCVectorStoreComponent):
),
]
@check_cached_vector_store
def build_vector_store(self) -> FAISS:
"""
Builds the FAISS object.

View file

@ -1,6 +1,6 @@
from typing import List
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import (
DataInput,
@ -71,6 +71,7 @@ class MilvusVectorStoreComponent(LCVectorStoreComponent):
),
]
@check_cached_vector_store
def build_vector_store(self):
try:
from langchain_milvus.vectorstores import Milvus as LangchainMilvus

View file

@ -2,7 +2,7 @@ from typing import List
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import HandleInput, IntInput, StrInput, SecretStrInput, DataInput, MultilineInput
from langflow.schema import Data
@ -36,10 +36,8 @@ class MongoVectorStoreComponent(LCVectorStoreComponent):
),
]
@check_cached_vector_store
def build_vector_store(self) -> MongoDBAtlasVectorSearch:
return self._build_mongodb_atlas()
def _build_mongodb_atlas(self) -> MongoDBAtlasVectorSearch:
try:
from pymongo import MongoClient
except ImportError:
@ -80,7 +78,7 @@ class MongoVectorStoreComponent(LCVectorStoreComponent):
def search_documents(self) -> List[Data]:
from bson import ObjectId
vector_store = self._build_mongodb_atlas()
vector_store = self.build_vector_store()
if self.search_query and isinstance(self.search_query, str):
docs = vector_store.similarity_search(

View file

@ -2,7 +2,7 @@ from typing import List
from langchain_pinecone import Pinecone
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import (
DropdownInput,
@ -22,7 +22,6 @@ class PineconeVectorStoreComponent(LCVectorStoreComponent):
documentation = "https://python.langchain.com/v0.2/docs/integrations/vectorstores/pinecone/"
name = "Pinecone"
icon = "Pinecone"
pinecone_instance = None
inputs = [
StrInput(name="index_name", display_name="Index Name", required=True),
@ -58,12 +57,8 @@ class PineconeVectorStoreComponent(LCVectorStoreComponent):
),
]
@check_cached_vector_store
def build_vector_store(self) -> Pinecone:
return self._build_pinecone()
def _build_pinecone(self) -> Pinecone:
if self.pinecone_instance is not None:
return self.pinecone_instance
from langchain_pinecone._utilities import DistanceStrategy
from langchain_pinecone.vectorstores import Pinecone
@ -88,11 +83,10 @@ class PineconeVectorStoreComponent(LCVectorStoreComponent):
if documents:
pinecone.add_documents(documents)
self.pinecone_instance = pinecone
return pinecone
def search_documents(self) -> List[Data]:
vector_store = self._build_pinecone()
vector_store = self.build_vector_store()
if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
docs = vector_store.similarity_search(

View file

@ -1,7 +1,7 @@
from typing import List
from langchain_community.vectorstores import Qdrant
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import (
DropdownInput,
@ -57,10 +57,8 @@ class QdrantVectorStoreComponent(LCVectorStoreComponent):
),
]
@check_cached_vector_store
def build_vector_store(self) -> Qdrant:
return self._build_qdrant()
def _build_qdrant(self) -> Qdrant:
qdrant_kwargs = {
"collection_name": self.collection_name,
"content_payload_key": self.content_payload_key,
@ -101,7 +99,7 @@ class QdrantVectorStoreComponent(LCVectorStoreComponent):
return qdrant
def search_documents(self) -> List[Data]:
vector_store = self._build_qdrant()
vector_store = self.build_vector_store()
if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
docs = vector_store.similarity_search(

View file

@ -2,7 +2,7 @@ from typing import List
from langchain_community.vectorstores.redis import Redis
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import HandleInput, IntInput, StrInput, SecretStrInput, DataInput, MultilineInput
from langflow.schema import Data
@ -46,6 +46,7 @@ class RedisVectorStoreComponent(LCVectorStoreComponent):
HandleInput(name="embedding", display_name="Embedding", input_types=["Embeddings"]),
]
@check_cached_vector_store
def build_vector_store(self) -> Redis:
documents = []

View file

@ -3,7 +3,7 @@ from typing import List
from langchain_community.vectorstores import SupabaseVectorStore
from supabase.client import Client, create_client
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import HandleInput, IntInput, StrInput, SecretStrInput, DataInput, MultilineInput
from langflow.schema import Data
@ -37,10 +37,8 @@ class SupabaseVectorStoreComponent(LCVectorStoreComponent):
),
]
@check_cached_vector_store
def build_vector_store(self) -> SupabaseVectorStore:
return self._build_supabase()
def _build_supabase(self) -> SupabaseVectorStore:
supabase: Client = create_client(self.supabase_url, supabase_key=self.supabase_service_key)
documents = []
@ -69,7 +67,7 @@ class SupabaseVectorStoreComponent(LCVectorStoreComponent):
return supabase_vs
def search_documents(self) -> List[Data]:
vector_store = self._build_supabase()
vector_store = self.build_vector_store()
if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
docs = vector_store.similarity_search(

View file

@ -2,7 +2,7 @@ from typing import List
from langchain_community.vectorstores import UpstashVectorStore
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import (
HandleInput,
@ -73,10 +73,8 @@ class UpstashVectorStoreComponent(LCVectorStoreComponent):
),
]
@check_cached_vector_store
def build_vector_store(self) -> UpstashVectorStore:
return self._build_upstash()
def _build_upstash(self) -> UpstashVectorStore:
use_upstash_embedding = self.embedding is None
documents = []
@ -117,7 +115,7 @@ class UpstashVectorStoreComponent(LCVectorStoreComponent):
return upstash_vs
def search_documents(self) -> List[Data]:
vector_store = self._build_upstash()
vector_store = self.build_vector_store()
if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
docs = vector_store.similarity_search(

View file

@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, List
from langchain_community.vectorstores import Vectara
from loguru import logger
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import HandleInput, IntInput, MessageTextInput, SecretStrInput, StrInput
from langflow.schema import Data
@ -51,6 +51,7 @@ class VectaraVectorStoreComponent(LCVectorStoreComponent):
),
]
@check_cached_vector_store
def build_vector_store(self) -> "Vectara":
"""
Builds the Vectara object.

View file

@ -3,7 +3,7 @@ from typing import List
import weaviate # type: ignore
from langchain_community.vectorstores import Weaviate
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, HandleInput, IntInput, StrInput, SecretStrInput, DataInput, MultilineInput
from langflow.schema import Data
@ -38,10 +38,8 @@ class WeaviateVectorStoreComponent(LCVectorStoreComponent):
BoolInput(name="search_by_text", display_name="Search By Text", advanced=True),
]
@check_cached_vector_store
def build_vector_store(self) -> Weaviate:
return self._build_weaviate()
def _build_weaviate(self) -> Weaviate:
if self.api_key:
auth_config = weaviate.AuthApiKey(api_key=self.api_key)
client = weaviate.Client(url=self.url, auth_client_secret=auth_config)
@ -73,7 +71,7 @@ class WeaviateVectorStoreComponent(LCVectorStoreComponent):
)
def search_documents(self) -> List[Data]:
vector_store = self._build_weaviate()
vector_store = self.build_vector_store()
if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
docs = vector_store.similarity_search(

View file

@ -2,7 +2,7 @@ from typing import List
from langchain_community.vectorstores import PGVector
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import HandleInput, IntInput, StrInput, SecretStrInput, DataInput, MultilineInput
from langflow.schema import Data
@ -36,10 +36,8 @@ class PGVectorStoreComponent(LCVectorStoreComponent):
HandleInput(name="embedding", display_name="Embedding", input_types=["Embeddings"]),
]
@check_cached_vector_store
def build_vector_store(self) -> PGVector:
return self._build_pgvector()
def _build_pgvector(self) -> PGVector:
documents = []
for _input in self.ingest_data or []:
if isinstance(_input, Data):
@ -66,7 +64,7 @@ class PGVectorStoreComponent(LCVectorStoreComponent):
return pgvector
def search_documents(self) -> List[Data]:
vector_store = self._build_pgvector()
vector_store = self.build_vector_store()
if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
docs = vector_store.similarity_search(