feat: adds NVIDIA components (#2591)
* chore: Update langchain-nvidia-ai-endpoints dependency to version 0.1.2 * feat: Add NVIDIAEmbeddingsComponent for generating embeddings using NVIDIA models * feat: Add NVIDIAModelComponent for generating text using NVIDIA LLMs * feat: Add NvidiaRerankComponent for reranking documents using the NVIDIA API and a retriever * fix: add type ignore * chore: Update NVIDIAEmbeddingsComponent and NVIDIAModelComponent to handle type ignore * chore(poetry.lock): update lock
This commit is contained in:
parent
a6f128c4cf
commit
06464eda46
9 changed files with 371 additions and 31 deletions
|
|
@ -0,0 +1,71 @@
|
|||
from typing import Any
|
||||
|
||||
from langflow.base.embeddings.model import LCEmbeddingsModel
|
||||
from langflow.field_typing import Embeddings
|
||||
from langflow.inputs.inputs import DropdownInput, SecretStrInput
|
||||
from langflow.io import FloatInput, MessageTextInput
|
||||
from langflow.schema.dotdict import dotdict
|
||||
|
||||
|
||||
class NVIDIAEmbeddingsComponent(LCEmbeddingsModel):
|
||||
display_name: str = "NVIDIA Embeddings"
|
||||
description: str = "Generate embeddings using NVIDIA models."
|
||||
icon = "NVIDIA"
|
||||
|
||||
inputs = [
|
||||
DropdownInput(
|
||||
name="model",
|
||||
display_name="Model",
|
||||
options=[
|
||||
"nvidia/nv-embed-v1",
|
||||
"snowflake/arctic-embed-I",
|
||||
],
|
||||
value="nvidia/nv-embed-v1",
|
||||
),
|
||||
MessageTextInput(
|
||||
name="base_url",
|
||||
display_name="NVIDIA Base URL",
|
||||
refresh_button=True,
|
||||
value="https://integrate.api.nvidia.com/v1",
|
||||
),
|
||||
SecretStrInput(
|
||||
name="nvidia_api_key",
|
||||
display_name="NVIDIA API Key",
|
||||
info="The NVIDIA API Key.",
|
||||
advanced=False,
|
||||
value="NVIDIA_API_KEY",
|
||||
),
|
||||
FloatInput(
|
||||
name="temperature",
|
||||
display_name="Model Temperature",
|
||||
value=0.1,
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "base_url" and field_value:
|
||||
try:
|
||||
build_model = self.build_embeddings()
|
||||
ids = [model.id for model in build_model.available_models] # type: ignore
|
||||
build_config["model"]["options"] = ids
|
||||
build_config["model"]["value"] = ids[0]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error getting model names: {e}")
|
||||
return build_config
|
||||
|
||||
def build_embeddings(self) -> Embeddings:
|
||||
try:
|
||||
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
|
||||
except ImportError:
|
||||
raise ImportError("Please install langchain-nvidia-ai-endpoints to use the Nvidia model.")
|
||||
try:
|
||||
output = NVIDIAEmbeddings(
|
||||
model=self.model,
|
||||
base_url=self.base_url,
|
||||
temperature=self.temperature,
|
||||
nvidia_api_key=self.nvidia_api_key,
|
||||
) # type: ignore
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not connect to NVIDIA API. Error: {e}") from e
|
||||
return output
|
||||
90
src/backend/base/langflow/components/models/NvidiaModel.py
Normal file
90
src/backend/base/langflow/components/models/NvidiaModel.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
from typing import Any
|
||||
|
||||
from langflow.base.constants import STREAM_INFO_TEXT
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import LanguageModel
|
||||
from langflow.inputs import BoolInput, DropdownInput, FloatInput, IntInput, MessageInput, SecretStrInput, StrInput
|
||||
from langflow.schema.dotdict import dotdict
|
||||
|
||||
|
||||
class NVIDIAModelComponent(LCModelComponent):
|
||||
display_name = "NVIDIA"
|
||||
description = "Generates text using NVIDIA LLMs."
|
||||
icon = "NVIDIA"
|
||||
|
||||
inputs = [
|
||||
MessageInput(name="input_value", display_name="Input"),
|
||||
IntInput(
|
||||
name="max_tokens",
|
||||
display_name="Max Tokens",
|
||||
advanced=True,
|
||||
info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
|
||||
),
|
||||
DropdownInput(
|
||||
name="model_name",
|
||||
display_name="Model Name",
|
||||
advanced=False,
|
||||
options=["mistralai/mixtral-8x7b-instruct-v0.1"],
|
||||
value="mistralai/mixtral-8x7b-instruct-v0.1",
|
||||
),
|
||||
StrInput(
|
||||
name="base_url",
|
||||
display_name="NVIDIA Base URL",
|
||||
value="https://integrate.api.nvidia.com/v1",
|
||||
refresh_button=True,
|
||||
info="The base URL of the NVIDIA API. Defaults to https://integrate.api.nvidia.com/v1.",
|
||||
),
|
||||
SecretStrInput(
|
||||
name="nvidia_api_key",
|
||||
display_name="NVIDIA API Key",
|
||||
info="The NVIDIA API Key.",
|
||||
advanced=False,
|
||||
value="NVIDIA_API_KEY",
|
||||
),
|
||||
FloatInput(name="temperature", display_name="Temperature", value=0.1),
|
||||
BoolInput(name="stream", display_name="Stream", info=STREAM_INFO_TEXT, advanced=True),
|
||||
StrInput(
|
||||
name="system_message",
|
||||
display_name="System Message",
|
||||
info="System message to pass to the model.",
|
||||
advanced=True,
|
||||
),
|
||||
IntInput(
|
||||
name="seed",
|
||||
display_name="Seed",
|
||||
info="The seed controls the reproducibility of the job.",
|
||||
advanced=True,
|
||||
value=1,
|
||||
),
|
||||
]
|
||||
|
||||
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "base_url" and field_value:
|
||||
try:
|
||||
build_model = self.build_model()
|
||||
ids = [model.id for model in build_model.available_models] # type: ignore
|
||||
build_config["model_name"]["options"] = ids
|
||||
build_config["model_name"]["value"] = ids[0]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error getting model names: {e}")
|
||||
return build_config
|
||||
|
||||
def build_model(self) -> LanguageModel: # type: ignore[type-var]
|
||||
try:
|
||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||
except ImportError:
|
||||
raise ImportError("Please install langchain-nvidia-ai-endpoints to use the NVIDIA model.")
|
||||
nvidia_api_key = self.nvidia_api_key
|
||||
temperature = self.temperature
|
||||
model_name: str = self.model_name
|
||||
max_tokens = self.max_tokens
|
||||
seed = self.seed
|
||||
output = ChatNVIDIA(
|
||||
max_tokens=max_tokens or None,
|
||||
model=model_name,
|
||||
base_url=self.base_url,
|
||||
api_key=nvidia_api_key, # type: ignore
|
||||
temperature=temperature or 0.1,
|
||||
seed=seed,
|
||||
)
|
||||
return output # type: ignore
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
from typing import Any, List, cast
|
||||
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
|
||||
from langflow.base.vectorstores.model import LCVectorStoreComponent
|
||||
from langflow.field_typing import Retriever
|
||||
from langflow.io import DropdownInput, HandleInput, MultilineInput, SecretStrInput, StrInput
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.dotdict import dotdict
|
||||
|
||||
|
||||
class NvidiaRerankComponent(LCVectorStoreComponent):
|
||||
display_name = "NVIDIA Rerank"
|
||||
description = "Rerank documents using the NVIDIA API and a retriever."
|
||||
icon = "NVIDIA"
|
||||
|
||||
inputs = [
|
||||
MultilineInput(
|
||||
name="search_query",
|
||||
display_name="Search Query",
|
||||
),
|
||||
StrInput(
|
||||
name="base_url",
|
||||
display_name="Base URL",
|
||||
value="https://integrate.api.nvidia.com/v1",
|
||||
refresh_button=True,
|
||||
info="The base URL of the NVIDIA API. Defaults to https://integrate.api.nvidia.com/v1.",
|
||||
),
|
||||
DropdownInput(
|
||||
name="model", display_name="Model", options=["nv-rerank-qa-mistral-4b:1"], value="nv-rerank-qa-mistral-4b:1"
|
||||
),
|
||||
SecretStrInput(name="api_key", display_name="API Key"),
|
||||
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
|
||||
]
|
||||
|
||||
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "base_url" and field_value:
|
||||
try:
|
||||
build_model = self.build_model()
|
||||
ids = [model.id for model in build_model.available_models]
|
||||
build_config["model"]["options"] = ids
|
||||
build_config["model"]["value"] = ids[0]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error getting model names: {e}")
|
||||
return build_config
|
||||
|
||||
def build_model(self):
|
||||
try:
|
||||
from langchain_nvidia_ai_endpoints import NVIDIARerank
|
||||
except ImportError:
|
||||
raise ImportError("Please install langchain-nvidia-ai-endpoints to use the NVIDIA model.")
|
||||
return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url)
|
||||
|
||||
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
|
||||
nvidia_reranker = self.build_model()
|
||||
retriever = ContextualCompressionRetriever(base_compressor=nvidia_reranker, base_retriever=self.retriever)
|
||||
return cast(Retriever, retriever)
|
||||
|
||||
async def search_documents(self) -> List[Data]: # type: ignore
|
||||
retriever = self.build_base_retriever()
|
||||
documents = await retriever.ainvoke(self.search_query)
|
||||
data = self.to_data(documents)
|
||||
self.status = data
|
||||
return data
|
||||
Loading…
Add table
Add a link
Reference in a new issue