refactor: Update utils.py and custom.py to improve type handling and code readability

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-20 12:03:08 -03:00
commit 80cf34ac98
2 changed files with 15 additions and 4 deletions

View file

@ -3,6 +3,7 @@ import contextlib
import re
import traceback
import warnings
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import UUID
@ -21,7 +22,7 @@ from langflow.custom.directory_reader.utils import (
from langflow.custom.eval import eval_custom_component_code
from langflow.custom.schema import MissingDefault
from langflow.field_typing.range_spec import RangeSpec
from langflow.helpers.custom import format_type
from langflow.helpers.custom import get_all_types_from_type
from langflow.schema import dotdict
from langflow.template.field.base import Input
from langflow.template.frontend_node.custom_components import ComponentFrontendNode, CustomComponentFrontendNode
@ -370,8 +371,8 @@ def build_custom_component_template_from_inputs(
if output.types:
continue
return_types = custom_component.get_method_return_type(output.method)
return_types = [format_type(return_type) for return_type in return_types]
output.add_types(return_types)
all_types = [get_all_types_from_type(return_type) for return_type in return_types]
output.add_types(chain.from_iterable(all_types))
output.set_selected()
# Validate that there is not name overlap between inputs and outputs
frontend_node.validate()

View file

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, get_args
def format_type(type_: Any) -> str:
@ -11,3 +11,13 @@ def format_type(type_: Any) -> str:
else:
type_ = str(type_)
return type_
def get_all_types_from_type(type_: Any) -> str:
args = get_args(type_)
if args:
formatted_types = [format_type(arg) for arg in args]
formatted_types.insert(0, format_type(type_))
return formatted_types
else:
return [format_type(type_)]