diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index 4a5218b71..548ad1dd7 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -7,10 +7,6 @@ from cachetools import TTLCache from langchain_core.documents import Document from pydantic import BaseModel -from langflow.custom.code_parser.utils import ( - extract_inner_type_from_generic_alias, - extract_union_types_from_generic_alias, -) from langflow.custom.custom_component.base_component import BaseComponent from langflow.helpers.flow import list_flows, load_flow, run_flow from langflow.schema import Record @@ -20,6 +16,10 @@ from langflow.schema.message import Message from langflow.schema.schema import Log from langflow.services.deps import get_storage_service, get_variable_service, session_scope from langflow.services.storage.service import StorageService +from langflow.type_extraction.type_extraction import ( + extract_inner_type_from_generic_alias, + extract_union_types_from_generic_alias, +) from langflow.utils import validate if TYPE_CHECKING: @@ -327,7 +327,6 @@ class CustomComponent(BaseComponent): return [] return_type = build_method["return_type"] - # If list or List is in the return type, then we remove it and return the inner type if hasattr(return_type, "__origin__") and return_type.__origin__ in [ list, List, diff --git a/src/backend/base/langflow/custom/utils.py b/src/backend/base/langflow/custom/utils.py index d6f80ce7e..19943d499 100644 --- a/src/backend/base/langflow/custom/utils.py +++ b/src/backend/base/langflow/custom/utils.py @@ -11,7 +11,6 @@ from loguru import logger from pydantic import BaseModel from langflow.custom import CustomComponent -from langflow.custom.code_parser.utils import extract_inner_type from langflow.custom.custom_component.component import Component from langflow.custom.directory_reader.utils import ( abuild_custom_component_list_from_path, @@ -26,6 +25,7 @@ from langflow.helpers.custom import format_type from langflow.schema import dotdict from langflow.template.field.base import Input from langflow.template.frontend_node.custom_components import ComponentFrontendNode, CustomComponentFrontendNode +from langflow.type_extraction.type_extraction import extract_inner_type from langflow.utils import validate from langflow.utils.util import get_base_classes diff --git a/src/backend/base/langflow/template/field/base.py b/src/backend/base/langflow/template/field/base.py index fe260c890..22f86b3c8 100644 --- a/src/backend/base/langflow/template/field/base.py +++ b/src/backend/base/langflow/template/field/base.py @@ -1,11 +1,13 @@ from enum import Enum -from types import GenericAlias -from typing import Any, Callable, Optional, Union, _GenericAlias, _UnionGenericAlias, get_args, get_origin +from typing import Any, Callable, GenericAlias, Optional, Union, _GenericAlias, _UnionGenericAlias + from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator from langflow.field_typing import Text from langflow.field_typing.range_spec import RangeSpec +from langflow.helpers.custom import format_type +from langflow.type_extraction.type_extraction import post_process_type class UndefinedType(Enum): @@ -118,21 +120,8 @@ class Input(BaseModel): # this should be done for all types # How to check if v is a type? if isinstance(v, (type, _GenericAlias, GenericAlias, _UnionGenericAlias)): - if isinstance(v, type): - v = v.__name__ - else: - origin = get_origin(v) - args = get_args(v) - if origin and args: - v = f"{origin.__name__}[{', '.join(arg.__name__ if isinstance(arg, type) else str(arg) for arg in args)}]" - # if v is union with None (e.g Union[someType, NoneType]) we need to remove NoneType - # we can return Optional[someType] instead of Union[someType, NoneType] - if "NoneType" in v: - v = v.replace(", NoneType", "") - v = v.replace("Union[", "Optional[") - - else: - v = str(v) + v = post_process_type(v)[0] + v = format_type(v) elif not isinstance(v, str): raise ValueError(f"type must be a string or a type, not {type(v)}") return v @@ -196,15 +185,6 @@ class Output(BaseModel): if not self.selected: self.selected = self.types[0] - @field_validator("display_name", mode="before") - def validate_display_name(cls, v, info): - if not v: - if info.data.get("name"): - return info.data["name"] - else: - raise ValueError("If display_name is not set, name must be set") - return v - @model_serializer(mode="wrap") def serialize_model(self, handler): result = handler(self) @@ -217,4 +197,8 @@ class Output(BaseModel): def validate_model(self): if self.value == UNDEFINED.value: self.value = UNDEFINED + if self.name is None: + raise ValueError("name must be set") + if self.display_name is None: + self.display_name = self.name return self diff --git a/src/backend/base/langflow/type_extraction/__init__.py b/src/backend/base/langflow/type_extraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/base/langflow/custom/code_parser/utils.py b/src/backend/base/langflow/type_extraction/type_extraction.py similarity index 51% rename from src/backend/base/langflow/custom/code_parser/utils.py rename to src/backend/base/langflow/type_extraction/type_extraction.py index 0f97b4c7b..f22a8eebe 100644 --- a/src/backend/base/langflow/custom/code_parser/utils.py +++ b/src/backend/base/langflow/type_extraction/type_extraction.py @@ -1,15 +1,6 @@ import re from types import GenericAlias -from typing import Any - - -def extract_inner_type(return_type: str) -> str: - """ - Extracts the inner type from a type hint that is a list. - """ - if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE): - return match[1] - return return_type +from typing import Any, List, Union def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any: @@ -21,6 +12,15 @@ def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any: return return_type +def extract_inner_type(return_type: str) -> str: + """ + Extracts the inner type from a type hint that is a list. + """ + if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE): + return match[1] + return return_type + + def extract_union_types(return_type: str) -> list[str]: """ Extracts the inner type from a type hint that is a list. @@ -31,6 +31,47 @@ def extract_union_types(return_type: str) -> list[str]: return [item.strip() for item in return_types] +def extract_uniont_types_from_generic_alias(return_type: GenericAlias) -> list: + """ + Extracts the inner type from a type hint that is a Union. + """ + if isinstance(return_type, list): + return [ + _inner_arg + for _type in return_type + for _inner_arg in _type.__args__ + if _inner_arg not in set((Any, type(None), type(Any))) + ] + + return list(return_type.__args__) + + +def post_process_type(_type): + """ + Process the return type of a function. + + Args: + _type (Any): The return type of the function. + + Returns: + Union[List[Any], Any]: The processed return type. + + """ + if hasattr(_type, "__origin__") and _type.__origin__ in [ + list, + List, + ]: + _type = extract_inner_type_from_generic_alias(_type) + + # If the return type is not a Union, then we just return it as a list + inner_type = _type[0] if isinstance(_type, list) else _type + if not hasattr(inner_type, "__origin__") or inner_type.__origin__ != Union: + return _type if isinstance(_type, list) else [_type] + # If the return type is a Union, then we need to parse it + _type = extract_union_types_from_generic_alias(_type) + return _type + + def extract_union_types_from_generic_alias(return_type: GenericAlias) -> list: """ Extracts the inner type from a type hint that is a Union.