From f7bdb711f184e0e06ee45240c612c643bded6547 Mon Sep 17 00:00:00 2001 From: Alexandre Henrique Date: Fri, 9 Jun 2023 15:05:56 -0300 Subject: [PATCH] Merged frontend_node/llm from origin/dev into add_extra_fields_documentloaders --- .../langflow/template/frontend_node/llms.py | 21 ++++++++++++------- .../template/frontend_node/textsplitters.py | 2 +- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/backend/langflow/template/frontend_node/llms.py b/src/backend/langflow/template/frontend_node/llms.py index 39e82422f..ac17cf8ed 100644 --- a/src/backend/langflow/template/frontend_node/llms.py +++ b/src/backend/langflow/template/frontend_node/llms.py @@ -16,20 +16,30 @@ class LLMFrontendNode(FrontendNode): def format_azure_field(field: TemplateField): if field.name == "model_name": field.show = False # Azure uses deployment_name instead of model_name. - if field.name == "openai_api_type": + elif field.name == "openai_api_type": field.show = False field.password = False field.value = "azure" - if field.name == "openai_api_version": + elif field.name == "openai_api_version": field.password = False field.value = "2023-03-15-preview" + @staticmethod + def format_llama_field(field: TemplateField): + field.show = True + field.advanced = not field.required + @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: display_names_dict = { "huggingfacehub_api_token": "HuggingFace Hub API Token", } FrontendNode.format_field(field, name) + LLMFrontendNode.format_openai_field(field) + if name and "azure" in name.lower(): + LLMFrontendNode.format_azure_field(field) + if name and "llama" in name.lower(): + LLMFrontendNode.format_llama_field(field) SHOW_FIELDS = ["repo_id"] if field.name in SHOW_FIELDS: field.show = True @@ -46,7 +56,8 @@ class LLMFrontendNode(FrontendNode): field.required = True field.show = True field.is_list = True - field.options = ["text-generation", "text2text-generation"] + field.options = ["text-generation", "text2text-generation", "summarization"] + field.value = field.options[0] field.advanced = True if display_name := display_names_dict.get(field.name): @@ -64,7 +75,3 @@ class LLMFrontendNode(FrontendNode): ]: field.advanced = False field.show = True - - LLMFrontendNode.format_openai_field(field) - if "azure" in name.lower(): - LLMFrontendNode.format_azure_field(field) diff --git a/src/backend/langflow/template/frontend_node/textsplitters.py b/src/backend/langflow/template/frontend_node/textsplitters.py index b784618c0..03880379d 100644 --- a/src/backend/langflow/template/frontend_node/textsplitters.py +++ b/src/backend/langflow/template/frontend_node/textsplitters.py @@ -10,7 +10,7 @@ class TextSplittersFrontendNode(FrontendNode): required=True, show=True, name="documents", - ) + ) ) name = "separator" if self.template.type_name == "CharacterTextSplitter":