feat: adds NVIDIA components (#2591)
* chore: Update langchain-nvidia-ai-endpoints dependency to version 0.1.2 * feat: Add NVIDIAEmbeddingsComponent for generating embeddings using NVIDIA models * feat: Add NVIDIAModelComponent for generating text using NVIDIA LLMs * feat: Add NvidiaRerankComponent for reranking documents using the NVIDIA API and a retriever * fix: add type ignore * chore: Update NVIDIAEmbeddingsComponent and NVIDIAModelComponent to handle type ignore * chore(poetry.lock): update lock
This commit is contained in:
parent
a6f128c4cf
commit
06464eda46
9 changed files with 371 additions and 31 deletions
81
poetry.lock
generated
81
poetry.lock
generated
|
|
@ -483,17 +483,17 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.34.140"
|
||||
version = "1.34.141"
|
||||
description = "The AWS SDK for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "boto3-1.34.140-py3-none-any.whl", hash = "sha256:23ca8d8f7a30c3bbd989808056b5fc5d68ff5121c02c722c6167b6b1bb7f8726"},
|
||||
{file = "boto3-1.34.140.tar.gz", hash = "sha256:578bbd5e356005719b6b610d03edff7ea1b0824d078afe62d3fb8bea72f83a87"},
|
||||
{file = "boto3-1.34.141-py3-none-any.whl", hash = "sha256:f906c797a02d37a3b88fe4c97e4d72b387e19ab6f3096d2f573578f020fd9bf4"},
|
||||
{file = "boto3-1.34.141.tar.gz", hash = "sha256:947c7a94ac3a2131142914a53afc3b1c5a572d6a79515bf2f0473188817cfcd6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
botocore = ">=1.34.140,<1.35.0"
|
||||
botocore = ">=1.34.141,<1.35.0"
|
||||
jmespath = ">=0.7.1,<2.0.0"
|
||||
s3transfer = ">=0.10.0,<0.11.0"
|
||||
|
||||
|
|
@ -502,13 +502,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
|
|||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.34.140"
|
||||
version = "1.34.141"
|
||||
description = "Low-level, data-driven core of boto 3."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "botocore-1.34.140-py3-none-any.whl", hash = "sha256:43940d3a67d946ba3301631ba4078476a75f1015d4fb0fb0272d0b754b2cf9de"},
|
||||
{file = "botocore-1.34.140.tar.gz", hash = "sha256:86302b2226c743b9eec7915a4c6cfaffd338ae03989cd9ee181078ef39d1ab39"},
|
||||
{file = "botocore-1.34.141-py3-none-any.whl", hash = "sha256:0e661a452c0489b6d62a9c91fed3320d5690a524489a7e50afc8efadb994dba8"},
|
||||
{file = "botocore-1.34.141.tar.gz", hash = "sha256:d2815c09037039a287461eddc07af895d798bc897e6ba4b08f5a12eaa9886ff1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -2615,8 +2615,8 @@ files = [
|
|||
[package.dependencies]
|
||||
cffi = {version = ">=1.12.2", markers = "platform_python_implementation == \"CPython\" and sys_platform == \"win32\""}
|
||||
greenlet = [
|
||||
{version = ">=2.0.0", markers = "platform_python_implementation == \"CPython\" and python_version < \"3.11\""},
|
||||
{version = ">=3.0rc3", markers = "platform_python_implementation == \"CPython\" and python_version >= \"3.11\""},
|
||||
{version = ">=2.0.0", markers = "platform_python_implementation == \"CPython\" and python_version < \"3.11\""},
|
||||
]
|
||||
"zope.event" = "*"
|
||||
"zope.interface" = "*"
|
||||
|
|
@ -2775,12 +2775,12 @@ files = [
|
|||
google-auth = ">=2.14.1,<3.0.dev0"
|
||||
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
|
||||
grpcio = [
|
||||
{version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
||||
]
|
||||
grpcio-status = [
|
||||
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
||||
]
|
||||
proto-plus = ">=1.22.3,<2.0.0dev"
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
|
||||
|
|
@ -2946,13 +2946,13 @@ grpc = ["grpcio (>=1.38.0,<2.0dev)", "grpcio-status (>=1.38.0,<2.0.dev0)"]
|
|||
|
||||
[[package]]
|
||||
name = "google-cloud-resource-manager"
|
||||
version = "1.12.3"
|
||||
version = "1.12.4"
|
||||
description = "Google Cloud Resource Manager API client library"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "google-cloud-resource-manager-1.12.3.tar.gz", hash = "sha256:809851824119834e4f2310b2c4f38621c1d16b2bb14d5b9f132e69c79d355e7f"},
|
||||
{file = "google_cloud_resource_manager-1.12.3-py2.py3-none-any.whl", hash = "sha256:92be7d6959927b76d90eafc4028985c37975a46ded5466a018f02e8649e113d4"},
|
||||
{file = "google-cloud-resource-manager-1.12.4.tar.gz", hash = "sha256:3eda914a925e92465ef80faaab7e0f7a9312d486dd4e123d2c76e04bac688ff0"},
|
||||
{file = "google_cloud_resource_manager-1.12.4-py2.py3-none-any.whl", hash = "sha256:0b6663585f7f862166c0fb4c55fdda721fce4dc2dc1d5b52d03ee4bf2653a85f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -2960,7 +2960,7 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extr
|
|||
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
|
||||
grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
|
||||
proto-plus = ">=1.22.3,<2.0.0dev"
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
|
||||
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
|
||||
|
||||
[[package]]
|
||||
name = "google-cloud-storage"
|
||||
|
|
@ -4271,13 +4271,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "jsonschema"
|
||||
version = "4.22.0"
|
||||
version = "4.23.0"
|
||||
description = "An implementation of JSON Schema validation for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "jsonschema-4.22.0-py3-none-any.whl", hash = "sha256:ff4cfd6b1367a40e7bc6411caec72effadd3db0bbe5017de188f2d6108335802"},
|
||||
{file = "jsonschema-4.22.0.tar.gz", hash = "sha256:5b22d434a45935119af990552c862e5d6d564e8f6601206b305a61fdf661a2b7"},
|
||||
{file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"},
|
||||
{file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -4288,7 +4288,7 @@ rpds-py = ">=0.7.1"
|
|||
|
||||
[package.extras]
|
||||
format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"]
|
||||
format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"]
|
||||
format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "jsonschema-specifications"
|
||||
|
|
@ -4662,6 +4662,22 @@ langchain-core = ">=0.1.46,<0.3"
|
|||
numpy = ">=1,<2"
|
||||
pymongo = ">=4.6.1,<5.0"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-nvidia-ai-endpoints"
|
||||
version = "0.1.2"
|
||||
description = "An integration package connecting NVIDIA AI Endpoints and LangChain"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8.1"
|
||||
files = [
|
||||
{file = "langchain_nvidia_ai_endpoints-0.1.2-py3-none-any.whl", hash = "sha256:cf40deea1aa6ba642b8e8d5071244536c550effc4d107c794fed14666ef0a468"},
|
||||
{file = "langchain_nvidia_ai_endpoints-0.1.2.tar.gz", hash = "sha256:ffba5e8c09dfe77cc9cfa25d9ba17cda2685b7d73ca38816ba712b1efd762fdd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = ">=3.9.1,<4.0.0"
|
||||
langchain-core = ">=0.1.27,<0.3"
|
||||
pillow = ">=10.0.0,<11.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-openai"
|
||||
version = "0.1.14"
|
||||
|
|
@ -4849,13 +4865,13 @@ requests = ">=2,<3"
|
|||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.41.11"
|
||||
version = "1.41.12"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
optional = false
|
||||
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
|
||||
files = [
|
||||
{file = "litellm-1.41.11-py3-none-any.whl", hash = "sha256:1b13bf8893f5c9c5e35213a7094848e2b3933f665002016246c4058d2189c99c"},
|
||||
{file = "litellm-1.41.11.tar.gz", hash = "sha256:649ea04234d8d4d2a4cd8f4ea915d991f5fbb8ce98e19719da2af6c7ab8848a9"},
|
||||
{file = "litellm-1.41.12-py3-none-any.whl", hash = "sha256:9af8b65ca48f0aa5b8ef10a63c21b00553843aa1e498f2c9308738b97d4a50f3"},
|
||||
{file = "litellm-1.41.12.tar.gz", hash = "sha256:f94b5ac8857ea8b98b87f7d3071dbd10e24fe9e0d7831969adafb549688ce0ce"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -4920,8 +4936,8 @@ psutil = ">=5.9.1"
|
|||
pywin32 = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
pyzmq = ">=25.0.0"
|
||||
requests = [
|
||||
{version = ">=2.26.0", markers = "python_version <= \"3.11\""},
|
||||
{version = ">=2.32.2", markers = "python_version > \"3.11\""},
|
||||
{version = ">=2.26.0", markers = "python_version <= \"3.11\""},
|
||||
]
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
typing-extensions = {version = ">=4.6.0", markers = "python_version < \"3.11\""}
|
||||
|
|
@ -6455,9 +6471,9 @@ files = [
|
|||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
|
|
@ -8214,13 +8230,13 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""}
|
|||
|
||||
[[package]]
|
||||
name = "qdrant-client"
|
||||
version = "1.10.0"
|
||||
version = "1.10.1"
|
||||
description = "Client library for the Qdrant vector search engine"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "qdrant_client-1.10.0-py3-none-any.whl", hash = "sha256:423c2586709ccf3db20850cd85c3d18954692a8faff98367dfa9dc82ab7f91d9"},
|
||||
{file = "qdrant_client-1.10.0.tar.gz", hash = "sha256:47c4f7abfab152fb7e5e4902ab0e2e9e33483c49ea5e80128ccd0295f342cf9b"},
|
||||
{file = "qdrant_client-1.10.1-py3-none-any.whl", hash = "sha256:b9fb8fe50dd168d92b2998be7c6135d5a229b3a3258ad158cc69c8adf9ff1810"},
|
||||
{file = "qdrant_client-1.10.1.tar.gz", hash = "sha256:2284c8c5bb1defb0d9dbacb07d16f344972f395f4f2ed062318476a7951fd84c"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -9429,17 +9445,20 @@ httpx = ">=0.24,<0.28"
|
|||
|
||||
[[package]]
|
||||
name = "sympy"
|
||||
version = "1.12.1"
|
||||
version = "1.13.0"
|
||||
description = "Computer algebra system (CAS) in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"},
|
||||
{file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"},
|
||||
{file = "sympy-1.13.0-py3-none-any.whl", hash = "sha256:6b0b32a4673fb91bd3cac3b55406c8e01d53ae22780be467301cc452f6680c92"},
|
||||
{file = "sympy-1.13.0.tar.gz", hash = "sha256:3b6af8f4d008b9a1a6a4268b335b984b23835f26d1d60b0526ebc71d48a25f57"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
mpmath = ">=1.1.0,<1.4.0"
|
||||
mpmath = ">=1.1.0,<1.4"
|
||||
|
||||
[package.extras]
|
||||
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "tabulate"
|
||||
|
|
@ -11258,4 +11277,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "9d2906342777f3b41e880756c5116b1c4414bb427a403e1efb7a9bdcf19e64ad"
|
||||
content-hash = "32ecfa0b1cf950dbb8b7d964ca8ec21ac3f7fa9652b60fe5f230fe10110c4c52"
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ langchain-aws = "^0.1.6"
|
|||
langchain-mongodb = "^0.1.6"
|
||||
kubernetes = "^30.1.0"
|
||||
firecrawl-py = "^0.0.16"
|
||||
langchain-nvidia-ai-endpoints = "^0.1.2"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,71 @@
|
|||
from typing import Any

from langflow.base.embeddings.model import LCEmbeddingsModel
from langflow.field_typing import Embeddings
from langflow.inputs.inputs import DropdownInput, SecretStrInput
from langflow.io import FloatInput, MessageTextInput
from langflow.schema.dotdict import dotdict


class NVIDIAEmbeddingsComponent(LCEmbeddingsModel):
    """Build a LangChain ``NVIDIAEmbeddings`` client from component inputs.

    Exposes model / base-URL / API-key / temperature inputs and refreshes the
    model dropdown from the live endpoint whenever the base URL changes.
    """

    display_name: str = "NVIDIA Embeddings"
    description: str = "Generate embeddings using NVIDIA models."
    icon = "NVIDIA"

    inputs = [
        DropdownInput(
            name="model",
            display_name="Model",
            options=[
                "nvidia/nv-embed-v1",
                # NOTE(review): "arctic-embed-I" (capital i) looks like a typo for
                # "snowflake/arctic-embed-l" -- confirm against the NVIDIA catalog.
                "snowflake/arctic-embed-I",
            ],
            value="nvidia/nv-embed-v1",
        ),
        MessageTextInput(
            name="base_url",
            display_name="NVIDIA Base URL",
            refresh_button=True,
            value="https://integrate.api.nvidia.com/v1",
        ),
        SecretStrInput(
            name="nvidia_api_key",
            display_name="NVIDIA API Key",
            info="The NVIDIA API Key.",
            advanced=False,
            value="NVIDIA_API_KEY",
        ),
        FloatInput(
            name="temperature",
            display_name="Model Temperature",
            value=0.1,
            advanced=True,
        ),
    ]

    def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
        """Refresh the model dropdown from the endpoint when ``base_url`` changes.

        Raises:
            ValueError: if the endpoint cannot be queried for its model list.
        """
        if field_name == "base_url" and field_value:
            try:
                build_model = self.build_embeddings()
                ids = [model.id for model in build_model.available_models]  # type: ignore
                build_config["model"]["options"] = ids
                build_config["model"]["value"] = ids[0]
            except Exception as e:
                # Chain the original exception so the root cause is preserved (B904).
                raise ValueError(f"Error getting model names: {e}") from e
        return build_config

    def build_embeddings(self) -> Embeddings:
        """Instantiate ``NVIDIAEmbeddings`` from the component inputs.

        Raises:
            ImportError: if langchain-nvidia-ai-endpoints is not installed.
            ValueError: if the client cannot be constructed / connected.
        """
        try:
            from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
        except ImportError as e:
            # Chain so the underlying import failure is not lost.
            raise ImportError("Please install langchain-nvidia-ai-endpoints to use the Nvidia model.") from e
        try:
            output = NVIDIAEmbeddings(
                model=self.model,
                base_url=self.base_url,
                # NOTE(review): embeddings endpoints normally have no sampling
                # temperature -- confirm NVIDIAEmbeddings accepts this kwarg.
                temperature=self.temperature,
                nvidia_api_key=self.nvidia_api_key,
            )  # type: ignore
        except Exception as e:
            raise ValueError(f"Could not connect to NVIDIA API. Error: {e}") from e
        return output
|
||||
90
src/backend/base/langflow/components/models/NvidiaModel.py
Normal file
90
src/backend/base/langflow/components/models/NvidiaModel.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
from typing import Any

from langflow.base.constants import STREAM_INFO_TEXT
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.inputs import BoolInput, DropdownInput, FloatInput, IntInput, MessageInput, SecretStrInput, StrInput
from langflow.schema.dotdict import dotdict


class NVIDIAModelComponent(LCModelComponent):
    """Build a LangChain ``ChatNVIDIA`` language model from component inputs."""

    display_name = "NVIDIA"
    description = "Generates text using NVIDIA LLMs."
    icon = "NVIDIA"

    inputs = [
        MessageInput(name="input_value", display_name="Input"),
        IntInput(
            name="max_tokens",
            display_name="Max Tokens",
            advanced=True,
            info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
        ),
        DropdownInput(
            name="model_name",
            display_name="Model Name",
            advanced=False,
            options=["mistralai/mixtral-8x7b-instruct-v0.1"],
            value="mistralai/mixtral-8x7b-instruct-v0.1",
        ),
        StrInput(
            name="base_url",
            display_name="NVIDIA Base URL",
            value="https://integrate.api.nvidia.com/v1",
            refresh_button=True,
            info="The base URL of the NVIDIA API. Defaults to https://integrate.api.nvidia.com/v1.",
        ),
        SecretStrInput(
            name="nvidia_api_key",
            display_name="NVIDIA API Key",
            info="The NVIDIA API Key.",
            advanced=False,
            value="NVIDIA_API_KEY",
        ),
        FloatInput(name="temperature", display_name="Temperature", value=0.1),
        BoolInput(name="stream", display_name="Stream", info=STREAM_INFO_TEXT, advanced=True),
        StrInput(
            name="system_message",
            display_name="System Message",
            info="System message to pass to the model.",
            advanced=True,
        ),
        IntInput(
            name="seed",
            display_name="Seed",
            info="The seed controls the reproducibility of the job.",
            advanced=True,
            value=1,
        ),
    ]

    def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
        """Refresh the model-name dropdown from the endpoint when ``base_url`` changes.

        Raises:
            ValueError: if the endpoint cannot be queried for its model list.
        """
        if field_name == "base_url" and field_value:
            try:
                build_model = self.build_model()
                ids = [model.id for model in build_model.available_models]  # type: ignore
                build_config["model_name"]["options"] = ids
                build_config["model_name"]["value"] = ids[0]
            except Exception as e:
                # Chain the original exception so the root cause is preserved (B904).
                raise ValueError(f"Error getting model names: {e}") from e
        return build_config

    def build_model(self) -> LanguageModel:  # type: ignore[type-var]
        """Instantiate ``ChatNVIDIA`` from the component inputs.

        Raises:
            ImportError: if langchain-nvidia-ai-endpoints is not installed.
        """
        try:
            from langchain_nvidia_ai_endpoints import ChatNVIDIA
        except ImportError as e:
            # Chain so the underlying import failure is not lost.
            raise ImportError("Please install langchain-nvidia-ai-endpoints to use the NVIDIA model.") from e
        nvidia_api_key = self.nvidia_api_key
        temperature = self.temperature
        model_name: str = self.model_name
        max_tokens = self.max_tokens
        seed = self.seed
        output = ChatNVIDIA(
            max_tokens=max_tokens or None,  # 0 means "unlimited" -> omit the cap (documented above)
            model=model_name,
            base_url=self.base_url,
            api_key=nvidia_api_key,  # type: ignore
            # BUG FIX: `temperature or 0.1` silently clobbered an explicit 0.0;
            # only fall back to the default when no value was provided.
            temperature=temperature if temperature is not None else 0.1,
            seed=seed,
        )
        return output  # type: ignore
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
from typing import Any, List, cast

from langchain.retrievers import ContextualCompressionRetriever

from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.field_typing import Retriever
from langflow.io import DropdownInput, HandleInput, MultilineInput, SecretStrInput, StrInput
from langflow.schema import Data
from langflow.schema.dotdict import dotdict


class NvidiaRerankComponent(LCVectorStoreComponent):
    """Rerank documents using the NVIDIA rerank API wrapped around a retriever.

    Builds an ``NVIDIARerank`` compressor and combines it with the connected
    retriever via ``ContextualCompressionRetriever``.
    """

    display_name = "NVIDIA Rerank"
    description = "Rerank documents using the NVIDIA API and a retriever."
    icon = "NVIDIA"

    inputs = [
        MultilineInput(
            name="search_query",
            display_name="Search Query",
        ),
        StrInput(
            name="base_url",
            display_name="Base URL",
            value="https://integrate.api.nvidia.com/v1",
            refresh_button=True,
            info="The base URL of the NVIDIA API. Defaults to https://integrate.api.nvidia.com/v1.",
        ),
        DropdownInput(
            name="model", display_name="Model", options=["nv-rerank-qa-mistral-4b:1"], value="nv-rerank-qa-mistral-4b:1"
        ),
        SecretStrInput(name="api_key", display_name="API Key"),
        HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
    ]

    def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
        """Refresh the model dropdown from the endpoint when ``base_url`` changes.

        Raises:
            ValueError: if the endpoint cannot be queried for its model list.
        """
        if field_name == "base_url" and field_value:
            try:
                build_model = self.build_model()
                # NOTE(review): assumes NVIDIARerank exposes `available_models`
                # like the chat/embeddings clients -- confirm for this version.
                ids = [model.id for model in build_model.available_models]
                build_config["model"]["options"] = ids
                build_config["model"]["value"] = ids[0]
            except Exception as e:
                # Chain the original exception so the root cause is preserved (B904).
                raise ValueError(f"Error getting model names: {e}") from e
        return build_config

    def build_model(self):
        """Instantiate the ``NVIDIARerank`` compressor from the component inputs.

        Raises:
            ImportError: if langchain-nvidia-ai-endpoints is not installed.
        """
        try:
            from langchain_nvidia_ai_endpoints import NVIDIARerank
        except ImportError as e:
            # Chain so the underlying import failure is not lost.
            raise ImportError("Please install langchain-nvidia-ai-endpoints to use the NVIDIA model.") from e
        return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url)

    def build_base_retriever(self) -> Retriever:  # type: ignore[type-var]
        """Wrap the connected retriever with the NVIDIA reranking compressor."""
        nvidia_reranker = self.build_model()
        retriever = ContextualCompressionRetriever(base_compressor=nvidia_reranker, base_retriever=self.retriever)
        return cast(Retriever, retriever)

    async def search_documents(self) -> List[Data]:  # type: ignore
        """Run the reranking retriever on ``search_query`` and return Data objects."""
        retriever = self.build_base_retriever()
        documents = await retriever.ainvoke(self.search_query)
        data = self.to_data(documents)
        self.status = data
        return data
|
||||
9
src/frontend/src/icons/Nvidia/index.tsx
Normal file
9
src/frontend/src/icons/Nvidia/index.tsx
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import React, { forwardRef } from "react";
import NvidiaSVG from "./nvidia";

// NVIDIA brand icon wrapper, registered in nodeIconsLucide under "NVIDIA".
// NOTE(review): NvidiaSVG is a plain function component, so React (<19)
// drops the `ref` passed here as a prop with a console warning -- the ref
// never reaches the underlying <svg>. Confirm whether NvidiaSVG should be
// wrapped in forwardRef like the other icon SVGs.
export const NvidiaIcon = forwardRef<
  SVGSVGElement,
  React.PropsWithChildren<{}>
>((props, ref) => {
  return <NvidiaSVG ref={ref} {...props} />;
});
|
||||
52
src/frontend/src/icons/Nvidia/nvidia.jsx
Normal file
52
src/frontend/src/icons/Nvidia/nvidia.jsx
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
// Inline NVIDIA badge SVG (converted from NVBadge_2D.eps via pstoedit).
// Path data is the verbatim vector outline of the NVIDIA logotype + eye mark.
// NOTE(review): this component does not forward refs; NvidiaIcon passes one.
const NvidiaSVG = (props) => (
  <svg
    version="1.1"
    id="svg2"
    xmlns="http://www.w3.org/2000/svg"
    x="0px"
    y="0px"
    width="351.46px"
    height="258.785px"
    viewBox="35.188 31.512 351.46 258.785"
    // BUG FIX: React requires the camelCase DOM property `enableBackground`;
    // the hyphenated `enable-background` is not a valid JSX attribute and
    // triggers an "Invalid DOM property" warning.
    enableBackground="new 35.188 31.512 351.46 258.785"
    {...props}
  >
    <title id="title4">
      generated by pstoedit version:3.44 from NVBadge_2D.eps
    </title>
    <path
      id="path17"
      d="M384.195,282.109c0,3.771-2.769,6.302-6.047,6.302v-0.023c-3.371,0.023-6.089-2.508-6.089-6.278
c0-3.769,2.718-6.293,6.089-6.293C381.427,275.816,384.195,278.34,384.195,282.109z M386.648,282.109c0-5.175-4.02-8.179-8.5-8.179
c-4.511,0-8.531,3.004-8.531,8.179c0,5.172,4.021,8.188,8.531,8.188C382.629,290.297,386.648,287.281,386.648,282.109
M376.738,282.801h0.91l2.109,3.703h2.316l-2.336-3.859c1.207-0.086,2.2-0.661,2.2-2.286c0-2.019-1.392-2.668-3.75-2.668h-3.411
v8.813h1.961V282.801 M376.738,281.309v-2.122h1.364c0.742,0,1.753,0.06,1.753,0.965c0,0.985-0.523,1.157-1.398,1.157H376.738"
    />
    <path
      id="path19"
      d="M329.406,237.027l10.598,28.993H318.48L329.406,237.027z M318.056,225.738l-24.423,61.88h17.246l3.863-10.934
h28.903l3.656,10.934h18.722l-24.605-61.888L318.056,225.738z M269.023,287.641h17.497v-61.922l-17.5-0.004L269.023,287.641z
M147.556,225.715l-14.598,49.078l-13.984-49.074l-18.879-0.004l19.972,61.926h25.207l20.133-61.926H147.556z M218.281,239.199h7.52
c10.91,0,17.966,4.898,17.966,17.609c0,12.714-7.056,17.613-17.966,17.613h-7.52V239.199z M200.931,225.715v61.926h28.366
c15.113,0,20.048-2.512,25.384-8.148c3.769-3.957,6.207-12.641,6.207-22.134c0-8.707-2.063-16.468-5.66-21.304
c-6.481-8.649-15.817-10.34-29.75-10.34H200.931z M35.188,225.629v62.012h17.645v-47.086l13.672,0.004
c4.527,0,7.754,1.128,9.934,3.457c2.765,2.945,3.894,7.699,3.894,16.395v27.23h17.098v-34.262c0-24.453-15.586-27.75-30.836-27.75
H35.188z M172.771,225.715l0.007,61.926h17.489v-61.926H172.771z"
    />
    <path
      id="path21"
      fill="#77B900"
      d="M82.211,102.414c0,0,22.504-33.203,67.437-36.638V53.73
c-49.769,3.997-92.867,46.149-92.867,46.149s24.41,70.565,92.867,77.026v-12.804C99.411,157.781,82.211,102.414,82.211,102.414z
M149.648,138.637v11.726c-37.968-6.769-48.507-46.237-48.507-46.237s18.23-20.195,48.507-23.47v12.867
c-0.023,0-0.039-0.007-0.058-0.007c-15.891-1.907-28.305,12.938-28.305,12.938S128.243,131.445,149.648,138.637 M149.648,31.512
V53.73c1.461-0.112,2.922-0.207,4.391-0.257c56.582-1.907,93.449,46.406,93.449,46.406s-42.343,51.488-86.457,51.488
c-4.043,0-7.828-0.375-11.383-1.005v13.739c3.04,0.386,6.192,0.613,9.481,0.613c41.051,0,70.738-20.965,99.484-45.778
c4.766,3.817,24.278,13.103,28.289,17.168c-27.332,22.883-91.031,41.329-127.144,41.329c-3.481,0-6.824-0.211-10.11-0.528v19.306
h156.032V31.512H149.648z M149.648,80.656V65.777c1.446-0.101,2.903-0.179,4.391-0.226c40.688-1.278,67.382,34.965,67.382,34.965
s-28.832,40.043-59.746,40.043c-4.449,0-8.438-0.715-12.028-1.922V93.523c15.84,1.914,19.028,8.911,28.551,24.786l21.18-17.859
c0,0-15.461-20.277-41.524-20.277C155.021,80.172,152.31,80.371,149.648,80.656"
    />
  </svg>
);
export default NvidiaSVG;
|
||||
32
src/frontend/src/icons/Nvidia/nvidia.svg
Normal file
32
src/frontend/src/icons/Nvidia/nvidia.svg
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- Generator: Adobe Illustrator 16.0.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg version="1.1" id="svg2" xmlns:svg="http://www.w3.org/2000/svg"
|
||||
xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" width="351.46px"
|
||||
height="258.785px" viewBox="35.188 31.512 351.46 258.785" enable-background="new 35.188 31.512 351.46 258.785"
|
||||
xml:space="preserve">
|
||||
<title id="title4">generated by pstoedit version:3.44 from NVBadge_2D.eps</title>
|
||||
<path id="path17" d="M384.195,282.109c0,3.771-2.769,6.302-6.047,6.302v-0.023c-3.371,0.023-6.089-2.508-6.089-6.278
|
||||
c0-3.769,2.718-6.293,6.089-6.293C381.427,275.816,384.195,278.34,384.195,282.109z M386.648,282.109c0-5.175-4.02-8.179-8.5-8.179
|
||||
c-4.511,0-8.531,3.004-8.531,8.179c0,5.172,4.021,8.188,8.531,8.188C382.629,290.297,386.648,287.281,386.648,282.109
|
||||
M376.738,282.801h0.91l2.109,3.703h2.316l-2.336-3.859c1.207-0.086,2.2-0.661,2.2-2.286c0-2.019-1.392-2.668-3.75-2.668h-3.411
|
||||
v8.813h1.961V282.801 M376.738,281.309v-2.122h1.364c0.742,0,1.753,0.06,1.753,0.965c0,0.985-0.523,1.157-1.398,1.157H376.738"/>
|
||||
<path id="path19" d="M329.406,237.027l10.598,28.993H318.48L329.406,237.027z M318.056,225.738l-24.423,61.88h17.246l3.863-10.934
|
||||
h28.903l3.656,10.934h18.722l-24.605-61.888L318.056,225.738z M269.023,287.641h17.497v-61.922l-17.5-0.004L269.023,287.641z
|
||||
M147.556,225.715l-14.598,49.078l-13.984-49.074l-18.879-0.004l19.972,61.926h25.207l20.133-61.926H147.556z M218.281,239.199h7.52
|
||||
c10.91,0,17.966,4.898,17.966,17.609c0,12.714-7.056,17.613-17.966,17.613h-7.52V239.199z M200.931,225.715v61.926h28.366
|
||||
c15.113,0,20.048-2.512,25.384-8.148c3.769-3.957,6.207-12.641,6.207-22.134c0-8.707-2.063-16.468-5.66-21.304
|
||||
c-6.481-8.649-15.817-10.34-29.75-10.34H200.931z M35.188,225.629v62.012h17.645v-47.086l13.672,0.004
|
||||
c4.527,0,7.754,1.128,9.934,3.457c2.765,2.945,3.894,7.699,3.894,16.395v27.23h17.098v-34.262c0-24.453-15.586-27.75-30.836-27.75
|
||||
H35.188z M172.771,225.715l0.007,61.926h17.489v-61.926H172.771z"/>
|
||||
<path id="path21" fill="#77B900" d="M82.211,102.414c0,0,22.504-33.203,67.437-36.638V53.73
|
||||
c-49.769,3.997-92.867,46.149-92.867,46.149s24.41,70.565,92.867,77.026v-12.804C99.411,157.781,82.211,102.414,82.211,102.414z
|
||||
M149.648,138.637v11.726c-37.968-6.769-48.507-46.237-48.507-46.237s18.23-20.195,48.507-23.47v12.867
|
||||
c-0.023,0-0.039-0.007-0.058-0.007c-15.891-1.907-28.305,12.938-28.305,12.938S128.243,131.445,149.648,138.637 M149.648,31.512
|
||||
V53.73c1.461-0.112,2.922-0.207,4.391-0.257c56.582-1.907,93.449,46.406,93.449,46.406s-42.343,51.488-86.457,51.488
|
||||
c-4.043,0-7.828-0.375-11.383-1.005v13.739c3.04,0.386,6.192,0.613,9.481,0.613c41.051,0,70.738-20.965,99.484-45.778
|
||||
c4.766,3.817,24.278,13.103,28.289,17.168c-27.332,22.883-91.031,41.329-127.144,41.329c-3.481,0-6.824-0.211-10.11-0.528v19.306
|
||||
h156.032V31.512H149.648z M149.648,80.656V65.777c1.446-0.101,2.903-0.179,4.391-0.226c40.688-1.278,67.382,34.965,67.382,34.965
|
||||
s-28.832,40.043-59.746,40.043c-4.449,0-8.438-0.715-12.028-1.922V93.523c15.84,1.914,19.028,8.911,28.551,24.786l21.18-17.859
|
||||
c0,0-15.461-20.277-41.524-20.277C155.021,80.172,152.31,80.371,149.648,80.656"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.2 KiB |
|
|
@ -188,6 +188,7 @@ import { MetaIcon } from "../icons/Meta";
|
|||
import { MidjourneyIcon } from "../icons/Midjorney";
|
||||
import { MongoDBIcon } from "../icons/MongoDB";
|
||||
import { NotionIcon } from "../icons/Notion";
|
||||
import { NvidiaIcon } from "../icons/Nvidia";
|
||||
import { OllamaIcon } from "../icons/Ollama";
|
||||
import { OpenAiIcon } from "../icons/OpenAi";
|
||||
import { PineconeIcon } from "../icons/Pinecone";
|
||||
|
|
@ -384,6 +385,7 @@ export const nodeIconsLucide: iconsType = {
|
|||
MongoDB: MongoDBIcon,
|
||||
MongoDBChatMessageHistory: MongoDBIcon,
|
||||
NotionDirectoryLoader: NotionIcon,
|
||||
NVIDIA: NvidiaIcon,
|
||||
ChatOpenAI: OpenAiIcon,
|
||||
AzureChatOpenAI: OpenAiIcon,
|
||||
OpenAI: OpenAiIcon,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue