feat: improve model input fields for HuggingFace model component (#5723)
* feat: improve model input fields for huggingface component 1. Make model_id, inference_endpoint and api_token fields required 2. Convert temperature to SliderInput with range 0-2 3. Update temperature info to match HuggingFace's description Co-authored-by: Vinícios Batista da Silva <vinicios.batsi@gmail.com> * feat(huggingface): update model selection and temperature input This commit enhances the HuggingFace component by implementing a fixed list of large language models in the dropdown selection and improving the temperature input with a slider control. Key changes in src/backend/base/langflow/components/models/huggingface.py: - Add DEFAULT_MODEL constant set to "meta-llama/Llama-3.3-70B-Instruct" - Replace model_id StrInput with DropdownInput containing pre-selected models: * meta-llama/Llama-3.3-70B-Instruct (default) * mistralai/Mixtral-8x7B-Instruct-v0.1 * mistralai/Mistral-7B-Instruct-v0.3 * meta-llama/Llama-3.1-8B-Instruct * Qwen/Qwen2.5-Coder-32B-Instruct * Qwen/QwQ-32B-Preview * openai-community/gpt2 * custom option - Add real_time_refresh to model_id dropdown for dynamic updates - Implement custom model input field that shows/hides based on selection - Replace temperature FloatInput with SliderInput for better UX: * Added RangeSpec with min=0, max=2, step=0.01 * Maintains default value of 0.8 - Add build_config update logic to handle custom model visibility - Update API URL generation to support custom model IDs - Import RangeSpec and SliderInput from langflow packages Co-authored-by: Vinícios Batista da Silva <vinicios.batsi@gmail.com> * [autofix.ci] apply automated fixes * test(huggingface): update tests for enhanced model selection and UI controls This commit updates the test suite to verify the new features and changes in the HuggingFace component, ensuring proper functionality of the model selection dropdown and improved temperature control. 
Key changes in src/backend/tests/unit/components/models/test_huggingface.py: - Update test_huggingface_inputs to verify new DropdownInput for model_id: * Check DEFAULT_MODEL as default value * Verify presence of 'custom' option * Validate required and real_time_refresh settings - Add verification for custom_model field: * Confirm initial custom_model hidden state * Verify required flag - Add specific checks for temperature SliderInput: * Validate default value of 0.8 * Verify RangeSpec configuration (min=0, max=2, step=0.01) - Improve test structure with detailed assertions for field configurations - Update imports to include DEFAULT_MODEL constant Related to previous commit that enhanced the HuggingFace component with fixed model list and slider controls. Co-authored-by: Vinícios Batista da Silva <vinicios.batsi@gmail.com> --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
1fdc79b0b2
commit
349f3441c8
2 changed files with 86 additions and 15 deletions
|
|
@ -3,11 +3,16 @@ from typing import Any
|
|||
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
|
||||
# TODO: langchain_community.llms.huggingface_endpoint is deprecated.
|
||||
# Need to update to langchain_huggingface, but there is a dependency on langchain_core 0.3.0
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import LanguageModel
|
||||
from langflow.io import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput
|
||||
from langflow.field_typing.range_spec import RangeSpec
|
||||
from langflow.io import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, SliderInput, StrInput
|
||||
|
||||
# TODO: langchain_community.llms.huggingface_endpoint is deprecated.
|
||||
# Need to update to langchain_huggingface, but there is a dependency on langchain_core 0.3.0
|
||||
|
||||
# Constants
|
||||
DEFAULT_MODEL = "meta-llama/Llama-3.3-70B-Instruct"
|
||||
|
||||
|
||||
class HuggingFaceEndpointsComponent(LCModelComponent):
|
||||
|
|
@ -18,7 +23,32 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
|
|||
|
||||
inputs = [
|
||||
*LCModelComponent._base_inputs,
|
||||
StrInput(name="model_id", display_name="Model ID", value="openai-community/gpt2"),
|
||||
DropdownInput(
|
||||
name="model_id",
|
||||
display_name="Model ID",
|
||||
info="Select a model from HuggingFace Hub",
|
||||
options=[
|
||||
DEFAULT_MODEL,
|
||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
"Qwen/Qwen2.5-Coder-32B-Instruct",
|
||||
"Qwen/QwQ-32B-Preview",
|
||||
"openai-community/gpt2",
|
||||
"custom",
|
||||
],
|
||||
value=DEFAULT_MODEL,
|
||||
required=True,
|
||||
real_time_refresh=True,
|
||||
),
|
||||
StrInput(
|
||||
name="custom_model",
|
||||
display_name="Custom Model ID",
|
||||
info="Enter a custom model ID from HuggingFace Hub",
|
||||
value="",
|
||||
show=False,
|
||||
required=True,
|
||||
),
|
||||
IntInput(
|
||||
name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens"
|
||||
),
|
||||
|
|
@ -45,12 +75,13 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
|
|||
advanced=True,
|
||||
info="Typical Decoding mass.",
|
||||
),
|
||||
FloatInput(
|
||||
SliderInput(
|
||||
name="temperature",
|
||||
display_name="Temperature",
|
||||
value=0.8,
|
||||
advanced=True,
|
||||
range_spec=RangeSpec(min=0, max=2, step=0.01),
|
||||
info="The value used to module the logits distribution",
|
||||
advanced=True,
|
||||
),
|
||||
FloatInput(
|
||||
name="repetition_penalty",
|
||||
|
|
@ -63,24 +94,47 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
|
|||
display_name="Inference Endpoint",
|
||||
value="https://api-inference.huggingface.co/models/",
|
||||
info="Custom inference endpoint URL.",
|
||||
required=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="task",
|
||||
display_name="Task",
|
||||
options=["text2text-generation", "text-generation", "summarization", "translation"],
|
||||
value="text-generation",
|
||||
advanced=True,
|
||||
info="The task to call the model with. Should be a task that returns `generated_text` or `summary_text`.",
|
||||
),
|
||||
SecretStrInput(name="huggingfacehub_api_token", display_name="API Token", password=True),
|
||||
SecretStrInput(name="huggingfacehub_api_token", display_name="API Token", password=True, required=True),
|
||||
DictInput(name="model_kwargs", display_name="Model Keyword Arguments", advanced=True),
|
||||
IntInput(name="retry_attempts", display_name="Retry Attempts", value=1, advanced=True),
|
||||
]
|
||||
|
||||
def get_api_url(self) -> str:
|
||||
if "huggingface" in self.inference_endpoint.lower():
|
||||
if self.model_id == "custom":
|
||||
if not self.custom_model:
|
||||
error_msg = "Custom model ID is required when 'custom' is selected"
|
||||
raise ValueError(error_msg)
|
||||
return f"{self.inference_endpoint}{self.custom_model}"
|
||||
return f"{self.inference_endpoint}{self.model_id}"
|
||||
return self.inference_endpoint
|
||||
|
||||
async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None) -> dict:
|
||||
"""Update build configuration based on field updates."""
|
||||
try:
|
||||
if field_name is None or field_name == "model_id":
|
||||
# If model_id is custom, show custom model field
|
||||
if field_value == "custom":
|
||||
build_config["custom_model"]["show"] = True
|
||||
build_config["custom_model"]["required"] = True
|
||||
else:
|
||||
build_config["custom_model"]["show"] = False
|
||||
build_config["custom_model"]["value"] = ""
|
||||
|
||||
except (KeyError, AttributeError) as e:
|
||||
self.log(f"Error updating build config: {e!s}")
|
||||
return build_config
|
||||
|
||||
def create_huggingface_endpoint(
|
||||
self,
|
||||
task: str | None,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from langflow.components.models.huggingface import HuggingFaceEndpointsComponent
|
||||
from langflow.inputs.inputs import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput
|
||||
from langflow.components.models.huggingface import DEFAULT_MODEL, HuggingFaceEndpointsComponent
|
||||
from langflow.inputs.inputs import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, SliderInput, StrInput
|
||||
|
||||
|
||||
def test_huggingface_inputs():
|
||||
|
|
@ -8,12 +8,13 @@ def test_huggingface_inputs():
|
|||
|
||||
# Define expected input types and their names
|
||||
expected_inputs = {
|
||||
"model_id": StrInput,
|
||||
"model_id": DropdownInput,
|
||||
"custom_model": StrInput,
|
||||
"max_new_tokens": IntInput,
|
||||
"top_k": IntInput,
|
||||
"top_p": FloatInput,
|
||||
"typical_p": FloatInput,
|
||||
"temperature": FloatInput,
|
||||
"temperature": SliderInput,
|
||||
"repetition_penalty": FloatInput,
|
||||
"inference_endpoint": StrInput,
|
||||
"task": DropdownInput,
|
||||
|
|
@ -22,8 +23,24 @@ def test_huggingface_inputs():
|
|||
"retry_attempts": IntInput,
|
||||
}
|
||||
|
||||
# Check if all expected inputs are present
|
||||
# Check if all expected inputs are present and have correct type
|
||||
for name, input_type in expected_inputs.items():
|
||||
assert any(isinstance(inp, input_type) and inp.name == name for inp in inputs), (
|
||||
f"Missing or incorrect input: {name}"
|
||||
)
|
||||
matching_inputs = [inp for inp in inputs if isinstance(inp, input_type) and inp.name == name]
|
||||
assert matching_inputs, f"Missing or incorrect input: {name}"
|
||||
|
||||
if name == "model_id":
|
||||
input_field = matching_inputs[0]
|
||||
assert input_field.value == DEFAULT_MODEL
|
||||
assert "custom" in input_field.options
|
||||
assert input_field.required is True
|
||||
assert input_field.real_time_refresh is True
|
||||
elif name == "custom_model":
|
||||
input_field = matching_inputs[0]
|
||||
assert input_field.show is False
|
||||
assert input_field.required is True
|
||||
elif name == "temperature":
|
||||
input_field = matching_inputs[0]
|
||||
assert input_field.value == 0.8
|
||||
assert input_field.range_spec.min == 0
|
||||
assert input_field.range_spec.max == 2
|
||||
assert input_field.range_spec.step == 0.01
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue