🐛 fix(custom_component.py): handle return_type as a Union[type1, type2] and add support for multiple return types in add_base_classes function

🐛 fix(types.py): handle multiple return types in add_base_classes function and raise HTTPException with appropriate error message if return type is invalid
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-09 14:36:45 -03:00
commit 72f88e1a16
2 changed files with 28 additions and 19 deletions

View file

@ -117,7 +117,15 @@ class CustomComponent(Component, extra=Extra.allow):
return ""
build_method = build_methods[0]
return_type = build_method["return_type"]
# It could be a type or a Union[type1, type2]
if "Union" in return_type:
return_type = (
return_type.replace("Union", "").replace("[", "").replace("]", "")
)
return_type = return_type.split(",")
return_type = [item.strip() for item in return_type]
return [item for item in return_type if item in self.return_type_valid_list]
return build_method["return_type"]
@property

View file

@ -1,6 +1,6 @@
import ast
import contextlib
from typing import Any
from typing import Any, List
from langflow.api.utils import merge_nested_dicts_with_renaming
from langflow.interface.agents.base import agent_creator
from langflow.interface.chains.base import chain_creator
@ -257,26 +257,27 @@ def get_field_properties(extra_field):
return field_name, field_type, field_value, field_required
def add_base_classes(frontend_node, return_type):
def add_base_classes(frontend_node, return_types: List[str]):
"""Add base classes to the frontend node"""
if return_type not in CUSTOM_COMPONENT_SUPPORTED_TYPES or return_type is None:
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid return type should be one of: "
f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}"
),
"traceback": traceback.format_exc(),
},
)
for return_type in return_types:
if return_type not in CUSTOM_COMPONENT_SUPPORTED_TYPES or return_type is None:
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid return type should be one of: "
f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}"
),
"traceback": traceback.format_exc(),
},
)
return_type_instance = CUSTOM_COMPONENT_SUPPORTED_TYPES.get(return_type)
base_classes = get_base_classes(return_type_instance)
return_type_instance = CUSTOM_COMPONENT_SUPPORTED_TYPES.get(return_type)
base_classes = get_base_classes(return_type_instance)
for base_class in base_classes:
if base_class not in CLASSES_TO_REMOVE:
frontend_node.get("base_classes").append(base_class)
for base_class in base_classes:
if base_class not in CLASSES_TO_REMOVE:
frontend_node.get("base_classes").append(base_class)
def build_langchain_template_custom_component(custom_component: CustomComponent):