From 72f88e1a168d5122a561f5537df793a6b384e23f Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 9 Aug 2023 14:36:45 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(custom=5Fcomponent.py):=20ha?= =?UTF-8?q?ndle=20return=5Ftype=20as=20a=20Union[type1,=20type2]=20and=20a?= =?UTF-8?q?dd=20support=20for=20multiple=20return=20types=20in=20add=5Fbas?= =?UTF-8?q?e=5Fclasses=20function=20=F0=9F=90=9B=20fix(types.py):=20handle?= =?UTF-8?q?=20multiple=20return=20types=20in=20add=5Fbase=5Fclasses=20func?= =?UTF-8?q?tion=20and=20raise=20HTTPException=20with=20appropriate=20error?= =?UTF-8?q?=20message=20if=20return=20type=20is=20invalid?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../interface/custom/custom_component.py | 10 ++++- src/backend/langflow/interface/types.py | 37 ++++++++++--------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index ce8956660..8c0b2537a 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -117,7 +117,15 @@ class CustomComponent(Component, extra=Extra.allow): return "" build_method = build_methods[0] - + return_type = build_method["return_type"] + # It could be a type or a Union[type1, type2] + if "Union" in return_type: + return_type = ( + return_type.replace("Union", "").replace("[", "").replace("]", "") + ) + return_type = return_type.split(",") + return_type = [item.strip() for item in return_type] + return [item for item in return_type if item in self.return_type_valid_list] return build_method["return_type"] @property diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index 76dc144a0..950f227b4 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -1,6 +1,6 @@ import ast import contextlib -from typing import Any +from typing import Any, List from langflow.api.utils import merge_nested_dicts_with_renaming from langflow.interface.agents.base import agent_creator from langflow.interface.chains.base import chain_creator @@ -257,26 +257,27 @@ def get_field_properties(extra_field): return field_name, field_type, field_value, field_required -def add_base_classes(frontend_node, return_type): +def add_base_classes(frontend_node, return_types: List[str]): """Add base classes to the frontend node""" - if return_type not in CUSTOM_COMPONENT_SUPPORTED_TYPES or return_type is None: - raise HTTPException( - status_code=400, - detail={ - "error": ( - "Invalid return type should be one of: " - f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}" - ), - "traceback": traceback.format_exc(), - }, - ) + for return_type in return_types: + if return_type not in CUSTOM_COMPONENT_SUPPORTED_TYPES or return_type is None: + raise HTTPException( + status_code=400, + detail={ + "error": ( + "Invalid return type should be one of: " + f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}" + ), + "traceback": traceback.format_exc(), + }, + ) - return_type_instance = CUSTOM_COMPONENT_SUPPORTED_TYPES.get(return_type) - base_classes = get_base_classes(return_type_instance) + return_type_instance = CUSTOM_COMPONENT_SUPPORTED_TYPES.get(return_type) + base_classes = get_base_classes(return_type_instance) - for base_class in base_classes: - if base_class not in CLASSES_TO_REMOVE: - frontend_node.get("base_classes").append(base_class) + for base_class in base_classes: + if base_class not in CLASSES_TO_REMOVE: + frontend_node.get("base_classes").append(base_class) def build_langchain_template_custom_component(custom_component: CustomComponent):