From 9ba00de4d59fb457645b01a09283b93d3d7a9ba0 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Tue, 18 Jun 2024 22:11:33 -0300 Subject: [PATCH] refactor: Validate name overlap between inputs and outputs in FrontendNode The code changes in `frontend_node/base.py` add a new method `validate_name_overlap` to the `FrontendNode` class. This method checks if any of the output names overlap with any of the input names and raises a `ValueError` if there is a duplication. This refactor improves the consistency and correctness of the code. The commit message follows the established convention of using a prefix to indicate the type of change. --- src/backend/base/langflow/custom/utils.py | 2 ++ .../base/langflow/template/frontend_node/base.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/src/backend/base/langflow/custom/utils.py b/src/backend/base/langflow/custom/utils.py index 93a888c81..ae1827563 100644 --- a/src/backend/base/langflow/custom/utils.py +++ b/src/backend/base/langflow/custom/utils.py @@ -373,6 +373,8 @@ def build_custom_component_template_from_inputs( return_types = [format_type(return_type) for return_type in return_types] output.add_types(return_types) output.set_selected() + # Validate that there is not name overlap between inputs and outputs + frontend_node.validate_name_overlap() # ! This should be removed when we have a better way to handle this frontend_node.get_base_classes_from_outputs() reorder_fields(frontend_node, custom_component._get_field_order()) diff --git a/src/backend/base/langflow/template/frontend_node/base.py b/src/backend/base/langflow/template/frontend_node/base.py index 6d9a1d397..6694d9d2c 100644 --- a/src/backend/base/langflow/template/frontend_node/base.py +++ b/src/backend/base/langflow/template/frontend_node/base.py @@ -102,6 +102,17 @@ class FrontendNode(BaseModel): def get_base_classes_from_outputs(self) -> list[str]: self.base_classes = [output_type for output in self.outputs for output_type in output.types] + def validate_name_overlap(self) -> None: + # Check if any of the output names overlap with the any of the inputs + output_names = [output.name for output in self.outputs] + input_names = [input_.name for input_ in self.template.fields] + overlap = set(output_names).intersection(input_names) + if overlap: + overlap = ", ".join(map(lambda x: f"'{x}'", overlap)) + raise ValueError( + f"There should be no overlap between input and output names. Names {overlap} are duplicated." + ) + def add_base_class(self, base_class: Union[str, List[str]]) -> None: """Adds a base class to the frontend node.""" if isinstance(base_class, str):