diff --git a/src/backend/base/langflow/components/nvidia/nvidia_rerank.py b/src/backend/base/langflow/components/nvidia/nvidia_rerank.py index 4228fa3a8..746c1f1b7 100644 --- a/src/backend/base/langflow/components/nvidia/nvidia_rerank.py +++ b/src/backend/base/langflow/components/nvidia/nvidia_rerank.py @@ -1,19 +1,9 @@ -from typing import Any, cast +from typing import Any -from langchain.retrievers import ContextualCompressionRetriever - -from langflow.base.vectorstores.model import ( - LCVectorStoreComponent, - check_cached_vector_store, -) -from langflow.field_typing import Retriever, VectorStore -from langflow.io import ( - DropdownInput, - HandleInput, - MultilineInput, - SecretStrInput, - StrInput, -) +from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store +from langflow.field_typing import VectorStore +from langflow.inputs.inputs import DataInput +from langflow.io import DropdownInput, MultilineInput, SecretStrInput, StrInput from langflow.schema import Data from langflow.schema.dotdict import dotdict from langflow.template.field.base import Output @@ -23,12 +13,12 @@ class NvidiaRerankComponent(LCVectorStoreComponent): display_name = "NVIDIA Rerank" description = "Rerank documents using the NVIDIA API and a retriever." icon = "NVIDIA" - legacy: bool = True inputs = [ MultilineInput( name="search_query", display_name="Search Query", + tool_mode=True, ), StrInput( name="base_url", @@ -44,19 +34,19 @@ class NvidiaRerankComponent(LCVectorStoreComponent): value="nv-rerank-qa-mistral-4b:1", ), SecretStrInput(name="api_key", display_name="API Key"), - HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]), + DataInput( + name="search_results", + display_name="Search Results", + info="Search Results from a Vector Store.", + is_list=True, + ), ] outputs = [ Output( - display_name="Retriever", - name="base_retriever", - method="build_base_retriever", - ), - Output( - display_name="Search Results", - name="search_results", - method="search_documents", + display_name="Reranked Documents", + name="reranked_documents", + method="rerank_documents", ), ] @@ -72,7 +62,7 @@ class NvidiaRerankComponent(LCVectorStoreComponent): raise ValueError(msg) from e return build_config - def build_model(self): + def build_reranker(self): try: from langchain_nvidia_ai_endpoints import NVIDIARerank except ImportError as e: @@ -80,14 +70,12 @@ class NvidiaRerankComponent(LCVectorStoreComponent): raise ImportError(msg) from e return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url) - def build_base_retriever(self) -> Retriever: # type: ignore[type-var] - nvidia_reranker = self.build_model() - retriever = ContextualCompressionRetriever(base_compressor=nvidia_reranker, base_retriever=self.retriever) - return cast("Retriever", retriever) - - async def search_documents(self) -> list[Data]: # type: ignore[override] - retriever = self.build_base_retriever() - documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()}) + async def rerank_documents(self) -> list[Data]: # type: ignore[override] + reranker = self.build_reranker() + documents = reranker.compress_documents( + query=self.search_query, + documents=[passage.to_lc_document() for passage in self.search_results if isinstance(passage, Data)], + ) data = self.to_data(documents) self.status = data return data