refactor: Validate name overlap and attributes in FrontendNode class

This refactor adds a new method `validate` to the `FrontendNode` class in the `frontend_node/base.py` file. The `validate` method now calls two separate validation methods: `validate_name_overlap` and `validate_attributes`. The `validate_name_overlap` method checks for any overlap between input and output names, while the `validate_attributes` method checks for reserved attributes in the input and output names. These changes improve the reliability and maintainability of the code.
This commit is contained in:
ogabrielluiz 2024-06-19 01:03:19 -03:00
commit 2b3ca8ecc6
2 changed files with 34 additions and 1 deletions

View file

@ -374,7 +374,7 @@ def build_custom_component_template_from_inputs(
output.add_types(return_types)
output.set_selected()
# Validate that there is not name overlap between inputs and outputs
frontend_node.validate_name_overlap()
frontend_node.validate()
# ! This should be removed when we have a better way to handle this
frontend_node.get_base_classes_from_outputs()
reorder_fields(frontend_node, custom_component._get_field_order())

View file

@ -102,6 +102,10 @@ class FrontendNode(BaseModel):
def get_base_classes_from_outputs(self) -> list[str]:
self.base_classes = [output_type for output in self.outputs for output_type in output.types]
def validate(self) -> None:
self.validate_name_overlap()
self.validate_attributes()
def validate_name_overlap(self) -> None:
# Check if any of the output names overlap with the any of the inputs
output_names = [output.name for output in self.outputs]
@ -113,6 +117,35 @@ class FrontendNode(BaseModel):
f"There should be no overlap between input and output names. Names {overlap} are duplicated."
)
def validate_attributes(self) -> None:
# None of inputs, outputs, _artifacts, _results, logs, status, vertex, graph, display_name, description, documentation, icon
# should be present in outputs or input names
output_names = [output.name for output in self.outputs]
input_names = [input_.name for input_ in self.template.fields]
attributes = [
"inputs",
"outputs",
"_artifacts",
"_results",
"logs",
"status",
"vertex",
"graph",
"display_name",
"description",
"documentation",
"icon",
]
output_overlap = set(output_names).intersection(attributes)
input_overlap = set(input_names).intersection(attributes)
error_message = ""
if output_overlap:
output_overlap = ", ".join(map(lambda x: f"'{x}'", output_overlap))
error_message += f"Output names {output_overlap} are reserved attributes.\n"
if input_overlap:
input_overlap = ", ".join(map(lambda x: f"'{x}'", input_overlap))
error_message += f"Input names {input_overlap} are reserved attributes."
def add_base_class(self, base_class: Union[str, List[str]]) -> None:
"""Adds a base class to the frontend node."""
if isinstance(base_class, str):