diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index ce8956660..8c0b2537a 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -117,7 +117,15 @@ class CustomComponent(Component, extra=Extra.allow): return "" build_method = build_methods[0] - + return_type = build_method["return_type"] + # It could be a type or a Union[type1, type2] + if "Union" in return_type: + return_type = ( + return_type.replace("Union", "").replace("[", "").replace("]", "") + ) + return_type = return_type.split(",") + return_type = [item.strip() for item in return_type] + return [item for item in return_type if item in self.return_type_valid_list] return build_method["return_type"] @property diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index 76dc144a0..950f227b4 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -1,6 +1,6 @@ import ast import contextlib -from typing import Any +from typing import Any, List from langflow.api.utils import merge_nested_dicts_with_renaming from langflow.interface.agents.base import agent_creator from langflow.interface.chains.base import chain_creator @@ -257,26 +257,27 @@ def get_field_properties(extra_field): return field_name, field_type, field_value, field_required -def add_base_classes(frontend_node, return_type): +def add_base_classes(frontend_node, return_types: List[str]): """Add base classes to the frontend node""" - if return_type not in CUSTOM_COMPONENT_SUPPORTED_TYPES or return_type is None: - raise HTTPException( - status_code=400, - detail={ - "error": ( - "Invalid return type should be one of: " - f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}" - ), - "traceback": traceback.format_exc(), - }, - ) + for return_type in return_types: + if return_type not in CUSTOM_COMPONENT_SUPPORTED_TYPES or return_type is None: + raise HTTPException( + status_code=400, + detail={ + "error": ( + "Invalid return type should be one of: " + f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}" + ), + "traceback": traceback.format_exc(), + }, + ) - return_type_instance = CUSTOM_COMPONENT_SUPPORTED_TYPES.get(return_type) - base_classes = get_base_classes(return_type_instance) + return_type_instance = CUSTOM_COMPONENT_SUPPORTED_TYPES.get(return_type) + base_classes = get_base_classes(return_type_instance) - for base_class in base_classes: - if base_class not in CLASSES_TO_REMOVE: - frontend_node.get("base_classes").append(base_class) + for base_class in base_classes: + if base_class not in CLASSES_TO_REMOVE: + frontend_node.get("base_classes").append(base_class) def build_langchain_template_custom_component(custom_component: CustomComponent):