refac: formatting moved to FrontendNode
This commit is contained in:
parent
43c4fe7dfc
commit
24bdfaa941
1 changed files with 88 additions and 15 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -21,8 +21,6 @@ class TemplateFieldCreator(BaseModel, ABC):
|
|||
content: Union[str, None] = None
|
||||
password: bool = False
|
||||
options: list[str] = []
|
||||
# _name will be used to store the name of the field
|
||||
# in the template
|
||||
name: str = ""
|
||||
|
||||
def to_dict(self):
|
||||
|
|
@ -60,9 +58,9 @@ class TemplateFieldCreator(BaseModel, ABC):
|
|||
_type = _type.replace("Mapping", "dict")
|
||||
|
||||
# Change type from str to Tool
|
||||
self.field_type = "Tool" if key in ["allowed_tools"] else _type
|
||||
self.field_type = "Tool" if key in {"allowed_tools"} else self.field_type
|
||||
|
||||
self.field_type = "int" if key in ["max_value_length"] else self.field_type
|
||||
self.field_type = "int" if key in {"max_value_length"} else self.field_type
|
||||
|
||||
# Show or not field
|
||||
self.show = bool(
|
||||
|
|
@ -73,18 +71,18 @@ class TemplateFieldCreator(BaseModel, ABC):
|
|||
|
||||
# Add password field
|
||||
self.password = any(
|
||||
text in key.lower() for text in ["password", "token", "api", "key"]
|
||||
text in key.lower() for text in {"password", "token", "api", "key"}
|
||||
)
|
||||
|
||||
# Add multline
|
||||
self.multiline = key in [
|
||||
self.multiline = key in {
|
||||
"suffix",
|
||||
"prefix",
|
||||
"template",
|
||||
"examples",
|
||||
"code",
|
||||
"headers",
|
||||
]
|
||||
}
|
||||
|
||||
# Replace dict type with str
|
||||
if "dict" in self.field_type.lower():
|
||||
|
|
@ -120,13 +118,17 @@ class Template(BaseModel):
|
|||
type_name: str
|
||||
fields: list[TemplateField]
|
||||
|
||||
def process_fields(self, name: Optional[str] = None) -> None:
|
||||
for field in self.fields:
|
||||
signature = field.to_dict()
|
||||
field.process_field(field.name, signature, name)
|
||||
def process_fields(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
format_field_func: Union[Callable, None] = None,
|
||||
):
|
||||
if format_field_func:
|
||||
for field in self.fields:
|
||||
format_field_func(field, name)
|
||||
|
||||
def to_dict(self):
|
||||
self.process_fields(self.type_name)
|
||||
def to_dict(self, format_field_func=None):
|
||||
self.process_fields(self.type_name, format_field_func)
|
||||
result = {field.name: field.to_dict() for field in self.fields}
|
||||
result["_type"] = self.type_name # type: ignore
|
||||
return result
|
||||
|
|
@ -141,8 +143,79 @@ class FrontendNode(BaseModel):
|
|||
def to_dict(self):
|
||||
return {
|
||||
self.name: {
|
||||
"template": self.template.to_dict(),
|
||||
"template": self.template.to_dict(self.format_field),
|
||||
"description": self.description,
|
||||
"base_classes": self.base_classes,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
key = field.name
|
||||
value = field.to_dict()
|
||||
_type = value["type"]
|
||||
|
||||
# Remove 'Optional' wrapper
|
||||
if "Optional" in _type:
|
||||
_type = _type.replace("Optional[", "")[:-1]
|
||||
|
||||
# Check for list type
|
||||
if "List" in _type:
|
||||
_type = _type.replace("List[", "")[:-1]
|
||||
self.is_list = True
|
||||
|
||||
# Replace 'Mapping' with 'dict'
|
||||
if "Mapping" in _type:
|
||||
_type = _type.replace("Mapping", "dict")
|
||||
|
||||
# Change type from str to Tool
|
||||
field.field_type = "Tool" if key in {"allowed_tools"} else field.field_type
|
||||
|
||||
field.field_type = "int" if key in {"max_value_length"} else field.field_type
|
||||
|
||||
# Show or not field
|
||||
field.show = bool(
|
||||
(field.required and key not in ["input_variables"])
|
||||
or key in FORCE_SHOW_FIELDS
|
||||
or "api_key" in key
|
||||
)
|
||||
|
||||
# Add password field
|
||||
field.password = any(
|
||||
text in key.lower() for text in {"password", "token", "api", "key"}
|
||||
)
|
||||
|
||||
# Add multline
|
||||
field.multiline = key in {
|
||||
"suffix",
|
||||
"prefix",
|
||||
"template",
|
||||
"examples",
|
||||
"code",
|
||||
"headers",
|
||||
}
|
||||
|
||||
# Replace dict type with str
|
||||
if "dict" in field.field_type.lower():
|
||||
field.field_type = "code"
|
||||
|
||||
if key == "dict_":
|
||||
field.field_type = "file"
|
||||
field.suffixes = [".json", ".yaml", ".yml"]
|
||||
field.file_types = ["json", "yaml", "yml"]
|
||||
|
||||
# Replace default value with actual value
|
||||
if "default" in value:
|
||||
field.value = value["default"]
|
||||
|
||||
if key == "headers":
|
||||
field.value = """{'Authorization':
|
||||
'Bearer <token>'}"""
|
||||
|
||||
# Add options to openai
|
||||
if name == "OpenAI" and key == "model_name":
|
||||
field.options = constants.OPENAI_MODELS
|
||||
field.is_list = True
|
||||
elif name == "OpenAIChat" and key == "model_name":
|
||||
field.options = constants.CHAT_OPENAI_MODELS
|
||||
field.is_list = True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue