🐛 fix(custom_component.py): update import statements and function calls to match changes in utils module

🔀 merge(custom_component.py): merge changes from utils module to handle generic aliases correctly in return type parsing
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-11-22 21:10:40 -03:00
commit 15cc7667a7

View file

@ -7,7 +7,10 @@ from fastapi import HTTPException
from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
from langflow.interface.custom.component import Component
from langflow.interface.custom.directory_reader import DirectoryReader
from langflow.interface.custom.utils import extract_inner_type, extract_union_types
from langflow.interface.custom.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
from langflow.services.database.models.flow import Flow
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
@ -137,16 +140,19 @@ class CustomComponent(Component):
if not return_type:
return []
# If list or List is in the return type, then we remove it and return the inner type
if return_type.startswith("list") or return_type.startswith("List"):
return_type = extract_inner_type(return_type)
if hasattr(return_type, "__origin__") and return_type.__origin__ in [list, List]:
return_type = extract_inner_type_from_generic_alias(return_type)
# If the return type is not a Union, then we just return it as a list
if "Union" not in return_type:
return [return_type] if return_type in self.return_type_valid_list else []
if not hasattr(return_type, "__origin__") or return_type.__origin__ != Union:
if isinstance(return_type, list):
return return_type
return [return_type] # if return_type in self.return_type_valid_list else []
# If the return type is a Union, then we need to parse it
return_type = extract_union_types(return_type)
return [item for item in return_type if item in self.return_type_valid_list]
# If the return type is a Union, then we need to parse itx
return_type = extract_union_types_from_generic_alias(return_type)
# return [item for item in return_type if item in self.return_type_valid_list]
return return_type
@property
def get_main_class_name(self):