From 8b68bb74fbe5bbab461b5d516eb1cb9bbb9727cf Mon Sep 17 00:00:00 2001 From: Edwin Jose Date: Wed, 15 Jan 2025 19:15:46 -0500 Subject: [PATCH] feat: Add Support for validating Tool Mode Models in NVIDIA LLM Component (#5703) * Update nvidia.py * ruff format --------- Co-authored-by: Gabriel Luiz Freitas Almeida --- .../base/langflow/components/models/nvidia.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/backend/base/langflow/components/models/nvidia.py b/src/backend/base/langflow/components/models/nvidia.py index 58bd6da8c..89132d3c6 100644 --- a/src/backend/base/langflow/components/models/nvidia.py +++ b/src/backend/base/langflow/components/models/nvidia.py @@ -2,7 +2,7 @@ from typing import Any from langflow.base.models.model import LCModelComponent from langflow.field_typing import LanguageModel -from langflow.inputs import DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput +from langflow.inputs import BoolInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput from langflow.schema.dotdict import dotdict @@ -33,6 +33,15 @@ class NVIDIAModelComponent(LCModelComponent): refresh_button=True, info="The base URL of the NVIDIA API. Defaults to https://integrate.api.nvidia.com/v1.", ), + BoolInput( + name="tool_model_enabled", + display_name="Enable Tool Models", + info=( + "Select if you want to use models that can work with tools. If yes, only those models will be shown." + ), + advanced=False, + value=True, + ), SecretStrInput( name="nvidia_api_key", display_name="NVIDIA API Key", @@ -50,11 +59,17 @@ class NVIDIAModelComponent(LCModelComponent): ), ] + def get_models(self, tool_model_enabled: bool | None = None) -> list[str]: + build_model = self.build_model() + if tool_model_enabled: + tool_models = [model for model in build_model.get_available_models() if model.supports_tools] + return [model.id for model in tool_models] + return [model.id for model in build_model.available_models] + def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): if field_name == "base_url" and field_value: try: - build_model = self.build_model() - ids = [model.id for model in build_model.available_models] + ids = self.get_models(self.tool_model_enabled) build_config["model_name"]["options"] = ids build_config["model_name"]["value"] = ids[0] except Exception as e: