fix: Add missing inputs for HuggingFace model component and include pytest (#4291)
* fix: Add missing inputs for HuggingFace model component and include pytest - Added missing inputs to the HuggingFace model component. - Implemented pytest to ensure the inputs are correctly handled. * remove print statement
This commit is contained in:
parent
3131c0ce08
commit
a0c42d148d
2 changed files with 71 additions and 1 deletions
|
|
@ -8,7 +8,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed
|
|||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import LanguageModel
|
||||
from langflow.inputs.inputs import HandleInput
|
||||
from langflow.io import DictInput, DropdownInput, IntInput, SecretStrInput, StrInput
|
||||
from langflow.io import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput
|
||||
|
||||
|
||||
class HuggingFaceEndpointsComponent(LCModelComponent):
|
||||
|
|
@ -20,6 +20,45 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
|
|||
inputs = [
|
||||
*LCModelComponent._base_inputs,
|
||||
StrInput(name="model_id", display_name="Model ID", value="openai-community/gpt2"),
|
||||
IntInput(
|
||||
name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens"
|
||||
),
|
||||
IntInput(
|
||||
name="top_k",
|
||||
display_name="Top K",
|
||||
advanced=True,
|
||||
info="The number of highest probability vocabulary tokens to keep for top-k-filtering",
|
||||
),
|
||||
FloatInput(
|
||||
name="top_p",
|
||||
display_name="Top P",
|
||||
value=0.95,
|
||||
advanced=True,
|
||||
info=(
|
||||
"If set to < 1, only the smallest set of most probable tokens with "
|
||||
"probabilities that add up to `top_p` or higher are kept for generation"
|
||||
),
|
||||
),
|
||||
FloatInput(
|
||||
name="typical_p",
|
||||
display_name="Typical P",
|
||||
value=0.95,
|
||||
advanced=True,
|
||||
info="Typical Decoding mass.",
|
||||
),
|
||||
FloatInput(
|
||||
name="temperature",
|
||||
display_name="Temperature",
|
||||
value=0.8,
|
||||
advanced=True,
|
||||
info="The value used to module the logits distribution",
|
||||
),
|
||||
FloatInput(
|
||||
name="repetition_penalty",
|
||||
display_name="Repetition Penalty",
|
||||
info="The parameter for repetition penalty. 1.0 means no penalty.",
|
||||
advanced=True,
|
||||
),
|
||||
StrInput(
|
||||
name="inference_endpoint",
|
||||
display_name="Inference Endpoint",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue