diff --git a/src/backend/base/langflow/components/models/huggingface.py b/src/backend/base/langflow/components/models/huggingface.py index 96eaa89b1..88b8651cf 100644 --- a/src/backend/base/langflow/components/models/huggingface.py +++ b/src/backend/base/langflow/components/models/huggingface.py @@ -3,11 +3,16 @@ from typing import Any from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint from tenacity import retry, stop_after_attempt, wait_fixed -# TODO: langchain_community.llms.huggingface_endpoint is depreciated. -# Need to update to langchain_huggingface, but have dependency with langchain_core 0.3.0 from langflow.base.models.model import LCModelComponent from langflow.field_typing import LanguageModel -from langflow.io import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput +from langflow.field_typing.range_spec import RangeSpec +from langflow.io import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, SliderInput, StrInput + +# TODO: langchain_community.llms.huggingface_endpoint is depreciated. +# Need to update to langchain_huggingface, but have dependency with langchain_core 0.3.0 + +# Constants +DEFAULT_MODEL = "meta-llama/Llama-3.3-70B-Instruct" class HuggingFaceEndpointsComponent(LCModelComponent): @@ -18,7 +23,32 @@ class HuggingFaceEndpointsComponent(LCModelComponent): inputs = [ *LCModelComponent._base_inputs, - StrInput(name="model_id", display_name="Model ID", value="openai-community/gpt2"), + DropdownInput( + name="model_id", + display_name="Model ID", + info="Select a model from HuggingFace Hub", + options=[ + DEFAULT_MODEL, + "mistralai/Mixtral-8x7B-Instruct-v0.1", + "mistralai/Mistral-7B-Instruct-v0.3", + "meta-llama/Llama-3.1-8B-Instruct", + "Qwen/Qwen2.5-Coder-32B-Instruct", + "Qwen/QwQ-32B-Preview", + "openai-community/gpt2", + "custom", + ], + value=DEFAULT_MODEL, + required=True, + real_time_refresh=True, + ), + StrInput( + name="custom_model", + display_name="Custom Model ID", + info="Enter a custom model ID from HuggingFace Hub", + value="", + show=False, + required=True, + ), IntInput( name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens" ), @@ -45,12 +75,13 @@ class HuggingFaceEndpointsComponent(LCModelComponent): advanced=True, info="Typical Decoding mass.", ), - FloatInput( + SliderInput( name="temperature", display_name="Temperature", value=0.8, - advanced=True, + range_spec=RangeSpec(min=0, max=2, step=0.01), info="The value used to module the logits distribution", + advanced=True, ), FloatInput( name="repetition_penalty", @@ -63,24 +94,47 @@ class HuggingFaceEndpointsComponent(LCModelComponent): display_name="Inference Endpoint", value="https://api-inference.huggingface.co/models/", info="Custom inference endpoint URL.", + required=True, ), DropdownInput( name="task", display_name="Task", options=["text2text-generation", "text-generation", "summarization", "translation"], + value="text-generation", advanced=True, info="The task to call the model with. Should be a task that returns `generated_text` or `summary_text`.", ), - SecretStrInput(name="huggingfacehub_api_token", display_name="API Token", password=True), + SecretStrInput(name="huggingfacehub_api_token", display_name="API Token", password=True, required=True), DictInput(name="model_kwargs", display_name="Model Keyword Arguments", advanced=True), IntInput(name="retry_attempts", display_name="Retry Attempts", value=1, advanced=True), ] def get_api_url(self) -> str: if "huggingface" in self.inference_endpoint.lower(): + if self.model_id == "custom": + if not self.custom_model: + error_msg = "Custom model ID is required when 'custom' is selected" + raise ValueError(error_msg) + return f"{self.inference_endpoint}{self.custom_model}" return f"{self.inference_endpoint}{self.model_id}" return self.inference_endpoint + async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None) -> dict: + """Update build configuration based on field updates.""" + try: + if field_name is None or field_name == "model_id": + # If model_id is custom, show custom model field + if field_value == "custom": + build_config["custom_model"]["show"] = True + build_config["custom_model"]["required"] = True + else: + build_config["custom_model"]["show"] = False + build_config["custom_model"]["value"] = "" + + except (KeyError, AttributeError) as e: + self.log(f"Error updating build config: {e!s}") + return build_config + def create_huggingface_endpoint( self, task: str | None, diff --git a/src/backend/tests/unit/components/models/test_huggingface.py b/src/backend/tests/unit/components/models/test_huggingface.py index e3e3073c8..17023cf54 100644 --- a/src/backend/tests/unit/components/models/test_huggingface.py +++ b/src/backend/tests/unit/components/models/test_huggingface.py @@ -1,5 +1,5 @@ -from langflow.components.models.huggingface import HuggingFaceEndpointsComponent -from langflow.inputs.inputs import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput +from langflow.components.models.huggingface import DEFAULT_MODEL, HuggingFaceEndpointsComponent +from langflow.inputs.inputs import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, SliderInput, StrInput def test_huggingface_inputs(): @@ -8,12 +8,13 @@ def test_huggingface_inputs(): # Define expected input types and their names expected_inputs = { - "model_id": StrInput, + "model_id": DropdownInput, + "custom_model": StrInput, "max_new_tokens": IntInput, "top_k": IntInput, "top_p": FloatInput, "typical_p": FloatInput, - "temperature": FloatInput, + "temperature": SliderInput, "repetition_penalty": FloatInput, "inference_endpoint": StrInput, "task": DropdownInput, @@ -22,8 +23,24 @@ def test_huggingface_inputs(): "retry_attempts": IntInput, } - # Check if all expected inputs are present + # Check if all expected inputs are present and have correct type for name, input_type in expected_inputs.items(): - assert any(isinstance(inp, input_type) and inp.name == name for inp in inputs), ( - f"Missing or incorrect input: {name}" - ) + matching_inputs = [inp for inp in inputs if isinstance(inp, input_type) and inp.name == name] + assert matching_inputs, f"Missing or incorrect input: {name}" + + if name == "model_id": + input_field = matching_inputs[0] + assert input_field.value == DEFAULT_MODEL + assert "custom" in input_field.options + assert input_field.required is True + assert input_field.real_time_refresh is True + elif name == "custom_model": + input_field = matching_inputs[0] + assert input_field.show is False + assert input_field.required is True + elif name == "temperature": + input_field = matching_inputs[0] + assert input_field.value == 0.8 + assert input_field.range_spec.min == 0 + assert input_field.range_spec.max == 2 + assert input_field.range_spec.step == 0.01