feat: support bool type variable frontend (#24437)

Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
Joel 2025-08-26 18:16:05 +08:00 committed by GitHub
commit dac72b078d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
126 changed files with 3832 additions and 512 deletions

11
api/child_class.py Normal file
View file

@ -0,0 +1,11 @@
from tests.integration_tests.utils.parent_class import ParentClass
class ChildClass(ParentClass):
"""Test child class for module import helper tests"""
def __init__(self, name):
super().__init__(name)
def get_name(self):
return f"Child: {self.name}"

View file

@ -3,6 +3,17 @@ import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
[
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.EXTERNAL_DATA_TOOL,
VariableEntityType.CHECKBOX,
]
)
class BasicVariablesConfigManager:
@classmethod
@ -47,6 +58,7 @@ class BasicVariablesConfigManager:
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
VariableEntityType.CHECKBOX,
}:
variable = variables[variable_type]
variable_entities.append(
@ -96,8 +108,17 @@ class BasicVariablesConfigManager:
variables = []
for item in config["user_input_form"]:
key = list(item.keys())[0]
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
# if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
if key not in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.EXTERNAL_DATA_TOOL,
VariableEntityType.CHECKBOX,
}:
allowed_keys = ", ".join(i.value for i in _ALLOWED_VARIABLE_ENTITY_TYPE)
raise ValueError(f"Keys in user_input_form list can only be {allowed_keys}")
form_item = item[key]
if "label" not in form_item:

View file

@ -97,6 +97,7 @@ class VariableEntityType(StrEnum):
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
class VariableEntity(BaseModel):

View file

@ -103,18 +103,23 @@ class BaseAppGenerator:
f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
)
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
# handle empty string case
if not value.strip():
return None
# may raise ValueError if user_input_value is not a valid number
try:
if "." in value:
return float(value)
else:
return int(value)
except ValueError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
if variable_entity.type == VariableEntityType.NUMBER:
if isinstance(value, (int, float)):
return value
elif isinstance(value, str):
# handle empty string case
if not value.strip():
return None
# may raise ValueError if user_input_value is not a valid number
try:
if "." in value:
return float(value)
else:
return int(value)
except ValueError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
else:
raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}")
match variable_entity.type:
case VariableEntityType.SELECT:
@ -144,6 +149,11 @@ class BaseAppGenerator:
raise ValueError(
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
)
case VariableEntityType.CHECKBOX:
if not isinstance(value, bool):
raise ValueError(f"{variable_entity.variable} in input form must be a valid boolean value")
case _:
raise AssertionError("this statement should be unreachable.")
return value

View file

@ -151,6 +151,11 @@ class FileSegment(Segment):
return ""
class BooleanSegment(Segment):
value_type: SegmentType = SegmentType.BOOLEAN
value: bool
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any]
@ -198,6 +203,11 @@ class ArrayFileSegment(ArraySegment):
return ""
class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
value: Sequence[bool]
def get_segment_discriminator(v: Any) -> SegmentType | None:
if isinstance(v, Segment):
return v.value_type
@ -231,11 +241,13 @@ SegmentUnion: TypeAlias = Annotated[
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
| Annotated[FileSegment, Tag(SegmentType.FILE)]
| Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
),
Discriminator(get_segment_discriminator),
]

View file

@ -6,7 +6,12 @@ from core.file.models import File
class ArrayValidation(StrEnum):
"""Strategy for validating array elements"""
"""Strategy for validating array elements.
Note:
The `NONE` and `FIRST` strategies are primarily for compatibility purposes.
Avoid using them in new code whenever possible.
"""
# Skip element validation (only check array container)
NONE = "none"
@ -27,12 +32,14 @@ class SegmentType(StrEnum):
SECRET = "secret"
FILE = "file"
BOOLEAN = "boolean"
ARRAY_ANY = "array[any]"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILE = "array[file]"
ARRAY_BOOLEAN = "array[boolean]"
NONE = "none"
@ -76,12 +83,18 @@ class SegmentType(StrEnum):
return SegmentType.ARRAY_FILE
case SegmentType.NONE:
return SegmentType.ARRAY_ANY
case SegmentType.BOOLEAN:
return SegmentType.ARRAY_BOOLEAN
case _:
# This should be unreachable.
raise ValueError(f"not supported value {value}")
if value is None:
return SegmentType.NONE
elif isinstance(value, int) and not isinstance(value, bool):
# Important: The check for `bool` must precede the check for `int`,
# as `bool` is a subclass of `int` in Python's type hierarchy.
elif isinstance(value, bool):
return SegmentType.BOOLEAN
elif isinstance(value, int):
return SegmentType.INTEGER
elif isinstance(value, float):
return SegmentType.FLOAT
@ -111,7 +124,7 @@ class SegmentType(StrEnum):
else:
return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.ALL) -> bool:
"""
Check if a value matches the segment type.
Users of `SegmentType` should call this method, instead of using
@ -126,6 +139,10 @@ class SegmentType(StrEnum):
"""
if self.is_array_type():
return self._validate_array(value, array_validation)
# Important: The check for `bool` must precede the check for `int`,
# as `bool` is a subclass of `int` in Python's type hierarchy.
elif self == SegmentType.BOOLEAN:
return isinstance(value, bool)
elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]:
return isinstance(value, (int, float))
elif self == SegmentType.STRING:
@ -141,6 +158,27 @@ class SegmentType(StrEnum):
else:
raise AssertionError("this statement should be unreachable.")
@staticmethod
def cast_value(value: Any, type_: "SegmentType") -> Any:
# Cast Python's `bool` type to `int` when the runtime type requires
# an integer or number.
#
# This ensures compatibility with existing workflows that may use `bool` as
# `int`, since in Python's type system, `bool` is a subtype of `int`.
#
# This function exists solely to maintain compatibility with existing workflows.
# It should not be used to compromise the integrity of the runtime type system.
# No additional casting rules should be introduced to this function.
if type_ in (
SegmentType.INTEGER,
SegmentType.NUMBER,
) and isinstance(value, bool):
return int(value)
if type_ == SegmentType.ARRAY_NUMBER and all(isinstance(i, bool) for i in value):
return [int(i) for i in value]
return value
def exposed_type(self) -> "SegmentType":
"""Returns the type exposed to the frontend.
@ -150,6 +188,20 @@ class SegmentType(StrEnum):
return SegmentType.NUMBER
return self
def element_type(self) -> "SegmentType | None":
"""Return the element type of the current segment type, or `None` if the element type is undefined.
Raises:
ValueError: If the current segment type is not an array type.
Note:
For certain array types, such as `SegmentType.ARRAY_ANY`, their element types are not defined
by the runtime system. In such cases, this method will return `None`.
"""
if not self.is_array_type():
raise ValueError(f"element_type is only supported by array type, got {self}")
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have corresponding element type.
@ -157,6 +209,7 @@ _ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
SegmentType.ARRAY_FILE: SegmentType.FILE,
SegmentType.ARRAY_BOOLEAN: SegmentType.BOOLEAN,
}
_ARRAY_TYPES = frozenset(

View file

@ -8,11 +8,13 @@ from core.helper import encrypter
from .segments import (
ArrayAnySegment,
ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
@ -96,10 +98,18 @@ class FileVariable(FileSegment, Variable):
pass
class BooleanVariable(BooleanSegment, Variable):
pass
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
pass
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required.
#
@ -114,11 +124,13 @@ VariableUnion: TypeAlias = Annotated[
| Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
| Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
| Annotated[FileVariable, Tag(SegmentType.FILE)]
| Annotated[BooleanVariable, Tag(SegmentType.BOOLEAN)]
| Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
),
Discriminator(get_segment_discriminator),

View file

@ -8,6 +8,7 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@ -119,6 +120,14 @@ class CodeNode(BaseNode):
return value.replace("\x00", "")
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
if value is None:
return None
if not isinstance(value, bool):
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
return value
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
"""
Check number
@ -173,6 +182,8 @@ class CodeNode(BaseNode):
prefix=f"{prefix}.{output_name}" if prefix else output_name,
depth=depth + 1,
)
elif isinstance(output_value, bool):
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
elif isinstance(output_value, int | float):
self._check_number(
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
@ -232,7 +243,7 @@ class CodeNode(BaseNode):
if output_name not in result:
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
if output_config.type == "object":
if output_config.type == SegmentType.OBJECT:
# check if output is object
if not isinstance(result.get(output_name), dict):
if result[output_name] is None:
@ -249,18 +260,28 @@ class CodeNode(BaseNode):
prefix=f"{prefix}.{output_name}",
depth=depth + 1,
)
elif output_config.type == "number":
elif output_config.type == SegmentType.NUMBER:
# check if number available
transformed_result[output_name] = self._check_number(
value=result[output_name], variable=f"{prefix}{dot}{output_name}"
)
elif output_config.type == "string":
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
# If the output is a boolean and the output schema specifies a NUMBER type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
transformed_result[output_name] = self._convert_boolean_to_int(checked)
elif output_config.type == SegmentType.STRING:
# check if string available
transformed_result[output_name] = self._check_string(
value=result[output_name],
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == "array[number]":
elif output_config.type == SegmentType.BOOLEAN:
transformed_result[output_name] = self._check_boolean(
value=result[output_name],
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == SegmentType.ARRAY_NUMBER:
# check if array of number available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -278,10 +299,17 @@ class CodeNode(BaseNode):
)
transformed_result[output_name] = [
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
# If the element is a boolean and the output schema specifies a `array[number]` type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
self._convert_boolean_to_int(
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
)
for i, value in enumerate(result[output_name])
]
elif output_config.type == "array[string]":
elif output_config.type == SegmentType.ARRAY_STRING:
# check if array of string available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -302,7 +330,7 @@ class CodeNode(BaseNode):
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
elif output_config.type == "array[object]":
elif output_config.type == SegmentType.ARRAY_OBJECT:
# check if array of object available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -340,6 +368,22 @@ class CodeNode(BaseNode):
)
for i, value in enumerate(result[output_name])
]
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
# check if array of object available
if not isinstance(result[output_name], list):
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
transformed_result[output_name] = [
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
else:
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
@ -374,3 +418,16 @@ class CodeNode(BaseNode):
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
@staticmethod
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
This ensures compatibility with existing workflows that may use
`True` and `False` as values for NUMBER type outputs.
"""
if value is None:
return None
if isinstance(value, bool):
return int(value)
return value

View file

@ -1,11 +1,31 @@
from typing import Literal, Optional
from typing import Annotated, Literal, Optional
from pydantic import BaseModel
from pydantic import AfterValidator, BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
[
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
]
)
def _validate_type(segment_type: SegmentType) -> SegmentType:
if segment_type not in _ALLOWED_OUTPUT_FROM_CODE:
raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
return segment_type
class CodeNodeData(BaseNodeData):
"""
@ -13,7 +33,7 @@ class CodeNodeData(BaseNodeData):
"""
class Output(BaseModel):
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: Optional[dict[str, "CodeNodeData.Output"]] = None
class Dependency(BaseModel):

View file

@ -1,36 +1,43 @@
from collections.abc import Sequence
from typing import Literal
from enum import StrEnum
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
_Condition = Literal[
class FilterOperator(StrEnum):
# string conditions
"contains",
"start with",
"end with",
"is",
"in",
"empty",
"not contains",
"is not",
"not in",
"not empty",
CONTAINS = "contains"
START_WITH = "start with"
END_WITH = "end with"
IS = "is"
IN = "in"
EMPTY = "empty"
NOT_CONTAINS = "not contains"
IS_NOT = "is not"
NOT_IN = "not in"
NOT_EMPTY = "not empty"
# number conditions
"=",
"",
"<",
">",
"",
"",
]
EQUAL = "="
NOT_EQUAL = ""
LESS_THAN = "<"
GREATER_THAN = ">"
GREATER_THAN_OR_EQUAL = ""
LESS_THAN_OR_EQUAL = ""
class Order(StrEnum):
ASC = "asc"
DESC = "desc"
class FilterCondition(BaseModel):
key: str = ""
comparison_operator: _Condition = "contains"
value: str | Sequence[str] = ""
comparison_operator: FilterOperator = FilterOperator.CONTAINS
# the value is bool if the filter operator is comparing with
# a boolean constant.
value: str | Sequence[str] | bool = ""
class FilterBy(BaseModel):
@ -38,10 +45,10 @@ class FilterBy(BaseModel):
conditions: Sequence[FilterCondition] = Field(default_factory=list)
class OrderBy(BaseModel):
class OrderByConfig(BaseModel):
enabled: bool = False
key: str = ""
value: Literal["asc", "desc"] = "asc"
value: Order = Order.ASC
class Limit(BaseModel):
@ -57,6 +64,6 @@ class ExtractConfig(BaseModel):
class ListOperatorNodeData(BaseNodeData):
variable: Sequence[str] = Field(default_factory=list)
filter_by: FilterBy
order_by: OrderBy
order_by: OrderByConfig
limit: Limit
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)

View file

@ -1,18 +1,40 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Optional, Union
from typing import Any, Optional, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import ListOperatorNodeData
from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
_SUPPORTED_TYPES_TUPLE = (
ArrayFileSegment,
ArrayNumberSegment,
ArrayStringSegment,
ArrayBooleanSegment,
)
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
_T = TypeVar("_T")
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
"""Returns the negation of a given filter function. If the original filter
returns `True` for a value, the negated filter will return `False`, and vice versa.
"""
def wrapper(value: _T) -> bool:
return not filter_(value)
return wrapper
class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
@ -69,11 +91,8 @@ class ListOperatorNode(BaseNode):
process_data=process_data,
outputs=outputs,
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
)
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@ -122,9 +141,7 @@ class ListOperatorNode(BaseNode):
outputs=outputs,
)
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self._node_data.filter_by.conditions:
@ -154,33 +171,35 @@ class ListOperatorNode(BaseNode):
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayBooleanSegment):
if not isinstance(condition.value, bool):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statment should be unreachable.")
return variable
def _apply_order(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statement should be unreachable")
return variable
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
@ -232,11 +251,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
case "empty":
return lambda x: x == ""
case "not contains":
return lambda x: not _contains(value)(x)
return _negation(_contains(value))
case "is not":
return lambda x: not _is(value)(x)
return _negation(_is(value))
case "not in":
return lambda x: not _in(value)(x)
return _negation(_in(value))
case "not empty":
return lambda x: x != ""
case _:
@ -248,7 +267,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
case "in":
return _in(value)
case "not in":
return lambda x: not _in(value)(x)
return _negation(_in(value))
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
@ -271,6 +290,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
match condition:
case FilterOperator.IS:
return _is(value)
case FilterOperator.IS_NOT:
return _negation(_is(value))
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
@ -298,7 +327,7 @@ def _endswith(value: str) -> Callable[[str], bool]:
return lambda x: x.endswith(value)
def _is(value: str) -> Callable[[str], bool]:
def _is(value: _T) -> Callable[[_T], bool]:
return lambda x: x == value
@ -330,21 +359,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x >= value
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
extract_func: Callable[[File], Any]
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
elif order_by == "size":
extract_func = _get_file_extract_number_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
else:
raise InvalidKeyError(f"Invalid order key: {order_by}")

View file

@ -3,7 +3,7 @@ import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
@ -55,7 +55,6 @@ from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
@ -90,6 +89,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
logger = logging.getLogger(__name__)
@ -161,7 +161,7 @@ class LLMNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
node_inputs: Optional[dict[str, Any]] = None
process_data = None
result_text = ""

View file

@ -12,9 +12,11 @@ _VALID_VAR_TYPE = frozenset(
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
]
)

View file

@ -404,11 +404,11 @@ class LoopNode(BaseNode):
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
_outputs = {}
_outputs: dict[str, Segment | int | None] = {}
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
_loop_variable_segment = variable_pool.get(loop_variable_selector)
if _loop_variable_segment:
_outputs[loop_variable_key] = _loop_variable_segment.value
_outputs[loop_variable_key] = _loop_variable_segment
else:
_outputs[loop_variable_key] = None
@ -522,21 +522,30 @@ class LoopNode(BaseNode):
return variable_mapping
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
if var_type in ["array[string]", "array[number]", "array[object]"]:
if value and isinstance(value, str):
value = json.loads(value)
if var_type in [
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
]:
if original_value and isinstance(original_value, str):
value = json.loads(original_value)
else:
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
value = []
elif var_type == SegmentType.ARRAY_BOOLEAN:
value = original_value
else:
raise AssertionError("this statement should be unreachable.")
try:
return build_segment_with_type(var_type, value)
return build_segment_with_type(var_type, value=value)
except TypeMismatchError as type_exc:
# Attempt to parse the value as a JSON-encoded string, if applicable.
if not isinstance(value, str):
if not isinstance(original_value, str):
raise
try:
value = json.loads(value)
value = json.loads(original_value)
except ValueError:
raise type_exc
return build_segment_with_type(var_type, value)

View file

@ -1,10 +1,46 @@
from typing import Any, Literal, Optional
from typing import Annotated, Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from pydantic import (
BaseModel,
BeforeValidator,
Field,
field_validator,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"
_VALID_PARAMETER_TYPES = frozenset(
[
SegmentType.STRING, # "string",
SegmentType.NUMBER, # "number",
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
]
)
def _validate_type(parameter_type: str) -> SegmentType:
if not isinstance(parameter_type, str):
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
if parameter_type not in _VALID_PARAMETER_TYPES:
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
if parameter_type == _OLD_BOOL_TYPE_NAME:
return SegmentType.BOOLEAN
elif parameter_type == _OLD_SELECT_TYPE_NAME:
return SegmentType.STRING
return SegmentType(parameter_type)
class _ParameterConfigError(Exception):
@ -17,7 +53,7 @@ class ParameterConfig(BaseModel):
"""
name: str
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
options: Optional[list[str]] = None
description: str
required: bool
@ -32,17 +68,20 @@ class ParameterConfig(BaseModel):
return str(value)
def is_array_type(self) -> bool:
return self.type in ("array[string]", "array[number]", "array[object]")
return self.type.is_array_type()
def element_type(self) -> Literal["string", "number", "object"]:
if self.type == "array[number]":
return "number"
elif self.type == "array[string]":
return "string"
elif self.type == "array[object]":
return "object"
else:
raise _ParameterConfigError(f"{self.type} is not array type.")
def element_type(self) -> SegmentType:
"""Return the element type of the parameter.
Raises a ValueError if the parameter's type is not an array type.
"""
element_type = self.type.element_type()
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
#
# See: _VALID_PARAMETER_TYPES for reference.
assert element_type is not None, f"the element type should not be None, {self.type=}"
return element_type
class ParameterExtractorNodeData(BaseNodeData):
@ -74,16 +113,18 @@ class ParameterExtractorNodeData(BaseNodeData):
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
if parameter.type in {"string", "select"}:
if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string"
elif parameter.type.startswith("array"):
elif parameter.type.is_array_type():
parameter_schema["type"] = "array"
nested_type = parameter.type[6:-1]
parameter_schema["items"] = {"type": nested_type}
element_type = parameter.type.element_type()
if element_type is None:
raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value}
else:
parameter_schema["type"] = parameter.type
if parameter.type == "select":
if parameter.options:
parameter_schema["enum"] = parameter.options
parameters["properties"][parameter.name] = parameter_schema

View file

@ -1,3 +1,8 @@
from typing import Any
from core.variables.types import SegmentType
class ParameterExtractorNodeError(ValueError):
"""Base error for ParameterExtractorNode."""
@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError):
class InvalidModelModeError(ParameterExtractorNodeError):
"""Raised when the model mode is invalid."""
class InvalidValueTypeError(ParameterExtractorNodeError):
def __init__(
self,
/,
parameter_name: str,
expected_type: SegmentType,
actual_type: SegmentType | None,
value: Any,
) -> None:
message = (
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
)
super().__init__(message)
self.parameter_name = parameter_name
self.expected_type = expected_type
self.actual_type = actual_type
self.value = value

View file

@ -26,7 +26,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import SegmentType
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -39,16 +39,13 @@ from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
InvalidArrayValueError,
InvalidBoolValueError,
InvalidInvokeResultError,
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
InvalidNumberValueError,
InvalidSelectValueError,
InvalidStringValueError,
InvalidTextContentTypeError,
InvalidValueTypeError,
ModelSchemaNotFoundError,
ParameterExtractorNodeError,
RequiredParameterMissingError,
@ -549,9 +546,6 @@ class ParameterExtractorNode(BaseNode):
return prompt_messages
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
"""
Validate result.
"""
if len(data.parameters) != len(result):
raise InvalidNumberOfParametersError("Invalid number of parameters")
@ -559,101 +553,106 @@ class ParameterExtractorNode(BaseNode):
if parameter.required and parameter.name not in result:
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith("array"):
parameters = result.get(parameter.name)
if not isinstance(parameters, list):
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
nested_type = parameter.type[6:-1]
for item in parameters:
if nested_type == "number" and not isinstance(item, int | float):
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
if nested_type == "string" and not isinstance(item, str):
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
if nested_type == "object" and not isinstance(item, dict):
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
param_value = result.get(parameter.name)
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
inferred_type = SegmentType.infer_segment_type(param_value)
raise InvalidValueTypeError(
parameter_name=parameter.name,
expected_type=parameter.type,
actual_type=inferred_type,
value=param_value,
)
if parameter.type == SegmentType.STRING and parameter.options:
if param_value not in parameter.options:
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
return result
@staticmethod
def _transform_number(value: int | float | str | bool) -> int | float | None:
"""
Attempts to transform the input into an integer or float.
Returns:
int or float: The transformed number if the conversion is successful.
None: If the transformation fails.
Note:
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
"""
if isinstance(value, bool):
return int(value)
elif isinstance(value, (int, float)):
return value
elif not isinstance(value, str):
return None
if "." in value:
try:
return float(value)
except ValueError:
return None
else:
try:
return int(value)
except ValueError:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
"""
Transform result into standard format.
"""
transformed_result = {}
transformed_result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.name in result:
param_value = result[parameter.name]
# transform value
if parameter.type == "number":
if isinstance(result[parameter.name], int | float):
transformed_result[parameter.name] = result[parameter.name]
elif isinstance(result[parameter.name], str):
try:
if "." in result[parameter.name]:
result[parameter.name] = float(result[parameter.name])
else:
result[parameter.name] = int(result[parameter.name])
except ValueError:
pass
else:
pass
# TODO: bool is not supported in the current version
# elif parameter.type == 'bool':
# if isinstance(result[parameter.name], bool):
# transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ['true', 'false']:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
# elif isinstance(result[parameter.name], int):
# transformed_result[parameter.name] = bool(result[parameter.name])
elif parameter.type in {"string", "select"}:
if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name]
if parameter.type == SegmentType.NUMBER:
transformed = self._transform_number(param_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ["true", "false"]:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
elif parameter.type == SegmentType.STRING:
if isinstance(param_value, str):
transformed_result[parameter.name] = param_value
elif parameter.is_array_type():
if isinstance(result[parameter.name], list):
if isinstance(param_value, list):
nested_type = parameter.element_type()
assert nested_type is not None
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
transformed_result[parameter.name] = segment_value
for item in result[parameter.name]:
if nested_type == "number":
if isinstance(item, int | float):
segment_value.value.append(item)
elif isinstance(item, str):
try:
if "." in item:
segment_value.value.append(float(item))
else:
segment_value.value.append(int(item))
except ValueError:
pass
elif nested_type == "string":
for item in param_value:
if nested_type == SegmentType.NUMBER:
transformed = self._transform_number(item)
if transformed is not None:
segment_value.value.append(transformed)
elif nested_type == SegmentType.STRING:
if isinstance(item, str):
segment_value.value.append(item)
elif nested_type == "object":
elif nested_type == SegmentType.OBJECT:
if isinstance(item, dict):
segment_value.value.append(item)
elif nested_type == SegmentType.BOOLEAN:
if isinstance(item, bool):
segment_value.value.append(item)
if parameter.name not in transformed_result:
if parameter.type == "number":
transformed_result[parameter.name] = 0
elif parameter.type == "bool":
transformed_result[parameter.name] = False
elif parameter.type in {"string", "select"}:
transformed_result[parameter.name] = ""
elif parameter.type.startswith("array"):
if parameter.type.is_array_type():
transformed_result[parameter.name] = build_segment_with_type(
segment_type=SegmentType(parameter.type), value=[]
)
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
transformed_result[parameter.name] = ""
elif parameter.type == SegmentType.NUMBER:
transformed_result[parameter.name] = 0
elif parameter.type == SegmentType.BOOLEAN:
transformed_result[parameter.name] = False
else:
raise AssertionError("this statement should be unreachable.")
return transformed_result

View file

@ -2,6 +2,7 @@ from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
from core.variables import SegmentType, Variable
from core.variables.segments import BooleanSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
@ -158,8 +159,8 @@ class VariableAssignerNode(BaseNode):
def get_zero_value(t: SegmentType):
# TODO(QuantumGhost): this should be a method of `SegmentType`.
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return variable_factory.build_segment([])
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
return variable_factory.build_segment_with_type(t, [])
case SegmentType.OBJECT:
return variable_factory.build_segment({})
case SegmentType.STRING:
@ -170,5 +171,7 @@ def get_zero_value(t: SegmentType):
return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case SegmentType.BOOLEAN:
return BooleanSegment(value=False)
case _:
raise VariableOperatorNodeError(f"unsupported variable type: {t}")

View file

@ -4,9 +4,11 @@ from core.variables import SegmentType
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,
SegmentType.BOOLEAN: False,
SegmentType.OBJECT: {},
SegmentType.ARRAY_ANY: [],
SegmentType.ARRAY_STRING: [],
SegmentType.ARRAY_NUMBER: [],
SegmentType.ARRAY_OBJECT: [],
SegmentType.ARRAY_BOOLEAN: [],
}

View file

@ -16,28 +16,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
SegmentType.NUMBER,
SegmentType.INTEGER,
SegmentType.FLOAT,
SegmentType.BOOLEAN,
}
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variable can be added, subtracted, multiplied or divided
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
case Operation.APPEND | Operation.EXTEND:
case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
# Only array variable can be appended or extended
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
# Only array variable can have elements removed
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
return variable_type.is_array_type()
case _:
return False
@ -50,7 +37,7 @@ def is_variable_input_supported(*, operation: Operation):
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
match variable_type:
case SegmentType.STRING | SegmentType.OBJECT:
case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN:
return operation in {Operation.OVER_WRITE, Operation.SET}
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return operation in {
@ -72,6 +59,9 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
case SegmentType.STRING:
return isinstance(value, str)
case SegmentType.BOOLEAN:
return isinstance(value, bool)
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
if not isinstance(value, int | float):
return False
@ -91,6 +81,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
return isinstance(value, int | float)
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
return isinstance(value, dict)
case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND:
return isinstance(value, bool)
# Array & Extend / Overwrite
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
@ -101,6 +93,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, bool) for item in value)
case _:
return False

View file

@ -45,5 +45,5 @@ class SubVariableCondition(BaseModel):
class Condition(BaseModel):
variable_selector: list[str]
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None = None
value: str | Sequence[str] | bool | None = None
sub_variable_condition: SubVariableCondition | None = None

View file

@ -1,13 +1,27 @@
import json
from collections.abc import Sequence
from typing import Any, Literal
from typing import Any, Literal, Union
from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
from core.workflow.entities.variable_pool import VariablePool
from .entities import Condition, SubCondition, SupportedComparisonOperator
def _convert_to_bool(value: Any) -> bool:
if isinstance(value, int):
return bool(value)
if isinstance(value, str):
loaded = json.loads(value)
if isinstance(loaded, (int, bool)):
return bool(loaded)
raise TypeError(f"unexpected value: type={type(value)}, value={value}")
class ConditionProcessor:
def process_conditions(
self,
@ -48,9 +62,16 @@ class ConditionProcessor:
)
else:
actual_value = variable.value if variable else None
expected_value = condition.value
expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value
if isinstance(expected_value, str):
expected_value = variable_pool.convert_template(expected_value).text
# Here we need to explicit convet the input string to boolean.
if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None:
# The following two lines is for compatibility with existing workflows.
if isinstance(expected_value, list):
expected_value = [_convert_to_bool(i) for i in expected_value]
else:
expected_value = _convert_to_bool(expected_value)
input_conditions.append(
{
"actual_value": actual_value,
@ -77,7 +98,7 @@ def _evaluate_condition(
*,
operator: SupportedComparisonOperator,
value: Any,
expected: str | Sequence[str] | None,
expected: Union[str, Sequence[str], bool | Sequence[bool], None],
) -> bool:
match operator:
case "contains":
@ -130,7 +151,7 @@ def _assert_contains(*, value: Any, expected: Any) -> bool:
if not value:
return False
if not isinstance(value, str | list):
if not isinstance(value, (str, list)):
raise ValueError("Invalid actual value type: string or array")
if expected not in value:
@ -142,7 +163,7 @@ def _assert_not_contains(*, value: Any, expected: Any) -> bool:
if not value:
return True
if not isinstance(value, str | list):
if not isinstance(value, (str, list)):
raise ValueError("Invalid actual value type: string or array")
if expected in value:
@ -178,8 +199,8 @@ def _assert_is(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
if not isinstance(value, (str, bool)):
raise ValueError("Invalid actual value type: string or boolean")
if value != expected:
return False
@ -190,8 +211,8 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
if not isinstance(value, (str, bool)):
raise ValueError("Invalid actual value type: string or boolean")
if value == expected:
return False
@ -214,10 +235,13 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if not isinstance(value, (int, float, bool)):
raise ValueError("Invalid actual value type: number or boolean")
if isinstance(value, int):
# Handle boolean comparison
if isinstance(value, bool):
expected = bool(expected)
elif isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
@ -231,10 +255,13 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if not isinstance(value, (int, float, bool)):
raise ValueError("Invalid actual value type: number or boolean")
if isinstance(value, int):
# Handle boolean comparison
if isinstance(value, bool):
expected = bool(expected)
elif isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
@ -248,7 +275,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
@ -265,7 +292,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
@ -282,7 +309,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
@ -299,7 +326,7 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):

View file

@ -7,11 +7,13 @@ from core.file import File
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
@ -23,10 +25,12 @@ from core.variables.segments import (
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayBooleanVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
BooleanVariable,
FileVariable,
FloatVariable,
IntegerVariable,
@ -49,17 +53,19 @@ class TypeMismatchError(Exception):
# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
StringSegment: StringVariable,
IntegerSegment: IntegerVariable,
FloatSegment: FloatVariable,
ObjectSegment: ObjectVariable,
FileSegment: FileVariable,
ArrayStringSegment: ArrayStringVariable,
ArrayAnySegment: ArrayAnyVariable,
ArrayBooleanSegment: ArrayBooleanVariable,
ArrayFileSegment: ArrayFileVariable,
ArrayNumberSegment: ArrayNumberVariable,
ArrayObjectSegment: ArrayObjectVariable,
ArrayFileSegment: ArrayFileVariable,
ArrayAnySegment: ArrayAnyVariable,
ArrayStringSegment: ArrayStringVariable,
BooleanSegment: BooleanVariable,
FileSegment: FileVariable,
FloatSegment: FloatVariable,
IntegerSegment: IntegerVariable,
NoneSegment: NoneVariable,
ObjectSegment: ObjectVariable,
StringSegment: StringVariable,
}
@ -99,6 +105,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
mapping = dict(mapping)
mapping["value_type"] = SegmentType.FLOAT
result = FloatVariable.model_validate(mapping)
case SegmentType.BOOLEAN:
result = BooleanVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
case SegmentType.OBJECT if isinstance(value, dict):
@ -109,6 +117,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_BOOLEAN if isinstance(value, list):
result = ArrayBooleanVariable.model_validate(mapping)
case _:
raise VariableError(f"not supported value type {value_type}")
if result.size > dify_config.MAX_VARIABLE_SIZE:
@ -129,6 +139,8 @@ def build_segment(value: Any, /) -> Segment:
return NoneSegment()
if isinstance(value, str):
return StringSegment(value=value)
if isinstance(value, bool):
return BooleanSegment(value=value)
if isinstance(value, int):
return IntegerSegment(value=value)
if isinstance(value, float):
@ -152,6 +164,8 @@ def build_segment(value: Any, /) -> Segment:
return ArrayStringSegment(value=value)
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return ArrayNumberSegment(value=value)
case SegmentType.BOOLEAN:
return ArrayBooleanSegment(value=value)
case SegmentType.OBJECT:
return ArrayObjectSegment(value=value)
case SegmentType.FILE:
@ -170,6 +184,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.INTEGER: IntegerSegment,
SegmentType.FLOAT: FloatSegment,
SegmentType.FILE: FileSegment,
SegmentType.BOOLEAN: BooleanSegment,
SegmentType.OBJECT: ObjectSegment,
# Array types
SegmentType.ARRAY_ANY: ArrayAnySegment,
@ -177,6 +192,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
SegmentType.ARRAY_FILE: ArrayFileSegment,
SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
}
@ -225,6 +241,8 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
return ArrayAnySegment(value=value)
elif segment_type == SegmentType.ARRAY_STRING:
return ArrayStringSegment(value=value)
elif segment_type == SegmentType.ARRAY_BOOLEAN:
return ArrayBooleanSegment(value=value)
elif segment_type == SegmentType.ARRAY_NUMBER:
return ArrayNumberSegment(value=value)
elif segment_type == SegmentType.ARRAY_OBJECT:

11
api/lazy_load_class.py Normal file
View file

@ -0,0 +1,11 @@
from tests.integration_tests.utils.parent_class import ParentClass
class LazyLoadChildClass(ParentClass):
"""Test lazy load child class for module import helper tests"""
def __init__(self, name):
super().__init__(name)
def get_name(self):
return self.name

View file

@ -20,3 +20,6 @@ ignore_missing_imports=True
[mypy-flask_restx.inputs]
ignore_missing_imports=True
[mypy-google.cloud.storage]
ignore_missing_imports=True

View file

@ -23,6 +23,7 @@ class TestSegmentTypeIsArrayType:
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
]
expected_non_array_types = [
SegmentType.INTEGER,
@ -34,6 +35,7 @@ class TestSegmentTypeIsArrayType:
SegmentType.FILE,
SegmentType.NONE,
SegmentType.GROUP,
SegmentType.BOOLEAN,
]
for seg_type in expected_array_types:

View file

@ -0,0 +1,729 @@
"""
Comprehensive unit tests for SegmentType.is_valid and SegmentType._validate_array methods.
This module provides thorough testing of the validation logic for all SegmentType values,
including edge cases, error conditions, and different ArrayValidation strategies.
"""
from dataclasses import dataclass
from typing import Any
import pytest
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.variables.types import ArrayValidation, SegmentType
def create_test_file(
file_type: FileType = FileType.DOCUMENT,
transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
filename: str = "test.txt",
extension: str = ".txt",
mime_type: str = "text/plain",
size: int = 1024,
) -> File:
"""Factory function to create File objects for testing."""
return File(
tenant_id="test-tenant",
type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
storage_key="test-storage-key",
)
@dataclass
class ValidationTestCase:
"""Test case data structure for validation tests."""
segment_type: SegmentType
value: Any
expected: bool
description: str
def get_id(self):
return self.description
@dataclass
class ArrayValidationTestCase:
"""Test case data structure for array validation tests."""
segment_type: SegmentType
value: Any
array_validation: ArrayValidation
expected: bool
description: str
def get_id(self):
return self.description
# Test data construction functions
def get_boolean_cases() -> list[ValidationTestCase]:
return [
# valid values
ValidationTestCase(SegmentType.BOOLEAN, True, True, "True boolean"),
ValidationTestCase(SegmentType.BOOLEAN, False, True, "False boolean"),
# Invalid values
ValidationTestCase(SegmentType.BOOLEAN, 1, False, "Integer 1 (not boolean)"),
ValidationTestCase(SegmentType.BOOLEAN, 0, False, "Integer 0 (not boolean)"),
ValidationTestCase(SegmentType.BOOLEAN, "true", False, "String 'true'"),
ValidationTestCase(SegmentType.BOOLEAN, "false", False, "String 'false'"),
ValidationTestCase(SegmentType.BOOLEAN, None, False, "None value"),
ValidationTestCase(SegmentType.BOOLEAN, [], False, "Empty list"),
ValidationTestCase(SegmentType.BOOLEAN, {}, False, "Empty dict"),
]
def get_number_cases() -> list[ValidationTestCase]:
"""Get test cases for valid number values."""
return [
# valid values
ValidationTestCase(SegmentType.NUMBER, 42, True, "Positive integer"),
ValidationTestCase(SegmentType.NUMBER, -42, True, "Negative integer"),
ValidationTestCase(SegmentType.NUMBER, 0, True, "Zero integer"),
ValidationTestCase(SegmentType.NUMBER, 3.14, True, "Positive float"),
ValidationTestCase(SegmentType.NUMBER, -3.14, True, "Negative float"),
ValidationTestCase(SegmentType.NUMBER, 0.0, True, "Zero float"),
ValidationTestCase(SegmentType.NUMBER, float("inf"), True, "Positive infinity"),
ValidationTestCase(SegmentType.NUMBER, float("-inf"), True, "Negative infinity"),
ValidationTestCase(SegmentType.NUMBER, float("nan"), True, "float(NaN)"),
# invalid number values
ValidationTestCase(SegmentType.NUMBER, "42", False, "String number"),
ValidationTestCase(SegmentType.NUMBER, None, False, "None value"),
ValidationTestCase(SegmentType.NUMBER, [], False, "Empty list"),
ValidationTestCase(SegmentType.NUMBER, {}, False, "Empty dict"),
ValidationTestCase(SegmentType.NUMBER, "3.14", False, "String float"),
]
def get_string_cases() -> list[ValidationTestCase]:
"""Get test cases for valid string values."""
return [
# valid values
ValidationTestCase(SegmentType.STRING, "", True, "Empty string"),
ValidationTestCase(SegmentType.STRING, "hello", True, "Simple string"),
ValidationTestCase(SegmentType.STRING, "🚀", True, "Unicode emoji"),
ValidationTestCase(SegmentType.STRING, "line1\nline2", True, "Multiline string"),
# invalid values
ValidationTestCase(SegmentType.STRING, 123, False, "Integer"),
ValidationTestCase(SegmentType.STRING, 3.14, False, "Float"),
ValidationTestCase(SegmentType.STRING, True, False, "Boolean"),
ValidationTestCase(SegmentType.STRING, None, False, "None value"),
ValidationTestCase(SegmentType.STRING, [], False, "Empty list"),
ValidationTestCase(SegmentType.STRING, {}, False, "Empty dict"),
]
def get_object_cases() -> list[ValidationTestCase]:
"""Get test cases for valid object values."""
return [
# valid cases
ValidationTestCase(SegmentType.OBJECT, {}, True, "Empty dict"),
ValidationTestCase(SegmentType.OBJECT, {"key": "value"}, True, "Simple dict"),
ValidationTestCase(SegmentType.OBJECT, {"a": 1, "b": 2}, True, "Dict with numbers"),
ValidationTestCase(SegmentType.OBJECT, {"nested": {"key": "value"}}, True, "Nested dict"),
ValidationTestCase(SegmentType.OBJECT, {"list": [1, 2, 3]}, True, "Dict with list"),
ValidationTestCase(SegmentType.OBJECT, {"mixed": [1, "two", {"three": 3}]}, True, "Complex dict"),
# invalid cases
ValidationTestCase(SegmentType.OBJECT, "not a dict", False, "String"),
ValidationTestCase(SegmentType.OBJECT, 123, False, "Integer"),
ValidationTestCase(SegmentType.OBJECT, 3.14, False, "Float"),
ValidationTestCase(SegmentType.OBJECT, True, False, "Boolean"),
ValidationTestCase(SegmentType.OBJECT, None, False, "None value"),
ValidationTestCase(SegmentType.OBJECT, [], False, "Empty list"),
ValidationTestCase(SegmentType.OBJECT, [1, 2, 3], False, "List with values"),
]
def get_secret_cases() -> list[ValidationTestCase]:
"""Get test cases for valid secret values."""
return [
# valid cases
ValidationTestCase(SegmentType.SECRET, "", True, "Empty secret"),
ValidationTestCase(SegmentType.SECRET, "secret", True, "Simple secret"),
ValidationTestCase(SegmentType.SECRET, "api_key_123", True, "API key format"),
ValidationTestCase(SegmentType.SECRET, "very_long_secret_key_with_special_chars!@#", True, "Complex secret"),
# invalid cases
ValidationTestCase(SegmentType.SECRET, 123, False, "Integer"),
ValidationTestCase(SegmentType.SECRET, 3.14, False, "Float"),
ValidationTestCase(SegmentType.SECRET, True, False, "Boolean"),
ValidationTestCase(SegmentType.SECRET, None, False, "None value"),
ValidationTestCase(SegmentType.SECRET, [], False, "Empty list"),
ValidationTestCase(SegmentType.SECRET, {}, False, "Empty dict"),
]
def get_file_cases() -> list[ValidationTestCase]:
"""Get test cases for valid file values."""
test_file = create_test_file()
image_file = create_test_file(
file_type=FileType.IMAGE, filename="image.jpg", extension=".jpg", mime_type="image/jpeg"
)
remote_file = create_test_file(
transfer_method=FileTransferMethod.REMOTE_URL, filename="remote.pdf", extension=".pdf"
)
return [
# valid cases
ValidationTestCase(SegmentType.FILE, test_file, True, "Document file"),
ValidationTestCase(SegmentType.FILE, image_file, True, "Image file"),
ValidationTestCase(SegmentType.FILE, remote_file, True, "Remote file"),
# invalid cases
ValidationTestCase(SegmentType.FILE, "not a file", False, "String"),
ValidationTestCase(SegmentType.FILE, 123, False, "Integer"),
ValidationTestCase(SegmentType.FILE, {"filename": "test.txt"}, False, "Dict resembling file"),
ValidationTestCase(SegmentType.FILE, None, False, "None value"),
ValidationTestCase(SegmentType.FILE, [], False, "Empty list"),
ValidationTestCase(SegmentType.FILE, True, False, "Boolean"),
]
def get_none_cases() -> list[ValidationTestCase]:
"""Get test cases for valid none values."""
return [
# valid cases
ValidationTestCase(SegmentType.NONE, None, True, "None value"),
# invalid cases
ValidationTestCase(SegmentType.NONE, "", False, "Empty string"),
ValidationTestCase(SegmentType.NONE, 0, False, "Zero integer"),
ValidationTestCase(SegmentType.NONE, 0.0, False, "Zero float"),
ValidationTestCase(SegmentType.NONE, False, False, "False boolean"),
ValidationTestCase(SegmentType.NONE, [], False, "Empty list"),
ValidationTestCase(SegmentType.NONE, {}, False, "Empty dict"),
ValidationTestCase(SegmentType.NONE, "null", False, "String 'null'"),
]
def get_array_any_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_ANY validation."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_ANY,
[1, "string", 3.14, {"key": "value"}, True],
ArrayValidation.NONE,
True,
"Mixed types with NONE validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_ANY,
[1, "string", 3.14, {"key": "value"}, True],
ArrayValidation.FIRST,
True,
"Mixed types with FIRST validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_ANY,
[1, "string", 3.14, {"key": "value"}, True],
ArrayValidation.ALL,
True,
"Mixed types with ALL validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_ANY, [None, None, None], ArrayValidation.ALL, True, "All None values"
),
]
def get_array_string_validation_none_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_STRING validation with NONE strategy."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
["hello", "world"],
ArrayValidation.NONE,
True,
"Valid strings with NONE validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
[123, 456],
ArrayValidation.NONE,
True,
"Invalid elements with NONE validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
["valid", 123, True],
ArrayValidation.NONE,
True,
"Mixed types with NONE validation",
),
]
def get_array_string_validation_first_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_STRING validation with FIRST strategy."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["hello", "world"], ArrayValidation.FIRST, True, "All valid strings"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
["hello", 123, True],
ArrayValidation.FIRST,
True,
"First valid, others invalid",
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
[123, "hello", "world"],
ArrayValidation.FIRST,
False,
"First invalid, others valid",
),
ArrayValidationTestCase(SegmentType.ARRAY_STRING, [None, "hello"], ArrayValidation.FIRST, False, "First None"),
]
def get_array_string_validation_all_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_STRING validation with ALL strategy."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["hello", "world", "test"], ArrayValidation.ALL, True, "All valid strings"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["hello", 123, "world"], ArrayValidation.ALL, False, "One invalid element"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, [123, 456, 789], ArrayValidation.ALL, False, "All invalid elements"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["valid", None, "also_valid"], ArrayValidation.ALL, False, "Contains None"
),
]
def get_array_number_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_NUMBER validation with different strategies."""
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [1, 2.5, 3], ArrayValidation.NONE, True, "Valid numbers with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, ["not", "numbers"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [42, "not a number"], ArrayValidation.FIRST, True, "First valid number"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, ["not a number", 42], ArrayValidation.FIRST, False, "First invalid"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [3.14, 2.71, 1.41], ArrayValidation.FIRST, True, "All valid floats"
),
# ALL strategy
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [1, 2, 3, 4.5], ArrayValidation.ALL, True, "All valid numbers"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [1, "invalid", 3], ArrayValidation.ALL, False, "One invalid element"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER,
[float("inf"), float("-inf"), float("nan")],
ArrayValidation.ALL,
True,
"Special float values",
),
]
def get_array_object_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_OBJECT validation with different strategies."""
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT, [{}, {"key": "value"}], ArrayValidation.NONE, True, "Valid objects with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT, ["not", "objects"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
[{"valid": "object"}, "not an object"],
ArrayValidation.FIRST,
True,
"First valid object",
),
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
["not an object", {"valid": "object"}],
ArrayValidation.FIRST,
False,
"First invalid",
),
# ALL strategy
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
[{}, {"a": 1}, {"nested": {"key": "value"}}],
ArrayValidation.ALL,
True,
"All valid objects",
),
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
[{"valid": "object"}, "invalid", {"another": "object"}],
ArrayValidation.ALL,
False,
"One invalid element",
),
]
def get_array_file_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_FILE validation with different strategies."""
file1 = create_test_file(filename="file1.txt")
file2 = create_test_file(filename="file2.txt")
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.NONE, True, "Valid files with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, ["not", "files"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, [file1, "not a file"], ArrayValidation.FIRST, True, "First valid file"
),
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, ["not a file", file1], ArrayValidation.FIRST, False, "First invalid"
),
# ALL strategy
ArrayValidationTestCase(SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.ALL, True, "All valid files"),
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, [file1, "invalid", file2], ArrayValidation.ALL, False, "One invalid element"
),
]
def get_array_boolean_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_BOOLEAN validation with different strategies."""
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, False, True], ArrayValidation.NONE, True, "Valid booleans with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [1, 0, "true"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, 1, 0], ArrayValidation.FIRST, True, "First valid boolean"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [1, True, False], ArrayValidation.FIRST, False, "First invalid (integer 1)"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [0, True, False], ArrayValidation.FIRST, False, "First invalid (integer 0)"
),
# ALL strategy
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, False, True, False], ArrayValidation.ALL, True, "All valid booleans"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, 1, False], ArrayValidation.ALL, False, "One invalid element (integer)"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN,
[True, "false", False],
ArrayValidation.ALL,
False,
"One invalid element (string)",
),
]
class TestSegmentTypeIsValid:
"""Test suite for SegmentType.is_valid method covering all non-array types."""
@pytest.mark.parametrize("case", get_boolean_cases(), ids=lambda case: case.description)
def test_boolean_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_number_cases(), ids=lambda case: case.description)
def test_number_validation(self, case: ValidationTestCase):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_string_cases(), ids=lambda case: case.description)
def test_string_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_object_cases(), ids=lambda case: case.description)
def test_object_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_secret_cases(), ids=lambda case: case.description)
def test_secret_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_file_cases(), ids=lambda case: case.description)
def test_file_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_none_cases(), ids=lambda case: case.description)
def test_none_validation_valid_cases(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
def test_unsupported_segment_type_raises_assertion_error(self):
"""Test that unsupported SegmentType values raise AssertionError."""
# GROUP is not handled in is_valid method
with pytest.raises(AssertionError, match="this statement should be unreachable"):
SegmentType.GROUP.is_valid("any value")
class TestSegmentTypeArrayValidation:
"""Test suite for SegmentType._validate_array method and array type validation."""
def test_array_validation_non_list_values(self):
"""Test that non-list values return False for all array types."""
array_types = [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
]
non_list_values = [
"not a list",
123,
3.14,
True,
None,
{"key": "value"},
create_test_file(),
]
for array_type in array_types:
for value in non_list_values:
assert array_type.is_valid(value) is False, f"{array_type} should reject {type(value).__name__}"
def test_empty_array_validation(self):
"""Test that empty arrays are valid for all array types regardless of validation strategy."""
array_types = [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
]
validation_strategies = [ArrayValidation.NONE, ArrayValidation.FIRST, ArrayValidation.ALL]
for array_type in array_types:
for strategy in validation_strategies:
assert array_type.is_valid([], strategy) is True, (
f"{array_type} should accept empty array with {strategy}"
)
@pytest.mark.parametrize("case", get_array_any_validation_cases(), ids=lambda case: case.description)
def test_array_any_validation(self, case):
"""Test ARRAY_ANY validation accepts any list regardless of content."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_string_validation_none_cases(), ids=lambda case: case.description)
def test_array_string_validation_with_none_strategy(self, case):
"""Test ARRAY_STRING validation with NONE strategy (no element validation)."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_string_validation_first_cases(), ids=lambda case: case.description)
def test_array_string_validation_with_first_strategy(self, case):
"""Test ARRAY_STRING validation with FIRST strategy (validate first element only)."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_string_validation_all_cases(), ids=lambda case: case.description)
def test_array_string_validation_with_all_strategy(self, case):
"""Test ARRAY_STRING validation with ALL strategy (validate all elements)."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_number_validation_cases(), ids=lambda case: case.description)
def test_array_number_validation_with_different_strategies(self, case):
"""Test ARRAY_NUMBER validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_object_validation_cases(), ids=lambda case: case.description)
def test_array_object_validation_with_different_strategies(self, case):
"""Test ARRAY_OBJECT validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_file_validation_cases(), ids=lambda case: case.description)
def test_array_file_validation_with_different_strategies(self, case):
"""Test ARRAY_FILE validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_boolean_validation_cases(), ids=lambda case: case.description)
def test_array_boolean_validation_with_different_strategies(self, case):
"""Test ARRAY_BOOLEAN validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
def test_default_array_validation_strategy(self):
"""Test that default array validation strategy is FIRST."""
# When no array_validation parameter is provided, it should default to FIRST
assert SegmentType.ARRAY_STRING.is_valid(["valid", 123]) is False # First element valid
assert SegmentType.ARRAY_STRING.is_valid([123, "valid"]) is False # First element invalid
assert SegmentType.ARRAY_NUMBER.is_valid([42, "invalid"]) is False # First element valid
assert SegmentType.ARRAY_NUMBER.is_valid(["invalid", 42]) is False # First element invalid
def test_array_validation_edge_cases(self):
"""Test edge cases for array validation."""
# Test with nested arrays (should be invalid for specific array types)
nested_array = [["nested", "array"], ["another", "nested"]]
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.FIRST) is False
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.ALL) is False
assert SegmentType.ARRAY_ANY.is_valid(nested_array, ArrayValidation.ALL) is True
# Test with very large arrays (performance consideration)
large_valid_array = ["string"] * 1000
large_mixed_array = ["string"] * 999 + [123] # Last element invalid
assert SegmentType.ARRAY_STRING.is_valid(large_valid_array, ArrayValidation.ALL) is True
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.ALL) is False
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.FIRST) is True
class TestSegmentTypeValidationIntegration:
"""Integration tests for SegmentType validation covering interactions between methods."""
def test_non_array_types_ignore_array_validation_parameter(self):
"""Test that non-array types ignore the array_validation parameter."""
non_array_types = [
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
SegmentType.OBJECT,
SegmentType.SECRET,
SegmentType.FILE,
SegmentType.NONE,
]
for segment_type in non_array_types:
# Create appropriate valid value for each type
valid_value: Any
if segment_type == SegmentType.STRING:
valid_value = "test"
elif segment_type == SegmentType.NUMBER:
valid_value = 42
elif segment_type == SegmentType.BOOLEAN:
valid_value = True
elif segment_type == SegmentType.OBJECT:
valid_value = {"key": "value"}
elif segment_type == SegmentType.SECRET:
valid_value = "secret"
elif segment_type == SegmentType.FILE:
valid_value = create_test_file()
elif segment_type == SegmentType.NONE:
valid_value = None
else:
continue # Skip unsupported types
# All array validation strategies should give the same result
result_none = segment_type.is_valid(valid_value, ArrayValidation.NONE)
result_first = segment_type.is_valid(valid_value, ArrayValidation.FIRST)
result_all = segment_type.is_valid(valid_value, ArrayValidation.ALL)
assert result_none == result_first == result_all == True, (
f"{segment_type} should ignore array_validation parameter"
)
def test_comprehensive_type_coverage(self):
"""Test that all SegmentType enum values are covered in validation tests."""
all_segment_types = set(SegmentType)
# Types that should be handled by is_valid method
handled_types = {
# Non-array types
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
SegmentType.OBJECT,
SegmentType.SECRET,
SegmentType.FILE,
SegmentType.NONE,
# Array types
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
}
# Types that are not handled by is_valid (should raise AssertionError)
unhandled_types = {
SegmentType.GROUP,
SegmentType.INTEGER, # Handled by NUMBER validation logic
SegmentType.FLOAT, # Handled by NUMBER validation logic
}
# Verify all types are accounted for
assert handled_types | unhandled_types == all_segment_types, "All SegmentType values should be categorized"
# Test that handled types work correctly
for segment_type in handled_types:
if segment_type.is_array_type():
# Test with empty array (should always be valid)
assert segment_type.is_valid([]) is True, f"{segment_type} should accept empty array"
else:
# Test with appropriate valid value
if segment_type == SegmentType.STRING:
assert segment_type.is_valid("test") is True
elif segment_type == SegmentType.NUMBER:
assert segment_type.is_valid(42) is True
elif segment_type == SegmentType.BOOLEAN:
assert segment_type.is_valid(True) is True
elif segment_type == SegmentType.OBJECT:
assert segment_type.is_valid({}) is True
elif segment_type == SegmentType.SECRET:
assert segment_type.is_valid("secret") is True
elif segment_type == SegmentType.FILE:
assert segment_type.is_valid(create_test_file()) is True
elif segment_type == SegmentType.NONE:
assert segment_type.is_valid(None) is True
def test_boolean_vs_integer_type_distinction(self):
"""Test the important distinction between boolean and integer types in validation."""
# This tests the comment in the code about bool being a subclass of int
# Boolean type should only accept actual booleans, not integers
assert SegmentType.BOOLEAN.is_valid(True) is True
assert SegmentType.BOOLEAN.is_valid(False) is True
assert SegmentType.BOOLEAN.is_valid(1) is False # Integer 1, not boolean
assert SegmentType.BOOLEAN.is_valid(0) is False # Integer 0, not boolean
# Number type should accept both integers and floats, including booleans (since bool is subclass of int)
assert SegmentType.NUMBER.is_valid(42) is True
assert SegmentType.NUMBER.is_valid(3.14) is True
assert SegmentType.NUMBER.is_valid(True) is True # bool is subclass of int
assert SegmentType.NUMBER.is_valid(False) is True # bool is subclass of int
def test_array_validation_recursive_behavior(self):
"""Test that array validation correctly handles recursive validation calls."""
# When validating array elements, _validate_array calls is_valid recursively
# with ArrayValidation.NONE to avoid infinite recursion
# Test nested validation doesn't cause issues
nested_arrays = [["inner", "array"], ["another", "inner"]]
# ARRAY_ANY should accept nested arrays
assert SegmentType.ARRAY_ANY.is_valid(nested_arrays, ArrayValidation.ALL) is True
# ARRAY_STRING should reject nested arrays (first element is not a string)
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.FIRST) is False
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.ALL) is False

View file

@ -0,0 +1,27 @@
from core.variables.types import SegmentType
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig
class TestParameterConfig:
def test_select_type(self):
data = {
"name": "yes_or_no",
"type": "select",
"options": ["yes", "no"],
"description": "a simple select made of `yes` and `no`",
"required": True,
}
pc = ParameterConfig.model_validate(data)
assert pc.type == SegmentType.STRING
assert pc.options == data["options"]
def test_validate_bool_type(self):
data = {
"name": "boolean",
"type": "bool",
"description": "a simple boolean parameter",
"required": True,
}
pc = ParameterConfig.model_validate(data)
assert pc.type == SegmentType.BOOLEAN

View file

@ -0,0 +1,567 @@
"""
Test cases for ParameterExtractorNode._validate_result and _transform_result methods.
"""
from dataclasses import dataclass
from typing import Any
import pytest
from core.model_runtime.entities import LLMMode
from core.variables.types import SegmentType
from core.workflow.nodes.llm import ModelConfig, VisionConfig
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData
from core.workflow.nodes.parameter_extractor.exc import (
InvalidNumberOfParametersError,
InvalidSelectValueError,
InvalidValueTypeError,
RequiredParameterMissingError,
)
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from factories.variable_factory import build_segment_with_type
@dataclass
class ValidTestCase:
"""Test case data for valid scenarios."""
name: str
parameters: list[ParameterConfig]
result: dict[str, Any]
def get_name(self) -> str:
return self.name
@dataclass
class ErrorTestCase:
"""Test case data for error scenarios."""
name: str
parameters: list[ParameterConfig]
result: dict[str, Any]
expected_exception: type[Exception]
expected_message: str
def get_name(self) -> str:
return self.name
@dataclass
class TransformTestCase:
"""Test case data for transformation scenarios."""
name: str
parameters: list[ParameterConfig]
input_result: dict[str, Any]
expected_result: dict[str, Any]
def get_name(self) -> str:
return self.name
class TestParameterExtractorNodeMethods:
"""Test helper class that provides access to the methods under test."""
def validate_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
"""Wrapper to call _validate_result method."""
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
return node._validate_result(data=data, result=result)
def transform_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
"""Wrapper to call _transform_result method."""
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
return node._transform_result(data=data, result=result)
class TestValidateResult:
"""Test cases for _validate_result method."""
@staticmethod
def get_valid_test_cases() -> list[ValidTestCase]:
"""Get test cases that should pass validation."""
return [
ValidTestCase(
name="single_string_parameter",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
result={"name": "John"},
),
ValidTestCase(
name="single_number_parameter_int",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
result={"age": 25},
),
ValidTestCase(
name="single_number_parameter_float",
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
result={"price": 19.99},
),
ValidTestCase(
name="single_bool_parameter_true",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": True},
),
ValidTestCase(
name="single_bool_parameter_true",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": True},
),
ValidTestCase(
name="single_bool_parameter_false",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": False},
),
ValidTestCase(
name="select_parameter_valid_option",
parameters=[
ParameterConfig(
name="status",
type="select", # pyright: ignore[reportArgumentType]
description="Status",
required=True,
options=["active", "inactive"],
)
],
result={"status": "active"},
),
ValidTestCase(
name="array_string_parameter",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": ["tag1", "tag2", "tag3"]},
),
ValidTestCase(
name="array_number_parameter",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
result={"scores": [85, 92.5, 78]},
),
ValidTestCase(
name="array_object_parameter",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
result={"items": [{"name": "item1"}, {"name": "item2"}]},
),
ValidTestCase(
name="multiple_parameters",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
],
result={"name": "John", "age": 25, "active": True},
),
ValidTestCase(
name="optional_parameter_present",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="nickname", type=SegmentType.STRING, description="Nickname", required=False),
],
result={"name": "John", "nickname": "Johnny"},
),
ValidTestCase(
name="empty_array_parameter",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": []},
),
]
@staticmethod
def get_error_test_cases() -> list[ErrorTestCase]:
"""Get test cases that should raise exceptions."""
return [
ErrorTestCase(
name="invalid_number_of_parameters_too_few",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
],
result={"name": "John"},
expected_exception=InvalidNumberOfParametersError,
expected_message="Invalid number of parameters",
),
ErrorTestCase(
name="invalid_number_of_parameters_too_many",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
result={"name": "John", "age": 25},
expected_exception=InvalidNumberOfParametersError,
expected_message="Invalid number of parameters",
),
ErrorTestCase(
name="invalid_string_value_none",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
],
result={"name": None}, # Parameter present but None value, will trigger type check first
expected_exception=InvalidValueTypeError,
expected_message="Invalid value for parameter name, expected segment type: string, actual_type: none",
),
ErrorTestCase(
name="invalid_select_value",
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
description="Status",
required=True,
options=["active", "inactive"],
)
],
result={"status": "pending"},
expected_exception=InvalidSelectValueError,
expected_message="Invalid `select` value for parameter status",
),
ErrorTestCase(
name="invalid_number_value_string",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
result={"age": "twenty-five"},
expected_exception=InvalidValueTypeError,
expected_message="Invalid value for parameter age, expected segment type: number, actual_type: string",
),
ErrorTestCase(
name="invalid_bool_value_string",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": "yes"},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter active, expected segment type: boolean, actual_type: string"
),
),
ErrorTestCase(
name="invalid_string_value_number",
parameters=[
ParameterConfig(
name="description", type=SegmentType.STRING, description="Description", required=True
)
],
result={"description": 123},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter description, expected segment type: string, actual_type: integer"
),
),
ErrorTestCase(
name="invalid_array_value_not_list",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": "tag1,tag2,tag3"},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter tags, expected segment type: array[string], actual_type: string"
),
),
ErrorTestCase(
name="invalid_array_number_wrong_element_type",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
result={"scores": [85, "ninety-two", 78]},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter scores, expected segment type: array[number], actual_type: array[any]"
),
),
ErrorTestCase(
name="invalid_array_string_wrong_element_type",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": ["tag1", 123, "tag3"]},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter tags, expected segment type: array[string], actual_type: array[any]"
),
),
ErrorTestCase(
name="invalid_array_object_wrong_element_type",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
result={"items": [{"name": "item1"}, "item2"]},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter items, expected segment type: array[object], actual_type: array[any]"
),
),
ErrorTestCase(
name="required_parameter_missing",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=False),
],
result={"age": 25, "other": "value"}, # Missing required 'name' parameter, but has correct count
expected_exception=RequiredParameterMissingError,
expected_message="Parameter name is required",
),
]
@pytest.mark.parametrize("test_case", get_valid_test_cases(), ids=ValidTestCase.get_name)
def test_validate_result_valid_cases(self, test_case):
"""Test _validate_result with valid inputs."""
helper = TestParameterExtractorNodeMethods()
node_data = ParameterExtractorNodeData(
title="Test Node",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
query=["test_query"],
parameters=test_case.parameters,
reasoning_mode="function_call",
vision=VisionConfig(),
)
result = helper.validate_result(data=node_data, result=test_case.result)
assert result == test_case.result, f"Failed for case: {test_case.name}"
@pytest.mark.parametrize("test_case", get_error_test_cases(), ids=ErrorTestCase.get_name)
def test_validate_result_error_cases(self, test_case):
"""Test _validate_result with invalid inputs that should raise exceptions."""
helper = TestParameterExtractorNodeMethods()
node_data = ParameterExtractorNodeData(
title="Test Node",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
query=["test_query"],
parameters=test_case.parameters,
reasoning_mode="function_call",
vision=VisionConfig(),
)
with pytest.raises(test_case.expected_exception) as exc_info:
helper.validate_result(data=node_data, result=test_case.result)
assert test_case.expected_message in str(exc_info.value), f"Failed for case: {test_case.name}"
class TestTransformResult:
"""Test cases for _transform_result method."""
@staticmethod
def get_transform_test_cases() -> list[TransformTestCase]:
"""Get test cases for result transformation."""
return [
# String parameter transformation
TransformTestCase(
name="string_parameter_present",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
input_result={"name": "John"},
expected_result={"name": "John"},
),
TransformTestCase(
name="string_parameter_missing",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
input_result={},
expected_result={"name": ""},
),
# Number parameter transformation
TransformTestCase(
name="number_parameter_int_present",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": 25},
expected_result={"age": 25},
),
TransformTestCase(
name="number_parameter_float_present",
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
input_result={"price": 19.99},
expected_result={"price": 19.99},
),
TransformTestCase(
name="number_parameter_missing",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={},
expected_result={"age": 0},
),
# Bool parameter transformation
TransformTestCase(
name="bool_parameter_missing",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
input_result={},
expected_result={"active": False},
),
# Select parameter transformation
TransformTestCase(
name="select_parameter_present",
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
description="Status",
required=True,
options=["active", "inactive"],
)
],
input_result={"status": "active"},
expected_result={"status": "active"},
),
TransformTestCase(
name="select_parameter_missing",
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
description="Status",
required=True,
options=["active", "inactive"],
)
],
input_result={},
expected_result={"status": ""},
),
# Array parameter transformation - present cases
TransformTestCase(
name="array_string_parameter_present",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
input_result={"tags": ["tag1", "tag2"]},
expected_result={
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=["tag1", "tag2"])
},
),
TransformTestCase(
name="array_number_parameter_present",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={"scores": [85, 92.5]},
expected_result={
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5])
},
),
TransformTestCase(
name="array_number_parameter_with_string_conversion",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={"scores": [85, "92.5", "78"]},
expected_result={
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5, 78])
},
),
TransformTestCase(
name="array_object_parameter_present",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
input_result={"items": [{"name": "item1"}, {"name": "item2"}]},
expected_result={
"items": build_segment_with_type(
segment_type=SegmentType.ARRAY_OBJECT, value=[{"name": "item1"}, {"name": "item2"}]
)
},
),
# Array parameter transformation - missing cases
TransformTestCase(
name="array_string_parameter_missing",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
input_result={},
expected_result={"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[])},
),
TransformTestCase(
name="array_number_parameter_missing",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={},
expected_result={"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[])},
),
TransformTestCase(
name="array_object_parameter_missing",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
input_result={},
expected_result={"items": build_segment_with_type(segment_type=SegmentType.ARRAY_OBJECT, value=[])},
),
# Multiple parameters transformation
TransformTestCase(
name="multiple_parameters_mixed",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True),
],
input_result={"name": "John", "age": 25},
expected_result={
"name": "John",
"age": 25,
"active": False,
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[]),
},
),
# Number parameter transformation with string conversion
TransformTestCase(
name="number_parameter_string_to_float",
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
input_result={"price": "19.99"},
expected_result={"price": 19.99}, # String not converted, falls back to default
),
TransformTestCase(
name="number_parameter_string_to_int",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": "25"},
expected_result={"age": 25}, # String not converted, falls back to default
),
TransformTestCase(
name="number_parameter_invalid_string",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": "invalid_number"},
expected_result={"age": 0}, # Invalid string conversion fails, falls back to default
),
TransformTestCase(
name="number_parameter_non_string_non_number",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": ["not_a_number"]}, # Non-string, non-number value
expected_result={"age": 0}, # Falls back to default
),
TransformTestCase(
name="array_number_parameter_with_invalid_string_conversion",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={"scores": [85, "invalid", "78"]},
expected_result={
"scores": build_segment_with_type(
segment_type=SegmentType.ARRAY_NUMBER, value=[85, 78]
) # Invalid string skipped
},
),
]
@pytest.mark.parametrize("test_case", get_transform_test_cases(), ids=TransformTestCase.get_name)
def test_transform_result_cases(self, test_case):
"""Test _transform_result with various inputs."""
helper = TestParameterExtractorNodeMethods()
node_data = ParameterExtractorNodeData(
title="Test Node",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
query=["test_query"],
parameters=test_case.parameters,
reasoning_mode="function_call",
vision=VisionConfig(),
)
result = helper.transform_result(data=node_data, result=test_case.input_result)
assert result == test_case.expected_result, (
f"Failed for case: {test_case.name}. Expected: {test_case.expected_result}, Got: {result}"
)

View file

@ -2,6 +2,8 @@ import time
import uuid
from unittest.mock import MagicMock, Mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
@ -272,3 +274,220 @@ def test_array_file_contains_file_name():
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True
def _get_test_conditions() -> list:
conditions = [
# Test boolean "is" operator
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "true"},
# Test boolean "is not" operator
{"comparison_operator": "is not", "variable_selector": ["start", "bool_false"], "value": "true"},
# Test boolean "=" operator
{"comparison_operator": "=", "variable_selector": ["start", "bool_true"], "value": "1"},
# Test boolean "≠" operator
{"comparison_operator": "", "variable_selector": ["start", "bool_false"], "value": "1"},
# Test boolean "not null" operator
{"comparison_operator": "not null", "variable_selector": ["start", "bool_true"]},
# Test boolean array "contains" operator
{"comparison_operator": "contains", "variable_selector": ["start", "bool_array"], "value": "true"},
# Test boolean "in" operator
{
"comparison_operator": "in",
"variable_selector": ["start", "bool_true"],
"value": ["true", "false"],
},
]
return [Condition.model_validate(i) for i in conditions]
def _get_condition_test_id(c: Condition):
return c.comparison_operator
@pytest.mark.parametrize("condition", _get_test_conditions(), ids=_get_condition_test_id)
def test_execute_if_else_boolean_conditions(condition: Condition):
"""Test IfElseNode with boolean conditions using various operators"""
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool with boolean values
pool = VariablePool(
system_variables=SystemVariable(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
pool.add(["start", "bool_false"], False)
pool.add(["start", "bool_array"], [True, False, True])
pool.add(["start", "mixed_array"], [True, "false", 1, 0])
node_data = {
"title": "Boolean Test",
"type": "if-else",
"logical_operator": "and",
"conditions": [condition.model_dump()],
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={"id": "if-else", "data": node_data},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True
def test_execute_if_else_boolean_false_conditions():
"""Test IfElseNode with boolean conditions that should evaluate to false"""
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool with boolean values
pool = VariablePool(
system_variables=SystemVariable(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
pool.add(["start", "bool_false"], False)
pool.add(["start", "bool_array"], [True, False, True])
node_data = {
"title": "Boolean False Test",
"type": "if-else",
"logical_operator": "or",
"conditions": [
# Test boolean "is" operator (should be false)
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "false"},
# Test boolean "=" operator (should be false)
{"comparison_operator": "=", "variable_selector": ["start", "bool_false"], "value": "1"},
# Test boolean "not contains" operator (should be false)
{
"comparison_operator": "not contains",
"variable_selector": ["start", "bool_array"],
"value": "true",
},
],
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": node_data,
},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is False
def test_execute_if_else_boolean_cases_structure():
"""Test IfElseNode with boolean conditions using the new cases structure"""
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool with boolean values
pool = VariablePool(
system_variables=SystemVariable(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
pool.add(["start", "bool_false"], False)
node_data = {
"title": "Boolean Cases Test",
"type": "if-else",
"cases": [
{
"case_id": "true",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "is",
"variable_selector": ["start", "bool_true"],
"value": "true",
},
{
"comparison_operator": "is not",
"variable_selector": ["start", "bool_false"],
"value": "true",
},
],
}
],
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={"id": "if-else", "data": node_data},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True
assert result.outputs["selected_case_id"] == "true"

View file

@ -11,7 +11,8 @@ from core.workflow.nodes.list_operator.entities import (
FilterCondition,
Limit,
ListOperatorNodeData,
OrderBy,
Order,
OrderByConfig,
)
from core.workflow.nodes.list_operator.exc import InvalidKeyError
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
@ -27,7 +28,7 @@ def list_operator_node():
FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
],
),
"order_by": OrderBy(enabled=False, value="asc"),
"order_by": OrderByConfig(enabled=False, value=Order.ASC),
"limit": Limit(enabled=False, size=0),
"extract_by": ExtractConfig(enabled=False, serial="1"),
"title": "Test Title",

View file

@ -24,16 +24,18 @@ from core.variables.segments import (
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from core.variables.types import SegmentType
from factories import variable_factory
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type
def test_string_variable():
@ -139,6 +141,26 @@ def test_array_number_variable():
assert isinstance(variable.value[1], float)
def test_build_segment_scalar_values():
@dataclass
class TestCase:
value: Any
expected: Segment
description: str
cases = [
TestCase(
value=True,
expected=BooleanSegment(value=True),
description="build_segment with boolean should yield BooleanSegment",
)
]
for idx, c in enumerate(cases, 1):
seg = build_segment(c.value)
assert seg == c.expected, f"Test case {idx} failed: {c.description}"
def test_array_object_variable():
mapping = {
"id": str(uuid4()),
@ -847,15 +869,22 @@ class TestBuildSegmentValueErrors:
f"but got: {error_message}"
)
def test_build_segment_boolean_type_note(self):
"""Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError."""
# Boolean values in Python are subclasses of int, so they get processed as integers
# True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0)
def test_build_segment_boolean_type(self):
"""Test that Boolean values are correctly handled as boolean type, not integers."""
# Boolean values should now be processed as BooleanSegment, not IntegerSegment
# This is because the bool check now comes before the int check in build_segment
true_segment = variable_factory.build_segment(True)
false_segment = variable_factory.build_segment(False)
# Verify they are processed as integers, not as errors
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
assert true_segment.value_type == SegmentType.INTEGER
assert false_segment.value_type == SegmentType.INTEGER
# Verify they are processed as booleans, not integers
assert true_segment.value is True, "Test case 1 (boolean_true): Expected True to be processed as boolean True"
assert false_segment.value is False, (
"Test case 2 (boolean_false): Expected False to be processed as boolean False"
)
assert true_segment.value_type == SegmentType.BOOLEAN
assert false_segment.value_type == SegmentType.BOOLEAN
# Test array of booleans
bool_array_segment = variable_factory.build_segment([True, False, True])
assert bool_array_segment.value_type == SegmentType.ARRAY_BOOLEAN
assert bool_array_segment.value == [True, False, True]