From a85546f5b25afb9032dba8d2c47321321d2565e9 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 17 Jan 2025 19:54:05 -0300 Subject: [PATCH] feat: make NvidiaRerankComponent work with search_results (#5740) * Renaming data input to 'search results' Credit: @brian-ogrady * fix: Update document handling in NvidiaRerankComponent to use new Data model - Removed the import of Document from langchain.schema. - Updated the rerank_documents method to utilize the to_lc_document method for converting passages to the new Data model, ensuring type safety and consistency in document processing. This change enhances the integration with the updated data structures in the project. * refactor: Remove unused base_retriever output from NvidiaRerankComponent This change simplifies the output structure of the NvidiaRerankComponent by removing the base_retriever output, which was not being utilized. This refactor enhances code clarity and maintainability. * refactor: Rename output in NvidiaRerankComponent from 'search_results' to 'reranked_documents' This change updates the output display name and method in the NvidiaRerankComponent to better reflect its functionality, enhancing clarity in the component's purpose and usage. * refactor: Remove legacy flag from NvidiaRerankComponent This change simplifies the NvidiaRerankComponent by removing the unused legacy boolean flag, enhancing code clarity and maintainability. * refactor: Rename build_model method to build_reranker in NvidiaRerankComponent This change improves code clarity by renaming the method to better reflect its purpose, aligning with the component's functionality in the reranking process. * feat: Enable tool mode for search query input in NvidiaRerankComponent This change adds a new 'tool_mode' flag to the 'search_query' input in the NvidiaRerankComponent, enhancing its functionality and allowing for improved interaction with the component. This update aligns with recent refactors aimed at clarifying the component's purpose and usage. --- .../components/nvidia/nvidia_rerank.py | 56 ++++++++----------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/src/backend/base/langflow/components/nvidia/nvidia_rerank.py b/src/backend/base/langflow/components/nvidia/nvidia_rerank.py index 4228fa3a8..746c1f1b7 100644 --- a/src/backend/base/langflow/components/nvidia/nvidia_rerank.py +++ b/src/backend/base/langflow/components/nvidia/nvidia_rerank.py @@ -1,19 +1,9 @@ -from typing import Any, cast +from typing import Any -from langchain.retrievers import ContextualCompressionRetriever - -from langflow.base.vectorstores.model import ( - LCVectorStoreComponent, - check_cached_vector_store, -) -from langflow.field_typing import Retriever, VectorStore -from langflow.io import ( - DropdownInput, - HandleInput, - MultilineInput, - SecretStrInput, - StrInput, -) +from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store +from langflow.field_typing import VectorStore +from langflow.inputs.inputs import DataInput +from langflow.io import DropdownInput, MultilineInput, SecretStrInput, StrInput from langflow.schema import Data from langflow.schema.dotdict import dotdict from langflow.template.field.base import Output @@ -23,12 +13,12 @@ class NvidiaRerankComponent(LCVectorStoreComponent): display_name = "NVIDIA Rerank" description = "Rerank documents using the NVIDIA API and a retriever." icon = "NVIDIA" - legacy: bool = True inputs = [ MultilineInput( name="search_query", display_name="Search Query", + tool_mode=True, ), StrInput( name="base_url", @@ -44,19 +34,19 @@ class NvidiaRerankComponent(LCVectorStoreComponent): value="nv-rerank-qa-mistral-4b:1", ), SecretStrInput(name="api_key", display_name="API Key"), - HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]), + DataInput( + name="search_results", + display_name="Search Results", + info="Search Results from a Vector Store.", + is_list=True, + ), ] outputs = [ Output( - display_name="Retriever", - name="base_retriever", - method="build_base_retriever", - ), - Output( - display_name="Search Results", - name="search_results", - method="search_documents", + display_name="Reranked Documents", + name="reranked_documents", + method="rerank_documents", ), ] @@ -72,7 +62,7 @@ class NvidiaRerankComponent(LCVectorStoreComponent): raise ValueError(msg) from e return build_config - def build_model(self): + def build_reranker(self): try: from langchain_nvidia_ai_endpoints import NVIDIARerank except ImportError as e: @@ -80,14 +70,12 @@ class NvidiaRerankComponent(LCVectorStoreComponent): raise ImportError(msg) from e return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url) - def build_base_retriever(self) -> Retriever: # type: ignore[type-var] - nvidia_reranker = self.build_model() - retriever = ContextualCompressionRetriever(base_compressor=nvidia_reranker, base_retriever=self.retriever) - return cast("Retriever", retriever) - - async def search_documents(self) -> list[Data]: # type: ignore[override] - retriever = self.build_base_retriever() - documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()}) + async def rerank_documents(self) -> list[Data]: # type: ignore[override] + reranker = self.build_reranker() + documents = reranker.compress_documents( + query=self.search_query, + documents=[passage.to_lc_document() for passage in self.search_results if isinstance(passage, Data)], + ) data = self.to_data(documents) self.status = data return data