feat: make NvidiaRerankComponent work with search_results (#5740)

* Renaming data input to 'search results'

Credit: @brian-ogrady

* fix: Update document handling in NvidiaRerankComponent to use new Data model

- Removed the import of Document from langchain.schema.
- Updated the rerank_documents method to utilize the to_lc_document method for converting passages to the new Data model, ensuring type safety and consistency in document processing.

This change enhances the integration with the updated data structures in the project.

* refactor: Remove unused base_retriever output from NvidiaRerankComponent

This change simplifies the output structure of the NvidiaRerankComponent by removing the base_retriever output, which was not being utilized. This refactor enhances code clarity and maintainability.

* refactor: Rename output in NvidiaRerankComponent from 'search_results' to 'reranked_documents'

This change updates the output display name and method in the NvidiaRerankComponent to better reflect its functionality, enhancing clarity in the component's purpose and usage.

* refactor: Remove legacy flag from NvidiaRerankComponent

This change simplifies the NvidiaRerankComponent by removing the unused legacy boolean flag, enhancing code clarity and maintainability.

* refactor: Rename build_model method to build_reranker in NvidiaRerankComponent

This change improves code clarity by renaming the method to better reflect its purpose, aligning with the component's functionality in the reranking process.

* feat: Enable tool mode for search query input in NvidiaRerankComponent

This change adds a new 'tool_mode' flag to the 'search_query' input in the NvidiaRerankComponent, enhancing its functionality and allowing for improved interaction with the component. This update aligns with recent refactors aimed at clarifying the component's purpose and usage.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2025-01-17 19:54:05 -03:00 committed by GitHub
commit a85546f5b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,19 +1,9 @@
from typing import Any, cast
from typing import Any
from langchain.retrievers import ContextualCompressionRetriever
from langflow.base.vectorstores.model import (
LCVectorStoreComponent,
check_cached_vector_store,
)
from langflow.field_typing import Retriever, VectorStore
from langflow.io import (
DropdownInput,
HandleInput,
MultilineInput,
SecretStrInput,
StrInput,
)
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.field_typing import VectorStore
from langflow.inputs.inputs import DataInput
from langflow.io import DropdownInput, MultilineInput, SecretStrInput, StrInput
from langflow.schema import Data
from langflow.schema.dotdict import dotdict
from langflow.template.field.base import Output
@ -23,12 +13,12 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
display_name = "NVIDIA Rerank"
description = "Rerank documents using the NVIDIA API and a retriever."
icon = "NVIDIA"
legacy: bool = True
inputs = [
MultilineInput(
name="search_query",
display_name="Search Query",
tool_mode=True,
),
StrInput(
name="base_url",
@ -44,19 +34,19 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
value="nv-rerank-qa-mistral-4b:1",
),
SecretStrInput(name="api_key", display_name="API Key"),
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
DataInput(
name="search_results",
display_name="Search Results",
info="Search Results from a Vector Store.",
is_list=True,
),
]
outputs = [
Output(
display_name="Retriever",
name="base_retriever",
method="build_base_retriever",
),
Output(
display_name="Search Results",
name="search_results",
method="search_documents",
display_name="Reranked Documents",
name="reranked_documents",
method="rerank_documents",
),
]
@ -72,7 +62,7 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
raise ValueError(msg) from e
return build_config
def build_model(self):
def build_reranker(self):
try:
from langchain_nvidia_ai_endpoints import NVIDIARerank
except ImportError as e:
@ -80,14 +70,12 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
raise ImportError(msg) from e
return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url)
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
nvidia_reranker = self.build_model()
retriever = ContextualCompressionRetriever(base_compressor=nvidia_reranker, base_retriever=self.retriever)
return cast("Retriever", retriever)
async def search_documents(self) -> list[Data]: # type: ignore[override]
retriever = self.build_base_retriever()
documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
async def rerank_documents(self) -> list[Data]: # type: ignore[override]
reranker = self.build_reranker()
documents = reranker.compress_documents(
query=self.search_query,
documents=[passage.to_lc_document() for passage in self.search_results if isinstance(passage, Data)],
)
data = self.to_data(documents)
self.status = data
return data