Refactor field handling in frontend nodes

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-11 12:58:31 -03:00
commit c8f2469c9f
9 changed files with 75 additions and 65 deletions

View file

@ -129,7 +129,8 @@ def add_new_custom_field(
**sanitize_field_config(field_config),
)
frontend_node.template.upsert_field(field_name, new_field)
frontend_node.custom_fields[field_name] = None
if isinstance(frontend_node.custom_fields, dict):
frontend_node.custom_fields[field_name] = None
return frontend_node

View file

@ -37,10 +37,10 @@ class TemplateField(BaseModel):
password: bool = False
"""Specifies if the field is a password. Defaults to False."""
options: Union[list[str], Callable] = None
options: Optional[Union[list[str], Callable]] = None
"""List of options for the field. Only used when is_list=True. Default is an empty list."""
name: str = None
name: Optional[str] = None
"""Name of the field. Default is an empty string."""
display_name: Optional[str] = None
@ -61,7 +61,7 @@ class TemplateField(BaseModel):
refresh: Optional[bool] = None
"""Specifies if the field should be refreshed. Defaults to False."""
range_spec: Optional[RangeSpec] = Field(None, serialization_alias="rangeSpec")
range_spec: Optional[RangeSpec] = Field(default=None, serialization_alias="rangeSpec")
"""Range specification for the field. Defaults to None."""
def to_dict(self):

View file

@ -44,7 +44,7 @@ class FrontendNode(BaseModel):
description: Optional[str] = None
base_classes: List[str]
name: str = ""
display_name: str = ""
display_name: Optional[str] = ""
documentation: str = ""
custom_fields: Optional[Dict] = defaultdict(list)
output_types: List[str] = []
@ -201,7 +201,8 @@ class FrontendNode(BaseModel):
@staticmethod
def handle_kwargs_field(field: TemplateField) -> None:
"""Handles kwargs field by setting certain attributes."""
if "kwargs" in field.name.lower():
if "kwargs" in (field.name or "").lower():
field.advanced = True
field.required = False
field.show = False

View file

@ -48,16 +48,16 @@ class ChainFrontendNode(FrontendNode):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
if "name" == "RetrievalQA" and field.name == "memory":
key = field.name or ""
if "name" == "RetrievalQA" and key == "memory":
field.show = False
field.required = False
field.advanced = False
if "key" in field.name:
if "key" in key:
field.password = False
field.show = False
if field.name in ["input_key", "output_key"]:
if key in ["input_key", "output_key"]:
field.required = True
field.show = True
field.advanced = True
@ -71,26 +71,26 @@ class ChainFrontendNode(FrontendNode):
# field.value = field.value.template
# Separated for possible future changes
if field.name == "prompt" and field.value is None:
if key == "prompt" and field.value is None:
field.required = True
field.show = True
field.advanced = False
if field.name == "memory":
if key == "memory":
# field.required = False
field.show = True
field.advanced = False
if field.name == "verbose":
if key == "verbose":
field.required = False
field.show = False
field.advanced = True
if field.name == "llm":
if key == "llm":
field.required = True
field.show = True
field.advanced = False
field.field_type = "BaseLanguageModel" # temporary fix
field.is_list = False
if field.name == "return_source_documents":
if key == "return_source_documents":
field.required = False
field.show = True
field.advanced = True

View file

@ -22,13 +22,14 @@ class EmbeddingFrontendNode(FrontendNode):
@staticmethod
def format_vertex_field(field: TemplateField, name: str):
if "VertexAI" in name:
key = field.name or ""
advanced_fields = [
"verbose",
"top_p",
"top_k",
"max_output_tokens",
]
if field.name in advanced_fields:
if key in advanced_fields:
field.advanced = True
show_fields = [
"verbose",
@ -42,21 +43,22 @@ class EmbeddingFrontendNode(FrontendNode):
"top_k",
]
if field.name in show_fields:
if key in show_fields:
field.show = True
@staticmethod
def format_jina_fields(field: TemplateField):
if "jina" in field.name:
name = field.name or ""
if "jina" in name:
field.show = True
field.advanced = False
if "auth" in field.name or "token" in field.name:
if "auth" in name or "token" in name:
field.password = True
field.show = True
field.advanced = False
if field.name == "jina_api_url":
if name == "jina_api_url":
field.show = True
field.advanced = True
field.display_name = "Jina API URL"
@ -64,14 +66,15 @@ class EmbeddingFrontendNode(FrontendNode):
@staticmethod
def format_openai_fields(field: TemplateField):
if "openai" in field.name:
name = field.name or ""
if "openai" in name:
field.show = True
field.advanced = True
split_name = field.name.split("_")
split_name = name.split("_")
title_name = " ".join([s.capitalize() for s in split_name])
field.display_name = title_name.replace("Openai", "OpenAI").replace("Api", "API")
if "api_key" in field.name:
if "api_key" in name:
field.password = True
field.show = True
field.advanced = False
@ -83,13 +86,14 @@ class EmbeddingFrontendNode(FrontendNode):
EmbeddingFrontendNode.format_vertex_field(field, name)
field.advanced = not field.required
field.show = True
if field.name == "headers":
key = field.name or ""
if key == "headers":
field.show = False
if field.name == "model_kwargs":
if key == "model_kwargs":
field.field_type = "dict"
field.advanced = True
field.show = True
elif field.name in [
elif key in [
"model_name",
"temperature",
"model_file",
@ -99,9 +103,9 @@ class EmbeddingFrontendNode(FrontendNode):
]:
field.advanced = False
field.show = True
if field.name == "credentials":
if key == "credentials":
field.field_type = "file"
if name == "VertexAI" and field.name not in [
if name == "VertexAI" and key not in [
"callbacks",
"client",
"stop",

View file

@ -9,7 +9,7 @@ from langflow.utils.constants import ANTHROPIC_MODELS, CHAT_OPENAI_MODELS, OPENA
class OpenAIAPIKeyFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
if "api_key" in field.name and "OpenAI" in str(name):
if field.name and "api_key" in field.name and "OpenAI" in str(name):
field.display_name = "OpenAI API Key"
field.required = False
if field.value is None:
@ -25,14 +25,14 @@ class ModelSpecificFieldFormatter(FieldFormatter):
}
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
if name in self.MODEL_DICT and field.name == "model_name":
if field.name and name in self.MODEL_DICT and field.name == "model_name":
field.options = self.MODEL_DICT[name]
field.is_list = True
class KwargsFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
if "kwargs" in field.name.lower():
if field.name and "kwargs" in field.name.lower():
field.advanced = True
field.required = False
field.show = False
@ -40,11 +40,11 @@ class KwargsFormatter(FieldFormatter):
class APIKeyFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
if "api" in field.name.lower() and "key" in field.name.lower():
if field.name and "api" in field.name.lower() and "key" in field.name.lower():
field.required = False
field.advanced = False
field.display_name = field.name.replace("_", " ").title()
field.display_name = (field.name or "").replace("_", " ").title()
field.display_name = field.display_name.replace("Api", "API")
@ -94,7 +94,7 @@ class SpecialFieldFormatter(FieldFormatter):
class ShowFieldFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
key = field.name
key = field.name or ""
required = field.required
field.show = (
(required and key not in ["input_variables"])
@ -106,7 +106,7 @@ class ShowFieldFormatter(FieldFormatter):
class PasswordFieldFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
key = field.name
key = field.name or ""
show = field.show
if any(text in key.lower() for text in {"password", "token", "api", "key"}) and show:
field.password = True
@ -114,7 +114,7 @@ class PasswordFieldFormatter(FieldFormatter):
class MultilineFieldFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
key = field.name
key = field.name or ""
if key in {
"suffix",
"prefix",

View file

@ -1,10 +1,9 @@
from typing import Optional
from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.models.base import orjson_dumps
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.frontend_node.constants import CTRANSFORMERS_DEFAULT_CONFIG
from langflow.template.frontend_node.constants import OPENAI_API_BASE_INFO
from langflow.template.frontend_node.constants import CTRANSFORMERS_DEFAULT_CONFIG, OPENAI_API_BASE_INFO
class LLMFrontendNode(FrontendNode):
@ -24,6 +23,7 @@ class LLMFrontendNode(FrontendNode):
@staticmethod
def format_vertex_field(field: TemplateField, name: str):
key = field.name or ""
if "VertexAI" in name:
advanced_fields = [
"tuned_model_name",
@ -32,7 +32,7 @@ class LLMFrontendNode(FrontendNode):
"top_k",
"max_output_tokens",
]
if field.name in advanced_fields:
if key in advanced_fields:
field.advanced = True
show_fields = [
"tuned_model_name",
@ -47,20 +47,19 @@ class LLMFrontendNode(FrontendNode):
"top_k",
]
if field.name in show_fields:
if key in show_fields:
field.show = True
@staticmethod
def format_openai_field(field: TemplateField):
if "openai" in field.name.lower():
field.display_name = (field.name.title().replace("Openai", "OpenAI").replace("_", " ")).replace(
"Api", "API"
)
key = field.name or ""
if "openai" in key.lower():
field.display_name = (key.title().replace("Openai", "OpenAI").replace("_", " ")).replace("Api", "API")
if "key" not in field.name.lower() and "token" not in field.name.lower():
if "key" not in key.lower() and "token" not in key.lower():
field.password = False
if field.name == "openai_api_base":
if key == "openai_api_base":
field.info = OPENAI_API_BASE_INFO
def add_extra_base_classes(self) -> None:
@ -69,13 +68,14 @@ class LLMFrontendNode(FrontendNode):
@staticmethod
def format_azure_field(field: TemplateField):
if field.name == "model_name":
key = field.name or ""
if key == "model_name":
field.show = False # Azure uses deployment_name instead of model_name.
elif field.name == "openai_api_type":
elif key == "openai_api_type":
field.show = False
field.password = False
field.value = "azure"
elif field.name == "openai_api_version":
elif key == "openai_api_version":
field.password = False
@staticmethod
@ -85,7 +85,8 @@ class LLMFrontendNode(FrontendNode):
@staticmethod
def format_ctransformers_field(field: TemplateField):
if field.name == "config":
key = field.name or ""
if key == "config":
field.show = True
field.advanced = True
field.value = orjson_dumps(CTRANSFORMERS_DEFAULT_CONFIG, indent_2=True)
@ -105,10 +106,11 @@ class LLMFrontendNode(FrontendNode):
if name and "vertex" in name.lower():
LLMFrontendNode.format_vertex_field(field, name)
SHOW_FIELDS = ["repo_id"]
if field.name in SHOW_FIELDS:
key = field.name or ""
if key in SHOW_FIELDS:
field.show = True
if "api" in field.name and ("key" in field.name or ("token" in field.name and "tokens" not in field.name)):
if "api" in key and ("key" in key or ("token" in key and "tokens" not in key)):
field.password = True
field.show = True
# Required should be False to support
@ -116,7 +118,7 @@ class LLMFrontendNode(FrontendNode):
field.required = False
field.advanced = False
if field.name == "task":
if key == "task":
field.required = True
field.show = True
field.is_list = True
@ -124,13 +126,13 @@ class LLMFrontendNode(FrontendNode):
field.value = field.options[0]
field.advanced = True
if display_name := display_names_dict.get(field.name):
if display_name := display_names_dict.get(key):
field.display_name = display_name
if field.name == "model_kwargs":
if key == "model_kwargs":
field.field_type = "dict"
field.advanced = True
field.show = True
elif field.name in [
elif key in [
"model_name",
"temperature",
"model_file",
@ -140,9 +142,9 @@ class LLMFrontendNode(FrontendNode):
]:
field.advanced = False
field.show = True
if field.name == "credentials":
if key == "credentials":
field.field_type = "file"
if name == "VertexAI" and field.name not in [
if name == "VertexAI" and key not in [
"callbacks",
"client",
"stop",

View file

@ -1,6 +1,7 @@
from typing import Optional
from langchain.agents.mrkl import prompt
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.frontend_node.constants import DEFAULT_PROMPT, HUMAN_PROMPT, SYSTEM_PROMPT
@ -20,14 +21,15 @@ class PromptFrontendNode(FrontendNode):
"examples",
"format_instructions",
]
key = field.name or ""
if field.field_type == "StringPromptTemplate" and "Message" in str(name):
field.field_type = "prompt"
field.multiline = True
field.value = HUMAN_PROMPT if "Human" in field.name else SYSTEM_PROMPT
if field.name == "template" and field.value == "":
field.value = HUMAN_PROMPT if "Human" in key else SYSTEM_PROMPT
if key == "template" and field.value == "":
field.value = DEFAULT_PROMPT
if field.name in PROMPT_FIELDS:
if key and key in PROMPT_FIELDS:
field.field_type = "prompt"
field.advanced = False
@ -48,7 +50,8 @@ class PromptTemplateNode(FrontendNode):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
if field.name == "examples":
if (field.name or "") == "examples":
field.advanced = False

View file

@ -3,7 +3,6 @@ from typing import List, Optional
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
BASIC_FIELDS = [
"work_dir",
"collection_name",
@ -313,7 +312,7 @@ class VectorStoreFrontendNode(FrontendNode):
field.show = True
field.advanced = False
field.is_list = True
elif "embedding" in field.name:
elif field.name and "embedding" in field.name:
# for backwards compatibility
field.name = "embedding"
field.required = True