Fix CustomComponent get_build_method return value

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-14 09:39:19 -03:00
commit 5de717c770

View file

@ -76,7 +76,7 @@ class CustomComponent(Component):
@property
def tree(self):
return self.get_code_tree(self.code)
return self.get_code_tree(self.code or "")
@property
def get_function_entrypoint_args(self) -> list:
@ -104,11 +104,11 @@ class CustomComponent(Component):
@cachedmethod(operator.attrgetter("cache"))
def get_build_method(self):
if not self.code:
return []
return {}
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
if not component_classes:
return []
return {}
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
@ -116,19 +116,13 @@ class CustomComponent(Component):
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
]
if not build_methods:
return []
return build_methods[0]
return build_methods[0] if build_methods else {}
@property
def get_function_entrypoint_return_type(self) -> List[Any]:
build_method = self.get_build_method()
if not build_method:
if not build_method or not build_method.get("has_return"):
return []
elif not build_method["has_return"]:
return []
return_type = build_method["return_type"]
# If list or List is in the return type, then we remove it and return the inner type
@ -137,10 +131,7 @@ class CustomComponent(Component):
# If the return type is not a Union, then we just return it as a list
if not hasattr(return_type, "__origin__") or return_type.__origin__ != Union:
if isinstance(return_type, list):
return return_type
return [return_type]
return return_type if isinstance(return_type, list) else [return_type]
# If the return type is a Union, then we need to parse itx
return_type = extract_union_types_from_generic_alias(return_type)
return return_type
@ -206,9 +197,7 @@ class CustomComponent(Component):
"""Returns a function that returns the value at the given index in the iterable."""
def get_index(iterable: List[Any]):
if iterable:
return iterable[value]
return iterable
return iterable[value] if iterable else iterable
return get_index