Merged frontend_node/llm from origin/dev into add_extra_fields_documentloaders

This commit is contained in:
Alexandre Henrique 2023-06-09 15:05:56 -03:00
commit f7bdb711f1
2 changed files with 15 additions and 8 deletions

View file

@ -16,20 +16,30 @@ class LLMFrontendNode(FrontendNode):
def format_azure_field(field: TemplateField):
if field.name == "model_name":
field.show = False # Azure uses deployment_name instead of model_name.
if field.name == "openai_api_type":
elif field.name == "openai_api_type":
field.show = False
field.password = False
field.value = "azure"
if field.name == "openai_api_version":
elif field.name == "openai_api_version":
field.password = False
field.value = "2023-03-15-preview"
@staticmethod
def format_llama_field(field: TemplateField):
field.show = True
field.advanced = not field.required
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
display_names_dict = {
"huggingfacehub_api_token": "HuggingFace Hub API Token",
}
FrontendNode.format_field(field, name)
LLMFrontendNode.format_openai_field(field)
if name and "azure" in name.lower():
LLMFrontendNode.format_azure_field(field)
if name and "llama" in name.lower():
LLMFrontendNode.format_llama_field(field)
SHOW_FIELDS = ["repo_id"]
if field.name in SHOW_FIELDS:
field.show = True
@ -46,7 +56,8 @@ class LLMFrontendNode(FrontendNode):
field.required = True
field.show = True
field.is_list = True
field.options = ["text-generation", "text2text-generation"]
field.options = ["text-generation", "text2text-generation", "summarization"]
field.value = field.options[0]
field.advanced = True
if display_name := display_names_dict.get(field.name):
@ -64,7 +75,3 @@ class LLMFrontendNode(FrontendNode):
]:
field.advanced = False
field.show = True
LLMFrontendNode.format_openai_field(field)
if "azure" in name.lower():
LLMFrontendNode.format_azure_field(field)

View file

@ -10,7 +10,7 @@ class TextSplittersFrontendNode(FrontendNode):
required=True,
show=True,
name="documents",
)
)
)
name = "separator"
if self.template.type_name == "CharacterTextSplitter":