📝 (base.py): Refactor type processing logic to use the newly created post_process_type function for better readability and maintainability

📝 (custom_component.py): Move type extraction functions to type_extraction module for better organization and separation of concerns

📝 (utils.py): Update import statements to reflect the move of type extraction functions to type_extraction module

📝 (type_extraction.py): Add functions to extract inner types and union types from generic aliases for type extraction operations
This commit is contained in:
ogabrielluiz 2024-06-12 10:44:36 -03:00
commit f2c30be721
5 changed files with 66 additions and 42 deletions

View file

@ -7,10 +7,6 @@ from cachetools import TTLCache
from langchain_core.documents import Document
from pydantic import BaseModel
from langflow.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
from langflow.custom.custom_component.base_component import BaseComponent
from langflow.helpers.flow import list_flows, load_flow, run_flow
from langflow.schema import Record
@ -20,6 +16,10 @@ from langflow.schema.message import Message
from langflow.schema.schema import Log
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
from langflow.services.storage.service import StorageService
from langflow.type_extraction.type_extraction import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
from langflow.utils import validate
if TYPE_CHECKING:
@ -327,7 +327,6 @@ class CustomComponent(BaseComponent):
return []
return_type = build_method["return_type"]
# If list or List is in the return type, then we remove it and return the inner type
if hasattr(return_type, "__origin__") and return_type.__origin__ in [
list,
List,

View file

@ -11,7 +11,6 @@ from loguru import logger
from pydantic import BaseModel
from langflow.custom import CustomComponent
from langflow.custom.code_parser.utils import extract_inner_type
from langflow.custom.custom_component.component import Component
from langflow.custom.directory_reader.utils import (
abuild_custom_component_list_from_path,
@ -26,6 +25,7 @@ from langflow.helpers.custom import format_type
from langflow.schema import dotdict
from langflow.template.field.base import Input
from langflow.template.frontend_node.custom_components import ComponentFrontendNode, CustomComponentFrontendNode
from langflow.type_extraction.type_extraction import extract_inner_type
from langflow.utils import validate
from langflow.utils.util import get_base_classes

View file

@ -1,11 +1,13 @@
from enum import Enum
from types import GenericAlias
from typing import Any, Callable, Optional, Union, _GenericAlias, _UnionGenericAlias, get_args, get_origin
from typing import Any, Callable, GenericAlias, Optional, Union, _GenericAlias, _UnionGenericAlias
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator
from langflow.field_typing import Text
from langflow.field_typing.range_spec import RangeSpec
from langflow.helpers.custom import format_type
from langflow.type_extraction.type_extraction import post_process_type
class UndefinedType(Enum):
@ -118,21 +120,8 @@ class Input(BaseModel):
# this should be done for all types
# How to check if v is a type?
if isinstance(v, (type, _GenericAlias, GenericAlias, _UnionGenericAlias)):
if isinstance(v, type):
v = v.__name__
else:
origin = get_origin(v)
args = get_args(v)
if origin and args:
v = f"{origin.__name__}[{', '.join(arg.__name__ if isinstance(arg, type) else str(arg) for arg in args)}]"
# if v is union with None (e.g Union[someType, NoneType]) we need to remove NoneType
# we can return Optional[someType] instead of Union[someType, NoneType]
if "NoneType" in v:
v = v.replace(", NoneType", "")
v = v.replace("Union[", "Optional[")
else:
v = str(v)
v = post_process_type(v)[0]
v = format_type(v)
elif not isinstance(v, str):
raise ValueError(f"type must be a string or a type, not {type(v)}")
return v
@ -196,15 +185,6 @@ class Output(BaseModel):
if not self.selected:
self.selected = self.types[0]
@field_validator("display_name", mode="before")
def validate_display_name(cls, v, info):
if not v:
if info.data.get("name"):
return info.data["name"]
else:
raise ValueError("If display_name is not set, name must be set")
return v
@model_serializer(mode="wrap")
def serialize_model(self, handler):
result = handler(self)
@ -217,4 +197,8 @@ class Output(BaseModel):
def validate_model(self):
if self.value == UNDEFINED.value:
self.value = UNDEFINED
if self.name is None:
raise ValueError("name must be set")
if self.display_name is None:
self.display_name = self.name
return self

View file

@ -1,15 +1,6 @@
import re
from types import GenericAlias
from typing import Any
def extract_inner_type(return_type: str) -> str:
"""
Extracts the inner type from a type hint that is a list.
"""
if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE):
return match[1]
return return_type
from typing import Any, List, Union
def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any:
@ -21,6 +12,15 @@ def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any:
return return_type
def extract_inner_type(return_type: str) -> str:
"""
Extracts the inner type from a type hint that is a list.
"""
if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE):
return match[1]
return return_type
def extract_union_types(return_type: str) -> list[str]:
"""
Extracts the inner type from a type hint that is a list.
@ -31,6 +31,47 @@ def extract_union_types(return_type: str) -> list[str]:
return [item.strip() for item in return_types]
def extract_uniont_types_from_generic_alias(return_type: GenericAlias) -> list:
"""
Extracts the inner type from a type hint that is a Union.
"""
if isinstance(return_type, list):
return [
_inner_arg
for _type in return_type
for _inner_arg in _type.__args__
if _inner_arg not in set((Any, type(None), type(Any)))
]
return list(return_type.__args__)
def post_process_type(_type):
"""
Process the return type of a function.
Args:
_type (Any): The return type of the function.
Returns:
Union[List[Any], Any]: The processed return type.
"""
if hasattr(_type, "__origin__") and _type.__origin__ in [
list,
List,
]:
_type = extract_inner_type_from_generic_alias(_type)
# If the return type is not a Union, then we just return it as a list
inner_type = _type[0] if isinstance(_type, list) else _type
if not hasattr(inner_type, "__origin__") or inner_type.__origin__ != Union:
return _type if isinstance(_type, list) else [_type]
# If the return type is a Union, then we need to parse it
_type = extract_union_types_from_generic_alias(_type)
return _type
def extract_union_types_from_generic_alias(return_type: GenericAlias) -> list:
"""
Extracts the inner type from a type hint that is a Union.