diff --git a/src/backend/base/langflow/components/models/huggingface.py b/src/backend/base/langflow/components/models/huggingface.py index 5c3ae585e..e65c5ed45 100644 --- a/src/backend/base/langflow/components/models/huggingface.py +++ b/src/backend/base/langflow/components/models/huggingface.py @@ -8,7 +8,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed from langflow.base.models.model import LCModelComponent from langflow.field_typing import LanguageModel from langflow.inputs.inputs import HandleInput -from langflow.io import DictInput, DropdownInput, IntInput, SecretStrInput, StrInput +from langflow.io import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput class HuggingFaceEndpointsComponent(LCModelComponent): @@ -20,6 +20,45 @@ class HuggingFaceEndpointsComponent(LCModelComponent): inputs = [ *LCModelComponent._base_inputs, StrInput(name="model_id", display_name="Model ID", value="openai-community/gpt2"), + IntInput( + name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens" + ), + IntInput( + name="top_k", + display_name="Top K", + advanced=True, + info="The number of highest probability vocabulary tokens to keep for top-k-filtering", + ), + FloatInput( + name="top_p", + display_name="Top P", + value=0.95, + advanced=True, + info=( + "If set to < 1, only the smallest set of most probable tokens with " + "probabilities that add up to `top_p` or higher are kept for generation" + ), + ), + FloatInput( + name="typical_p", + display_name="Typical P", + value=0.95, + advanced=True, + info="Typical Decoding mass.", + ), + FloatInput( + name="temperature", + display_name="Temperature", + value=0.8, + advanced=True, + info="The value used to module the logits distribution", + ), + FloatInput( + name="repetition_penalty", + display_name="Repetition Penalty", + info="The parameter for repetition penalty. 1.0 means no penalty.", + advanced=True, + ), StrInput( name="inference_endpoint", display_name="Inference Endpoint", diff --git a/src/backend/tests/unit/components/models/test_huggingface.py b/src/backend/tests/unit/components/models/test_huggingface.py new file mode 100644 index 000000000..6bd313e4f --- /dev/null +++ b/src/backend/tests/unit/components/models/test_huggingface.py @@ -0,0 +1,31 @@ +from langflow.inputs.inputs import DictInput, DropdownInput, FloatInput, HandleInput, IntInput, SecretStrInput, StrInput + +from src.backend.base.langflow.components.models.huggingface import HuggingFaceEndpointsComponent + + +def test_huggingface_inputs(): + component = HuggingFaceEndpointsComponent() + inputs = component.inputs + + # Define expected input types and their names + expected_inputs = { + "model_id": StrInput, + "max_new_tokens": IntInput, + "top_k": IntInput, + "top_p": FloatInput, + "typical_p": FloatInput, + "temperature": FloatInput, + "repetition_penalty": FloatInput, + "inference_endpoint": StrInput, + "task": DropdownInput, + "huggingfacehub_api_token": SecretStrInput, + "model_kwargs": DictInput, + "retry_attempts": IntInput, + "output_parser": HandleInput, + } + + # Check if all expected inputs are present + for name, input_type in expected_inputs.items(): + assert any( + isinstance(inp, input_type) and inp.name == name for inp in inputs + ), f"Missing or incorrect input: {name}"