diff --git a/src/backend/base/langflow/base/vectorstores/model.py b/src/backend/base/langflow/base/vectorstores/model.py index 2cc93cc60..e05be744e 100644 --- a/src/backend/base/langflow/base/vectorstores/model.py +++ b/src/backend/base/langflow/base/vectorstores/model.py @@ -73,7 +73,7 @@ class LCVectorStoreComponent(Component): """ raise NotImplementedError("build_vector_store method must be implemented.") - def build_base_retriever(self) -> Retriever: + def build_base_retriever(self) -> Retriever: # type: ignore[type-var] """ Builds the BaseRetriever object. """ diff --git a/src/backend/base/langflow/components/retrievers/AmazonKendra.py b/src/backend/base/langflow/components/retrievers/AmazonKendra.py index ff830f2ed..90c70a7bc 100644 --- a/src/backend/base/langflow/components/retrievers/AmazonKendra.py +++ b/src/backend/base/langflow/components/retrievers/AmazonKendra.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, cast from langchain_community.retrievers import AmazonKendraRetriever @@ -36,7 +36,7 @@ class AmazonKendraRetrieverComponent(CustomComponent): credentials_profile_name: Optional[str] = None, attribute_filter: Optional[dict] = None, user_context: Optional[dict] = None, - ) -> Retriever: + ) -> Retriever: # type: ignore[type-var] try: output = AmazonKendraRetriever( index_id=index_id, @@ -48,4 +48,4 @@ class AmazonKendraRetrieverComponent(CustomComponent): ) # type: ignore except Exception as e: raise ValueError("Could not connect to AmazonKendra API.") from e - return output + return cast(Retriever, output) diff --git a/src/backend/base/langflow/components/retrievers/CohereRerank.py b/src/backend/base/langflow/components/retrievers/CohereRerank.py index dd560d588..939ee3989 100644 --- a/src/backend/base/langflow/components/retrievers/CohereRerank.py +++ b/src/backend/base/langflow/components/retrievers/CohereRerank.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, cast from langchain.retrievers import ContextualCompressionRetriever from langchain_cohere import CohereRerank @@ -36,15 +36,15 @@ class CohereRerankComponent(LCVectorStoreComponent): HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]), ] - def build_base_retriever(self) -> Retriever: + def build_base_retriever(self) -> Retriever: # type: ignore[type-var] cohere_reranker = CohereRerank( cohere_api_key=self.api_key, model=self.model, top_n=self.top_n, user_agent=self.user_agent ) retriever = ContextualCompressionRetriever(base_compressor=cohere_reranker, base_retriever=self.retriever) - return retriever + return cast(Retriever, retriever) - async def search_documents(self) -> List[Data]: - retriever: ContextualCompressionRetriever = self.build_base_retriever() + async def search_documents(self) -> List[Data]: # type: ignore + retriever = self.build_base_retriever() documents = await retriever.ainvoke(self.search_query) data = self.to_data(documents) self.status = data diff --git a/src/backend/base/langflow/components/retrievers/MetalRetriever.py b/src/backend/base/langflow/components/retrievers/MetalRetriever.py index 104adcbde..f3af0ebb0 100644 --- a/src/backend/base/langflow/components/retrievers/MetalRetriever.py +++ b/src/backend/base/langflow/components/retrievers/MetalRetriever.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, cast from langchain_community.retrievers import MetalRetriever from metal_sdk.metal import Metal # type: ignore @@ -20,9 +20,9 @@ class MetalRetrieverComponent(CustomComponent): "code": {"show": False}, } - def build(self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None) -> Retriever: + def build(self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None) -> Retriever: # type: ignore[type-var] try: metal = Metal(api_key=api_key, client_id=client_id, index_id=index_id) except Exception as e: raise ValueError("Could not connect to Metal API.") from e - return MetalRetriever(client=metal, params=params or {}) + return cast(Retriever, MetalRetriever(client=metal, params=params or {})) diff --git a/src/backend/base/langflow/components/retrievers/VectaraSelfQueryRetriver.py b/src/backend/base/langflow/components/retrievers/VectaraSelfQueryRetriver.py index 42ffd92d6..a328b2b20 100644 --- a/src/backend/base/langflow/components/retrievers/VectaraSelfQueryRetriver.py +++ b/src/backend/base/langflow/components/retrievers/VectaraSelfQueryRetriver.py @@ -1,13 +1,13 @@ import json -from typing import List +from typing import List, cast from langchain.chains.query_constructor.base import AttributeInfo from langchain.retrievers.self_query.base import SelfQueryRetriever from langchain_core.vectorstores import VectorStore from langflow.custom import CustomComponent -from langflow.field_typing.constants import LanguageModel from langflow.field_typing import Retriever +from langflow.field_typing.constants import LanguageModel class VectaraSelfQueryRetriverComponent(CustomComponent): @@ -40,7 +40,7 @@ class VectaraSelfQueryRetriverComponent(CustomComponent): document_content_description: str, llm: LanguageModel, metadata_field_info: List[str], - ) -> Retriever: + ) -> Retriever: # type: ignore metadata_field_obj = [] for meta in metadata_field_info: @@ -54,6 +54,9 @@ class VectaraSelfQueryRetriverComponent(CustomComponent): ) metadata_field_obj.append(attribute_info) - return SelfQueryRetriever.from_llm( - llm, vectorstore, document_content_description, metadata_field_obj, verbose=True + return cast( + Retriever, + SelfQueryRetriever.from_llm( + llm, vectorstore, document_content_description, metadata_field_obj, verbose=True + ), ) diff --git a/src/backend/base/langflow/field_typing/constants.py b/src/backend/base/langflow/field_typing/constants.py index 89e3785cd..0156849cc 100644 --- a/src/backend/base/langflow/field_typing/constants.py +++ b/src/backend/base/langflow/field_typing/constants.py @@ -13,15 +13,16 @@ from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate, PromptTemplate from langchain_core.retrievers import BaseRetriever from langchain_core.tools import Tool -from langchain_core.vectorstores import VectorStore +from langchain_core.vectorstores import VectorStore, VectorStoreRetriever from langchain_text_splitters import TextSplitter NestedDict: TypeAlias = Dict[str, Union[str, Dict]] LanguageModel = TypeVar("LanguageModel", BaseLanguageModel, BaseLLM, BaseChatModel) - - -class Retriever(BaseRetriever): - pass +Retriever = TypeVar( + "Retriever", + BaseRetriever, + VectorStoreRetriever, +) class Object: