feat: adds NVIDIA components (#2591)

* chore: Update langchain-nvidia-ai-endpoints dependency to version 0.1.2

* feat: Add NVIDIAEmbeddingsComponent for generating embeddings using NVIDIA models

* feat: Add NVIDIAModelComponent for generating text using NVIDIA LLMs

* feat: Add NvidiaRerankComponent for reranking documents using the NVIDIA API and a retriever

* fix: add type ignore

* chore: Update NVIDIAEmbeddingsComponent and NVIDIAModelComponent to handle type ignore

* chore(poetry.lock): update lock
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-07-09 15:26:23 -03:00 committed by GitHub
commit 06464eda46
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 371 additions and 31 deletions

View file

@ -0,0 +1,71 @@
from typing import Any
from langflow.base.embeddings.model import LCEmbeddingsModel
from langflow.field_typing import Embeddings
from langflow.inputs.inputs import DropdownInput, SecretStrInput
from langflow.io import FloatInput, MessageTextInput
from langflow.schema.dotdict import dotdict
class NVIDIAEmbeddingsComponent(LCEmbeddingsModel):
display_name: str = "NVIDIA Embeddings"
description: str = "Generate embeddings using NVIDIA models."
icon = "NVIDIA"
inputs = [
DropdownInput(
name="model",
display_name="Model",
options=[
"nvidia/nv-embed-v1",
"snowflake/arctic-embed-I",
],
value="nvidia/nv-embed-v1",
),
MessageTextInput(
name="base_url",
display_name="NVIDIA Base URL",
refresh_button=True,
value="https://integrate.api.nvidia.com/v1",
),
SecretStrInput(
name="nvidia_api_key",
display_name="NVIDIA API Key",
info="The NVIDIA API Key.",
advanced=False,
value="NVIDIA_API_KEY",
),
FloatInput(
name="temperature",
display_name="Model Temperature",
value=0.1,
advanced=True,
),
]
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "base_url" and field_value:
try:
build_model = self.build_embeddings()
ids = [model.id for model in build_model.available_models] # type: ignore
build_config["model"]["options"] = ids
build_config["model"]["value"] = ids[0]
except Exception as e:
raise ValueError(f"Error getting model names: {e}")
return build_config
def build_embeddings(self) -> Embeddings:
try:
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
except ImportError:
raise ImportError("Please install langchain-nvidia-ai-endpoints to use the Nvidia model.")
try:
output = NVIDIAEmbeddings(
model=self.model,
base_url=self.base_url,
temperature=self.temperature,
nvidia_api_key=self.nvidia_api_key,
) # type: ignore
except Exception as e:
raise ValueError(f"Could not connect to NVIDIA API. Error: {e}") from e
return output

View file

@ -0,0 +1,90 @@
from typing import Any
from langflow.base.constants import STREAM_INFO_TEXT
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.inputs import BoolInput, DropdownInput, FloatInput, IntInput, MessageInput, SecretStrInput, StrInput
from langflow.schema.dotdict import dotdict
class NVIDIAModelComponent(LCModelComponent):
display_name = "NVIDIA"
description = "Generates text using NVIDIA LLMs."
icon = "NVIDIA"
inputs = [
MessageInput(name="input_value", display_name="Input"),
IntInput(
name="max_tokens",
display_name="Max Tokens",
advanced=True,
info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
),
DropdownInput(
name="model_name",
display_name="Model Name",
advanced=False,
options=["mistralai/mixtral-8x7b-instruct-v0.1"],
value="mistralai/mixtral-8x7b-instruct-v0.1",
),
StrInput(
name="base_url",
display_name="NVIDIA Base URL",
value="https://integrate.api.nvidia.com/v1",
refresh_button=True,
info="The base URL of the NVIDIA API. Defaults to https://integrate.api.nvidia.com/v1.",
),
SecretStrInput(
name="nvidia_api_key",
display_name="NVIDIA API Key",
info="The NVIDIA API Key.",
advanced=False,
value="NVIDIA_API_KEY",
),
FloatInput(name="temperature", display_name="Temperature", value=0.1),
BoolInput(name="stream", display_name="Stream", info=STREAM_INFO_TEXT, advanced=True),
StrInput(
name="system_message",
display_name="System Message",
info="System message to pass to the model.",
advanced=True,
),
IntInput(
name="seed",
display_name="Seed",
info="The seed controls the reproducibility of the job.",
advanced=True,
value=1,
),
]
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "base_url" and field_value:
try:
build_model = self.build_model()
ids = [model.id for model in build_model.available_models] # type: ignore
build_config["model_name"]["options"] = ids
build_config["model_name"]["value"] = ids[0]
except Exception as e:
raise ValueError(f"Error getting model names: {e}")
return build_config
def build_model(self) -> LanguageModel: # type: ignore[type-var]
try:
from langchain_nvidia_ai_endpoints import ChatNVIDIA
except ImportError:
raise ImportError("Please install langchain-nvidia-ai-endpoints to use the NVIDIA model.")
nvidia_api_key = self.nvidia_api_key
temperature = self.temperature
model_name: str = self.model_name
max_tokens = self.max_tokens
seed = self.seed
output = ChatNVIDIA(
max_tokens=max_tokens or None,
model=model_name,
base_url=self.base_url,
api_key=nvidia_api_key, # type: ignore
temperature=temperature or 0.1,
seed=seed,
)
return output # type: ignore

View file

@ -0,0 +1,64 @@
from typing import Any, List, cast
from langchain.retrievers import ContextualCompressionRetriever
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.field_typing import Retriever
from langflow.io import DropdownInput, HandleInput, MultilineInput, SecretStrInput, StrInput
from langflow.schema import Data
from langflow.schema.dotdict import dotdict
class NvidiaRerankComponent(LCVectorStoreComponent):
display_name = "NVIDIA Rerank"
description = "Rerank documents using the NVIDIA API and a retriever."
icon = "NVIDIA"
inputs = [
MultilineInput(
name="search_query",
display_name="Search Query",
),
StrInput(
name="base_url",
display_name="Base URL",
value="https://integrate.api.nvidia.com/v1",
refresh_button=True,
info="The base URL of the NVIDIA API. Defaults to https://integrate.api.nvidia.com/v1.",
),
DropdownInput(
name="model", display_name="Model", options=["nv-rerank-qa-mistral-4b:1"], value="nv-rerank-qa-mistral-4b:1"
),
SecretStrInput(name="api_key", display_name="API Key"),
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
]
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "base_url" and field_value:
try:
build_model = self.build_model()
ids = [model.id for model in build_model.available_models]
build_config["model"]["options"] = ids
build_config["model"]["value"] = ids[0]
except Exception as e:
raise ValueError(f"Error getting model names: {e}")
return build_config
def build_model(self):
try:
from langchain_nvidia_ai_endpoints import NVIDIARerank
except ImportError:
raise ImportError("Please install langchain-nvidia-ai-endpoints to use the NVIDIA model.")
return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url)
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
nvidia_reranker = self.build_model()
retriever = ContextualCompressionRetriever(base_compressor=nvidia_reranker, base_retriever=self.retriever)
return cast(Retriever, retriever)
async def search_documents(self) -> List[Data]: # type: ignore
retriever = self.build_base_retriever()
documents = await retriever.ainvoke(self.search_query)
data = self.to_data(documents)
self.status = data
return data

View file

@ -0,0 +1,9 @@
import React, { forwardRef } from "react";
import NvidiaSVG from "./nvidia";
export const NvidiaIcon = forwardRef<
SVGSVGElement,
React.PropsWithChildren<{}>
>((props, ref) => {
return <NvidiaSVG ref={ref} {...props} />;
});

View file

@ -0,0 +1,52 @@
const NvidiaSVG = (props) => (
<svg
version="1.1"
id="svg2"
xmlns="http://www.w3.org/2000/svg"
x="0px"
y="0px"
width="351.46px"
height="258.785px"
viewBox="35.188 31.512 351.46 258.785"
enable-background="new 35.188 31.512 351.46 258.785"
{...props}
>
<title id="title4">
generated by pstoedit version:3.44 from NVBadge_2D.eps
</title>
<path
id="path17"
d="M384.195,282.109c0,3.771-2.769,6.302-6.047,6.302v-0.023c-3.371,0.023-6.089-2.508-6.089-6.278
c0-3.769,2.718-6.293,6.089-6.293C381.427,275.816,384.195,278.34,384.195,282.109z M386.648,282.109c0-5.175-4.02-8.179-8.5-8.179
c-4.511,0-8.531,3.004-8.531,8.179c0,5.172,4.021,8.188,8.531,8.188C382.629,290.297,386.648,287.281,386.648,282.109
M376.738,282.801h0.91l2.109,3.703h2.316l-2.336-3.859c1.207-0.086,2.2-0.661,2.2-2.286c0-2.019-1.392-2.668-3.75-2.668h-3.411
v8.813h1.961V282.801 M376.738,281.309v-2.122h1.364c0.742,0,1.753,0.06,1.753,0.965c0,0.985-0.523,1.157-1.398,1.157H376.738"
/>
<path
id="path19"
d="M329.406,237.027l10.598,28.993H318.48L329.406,237.027z M318.056,225.738l-24.423,61.88h17.246l3.863-10.934
h28.903l3.656,10.934h18.722l-24.605-61.888L318.056,225.738z M269.023,287.641h17.497v-61.922l-17.5-0.004L269.023,287.641z
M147.556,225.715l-14.598,49.078l-13.984-49.074l-18.879-0.004l19.972,61.926h25.207l20.133-61.926H147.556z M218.281,239.199h7.52
c10.91,0,17.966,4.898,17.966,17.609c0,12.714-7.056,17.613-17.966,17.613h-7.52V239.199z M200.931,225.715v61.926h28.366
c15.113,0,20.048-2.512,25.384-8.148c3.769-3.957,6.207-12.641,6.207-22.134c0-8.707-2.063-16.468-5.66-21.304
c-6.481-8.649-15.817-10.34-29.75-10.34H200.931z M35.188,225.629v62.012h17.645v-47.086l13.672,0.004
c4.527,0,7.754,1.128,9.934,3.457c2.765,2.945,3.894,7.699,3.894,16.395v27.23h17.098v-34.262c0-24.453-15.586-27.75-30.836-27.75
H35.188z M172.771,225.715l0.007,61.926h17.489v-61.926H172.771z"
/>
<path
id="path21"
fill="#77B900"
d="M82.211,102.414c0,0,22.504-33.203,67.437-36.638V53.73
c-49.769,3.997-92.867,46.149-92.867,46.149s24.41,70.565,92.867,77.026v-12.804C99.411,157.781,82.211,102.414,82.211,102.414z
M149.648,138.637v11.726c-37.968-6.769-48.507-46.237-48.507-46.237s18.23-20.195,48.507-23.47v12.867
c-0.023,0-0.039-0.007-0.058-0.007c-15.891-1.907-28.305,12.938-28.305,12.938S128.243,131.445,149.648,138.637 M149.648,31.512
V53.73c1.461-0.112,2.922-0.207,4.391-0.257c56.582-1.907,93.449,46.406,93.449,46.406s-42.343,51.488-86.457,51.488
c-4.043,0-7.828-0.375-11.383-1.005v13.739c3.04,0.386,6.192,0.613,9.481,0.613c41.051,0,70.738-20.965,99.484-45.778
c4.766,3.817,24.278,13.103,28.289,17.168c-27.332,22.883-91.031,41.329-127.144,41.329c-3.481,0-6.824-0.211-10.11-0.528v19.306
h156.032V31.512H149.648z M149.648,80.656V65.777c1.446-0.101,2.903-0.179,4.391-0.226c40.688-1.278,67.382,34.965,67.382,34.965
s-28.832,40.043-59.746,40.043c-4.449,0-8.438-0.715-12.028-1.922V93.523c15.84,1.914,19.028,8.911,28.551,24.786l21.18-17.859
c0,0-15.461-20.277-41.524-20.277C155.021,80.172,152.31,80.371,149.648,80.656"
/>
</svg>
);
export default NvidiaSVG;

View file

@ -0,0 +1,32 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 16.0.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg version="1.1" id="svg2" xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" width="351.46px"
height="258.785px" viewBox="35.188 31.512 351.46 258.785" enable-background="new 35.188 31.512 351.46 258.785"
xml:space="preserve">
<title id="title4">generated by pstoedit version:3.44 from NVBadge_2D.eps</title>
<path id="path17" d="M384.195,282.109c0,3.771-2.769,6.302-6.047,6.302v-0.023c-3.371,0.023-6.089-2.508-6.089-6.278
c0-3.769,2.718-6.293,6.089-6.293C381.427,275.816,384.195,278.34,384.195,282.109z M386.648,282.109c0-5.175-4.02-8.179-8.5-8.179
c-4.511,0-8.531,3.004-8.531,8.179c0,5.172,4.021,8.188,8.531,8.188C382.629,290.297,386.648,287.281,386.648,282.109
M376.738,282.801h0.91l2.109,3.703h2.316l-2.336-3.859c1.207-0.086,2.2-0.661,2.2-2.286c0-2.019-1.392-2.668-3.75-2.668h-3.411
v8.813h1.961V282.801 M376.738,281.309v-2.122h1.364c0.742,0,1.753,0.06,1.753,0.965c0,0.985-0.523,1.157-1.398,1.157H376.738"/>
<path id="path19" d="M329.406,237.027l10.598,28.993H318.48L329.406,237.027z M318.056,225.738l-24.423,61.88h17.246l3.863-10.934
h28.903l3.656,10.934h18.722l-24.605-61.888L318.056,225.738z M269.023,287.641h17.497v-61.922l-17.5-0.004L269.023,287.641z
M147.556,225.715l-14.598,49.078l-13.984-49.074l-18.879-0.004l19.972,61.926h25.207l20.133-61.926H147.556z M218.281,239.199h7.52
c10.91,0,17.966,4.898,17.966,17.609c0,12.714-7.056,17.613-17.966,17.613h-7.52V239.199z M200.931,225.715v61.926h28.366
c15.113,0,20.048-2.512,25.384-8.148c3.769-3.957,6.207-12.641,6.207-22.134c0-8.707-2.063-16.468-5.66-21.304
c-6.481-8.649-15.817-10.34-29.75-10.34H200.931z M35.188,225.629v62.012h17.645v-47.086l13.672,0.004
c4.527,0,7.754,1.128,9.934,3.457c2.765,2.945,3.894,7.699,3.894,16.395v27.23h17.098v-34.262c0-24.453-15.586-27.75-30.836-27.75
H35.188z M172.771,225.715l0.007,61.926h17.489v-61.926H172.771z"/>
<path id="path21" fill="#77B900" d="M82.211,102.414c0,0,22.504-33.203,67.437-36.638V53.73
c-49.769,3.997-92.867,46.149-92.867,46.149s24.41,70.565,92.867,77.026v-12.804C99.411,157.781,82.211,102.414,82.211,102.414z
M149.648,138.637v11.726c-37.968-6.769-48.507-46.237-48.507-46.237s18.23-20.195,48.507-23.47v12.867
c-0.023,0-0.039-0.007-0.058-0.007c-15.891-1.907-28.305,12.938-28.305,12.938S128.243,131.445,149.648,138.637 M149.648,31.512
V53.73c1.461-0.112,2.922-0.207,4.391-0.257c56.582-1.907,93.449,46.406,93.449,46.406s-42.343,51.488-86.457,51.488
c-4.043,0-7.828-0.375-11.383-1.005v13.739c3.04,0.386,6.192,0.613,9.481,0.613c41.051,0,70.738-20.965,99.484-45.778
c4.766,3.817,24.278,13.103,28.289,17.168c-27.332,22.883-91.031,41.329-127.144,41.329c-3.481,0-6.824-0.211-10.11-0.528v19.306
h156.032V31.512H149.648z M149.648,80.656V65.777c1.446-0.101,2.903-0.179,4.391-0.226c40.688-1.278,67.382,34.965,67.382,34.965
s-28.832,40.043-59.746,40.043c-4.449,0-8.438-0.715-12.028-1.922V93.523c15.84,1.914,19.028,8.911,28.551,24.786l21.18-17.859
c0,0-15.461-20.277-41.524-20.277C155.021,80.172,152.31,80.371,149.648,80.656"/>
</svg>

After

Width:  |  Height:  |  Size: 3.2 KiB

View file

@ -188,6 +188,7 @@ import { MetaIcon } from "../icons/Meta";
import { MidjourneyIcon } from "../icons/Midjorney";
import { MongoDBIcon } from "../icons/MongoDB";
import { NotionIcon } from "../icons/Notion";
import { NvidiaIcon } from "../icons/Nvidia";
import { OllamaIcon } from "../icons/Ollama";
import { OpenAiIcon } from "../icons/OpenAi";
import { PineconeIcon } from "../icons/Pinecone";
@ -384,6 +385,7 @@ export const nodeIconsLucide: iconsType = {
MongoDB: MongoDBIcon,
MongoDBChatMessageHistory: MongoDBIcon,
NotionDirectoryLoader: NotionIcon,
NVIDIA: NvidiaIcon,
ChatOpenAI: OpenAiIcon,
AzureChatOpenAI: OpenAiIcon,
OpenAI: OpenAiIcon,