feat: Restructure Rerankers so both NVIDIA and Cohere work properly (#5933)

* Removing reference to nonexistent method

* Restructuring rerankers to inherit from BaseDocumentCompressor. Adding Voyage AI reranker.

* Removing Voyage AI component and dependency.

* [autofix.ci] apply automated fixes

* feat: Add method to compress documents as DataFrame in LCCompressorComponent

* Changing description of abstract build_compressor method

* [autofix.ci] apply automated fixes

* Adding top_n as an argument to the NVIDIA reranker

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
brian-ogrady 2025-02-05 13:29:29 -05:00 committed by GitHub
commit 8bf7048485
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 117 additions and 716 deletions

View file

@@ -0,0 +1,64 @@
from abc import abstractmethod
from langflow.custom import Component
from langflow.field_typing import BaseDocumentCompressor
from langflow.io import DataInput, IntInput, MultilineInput, SecretStrInput
from langflow.schema import Data
from langflow.schema.dataframe import DataFrame
from langflow.template.field.base import Output
class LCCompressorComponent(Component):
    """Base component for LangChain document compressors (rerankers).

    Provides the shared inputs (search query, API key, search results,
    top_n) and the two outputs (Data list and DataFrame). Concrete
    rerankers subclass this and implement `build_compressor` to return a
    `BaseDocumentCompressor` instance.
    """

    inputs = [
        MultilineInput(
            name="search_query",
            display_name="Search Query",
            tool_mode=True,
        ),
        SecretStrInput(
            name="api_key",
            display_name="API Key",
        ),
        DataInput(
            name="search_results",
            display_name="Search Results",
            info="Search Results from a Vector Store.",
            is_list=True,
        ),
        IntInput(name="top_n", display_name="Top N", value=3, advanced=True),
    ]

    outputs = [
        Output(
            display_name="Data",
            name="compressed_documents",
            # `method` must be the attribute name of the method to invoke,
            # not a human-readable label — the framework resolves it via
            # getattr. "Compressed Documents" would never resolve.
            method="compress_documents",
        ),
        Output(
            display_name="DataFrame",
            name="compressed_documents_as_dataframe",
            method="compress_documents_as_dataframe",
        ),
    ]

    @abstractmethod
    def build_compressor(self) -> BaseDocumentCompressor:
        """Builds the Base Document Compressor object.

        Raises:
            NotImplementedError: always, in this base class; subclasses
                must override and return a concrete compressor.
        """
        msg = "build_compressor method must be implemented."
        raise NotImplementedError(msg)

    async def compress_documents(self) -> list[Data]:
        """Compresses the documents retrieved from the vector store.

        Non-`Data` entries in `search_results` are silently skipped.
        Sets `self.status` to the resulting list for UI display.
        """
        compressor = self.build_compressor()
        documents = compressor.compress_documents(
            query=self.search_query,
            documents=[passage.to_lc_document() for passage in self.search_results if isinstance(passage, Data)],
        )
        data = self.to_data(documents)
        self.status = data
        return data

    async def compress_documents_as_dataframe(self) -> DataFrame:
        """Compresses the documents retrieved from the vector store and returns a DataFrame."""
        data_objs = await self.compress_documents()
        return DataFrame(data=data_objs)

View file

@@ -1,37 +1,17 @@
from typing import cast
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langflow.base.vectorstores.model import (
LCVectorStoreComponent,
check_cached_vector_store,
)
from langflow.field_typing import Retriever, VectorStore
from langflow.io import (
DropdownInput,
HandleInput,
IntInput,
MessageTextInput,
MultilineInput,
SecretStrInput,
)
from langflow.schema import Data
from langflow.base.compressors.model import LCCompressorComponent
from langflow.field_typing import BaseDocumentCompressor
from langflow.io import DropdownInput
from langflow.template.field.base import Output
class CohereRerankComponent(LCVectorStoreComponent):
class CohereRerankComponent(LCCompressorComponent):
display_name = "Cohere Rerank"
description = "Rerank documents using the Cohere API and a retriever."
description = "Rerank documents using the Cohere API."
name = "CohereRerank"
icon = "Cohere"
legacy: bool = True
inputs = [
MultilineInput(
name="search_query",
display_name="Search Query",
),
*LCCompressorComponent.inputs,
DropdownInput(
name="model",
display_name="Model",
@@ -43,48 +23,24 @@ class CohereRerankComponent(LCVectorStoreComponent):
],
value="rerank-english-v3.0",
),
SecretStrInput(name="api_key", display_name="API Key"),
IntInput(name="top_n", display_name="Top N", value=3),
MessageTextInput(
name="user_agent",
display_name="User Agent",
value="langflow",
advanced=True,
),
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
]
outputs = [
Output(
display_name="Retriever",
name="base_retriever",
method="build_base_retriever",
),
Output(
display_name="Search Results",
name="search_results",
method="search_documents",
display_name="Reranked Documents",
name="reranked_documents",
method="compress_documents",
),
]
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
cohere_reranker = CohereRerank(
def build_compressor(self) -> BaseDocumentCompressor: # type: ignore[type-var]
try:
from langchain_cohere import CohereRerank
except ImportError as e:
msg = "Please install langchain-cohere to use the Cohere model."
raise ImportError(msg) from e
return CohereRerank(
cohere_api_key=self.api_key,
model=self.model,
top_n=self.top_n,
user_agent=self.user_agent,
)
retriever = ContextualCompressionRetriever(base_compressor=cohere_reranker, base_retriever=self.retriever)
return cast("Retriever", retriever)
async def search_documents(self) -> list[Data]: # type: ignore[override]
retriever = self.build_base_retriever()
documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
data = self.to_data(documents)
self.status = data
return data
@check_cached_vector_store
def build_vector_store(self) -> VectorStore:
msg = "Cohere Rerank does not support vector stores."
raise NotImplementedError(msg)

View file

@@ -1,25 +1,19 @@
from typing import Any
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.field_typing import VectorStore
from langflow.inputs.inputs import DataInput
from langflow.io import DropdownInput, MultilineInput, SecretStrInput, StrInput
from langflow.schema import Data
from langflow.base.compressors.model import LCCompressorComponent
from langflow.field_typing import BaseDocumentCompressor
from langflow.io import DropdownInput, StrInput
from langflow.schema.dotdict import dotdict
from langflow.template.field.base import Output
class NvidiaRerankComponent(LCVectorStoreComponent):
class NvidiaRerankComponent(LCCompressorComponent):
display_name = "NVIDIA Rerank"
description = "Rerank documents using the NVIDIA API and a retriever."
description = "Rerank documents using the NVIDIA API."
icon = "NVIDIA"
inputs = [
MultilineInput(
name="search_query",
display_name="Search Query",
tool_mode=True,
),
*LCCompressorComponent.inputs,
StrInput(
name="base_url",
display_name="Base URL",
@@ -33,27 +27,20 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
options=["nv-rerank-qa-mistral-4b:1"],
value="nv-rerank-qa-mistral-4b:1",
),
SecretStrInput(name="api_key", display_name="API Key"),
DataInput(
name="search_results",
display_name="Search Results",
info="Search Results from a Vector Store.",
is_list=True,
),
]
outputs = [
Output(
display_name="Reranked Documents",
name="reranked_documents",
method="rerank_documents",
method="compress_documents",
),
]
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "base_url" and field_value:
try:
build_model = self.build_reranker()
build_model = self.build_compressor()
ids = [model.id for model in build_model.available_models]
build_config["model"]["options"] = ids
build_config["model"]["value"] = ids[0]
@@ -62,25 +49,10 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
raise ValueError(msg) from e
return build_config
def build_reranker(self):
def build_compressor(self) -> BaseDocumentCompressor:
try:
from langchain_nvidia_ai_endpoints import NVIDIARerank
except ImportError as e:
msg = "Please install langchain-nvidia-ai-endpoints to use the NVIDIA model."
raise ImportError(msg) from e
return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url)
async def rerank_documents(self) -> list[Data]: # type: ignore[override]
reranker = self.build_reranker()
documents = reranker.compress_documents(
query=self.search_query,
documents=[passage.to_lc_document() for passage in self.search_results if isinstance(passage, Data)],
)
data = self.to_data(documents)
self.status = data
return data
@check_cached_vector_store
def build_vector_store(self) -> VectorStore:
msg = "NVIDIA Rerank does not support vector stores."
raise NotImplementedError(msg)
return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url, top_n=self.top_n)

View file

@@ -4,6 +4,7 @@ from .constants import (
AgentExecutor,
BaseChatMemory,
BaseChatModel,
BaseDocumentCompressor,
BaseLanguageModel,
BaseLLM,
BaseLoader,
@@ -61,6 +62,7 @@ __all__ = [
"AgentExecutor",
"BaseChatMemory",
"BaseChatModel",
"BaseDocumentCompressor",
"BaseLLM",
"BaseLanguageModel",
"BaseLoader",

View file

@@ -7,6 +7,7 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from langchain_core.documents.compressor import BaseDocumentCompressor
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel, BaseLLM
from langchain_core.language_models.chat_models import BaseChatModel
@@ -68,6 +69,7 @@ LANGCHAIN_BASE_TYPES = {
"BaseChatMemory": BaseChatMemory,
"BaseChatModel": BaseChatModel,
"Memory": Memory,
"BaseDocumentCompressor": BaseDocumentCompressor,
}
# Langchain base types plus Python base types
CUSTOM_COMPONENT_SUPPORTED_TYPES = {
@@ -96,6 +98,7 @@ from langchain_core.memory import BaseMemory
from langchain_core.output_parsers import BaseLLMOutputParser, BaseOutputParser
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate, PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents.compressor import BaseDocumentCompressor
from langchain_core.tools import BaseTool, Tool
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from langchain_text_splitters import TextSplitter

640
uv.lock generated

File diff suppressed because it is too large Load diff