diff --git a/src/backend/langflow/interface/llms/base.py b/src/backend/langflow/interface/llms/base.py index 85f9035db..91eefd845 100644 --- a/src/backend/langflow/interface/llms/base.py +++ b/src/backend/langflow/interface/llms/base.py @@ -1,14 +1,19 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Type from langflow.interface.base import LangChainTypeCreator from langflow.interface.custom_lists import llm_type_to_cls_dict from langflow.settings import settings +from langflow.template.nodes import LLMFrontendNode from langflow.utils.util import build_template_from_class class LLMCreator(LangChainTypeCreator): type_name: str = "llms" + @property + def frontend_node_class(self) -> Type[LLMFrontendNode]: + return LLMFrontendNode + @property def type_to_loader_dict(self) -> Dict: if self.type_dict is None: diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index c5d9bf23d..99c180c9e 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -299,3 +299,31 @@ class ChainFrontendNode(FrontendNode): if "key" in field.name: field.password = False field.show = False + + +class LLMFrontendNode(FrontendNode): + @staticmethod + def format_field(field: TemplateField, name: Optional[str] = None) -> None: + display_names_dict = { + "huggingfacehub_api_token": "HuggingFace Hub API Token", + } + FrontendNode.format_field(field, name) + SHOW_FIELDS = ["repo_id", "task", "model_kwargs"] + if field.name in SHOW_FIELDS: + field.show = True + + if "api" in field.name and ("key" in field.name or "token" in field.name): + field.password = True + field.show = True + field.required = True + + if field.name == "task": + field.required = True + field.show = True + field.is_list = True + field.options = ["text-generation", "text2text-generation"] + + if display_name := display_names_dict.get(field.name): + field.display_name = display_name + if field.name == "model_kwargs": + field.field_type = "code"