fix: use init_subclass instead of metaclass to enforce decorator (#3942)

* Use init_subclass instead of metaclass to enforce decorator

* [autofix.ci] apply automated fixes

* fix imports

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Jordan Frazier 2024-09-27 05:06:15 -07:00 committed by GitHub
commit 1bf6781dc4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 23 additions and 24 deletions

View file

@ -1,4 +1,4 @@
from abc import ABC, ABCMeta, abstractmethod
from abc import abstractmethod
from functools import wraps
from typing import cast
@ -15,6 +15,11 @@ from langflow.schema import Data
def check_cached_vector_store(f):
"""
Decorator to check for cached vector stores, and returns them if they exist.
Note: caching only occurs during the execution of a component - they do not persist
across separate invocations of the component. This method exists so that components with
multiple output methods share the same vector store during the same invocation of the
component.
"""
@wraps(f)
@ -30,30 +35,22 @@ def check_cached_vector_store(f):
return check_cached
class EnforceCacheDecoratorMeta(ABCMeta):
"""
Enforces that abstract methods marked with @check_cached_vector_store are implemented with the decorator.
"""
def __init__(cls, name, bases, dct):
for name, value in dct.items():
if hasattr(value, "__isabstractmethod__"):
cls._check_method_decorator(name, cls)
super().__init__(name, bases, dct)
@staticmethod
def _check_method_decorator(name, cls):
method = getattr(cls, name)
# Check if the method has been marked as decorated by `check_cached_vector_store`
if not getattr(method, "_is_cached_vector_store_checked", False):
raise TypeError(f"Concrete implementation of '{name}' must use '@check_cached_vector_store' decorator.")
class LCVectorStoreComponent(Component, ABC, metaclass=EnforceCacheDecoratorMeta):
class LCVectorStoreComponent(Component):
# Used to ensure a single vector store is built for each run of the flow
_cached_vector_store: VectorStore | None = None
def __init_subclass__(cls, **kwargs):
"""
Enforces the check cached decorator on all subclasses
"""
super().__init_subclass__(**kwargs)
if hasattr(cls, "build_vector_store"):
method = cls.build_vector_store
if not hasattr(method, "_is_cached_vector_store_checked"):
raise TypeError(
f"The method 'build_vector_store' in class {cls.__name__} must be decorated with @check_cached_vector_store"
)
trace_type = "retriever"
outputs = [
Output(

View file

@ -3,7 +3,7 @@ from typing import cast
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.field_typing import Retriever, VectorStore
from langflow.io import (
DropdownInput,
@ -80,5 +80,6 @@ class CohereRerankComponent(LCVectorStoreComponent):
self.status = data
return data
@check_cached_vector_store
def build_vector_store(self) -> VectorStore:
raise NotImplementedError("Cohere Rerank does not support vector stores.")

View file

@ -2,7 +2,7 @@ from typing import Any, cast
from langchain.retrievers import ContextualCompressionRetriever
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.field_typing import Retriever, VectorStore
from langflow.io import DropdownInput, HandleInput, MultilineInput, SecretStrInput, StrInput
from langflow.schema import Data
@ -77,5 +77,6 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
self.status = data
return data
@check_cached_vector_store
def build_vector_store(self) -> VectorStore:
raise NotImplementedError("NVIDIA Rerank does not support vector stores.")

Binary file not shown.