diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index 0475fd82b..055766bb0 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -7,7 +7,10 @@ from fastapi import HTTPException from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES from langflow.interface.custom.component import Component from langflow.interface.custom.directory_reader import DirectoryReader -from langflow.interface.custom.utils import extract_inner_type, extract_union_types +from langflow.interface.custom.utils import ( + extract_inner_type_from_generic_alias, + extract_union_types_from_generic_alias, +) from langflow.services.database.models.flow import Flow from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service @@ -137,16 +140,19 @@ class CustomComponent(Component): if not return_type: return [] # If list or List is in the return type, then we remove it and return the inner type - if return_type.startswith("list") or return_type.startswith("List"): - return_type = extract_inner_type(return_type) + if hasattr(return_type, "__origin__") and return_type.__origin__ in [list, List]: + return_type = extract_inner_type_from_generic_alias(return_type) # If the return type is not a Union, then we just return it as a list - if "Union" not in return_type: - return [return_type] if return_type in self.return_type_valid_list else [] + if not hasattr(return_type, "__origin__") or return_type.__origin__ != Union: + if isinstance(return_type, list): + return return_type + return [return_type] # if return_type in self.return_type_valid_list else [] - # If the return type is a Union, then we need to parse it - return_type = extract_union_types(return_type) - return [item for item in return_type if item in self.return_type_valid_list] + # If the return type is a Union, then we need to parse itx + return_type = extract_union_types_from_generic_alias(return_type) + # return [item for item in return_type if item in self.return_type_valid_list] + return return_type @property def get_main_class_name(self):