Refactor build_template_from_function and build_template_from_class functions

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-06 10:56:42 -03:00
commit 77ea76d37d

View file

@ -15,8 +15,12 @@ def remove_ansi_escape_codes(text):
return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text)
def build_template_from_function(name: str, type_to_loader_dict: Dict, add_function: bool = False):
classes = [item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()]
def build_template_from_function(
name: str, type_to_loader_dict: Dict, add_function: bool = False
):
classes = [
item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()
]
# Raise error if name is not in chains
if name not in classes:
@ -37,8 +41,10 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict, add_funct
for name_, value_ in value.__repr_args__():
if name_ == "default_factory":
try:
variables[class_field_items]["default"] = get_default_factory(
module=_class.__base__.__module__, function=value_
variables[class_field_items]["default"] = (
get_default_factory(
module=_class.__base__.__module__, function=value_
)
)
except Exception:
variables[class_field_items]["default"] = None
@ -46,7 +52,9 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict, add_funct
variables[class_field_items][name_] = value_
variables[class_field_items]["placeholder"] = (
docs.params[class_field_items] if class_field_items in docs.params else ""
docs.params[class_field_items]
if class_field_items in docs.params
else ""
)
# Adding function to base classes to allow
# the output to be a function
@ -61,7 +69,9 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict, add_funct
}
def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False):
def build_template_from_class(
name: str, type_to_cls_dict: Dict, add_function: bool = False
):
classes = [item.__name__ for item in type_to_cls_dict.values()]
# Raise error if name is not in chains
@ -85,8 +95,11 @@ def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: b
for name_, value_ in value.__repr_args__():
if name_ == "default_factory":
try:
variables[class_field_items]["default"] = get_default_factory(
module=_class.__base__.__module__, function=value_
variables[class_field_items]["default"] = (
get_default_factory(
module=_class.__base__.__module__,
function=value_,
)
)
except Exception:
variables[class_field_items]["default"] = None
@ -94,7 +107,9 @@ def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: b
variables[class_field_items][name_] = value_
variables[class_field_items]["placeholder"] = (
docs.params[class_field_items] if class_field_items in docs.params else ""
docs.params[class_field_items]
if class_field_items in docs.params
else ""
)
base_classes = get_base_classes(_class)
# Adding function to base classes to allow
@ -126,7 +141,9 @@ def build_template_from_method(
# Check if the method exists in this class
if not hasattr(_class, method_name):
raise ValueError(f"Method {method_name} not found in class {class_name}")
raise ValueError(
f"Method {method_name} not found in class {class_name}"
)
# Get the method
method = getattr(_class, method_name)
@ -145,8 +162,14 @@ def build_template_from_method(
"_type": _type,
**{
name: {
"default": param.default if param.default != param.empty else None,
"type": param.annotation if param.annotation != param.empty else None,
"default": (
param.default if param.default != param.empty else None
),
"type": (
param.annotation
if param.annotation != param.empty
else None
),
"required": param.default == param.empty,
}
for name, param in params.items()
@ -233,7 +256,9 @@ def sync_to_async(func):
return async_wrapper
def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) -> Dict[str, Any]:
def format_dict(
dictionary: Dict[str, Any], class_name: Optional[str] = None
) -> Dict[str, Any]:
"""
Formats a dictionary by removing certain keys and modifying the
values of other keys.
@ -243,7 +268,7 @@ def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) ->
"""
for key, value in dictionary.items():
if key == "_type":
if key in ["_type"]:
continue
_type: Union[str, type] = get_type(value)
@ -319,7 +344,9 @@ def check_list_type(_type: str, value: Dict[str, Any]) -> str:
The modified type string.
"""
if any(list_type in _type for list_type in ["List", "Sequence", "Set"]):
_type = _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
_type = (
_type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
)
value["list"] = True
else:
value["list"] = False
@ -422,7 +449,9 @@ def set_headers_value(value: Dict[str, Any]) -> None:
value["value"] = """{"Authorization": "Bearer <token>"}"""
def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: str) -> None:
def add_options_to_field(
value: Dict[str, Any], class_name: Optional[str], key: str
) -> None:
"""
Adds options to the field based on the class name and key.
"""