🔨 refactor(base.py): refactor FrontendNode.format_field() method to improve readability and maintainability
This commit refactors the FrontendNode.format_field() method to improve its readability and maintainability. The method now uses helper methods to handle specific field types and values, and to determine whether a field should be shown, be a password field, or be multiline. The method also uses a dictionary to handle special fields and their respective handlers.
This commit is contained in:
parent
373b599a1a
commit
342c2eaec7
1 changed files with 137 additions and 63 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from abc import ABC
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -139,10 +140,10 @@ class Template(BaseModel):
|
|||
class FrontendNode(BaseModel):
|
||||
template: Template
|
||||
description: str
|
||||
base_classes: list
|
||||
base_classes: List[str]
|
||||
name: str = ""
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
self.name: {
|
||||
"template": self.template.to_dict(self.format_field),
|
||||
|
|
@ -153,53 +154,145 @@ class FrontendNode(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
"""Formats a given field based on its attributes and value."""
|
||||
SPECIAL_FIELD_HANDLERS = {
|
||||
"allowed_tools": lambda field: "Tool",
|
||||
"max_value_length": lambda field: "int",
|
||||
}
|
||||
|
||||
key = field.name
|
||||
value = field.to_dict()
|
||||
_type = value["type"]
|
||||
|
||||
# Remove 'Optional' wrapper
|
||||
if "Optional" in _type:
|
||||
_type = _type.replace("Optional[", "")[:-1]
|
||||
_type = FrontendNode.remove_optional(_type)
|
||||
_type, is_list = FrontendNode.check_for_list_type(_type)
|
||||
field.is_list = is_list or field.is_list
|
||||
_type = FrontendNode.replace_mapping_with_dict(_type)
|
||||
_type = FrontendNode.handle_union_type(_type)
|
||||
|
||||
# Check for list type
|
||||
if "List" in _type or "Sequence" in _type:
|
||||
_type = _type.replace("List[", "")
|
||||
_type = _type.replace("Sequence[", "")[:-1]
|
||||
field.is_list = True
|
||||
field.field_type = FrontendNode.handle_special_field(
|
||||
field, key, _type, SPECIAL_FIELD_HANDLERS
|
||||
)
|
||||
field.field_type = FrontendNode.handle_dict_type(field, _type)
|
||||
field.show = FrontendNode.should_show_field(key, field.required)
|
||||
field.password = FrontendNode.should_be_password(key, field.show)
|
||||
field.multiline = FrontendNode.should_be_multiline(key)
|
||||
|
||||
# Replace 'Mapping' with 'dict'
|
||||
if "Mapping" in _type:
|
||||
_type = _type.replace("Mapping", "dict")
|
||||
FrontendNode.replace_default_value(field, value)
|
||||
FrontendNode.handle_specific_field_values(field, key, name)
|
||||
FrontendNode.handle_kwargs_field(field)
|
||||
FrontendNode.handle_api_key_field(field, key)
|
||||
|
||||
# {'type': 'Union[float, Tuple[float, float], NoneType]'} != {'type': 'float'}
|
||||
@staticmethod
|
||||
def remove_optional(_type: str) -> str:
|
||||
"""Removes 'Optional' wrapper from the type if present."""
|
||||
return re.sub(r"Optional\[(.*)\]", r"\1", _type)
|
||||
|
||||
@staticmethod
|
||||
def check_for_list_type(_type: str) -> tuple:
|
||||
"""Checks for list type and returns the modified type and a boolean indicating if it's a list."""
|
||||
is_list = "List" in _type or "Sequence" in _type
|
||||
if is_list:
|
||||
_type = re.sub(r"(List|Sequence)\[(.*)\]", r"\2", _type)
|
||||
return _type, is_list
|
||||
|
||||
@staticmethod
|
||||
def replace_mapping_with_dict(_type: str) -> str:
|
||||
"""Replaces 'Mapping' with 'dict'."""
|
||||
return _type.replace("Mapping", "dict")
|
||||
|
||||
@staticmethod
|
||||
def handle_union_type(_type: str) -> str:
|
||||
"""Simplifies the 'Union' type to the first type in the Union."""
|
||||
if "Union" in _type:
|
||||
_type = _type.replace("Union[", "")[:-1]
|
||||
_type = _type.split(",")[0]
|
||||
_type = _type.replace("]", "").replace("[", "")
|
||||
return _type
|
||||
|
||||
field.field_type = _type
|
||||
@staticmethod
|
||||
def handle_special_field(
|
||||
field, key: str, _type: str, SPECIAL_FIELD_HANDLERS
|
||||
) -> str:
|
||||
"""Handles special field by using the respective handler if present."""
|
||||
handler = SPECIAL_FIELD_HANDLERS.get(key)
|
||||
return handler(field) if handler else _type
|
||||
|
||||
# Change type from str to Tool
|
||||
field.field_type = "Tool" if key in {"allowed_tools"} else field.field_type
|
||||
@staticmethod
|
||||
def handle_dict_type(field: TemplateField, _type: str) -> str:
|
||||
"""Handles 'dict' type by replacing it with 'code' or 'file' based on the field name."""
|
||||
if "dict" in _type.lower():
|
||||
if field.name == "dict_":
|
||||
field.field_type = "file"
|
||||
field.suffixes = [".json", ".yaml", ".yml"]
|
||||
field.file_types = ["json", "yaml", "yml"]
|
||||
else:
|
||||
field.field_type = "code"
|
||||
return _type
|
||||
|
||||
field.field_type = "int" if key in {"max_value_length"} else field.field_type
|
||||
@staticmethod
|
||||
def replace_default_value(field: TemplateField, value: dict) -> None:
|
||||
"""Replaces default value with actual value if 'default' is present in value."""
|
||||
if "default" in value:
|
||||
field.value = value["default"]
|
||||
|
||||
# Show or not field
|
||||
field.show = bool(
|
||||
(field.required and key not in ["input_variables"])
|
||||
@staticmethod
|
||||
def handle_specific_field_values(
|
||||
field: TemplateField, key: str, name: Optional[str] = None
|
||||
) -> None:
|
||||
"""Handles specific field values for certain fields."""
|
||||
if key == "headers":
|
||||
field.value = """{'Authorization':
|
||||
'Bearer <token>'}"""
|
||||
if name == "OpenAI" and key == "model_name":
|
||||
field.options = constants.OPENAI_MODELS
|
||||
field.is_list = True
|
||||
elif name == "ChatOpenAI" and key == "model_name":
|
||||
field.options = constants.CHAT_OPENAI_MODELS
|
||||
field.is_list = True
|
||||
if "api_key" in key and "OpenAI" in str(name):
|
||||
field.display_name = "OpenAI API Key"
|
||||
field.required = False
|
||||
if field.value is None:
|
||||
field.value = ""
|
||||
|
||||
@staticmethod
|
||||
def handle_kwargs_field(field: TemplateField) -> None:
|
||||
"""Handles kwargs field by setting certain attributes."""
|
||||
if "kwargs" in field.name.lower():
|
||||
field.advanced = True
|
||||
field.required = False
|
||||
field.show = False
|
||||
|
||||
@staticmethod
|
||||
def handle_api_key_field(field: TemplateField, key: str) -> None:
|
||||
"""Handles api key field by setting certain attributes."""
|
||||
if "api" in key.lower() and "key" in key.lower():
|
||||
field.required = False
|
||||
field.advanced = False
|
||||
|
||||
@staticmethod
|
||||
def should_show_field(key: str, required: bool) -> bool:
|
||||
"""Determines whether the field should be shown."""
|
||||
return (
|
||||
(required and key not in ["input_variables"])
|
||||
or key in FORCE_SHOW_FIELDS
|
||||
or "api" in key
|
||||
or ("key" in key and "input" not in key and "output" not in key)
|
||||
)
|
||||
|
||||
# Add password field
|
||||
field.password = (
|
||||
@staticmethod
|
||||
def should_be_password(key: str, show: bool) -> bool:
|
||||
"""Determines whether the field should be a password field."""
|
||||
return (
|
||||
any(text in key.lower() for text in {"password", "token", "api", "key"})
|
||||
and field.show
|
||||
and show
|
||||
)
|
||||
|
||||
# Add multline
|
||||
field.multiline = key in {
|
||||
@staticmethod
|
||||
def should_be_multiline(key: str) -> bool:
|
||||
"""Determines whether the field should be multiline."""
|
||||
return key in {
|
||||
"suffix",
|
||||
"prefix",
|
||||
"template",
|
||||
|
|
@ -209,43 +302,24 @@ class FrontendNode(BaseModel):
|
|||
"description",
|
||||
}
|
||||
|
||||
# Replace dict type with str
|
||||
if "dict" in field.field_type.lower():
|
||||
field.field_type = "code"
|
||||
@staticmethod
|
||||
def replace_dict_with_code_or_file(
|
||||
field: TemplateField, _type: str, key: str
|
||||
) -> str:
|
||||
"""Replaces 'dict' type with 'code' or 'file'."""
|
||||
if "dict" in _type.lower():
|
||||
if key == "dict_":
|
||||
field.field_type = "file"
|
||||
field.suffixes = [".json", ".yaml", ".yml"]
|
||||
field.file_types = ["json", "yaml", "yml"]
|
||||
else:
|
||||
field.field_type = "code"
|
||||
return field.field_type
|
||||
|
||||
if key == "dict_":
|
||||
field.field_type = "file"
|
||||
field.suffixes = [".json", ".yaml", ".yml"]
|
||||
field.file_types = ["json", "yaml", "yml"]
|
||||
|
||||
# Replace default value with actual value
|
||||
@staticmethod
|
||||
def set_field_default_value(field: TemplateField, value: dict, key: str) -> None:
|
||||
"""Sets the field value with the default value if present."""
|
||||
if "default" in value:
|
||||
field.value = value["default"]
|
||||
|
||||
if key == "headers":
|
||||
field.value = """{'Authorization':
|
||||
'Bearer <token>'}"""
|
||||
|
||||
# Add options to openai
|
||||
if name == "OpenAI" and key == "model_name":
|
||||
field.options = constants.OPENAI_MODELS
|
||||
field.is_list = True
|
||||
elif name == "ChatOpenAI":
|
||||
if key == "model_name":
|
||||
field.options = constants.CHAT_OPENAI_MODELS
|
||||
field.is_list = True
|
||||
if "api_key" in key and "OpenAI" in str(name):
|
||||
field.display_name = "OpenAI API Key"
|
||||
field.required = False
|
||||
if field.value is None:
|
||||
field.value = ""
|
||||
|
||||
if "kwargs" in field.name.lower():
|
||||
field.advanced = True
|
||||
field.required = False
|
||||
field.show = False
|
||||
# If the field.name contains api or api and key, then it might be an api key
|
||||
# other conditions are to make sure that it is not an input or output variable
|
||||
if "api" in key.lower() and "key" in key.lower():
|
||||
field.required = False
|
||||
field.advanced = False
|
||||
field.value = """{'Authorization': 'Bearer <token>'}"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue