fix: Add missing inputs for HuggingFace model component and include pytest (#4291)

* fix: Add missing inputs for HuggingFace model component and include pytest

- Added missing inputs to the HuggingFace model component.
- Implemented pytest to ensure the inputs are correctly handled.

* remove print statement
This commit is contained in:
Edwin Jose 2024-10-26 15:41:32 -04:00 committed by GitHub
commit a0c42d148d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 71 additions and 1 deletions

View file

@ -8,7 +8,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.inputs.inputs import HandleInput
from langflow.io import DictInput, DropdownInput, IntInput, SecretStrInput, StrInput
from langflow.io import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput
class HuggingFaceEndpointsComponent(LCModelComponent):
@ -20,6 +20,45 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
inputs = [
*LCModelComponent._base_inputs,
StrInput(name="model_id", display_name="Model ID", value="openai-community/gpt2"),
IntInput(
name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens"
),
IntInput(
name="top_k",
display_name="Top K",
advanced=True,
info="The number of highest probability vocabulary tokens to keep for top-k-filtering",
),
FloatInput(
name="top_p",
display_name="Top P",
value=0.95,
advanced=True,
info=(
"If set to < 1, only the smallest set of most probable tokens with "
"probabilities that add up to `top_p` or higher are kept for generation"
),
),
FloatInput(
name="typical_p",
display_name="Typical P",
value=0.95,
advanced=True,
info="Typical Decoding mass.",
),
FloatInput(
name="temperature",
display_name="Temperature",
value=0.8,
advanced=True,
info="The value used to module the logits distribution",
),
FloatInput(
name="repetition_penalty",
display_name="Repetition Penalty",
info="The parameter for repetition penalty. 1.0 means no penalty.",
advanced=True,
),
StrInput(
name="inference_endpoint",
display_name="Inference Endpoint",

View file

@ -0,0 +1,31 @@
from langflow.inputs.inputs import DictInput, DropdownInput, FloatInput, HandleInput, IntInput, SecretStrInput, StrInput
from src.backend.base.langflow.components.models.huggingface import HuggingFaceEndpointsComponent
def test_huggingface_inputs():
component = HuggingFaceEndpointsComponent()
inputs = component.inputs
# Define expected input types and their names
expected_inputs = {
"model_id": StrInput,
"max_new_tokens": IntInput,
"top_k": IntInput,
"top_p": FloatInput,
"typical_p": FloatInput,
"temperature": FloatInput,
"repetition_penalty": FloatInput,
"inference_endpoint": StrInput,
"task": DropdownInput,
"huggingfacehub_api_token": SecretStrInput,
"model_kwargs": DictInput,
"retry_attempts": IntInput,
"output_parser": HandleInput,
}
# Check if all expected inputs are present
for name, input_type in expected_inputs.items():
assert any(
isinstance(inp, input_type) and inp.name == name for inp in inputs
), f"Missing or incorrect input: {name}"