fix: Add missing inputs for HuggingFace model component and include pytest (#4291)
* fix: Add missing inputs for HuggingFace model component and include pytest - Added missing inputs to the HuggingFace model component. - Implemented pytest to ensure the inputs are correctly handled. * remove print statement
This commit is contained in:
parent
3131c0ce08
commit
a0c42d148d
2 changed files with 71 additions and 1 deletions
|
|
@ -8,7 +8,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed
|
|||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import LanguageModel
|
||||
from langflow.inputs.inputs import HandleInput
|
||||
from langflow.io import DictInput, DropdownInput, IntInput, SecretStrInput, StrInput
|
||||
from langflow.io import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput
|
||||
|
||||
|
||||
class HuggingFaceEndpointsComponent(LCModelComponent):
|
||||
|
|
@ -20,6 +20,45 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
|
|||
inputs = [
|
||||
*LCModelComponent._base_inputs,
|
||||
StrInput(name="model_id", display_name="Model ID", value="openai-community/gpt2"),
|
||||
IntInput(
|
||||
name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens"
|
||||
),
|
||||
IntInput(
|
||||
name="top_k",
|
||||
display_name="Top K",
|
||||
advanced=True,
|
||||
info="The number of highest probability vocabulary tokens to keep for top-k-filtering",
|
||||
),
|
||||
FloatInput(
|
||||
name="top_p",
|
||||
display_name="Top P",
|
||||
value=0.95,
|
||||
advanced=True,
|
||||
info=(
|
||||
"If set to < 1, only the smallest set of most probable tokens with "
|
||||
"probabilities that add up to `top_p` or higher are kept for generation"
|
||||
),
|
||||
),
|
||||
FloatInput(
|
||||
name="typical_p",
|
||||
display_name="Typical P",
|
||||
value=0.95,
|
||||
advanced=True,
|
||||
info="Typical Decoding mass.",
|
||||
),
|
||||
FloatInput(
|
||||
name="temperature",
|
||||
display_name="Temperature",
|
||||
value=0.8,
|
||||
advanced=True,
|
||||
info="The value used to module the logits distribution",
|
||||
),
|
||||
FloatInput(
|
||||
name="repetition_penalty",
|
||||
display_name="Repetition Penalty",
|
||||
info="The parameter for repetition penalty. 1.0 means no penalty.",
|
||||
advanced=True,
|
||||
),
|
||||
StrInput(
|
||||
name="inference_endpoint",
|
||||
display_name="Inference Endpoint",
|
||||
|
|
|
|||
31
src/backend/tests/unit/components/models/test_huggingface.py
Normal file
31
src/backend/tests/unit/components/models/test_huggingface.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from langflow.inputs.inputs import DictInput, DropdownInput, FloatInput, HandleInput, IntInput, SecretStrInput, StrInput
|
||||
|
||||
from src.backend.base.langflow.components.models.huggingface import HuggingFaceEndpointsComponent
|
||||
|
||||
|
||||
def test_huggingface_inputs():
|
||||
component = HuggingFaceEndpointsComponent()
|
||||
inputs = component.inputs
|
||||
|
||||
# Define expected input types and their names
|
||||
expected_inputs = {
|
||||
"model_id": StrInput,
|
||||
"max_new_tokens": IntInput,
|
||||
"top_k": IntInput,
|
||||
"top_p": FloatInput,
|
||||
"typical_p": FloatInput,
|
||||
"temperature": FloatInput,
|
||||
"repetition_penalty": FloatInput,
|
||||
"inference_endpoint": StrInput,
|
||||
"task": DropdownInput,
|
||||
"huggingfacehub_api_token": SecretStrInput,
|
||||
"model_kwargs": DictInput,
|
||||
"retry_attempts": IntInput,
|
||||
"output_parser": HandleInput,
|
||||
}
|
||||
|
||||
# Check if all expected inputs are present
|
||||
for name, input_type in expected_inputs.items():
|
||||
assert any(
|
||||
isinstance(inp, input_type) and inp.name == name for inp in inputs
|
||||
), f"Missing or incorrect input: {name}"
|
||||
Loading…
Add table
Add a link
Reference in a new issue