From 15cc7667a7ae4229dde8d8f412f0961ca9aecd80 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 22 Nov 2023 21:10:40 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(custom=5Fcomponent.py):=20up?= =?UTF-8?q?date=20import=20statements=20and=20function=20calls=20to=20matc?= =?UTF-8?q?h=20changes=20in=20utils=20module=20=F0=9F=94=80=20merge(custom?= =?UTF-8?q?=5Fcomponent.py):=20merge=20changes=20from=20utils=20module=20t?= =?UTF-8?q?o=20handle=20generic=20aliases=20correctly=20in=20return=20type?= =?UTF-8?q?=20parsing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../interface/custom/custom_component.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index 0475fd82b..055766bb0 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -7,7 +7,10 @@ from fastapi import HTTPException from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES from langflow.interface.custom.component import Component from langflow.interface.custom.directory_reader import DirectoryReader -from langflow.interface.custom.utils import extract_inner_type, extract_union_types +from langflow.interface.custom.utils import ( + extract_inner_type_from_generic_alias, + extract_union_types_from_generic_alias, +) from langflow.services.database.models.flow import Flow from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service @@ -137,16 +140,19 @@ class CustomComponent(Component): if not return_type: return [] # If list or List is in the return type, then we remove it and return the inner type - if return_type.startswith("list") or return_type.startswith("List"): - return_type = extract_inner_type(return_type) + if hasattr(return_type, "__origin__") and return_type.__origin__ in [list, List]: + return_type = extract_inner_type_from_generic_alias(return_type) # If the return type is not a Union, then we just return it as a list - if "Union" not in return_type: - return [return_type] if return_type in self.return_type_valid_list else [] + if not hasattr(return_type, "__origin__") or return_type.__origin__ != Union: + if isinstance(return_type, list): + return return_type + return [return_type] # if return_type in self.return_type_valid_list else [] - # If the return type is a Union, then we need to parse it - return_type = extract_union_types(return_type) - return [item for item in return_type if item in self.return_type_valid_list] + # If the return type is a Union, then we need to parse itx + return_type = extract_union_types_from_generic_alias(return_type) + # return [item for item in return_type if item in self.return_type_valid_list] + return return_type @property def get_main_class_name(self):