diff --git a/src/backend/base/langflow/custom/utils.py b/src/backend/base/langflow/custom/utils.py index 49bb201a8..f2a2450a1 100644 --- a/src/backend/base/langflow/custom/utils.py +++ b/src/backend/base/langflow/custom/utils.py @@ -3,6 +3,7 @@ import contextlib import re import traceback import warnings +from itertools import chain from typing import Any, Dict, List, Optional, Tuple, Union from uuid import UUID @@ -21,7 +22,7 @@ from langflow.custom.directory_reader.utils import ( from langflow.custom.eval import eval_custom_component_code from langflow.custom.schema import MissingDefault from langflow.field_typing.range_spec import RangeSpec -from langflow.helpers.custom import format_type +from langflow.helpers.custom import get_all_types_from_type from langflow.schema import dotdict from langflow.template.field.base import Input from langflow.template.frontend_node.custom_components import ComponentFrontendNode, CustomComponentFrontendNode @@ -370,8 +371,8 @@ def build_custom_component_template_from_inputs( if output.types: continue return_types = custom_component.get_method_return_type(output.method) - return_types = [format_type(return_type) for return_type in return_types] - output.add_types(return_types) + all_types = [get_all_types_from_type(return_type) for return_type in return_types] + output.add_types(chain.from_iterable(all_types)) output.set_selected() # Validate that there is not name overlap between inputs and outputs frontend_node.validate() diff --git a/src/backend/base/langflow/helpers/custom.py b/src/backend/base/langflow/helpers/custom.py index bdbb128f4..61ab4f46b 100644 --- a/src/backend/base/langflow/helpers/custom.py +++ b/src/backend/base/langflow/helpers/custom.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, get_args def format_type(type_: Any) -> str: @@ -11,3 +11,13 @@ def format_type(type_: Any) -> str: else: type_ = str(type_) return type_ + + +def get_all_types_from_type(type_: Any) -> str: + args = get_args(type_) + if args: + formatted_types = [format_type(arg) for arg in args] + formatted_types.insert(0, format_type(type_)) + return formatted_types + else: + return [format_type(type_)]