feat: Restructure Rerankers so both NVIDIA and Cohere work properly (#5933)
* Removing reference to nonexistent method * Restructuring rerankers to inherit from BaseDocumentCompressor. Adding Voyage AI reranker. * Removing Voyage AI component and dependency. * [autofix.ci] apply automated fixes * feat: Add method to compress documents as DataFrame in LCCompressorComponent * Changing description of abstract build_compressor method * [autofix.ci] apply automated fixes * Adding top_n as an argument to the NVIDIA reranker * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
8ab74e037a
commit
8bf7048485
7 changed files with 117 additions and 716 deletions
0
src/backend/base/langflow/base/compressors/__init__.py
Normal file
0
src/backend/base/langflow/base/compressors/__init__.py
Normal file
64
src/backend/base/langflow/base/compressors/model.py
Normal file
64
src/backend/base/langflow/base/compressors/model.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from abc import abstractmethod
|
||||
|
||||
from langflow.custom import Component
|
||||
from langflow.field_typing import BaseDocumentCompressor
|
||||
from langflow.io import DataInput, IntInput, MultilineInput, SecretStrInput
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.dataframe import DataFrame
|
||||
from langflow.template.field.base import Output
|
||||
|
||||
|
||||
class LCCompressorComponent(Component):
|
||||
inputs = [
|
||||
MultilineInput(
|
||||
name="search_query",
|
||||
display_name="Search Query",
|
||||
tool_mode=True,
|
||||
),
|
||||
SecretStrInput(
|
||||
name="api_key",
|
||||
display_name="API Key",
|
||||
),
|
||||
DataInput(
|
||||
name="search_results",
|
||||
display_name="Search Results",
|
||||
info="Search Results from a Vector Store.",
|
||||
is_list=True,
|
||||
),
|
||||
IntInput(name="top_n", display_name="Top N", value=3, advanced=True),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(
|
||||
display_name="Data",
|
||||
name="compressed_documents",
|
||||
method="Compressed Documents",
|
||||
),
|
||||
Output(
|
||||
display_name="DataFrame",
|
||||
name="compressed_documents_as_dataframe",
|
||||
method="Compressed Documents as DataFrame",
|
||||
),
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def build_compressor(self) -> BaseDocumentCompressor:
|
||||
"""Builds the Base Document Compressor object."""
|
||||
msg = "build_compressor method must be implemented."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
async def compress_documents(self) -> list[Data]:
|
||||
"""Compresses the documents retrieved from the vector store."""
|
||||
compressor = self.build_compressor()
|
||||
documents = compressor.compress_documents(
|
||||
query=self.search_query,
|
||||
documents=[passage.to_lc_document() for passage in self.search_results if isinstance(passage, Data)],
|
||||
)
|
||||
data = self.to_data(documents)
|
||||
self.status = data
|
||||
return data
|
||||
|
||||
async def compress_documents_as_dataframe(self) -> DataFrame:
|
||||
"""Compresses the documents retrieved from the vector store and returns a pandas DataFrame."""
|
||||
data_objs = await self.compress_documents()
|
||||
return DataFrame(data=data_objs)
|
||||
|
|
@ -1,37 +1,17 @@
|
|||
from typing import cast
|
||||
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain_cohere import CohereRerank
|
||||
|
||||
from langflow.base.vectorstores.model import (
|
||||
LCVectorStoreComponent,
|
||||
check_cached_vector_store,
|
||||
)
|
||||
from langflow.field_typing import Retriever, VectorStore
|
||||
from langflow.io import (
|
||||
DropdownInput,
|
||||
HandleInput,
|
||||
IntInput,
|
||||
MessageTextInput,
|
||||
MultilineInput,
|
||||
SecretStrInput,
|
||||
)
|
||||
from langflow.schema import Data
|
||||
from langflow.base.compressors.model import LCCompressorComponent
|
||||
from langflow.field_typing import BaseDocumentCompressor
|
||||
from langflow.io import DropdownInput
|
||||
from langflow.template.field.base import Output
|
||||
|
||||
|
||||
class CohereRerankComponent(LCVectorStoreComponent):
|
||||
class CohereRerankComponent(LCCompressorComponent):
|
||||
display_name = "Cohere Rerank"
|
||||
description = "Rerank documents using the Cohere API and a retriever."
|
||||
description = "Rerank documents using the Cohere API."
|
||||
name = "CohereRerank"
|
||||
icon = "Cohere"
|
||||
legacy: bool = True
|
||||
|
||||
inputs = [
|
||||
MultilineInput(
|
||||
name="search_query",
|
||||
display_name="Search Query",
|
||||
),
|
||||
*LCCompressorComponent.inputs,
|
||||
DropdownInput(
|
||||
name="model",
|
||||
display_name="Model",
|
||||
|
|
@ -43,48 +23,24 @@ class CohereRerankComponent(LCVectorStoreComponent):
|
|||
],
|
||||
value="rerank-english-v3.0",
|
||||
),
|
||||
SecretStrInput(name="api_key", display_name="API Key"),
|
||||
IntInput(name="top_n", display_name="Top N", value=3),
|
||||
MessageTextInput(
|
||||
name="user_agent",
|
||||
display_name="User Agent",
|
||||
value="langflow",
|
||||
advanced=True,
|
||||
),
|
||||
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(
|
||||
display_name="Retriever",
|
||||
name="base_retriever",
|
||||
method="build_base_retriever",
|
||||
),
|
||||
Output(
|
||||
display_name="Search Results",
|
||||
name="search_results",
|
||||
method="search_documents",
|
||||
display_name="Reranked Documents",
|
||||
name="reranked_documents",
|
||||
method="compress_documents",
|
||||
),
|
||||
]
|
||||
|
||||
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
|
||||
cohere_reranker = CohereRerank(
|
||||
def build_compressor(self) -> BaseDocumentCompressor: # type: ignore[type-var]
|
||||
try:
|
||||
from langchain_cohere import CohereRerank
|
||||
except ImportError as e:
|
||||
msg = "Please install langchain-cohere to use the Cohere model."
|
||||
raise ImportError(msg) from e
|
||||
return CohereRerank(
|
||||
cohere_api_key=self.api_key,
|
||||
model=self.model,
|
||||
top_n=self.top_n,
|
||||
user_agent=self.user_agent,
|
||||
)
|
||||
retriever = ContextualCompressionRetriever(base_compressor=cohere_reranker, base_retriever=self.retriever)
|
||||
return cast("Retriever", retriever)
|
||||
|
||||
async def search_documents(self) -> list[Data]: # type: ignore[override]
|
||||
retriever = self.build_base_retriever()
|
||||
documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
|
||||
data = self.to_data(documents)
|
||||
self.status = data
|
||||
return data
|
||||
|
||||
@check_cached_vector_store
|
||||
def build_vector_store(self) -> VectorStore:
|
||||
msg = "Cohere Rerank does not support vector stores."
|
||||
raise NotImplementedError(msg)
|
||||
|
|
|
|||
|
|
@ -1,25 +1,19 @@
|
|||
from typing import Any
|
||||
|
||||
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
|
||||
from langflow.field_typing import VectorStore
|
||||
from langflow.inputs.inputs import DataInput
|
||||
from langflow.io import DropdownInput, MultilineInput, SecretStrInput, StrInput
|
||||
from langflow.schema import Data
|
||||
from langflow.base.compressors.model import LCCompressorComponent
|
||||
from langflow.field_typing import BaseDocumentCompressor
|
||||
from langflow.io import DropdownInput, StrInput
|
||||
from langflow.schema.dotdict import dotdict
|
||||
from langflow.template.field.base import Output
|
||||
|
||||
|
||||
class NvidiaRerankComponent(LCVectorStoreComponent):
|
||||
class NvidiaRerankComponent(LCCompressorComponent):
|
||||
display_name = "NVIDIA Rerank"
|
||||
description = "Rerank documents using the NVIDIA API and a retriever."
|
||||
description = "Rerank documents using the NVIDIA API."
|
||||
icon = "NVIDIA"
|
||||
|
||||
inputs = [
|
||||
MultilineInput(
|
||||
name="search_query",
|
||||
display_name="Search Query",
|
||||
tool_mode=True,
|
||||
),
|
||||
*LCCompressorComponent.inputs,
|
||||
StrInput(
|
||||
name="base_url",
|
||||
display_name="Base URL",
|
||||
|
|
@ -33,27 +27,20 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
|
|||
options=["nv-rerank-qa-mistral-4b:1"],
|
||||
value="nv-rerank-qa-mistral-4b:1",
|
||||
),
|
||||
SecretStrInput(name="api_key", display_name="API Key"),
|
||||
DataInput(
|
||||
name="search_results",
|
||||
display_name="Search Results",
|
||||
info="Search Results from a Vector Store.",
|
||||
is_list=True,
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(
|
||||
display_name="Reranked Documents",
|
||||
name="reranked_documents",
|
||||
method="rerank_documents",
|
||||
method="compress_documents",
|
||||
),
|
||||
]
|
||||
|
||||
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "base_url" and field_value:
|
||||
try:
|
||||
build_model = self.build_reranker()
|
||||
build_model = self.build_compressor()
|
||||
ids = [model.id for model in build_model.available_models]
|
||||
build_config["model"]["options"] = ids
|
||||
build_config["model"]["value"] = ids[0]
|
||||
|
|
@ -62,25 +49,10 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
|
|||
raise ValueError(msg) from e
|
||||
return build_config
|
||||
|
||||
def build_reranker(self):
|
||||
def build_compressor(self) -> BaseDocumentCompressor:
|
||||
try:
|
||||
from langchain_nvidia_ai_endpoints import NVIDIARerank
|
||||
except ImportError as e:
|
||||
msg = "Please install langchain-nvidia-ai-endpoints to use the NVIDIA model."
|
||||
raise ImportError(msg) from e
|
||||
return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url)
|
||||
|
||||
async def rerank_documents(self) -> list[Data]: # type: ignore[override]
|
||||
reranker = self.build_reranker()
|
||||
documents = reranker.compress_documents(
|
||||
query=self.search_query,
|
||||
documents=[passage.to_lc_document() for passage in self.search_results if isinstance(passage, Data)],
|
||||
)
|
||||
data = self.to_data(documents)
|
||||
self.status = data
|
||||
return data
|
||||
|
||||
@check_cached_vector_store
|
||||
def build_vector_store(self) -> VectorStore:
|
||||
msg = "NVIDIA Rerank does not support vector stores."
|
||||
raise NotImplementedError(msg)
|
||||
return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url, top_n=self.top_n)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from .constants import (
|
|||
AgentExecutor,
|
||||
BaseChatMemory,
|
||||
BaseChatModel,
|
||||
BaseDocumentCompressor,
|
||||
BaseLanguageModel,
|
||||
BaseLLM,
|
||||
BaseLoader,
|
||||
|
|
@ -61,6 +62,7 @@ __all__ = [
|
|||
"AgentExecutor",
|
||||
"BaseChatMemory",
|
||||
"BaseChatModel",
|
||||
"BaseDocumentCompressor",
|
||||
"BaseLLM",
|
||||
"BaseLanguageModel",
|
||||
"BaseLoader",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from langchain.memory.chat_memory import BaseChatMemory
|
|||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.document_loaders import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.documents.compressor import BaseDocumentCompressor
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLanguageModel, BaseLLM
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
|
|
@ -68,6 +69,7 @@ LANGCHAIN_BASE_TYPES = {
|
|||
"BaseChatMemory": BaseChatMemory,
|
||||
"BaseChatModel": BaseChatModel,
|
||||
"Memory": Memory,
|
||||
"BaseDocumentCompressor": BaseDocumentCompressor,
|
||||
}
|
||||
# Langchain base types plus Python base types
|
||||
CUSTOM_COMPONENT_SUPPORTED_TYPES = {
|
||||
|
|
@ -96,6 +98,7 @@ from langchain_core.memory import BaseMemory
|
|||
from langchain_core.output_parsers import BaseLLMOutputParser, BaseOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate, PromptTemplate
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.documents.compressor import BaseDocumentCompressor
|
||||
from langchain_core.tools import BaseTool, Tool
|
||||
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
|
||||
from langchain_text_splitters import TextSplitter
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue