feat: support bool type variable frontend (#24437)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
parent
b5c2756261
commit
dac72b078d
126 changed files with 3832 additions and 512 deletions
11
api/child_class.py
Normal file
11
api/child_class.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from tests.integration_tests.utils.parent_class import ParentClass
|
||||
|
||||
|
||||
class ChildClass(ParentClass):
|
||||
"""Test child class for module import helper tests"""
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
|
||||
def get_name(self):
|
||||
return f"Child: {self.name}"
|
||||
|
|
@ -3,6 +3,17 @@ import re
|
|||
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
[
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
VariableEntityType.CHECKBOX,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class BasicVariablesConfigManager:
|
||||
@classmethod
|
||||
|
|
@ -47,6 +58,7 @@ class BasicVariablesConfigManager:
|
|||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.CHECKBOX,
|
||||
}:
|
||||
variable = variables[variable_type]
|
||||
variable_entities.append(
|
||||
|
|
@ -96,8 +108,17 @@ class BasicVariablesConfigManager:
|
|||
variables = []
|
||||
for item in config["user_input_form"]:
|
||||
key = list(item.keys())[0]
|
||||
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
|
||||
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
|
||||
# if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
|
||||
if key not in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
VariableEntityType.CHECKBOX,
|
||||
}:
|
||||
allowed_keys = ", ".join(i.value for i in _ALLOWED_VARIABLE_ENTITY_TYPE)
|
||||
raise ValueError(f"Keys in user_input_form list can only be {allowed_keys}")
|
||||
|
||||
form_item = item[key]
|
||||
if "label" not in form_item:
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ class VariableEntityType(StrEnum):
|
|||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
FILE = "file"
|
||||
FILE_LIST = "file-list"
|
||||
CHECKBOX = "checkbox"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
|
|
|
|||
|
|
@ -103,18 +103,23 @@ class BaseAppGenerator:
|
|||
f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
|
||||
)
|
||||
|
||||
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
|
||||
# handle empty string case
|
||||
if not value.strip():
|
||||
return None
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
|
||||
if variable_entity.type == VariableEntityType.NUMBER:
|
||||
if isinstance(value, (int, float)):
|
||||
return value
|
||||
elif isinstance(value, str):
|
||||
# handle empty string case
|
||||
if not value.strip():
|
||||
return None
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
|
||||
else:
|
||||
raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}")
|
||||
|
||||
match variable_entity.type:
|
||||
case VariableEntityType.SELECT:
|
||||
|
|
@ -144,6 +149,11 @@ class BaseAppGenerator:
|
|||
raise ValueError(
|
||||
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
|
||||
)
|
||||
case VariableEntityType.CHECKBOX:
|
||||
if not isinstance(value, bool):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid boolean value")
|
||||
case _:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
return value
|
||||
|
||||
|
|
|
|||
|
|
@ -151,6 +151,11 @@ class FileSegment(Segment):
|
|||
return ""
|
||||
|
||||
|
||||
class BooleanSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.BOOLEAN
|
||||
value: bool
|
||||
|
||||
|
||||
class ArrayAnySegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_ANY
|
||||
value: Sequence[Any]
|
||||
|
|
@ -198,6 +203,11 @@ class ArrayFileSegment(ArraySegment):
|
|||
return ""
|
||||
|
||||
|
||||
class ArrayBooleanSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
|
||||
value: Sequence[bool]
|
||||
|
||||
|
||||
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||
if isinstance(v, Segment):
|
||||
return v.value_type
|
||||
|
|
@ -231,11 +241,13 @@ SegmentUnion: TypeAlias = Annotated[
|
|||
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
|
||||
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
|
||||
| Annotated[FileSegment, Tag(SegmentType.FILE)]
|
||||
| Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
|
||||
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
|
||||
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
|
||||
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
|
||||
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,12 @@ from core.file.models import File
|
|||
|
||||
|
||||
class ArrayValidation(StrEnum):
|
||||
"""Strategy for validating array elements"""
|
||||
"""Strategy for validating array elements.
|
||||
|
||||
Note:
|
||||
The `NONE` and `FIRST` strategies are primarily for compatibility purposes.
|
||||
Avoid using them in new code whenever possible.
|
||||
"""
|
||||
|
||||
# Skip element validation (only check array container)
|
||||
NONE = "none"
|
||||
|
|
@ -27,12 +32,14 @@ class SegmentType(StrEnum):
|
|||
SECRET = "secret"
|
||||
|
||||
FILE = "file"
|
||||
BOOLEAN = "boolean"
|
||||
|
||||
ARRAY_ANY = "array[any]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
|
||||
NONE = "none"
|
||||
|
||||
|
|
@ -76,12 +83,18 @@ class SegmentType(StrEnum):
|
|||
return SegmentType.ARRAY_FILE
|
||||
case SegmentType.NONE:
|
||||
return SegmentType.ARRAY_ANY
|
||||
case SegmentType.BOOLEAN:
|
||||
return SegmentType.ARRAY_BOOLEAN
|
||||
case _:
|
||||
# This should be unreachable.
|
||||
raise ValueError(f"not supported value {value}")
|
||||
if value is None:
|
||||
return SegmentType.NONE
|
||||
elif isinstance(value, int) and not isinstance(value, bool):
|
||||
# Important: The check for `bool` must precede the check for `int`,
|
||||
# as `bool` is a subclass of `int` in Python's type hierarchy.
|
||||
elif isinstance(value, bool):
|
||||
return SegmentType.BOOLEAN
|
||||
elif isinstance(value, int):
|
||||
return SegmentType.INTEGER
|
||||
elif isinstance(value, float):
|
||||
return SegmentType.FLOAT
|
||||
|
|
@ -111,7 +124,7 @@ class SegmentType(StrEnum):
|
|||
else:
|
||||
return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
|
||||
|
||||
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
|
||||
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.ALL) -> bool:
|
||||
"""
|
||||
Check if a value matches the segment type.
|
||||
Users of `SegmentType` should call this method, instead of using
|
||||
|
|
@ -126,6 +139,10 @@ class SegmentType(StrEnum):
|
|||
"""
|
||||
if self.is_array_type():
|
||||
return self._validate_array(value, array_validation)
|
||||
# Important: The check for `bool` must precede the check for `int`,
|
||||
# as `bool` is a subclass of `int` in Python's type hierarchy.
|
||||
elif self == SegmentType.BOOLEAN:
|
||||
return isinstance(value, bool)
|
||||
elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]:
|
||||
return isinstance(value, (int, float))
|
||||
elif self == SegmentType.STRING:
|
||||
|
|
@ -141,6 +158,27 @@ class SegmentType(StrEnum):
|
|||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
@staticmethod
|
||||
def cast_value(value: Any, type_: "SegmentType") -> Any:
|
||||
# Cast Python's `bool` type to `int` when the runtime type requires
|
||||
# an integer or number.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use `bool` as
|
||||
# `int`, since in Python's type system, `bool` is a subtype of `int`.
|
||||
#
|
||||
# This function exists solely to maintain compatibility with existing workflows.
|
||||
# It should not be used to compromise the integrity of the runtime type system.
|
||||
# No additional casting rules should be introduced to this function.
|
||||
|
||||
if type_ in (
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.NUMBER,
|
||||
) and isinstance(value, bool):
|
||||
return int(value)
|
||||
if type_ == SegmentType.ARRAY_NUMBER and all(isinstance(i, bool) for i in value):
|
||||
return [int(i) for i in value]
|
||||
return value
|
||||
|
||||
def exposed_type(self) -> "SegmentType":
|
||||
"""Returns the type exposed to the frontend.
|
||||
|
||||
|
|
@ -150,6 +188,20 @@ class SegmentType(StrEnum):
|
|||
return SegmentType.NUMBER
|
||||
return self
|
||||
|
||||
def element_type(self) -> "SegmentType | None":
|
||||
"""Return the element type of the current segment type, or `None` if the element type is undefined.
|
||||
|
||||
Raises:
|
||||
ValueError: If the current segment type is not an array type.
|
||||
|
||||
Note:
|
||||
For certain array types, such as `SegmentType.ARRAY_ANY`, their element types are not defined
|
||||
by the runtime system. In such cases, this method will return `None`.
|
||||
"""
|
||||
if not self.is_array_type():
|
||||
raise ValueError(f"element_type is only supported by array type, got {self}")
|
||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||
|
||||
|
||||
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
||||
# ARRAY_ANY does not have corresponding element type.
|
||||
|
|
@ -157,6 +209,7 @@ _ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
|||
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
|
||||
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
|
||||
SegmentType.ARRAY_FILE: SegmentType.FILE,
|
||||
SegmentType.ARRAY_BOOLEAN: SegmentType.BOOLEAN,
|
||||
}
|
||||
|
||||
_ARRAY_TYPES = frozenset(
|
||||
|
|
|
|||
|
|
@ -8,11 +8,13 @@ from core.helper import encrypter
|
|||
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayBooleanSegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArraySegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
|
|
@ -96,10 +98,18 @@ class FileVariable(FileSegment, Variable):
|
|||
pass
|
||||
|
||||
|
||||
class BooleanVariable(BooleanSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
|
||||
# Use `Variable` for type hinting when serialization is not required.
|
||||
#
|
||||
|
|
@ -114,11 +124,13 @@ VariableUnion: TypeAlias = Annotated[
|
|||
| Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
|
||||
| Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
|
||||
| Annotated[FileVariable, Tag(SegmentType.FILE)]
|
||||
| Annotated[BooleanVariable, Tag(SegmentType.BOOLEAN)]
|
||||
| Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
|
||||
| Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
|
||||
| Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
|
||||
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
|||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
|
@ -119,6 +120,14 @@ class CodeNode(BaseNode):
|
|||
|
||||
return value.replace("\x00", "")
|
||||
|
||||
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, bool):
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
|
||||
|
||||
return value
|
||||
|
||||
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
|
||||
"""
|
||||
Check number
|
||||
|
|
@ -173,6 +182,8 @@ class CodeNode(BaseNode):
|
|||
prefix=f"{prefix}.{output_name}" if prefix else output_name,
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif isinstance(output_value, bool):
|
||||
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
|
||||
elif isinstance(output_value, int | float):
|
||||
self._check_number(
|
||||
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
|
||||
|
|
@ -232,7 +243,7 @@ class CodeNode(BaseNode):
|
|||
if output_name not in result:
|
||||
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
|
||||
if output_config.type == "object":
|
||||
if output_config.type == SegmentType.OBJECT:
|
||||
# check if output is object
|
||||
if not isinstance(result.get(output_name), dict):
|
||||
if result[output_name] is None:
|
||||
|
|
@ -249,18 +260,28 @@ class CodeNode(BaseNode):
|
|||
prefix=f"{prefix}.{output_name}",
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif output_config.type == "number":
|
||||
elif output_config.type == SegmentType.NUMBER:
|
||||
# check if number available
|
||||
transformed_result[output_name] = self._check_number(
|
||||
value=result[output_name], variable=f"{prefix}{dot}{output_name}"
|
||||
)
|
||||
elif output_config.type == "string":
|
||||
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
|
||||
# If the output is a boolean and the output schema specifies a NUMBER type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
transformed_result[output_name] = self._convert_boolean_to_int(checked)
|
||||
|
||||
elif output_config.type == SegmentType.STRING:
|
||||
# check if string available
|
||||
transformed_result[output_name] = self._check_string(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == "array[number]":
|
||||
elif output_config.type == SegmentType.BOOLEAN:
|
||||
transformed_result[output_name] = self._check_boolean(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == SegmentType.ARRAY_NUMBER:
|
||||
# check if array of number available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
|
|
@ -278,10 +299,17 @@ class CodeNode(BaseNode):
|
|||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
# If the element is a boolean and the output schema specifies a `array[number]` type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
self._convert_boolean_to_int(
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == "array[string]":
|
||||
elif output_config.type == SegmentType.ARRAY_STRING:
|
||||
# check if array of string available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
|
|
@ -302,7 +330,7 @@ class CodeNode(BaseNode):
|
|||
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == "array[object]":
|
||||
elif output_config.type == SegmentType.ARRAY_OBJECT:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
|
|
@ -340,6 +368,22 @@ class CodeNode(BaseNode):
|
|||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
transformed_result[output_name] = [
|
||||
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
|
||||
else:
|
||||
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
|
||||
|
||||
|
|
@ -374,3 +418,16 @@ class CodeNode(BaseNode):
|
|||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
@staticmethod
|
||||
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
|
||||
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
|
||||
|
||||
This ensures compatibility with existing workflows that may use
|
||||
`True` and `False` as values for NUMBER type outputs.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -1,11 +1,31 @@
|
|||
from typing import Literal, Optional
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(segment_type: SegmentType) -> SegmentType:
|
||||
if segment_type not in _ALLOWED_OUTPUT_FROM_CODE:
|
||||
raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
|
||||
return segment_type
|
||||
|
||||
|
||||
class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
|
|
@ -13,7 +33,7 @@ class CodeNodeData(BaseNodeData):
|
|||
"""
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: Optional[dict[str, "CodeNodeData.Output"]] = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1,36 +1,43 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
_Condition = Literal[
|
||||
|
||||
class FilterOperator(StrEnum):
|
||||
# string conditions
|
||||
"contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"in",
|
||||
"empty",
|
||||
"not contains",
|
||||
"is not",
|
||||
"not in",
|
||||
"not empty",
|
||||
CONTAINS = "contains"
|
||||
START_WITH = "start with"
|
||||
END_WITH = "end with"
|
||||
IS = "is"
|
||||
IN = "in"
|
||||
EMPTY = "empty"
|
||||
NOT_CONTAINS = "not contains"
|
||||
IS_NOT = "is not"
|
||||
NOT_IN = "not in"
|
||||
NOT_EMPTY = "not empty"
|
||||
# number conditions
|
||||
"=",
|
||||
"≠",
|
||||
"<",
|
||||
">",
|
||||
"≥",
|
||||
"≤",
|
||||
]
|
||||
EQUAL = "="
|
||||
NOT_EQUAL = "≠"
|
||||
LESS_THAN = "<"
|
||||
GREATER_THAN = ">"
|
||||
GREATER_THAN_OR_EQUAL = "≥"
|
||||
LESS_THAN_OR_EQUAL = "≤"
|
||||
|
||||
|
||||
class Order(StrEnum):
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class FilterCondition(BaseModel):
|
||||
key: str = ""
|
||||
comparison_operator: _Condition = "contains"
|
||||
value: str | Sequence[str] = ""
|
||||
comparison_operator: FilterOperator = FilterOperator.CONTAINS
|
||||
# the value is bool if the filter operator is comparing with
|
||||
# a boolean constant.
|
||||
value: str | Sequence[str] | bool = ""
|
||||
|
||||
|
||||
class FilterBy(BaseModel):
|
||||
|
|
@ -38,10 +45,10 @@ class FilterBy(BaseModel):
|
|||
conditions: Sequence[FilterCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OrderBy(BaseModel):
|
||||
class OrderByConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
key: str = ""
|
||||
value: Literal["asc", "desc"] = "asc"
|
||||
value: Order = Order.ASC
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
|
|
@ -57,6 +64,6 @@ class ExtractConfig(BaseModel):
|
|||
class ListOperatorNodeData(BaseNodeData):
|
||||
variable: Sequence[str] = Field(default_factory=list)
|
||||
filter_by: FilterBy
|
||||
order_by: OrderBy
|
||||
order_by: OrderByConfig
|
||||
limit: Limit
|
||||
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,40 @@
|
|||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Optional, TypeAlias, TypeVar
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
|
||||
from .entities import ListOperatorNodeData
|
||||
from .entities import FilterOperator, ListOperatorNodeData, Order
|
||||
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
||||
|
||||
_SUPPORTED_TYPES_TUPLE = (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayStringSegment,
|
||||
ArrayBooleanSegment,
|
||||
)
|
||||
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
|
||||
"""Returns the negation of a given filter function. If the original filter
|
||||
returns `True` for a value, the negated filter will return `False`, and vice versa.
|
||||
"""
|
||||
|
||||
def wrapper(value: _T) -> bool:
|
||||
return not filter_(value)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode):
|
||||
_node_type = NodeType.LIST_OPERATOR
|
||||
|
|
@ -69,11 +91,8 @@ class ListOperatorNode(BaseNode):
|
|||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
|
||||
error_message = (
|
||||
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
|
||||
"or ArrayStringSegment"
|
||||
)
|
||||
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
|
||||
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||
)
|
||||
|
|
@ -122,9 +141,7 @@ class ListOperatorNode(BaseNode):
|
|||
outputs=outputs,
|
||||
)
|
||||
|
||||
def _apply_filter(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
filter_func: Callable[[Any], bool]
|
||||
result: list[Any] = []
|
||||
for condition in self._node_data.filter_by.conditions:
|
||||
|
|
@ -154,33 +171,35 @@ class ListOperatorNode(BaseNode):
|
|||
)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayBooleanSegment):
|
||||
if not isinstance(condition.value, bool):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statment should be unreachable.")
|
||||
return variable
|
||||
|
||||
def _apply_order(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayNumberSegment):
|
||||
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
|
||||
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
|
||||
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
result = _order_file(
|
||||
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
|
||||
)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable")
|
||||
|
||||
return variable
|
||||
|
||||
def _apply_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
result = variable.value[: self._node_data.limit.size]
|
||||
return variable.model_copy(update={"value": result})
|
||||
|
||||
def _extract_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
|
||||
if value < 1:
|
||||
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
|
||||
|
|
@ -232,11 +251,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
|
|||
case "empty":
|
||||
return lambda x: x == ""
|
||||
case "not contains":
|
||||
return lambda x: not _contains(value)(x)
|
||||
return _negation(_contains(value))
|
||||
case "is not":
|
||||
return lambda x: not _is(value)(x)
|
||||
return _negation(_is(value))
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
return _negation(_in(value))
|
||||
case "not empty":
|
||||
return lambda x: x != ""
|
||||
case _:
|
||||
|
|
@ -248,7 +267,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
|
|||
case "in":
|
||||
return _in(value)
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
return _negation(_in(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
|
@ -271,6 +290,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
|
|||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
|
||||
match condition:
|
||||
case FilterOperator.IS:
|
||||
return _is(value)
|
||||
case FilterOperator.IS_NOT:
|
||||
return _negation(_is(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
extract_func: Callable[[File], Any]
|
||||
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
|
||||
|
|
@ -298,7 +327,7 @@ def _endswith(value: str) -> Callable[[str], bool]:
|
|||
return lambda x: x.endswith(value)
|
||||
|
||||
|
||||
def _is(value: str) -> Callable[[str], bool]:
|
||||
def _is(value: _T) -> Callable[[_T], bool]:
|
||||
return lambda x: x == value
|
||||
|
||||
|
||||
|
|
@ -330,21 +359,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:
|
|||
return lambda x: x >= value
|
||||
|
||||
|
||||
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
|
||||
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
|
||||
extract_func: Callable[[File], Any]
|
||||
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
|
||||
extract_func = _get_file_extract_string_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
elif order_by == "size":
|
||||
extract_func = _get_file_extract_number_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
else:
|
||||
raise InvalidKeyError(f"Invalid order key: {order_by}")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import io
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
|
|
@ -55,7 +55,6 @@ from core.workflow.entities.variable_entities import VariableSelector
|
|||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
|
|
@ -90,6 +89,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
|
|||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -161,7 +161,7 @@ class LLMNode(BaseNode):
|
|||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
node_inputs: Optional[dict[str, Any]] = None
|
||||
process_data = None
|
||||
result_text = ""
|
||||
|
|
|
|||
|
|
@ -12,9 +12,11 @@ _VALID_VAR_TYPE = frozenset(
|
|||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -404,11 +404,11 @@ class LoopNode(BaseNode):
|
|||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
_outputs = {}
|
||||
_outputs: dict[str, Segment | int | None] = {}
|
||||
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
|
||||
_loop_variable_segment = variable_pool.get(loop_variable_selector)
|
||||
if _loop_variable_segment:
|
||||
_outputs[loop_variable_key] = _loop_variable_segment.value
|
||||
_outputs[loop_variable_key] = _loop_variable_segment
|
||||
else:
|
||||
_outputs[loop_variable_key] = None
|
||||
|
||||
|
|
@ -522,21 +522,30 @@ class LoopNode(BaseNode):
|
|||
return variable_mapping
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
if var_type in ["array[string]", "array[number]", "array[object]"]:
|
||||
if value and isinstance(value, str):
|
||||
value = json.loads(value)
|
||||
if var_type in [
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
]:
|
||||
if original_value and isinstance(original_value, str):
|
||||
value = json.loads(original_value)
|
||||
else:
|
||||
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
|
||||
value = []
|
||||
elif var_type == SegmentType.ARRAY_BOOLEAN:
|
||||
value = original_value
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
try:
|
||||
return build_segment_with_type(var_type, value)
|
||||
return build_segment_with_type(var_type, value=value)
|
||||
except TypeMismatchError as type_exc:
|
||||
# Attempt to parse the value as a JSON-encoded string, if applicable.
|
||||
if not isinstance(value, str):
|
||||
if not isinstance(original_value, str):
|
||||
raise
|
||||
try:
|
||||
value = json.loads(value)
|
||||
value = json.loads(original_value)
|
||||
except ValueError:
|
||||
raise type_exc
|
||||
return build_segment_with_type(var_type, value)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,46 @@
|
|||
from typing import Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
_OLD_BOOL_TYPE_NAME = "bool"
|
||||
_OLD_SELECT_TYPE_NAME = "select"
|
||||
|
||||
_VALID_PARAMETER_TYPES = frozenset(
|
||||
[
|
||||
SegmentType.STRING, # "string",
|
||||
SegmentType.NUMBER, # "number",
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
|
||||
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(parameter_type: str) -> SegmentType:
|
||||
if not isinstance(parameter_type, str):
|
||||
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
|
||||
if parameter_type not in _VALID_PARAMETER_TYPES:
|
||||
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
|
||||
|
||||
if parameter_type == _OLD_BOOL_TYPE_NAME:
|
||||
return SegmentType.BOOLEAN
|
||||
elif parameter_type == _OLD_SELECT_TYPE_NAME:
|
||||
return SegmentType.STRING
|
||||
return SegmentType(parameter_type)
|
||||
|
||||
|
||||
class _ParameterConfigError(Exception):
|
||||
|
|
@ -17,7 +53,7 @@ class ParameterConfig(BaseModel):
|
|||
"""
|
||||
|
||||
name: str
|
||||
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
|
||||
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
|
||||
options: Optional[list[str]] = None
|
||||
description: str
|
||||
required: bool
|
||||
|
|
@ -32,17 +68,20 @@ class ParameterConfig(BaseModel):
|
|||
return str(value)
|
||||
|
||||
def is_array_type(self) -> bool:
|
||||
return self.type in ("array[string]", "array[number]", "array[object]")
|
||||
return self.type.is_array_type()
|
||||
|
||||
def element_type(self) -> Literal["string", "number", "object"]:
|
||||
if self.type == "array[number]":
|
||||
return "number"
|
||||
elif self.type == "array[string]":
|
||||
return "string"
|
||||
elif self.type == "array[object]":
|
||||
return "object"
|
||||
else:
|
||||
raise _ParameterConfigError(f"{self.type} is not array type.")
|
||||
def element_type(self) -> SegmentType:
|
||||
"""Return the element type of the parameter.
|
||||
|
||||
Raises a ValueError if the parameter's type is not an array type.
|
||||
"""
|
||||
element_type = self.type.element_type()
|
||||
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
|
||||
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
|
||||
#
|
||||
# See: _VALID_PARAMETER_TYPES for reference.
|
||||
assert element_type is not None, f"the element type should not be None, {self.type=}"
|
||||
return element_type
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
|
|
@ -74,16 +113,18 @@ class ParameterExtractorNodeData(BaseNodeData):
|
|||
for parameter in self.parameters:
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
|
||||
if parameter.type in {"string", "select"}:
|
||||
if parameter.type == SegmentType.STRING:
|
||||
parameter_schema["type"] = "string"
|
||||
elif parameter.type.startswith("array"):
|
||||
elif parameter.type.is_array_type():
|
||||
parameter_schema["type"] = "array"
|
||||
nested_type = parameter.type[6:-1]
|
||||
parameter_schema["items"] = {"type": nested_type}
|
||||
element_type = parameter.type.element_type()
|
||||
if element_type is None:
|
||||
raise AssertionError("element type should not be None.")
|
||||
parameter_schema["items"] = {"type": element_type.value}
|
||||
else:
|
||||
parameter_schema["type"] = parameter.type
|
||||
|
||||
if parameter.type == "select":
|
||||
if parameter.options:
|
||||
parameter_schema["enum"] = parameter.options
|
||||
|
||||
parameters["properties"][parameter.name] = parameter_schema
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
from typing import Any
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
|
||||
class ParameterExtractorNodeError(ValueError):
|
||||
"""Base error for ParameterExtractorNode."""
|
||||
|
||||
|
|
@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError):
|
|||
|
||||
class InvalidModelModeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model mode is invalid."""
|
||||
|
||||
|
||||
class InvalidValueTypeError(ParameterExtractorNodeError):
|
||||
def __init__(
|
||||
self,
|
||||
/,
|
||||
parameter_name: str,
|
||||
expected_type: SegmentType,
|
||||
actual_type: SegmentType | None,
|
||||
value: Any,
|
||||
) -> None:
|
||||
message = (
|
||||
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
|
||||
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
|
||||
)
|
||||
super().__init__(message)
|
||||
self.parameter_name = parameter_name
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
self.value = value
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
|||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
|
@ -39,16 +39,13 @@ from factories.variable_factory import build_segment_with_type
|
|||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
InvalidArrayValueError,
|
||||
InvalidBoolValueError,
|
||||
InvalidInvokeResultError,
|
||||
InvalidModelModeError,
|
||||
InvalidModelTypeError,
|
||||
InvalidNumberOfParametersError,
|
||||
InvalidNumberValueError,
|
||||
InvalidSelectValueError,
|
||||
InvalidStringValueError,
|
||||
InvalidTextContentTypeError,
|
||||
InvalidValueTypeError,
|
||||
ModelSchemaNotFoundError,
|
||||
ParameterExtractorNodeError,
|
||||
RequiredParameterMissingError,
|
||||
|
|
@ -549,9 +546,6 @@ class ParameterExtractorNode(BaseNode):
|
|||
return prompt_messages
|
||||
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
Validate result.
|
||||
"""
|
||||
if len(data.parameters) != len(result):
|
||||
raise InvalidNumberOfParametersError("Invalid number of parameters")
|
||||
|
||||
|
|
@ -559,101 +553,106 @@ class ParameterExtractorNode(BaseNode):
|
|||
if parameter.required and parameter.name not in result:
|
||||
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
|
||||
|
||||
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
|
||||
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
|
||||
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
|
||||
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type.startswith("array"):
|
||||
parameters = result.get(parameter.name)
|
||||
if not isinstance(parameters, list):
|
||||
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
nested_type = parameter.type[6:-1]
|
||||
for item in parameters:
|
||||
if nested_type == "number" and not isinstance(item, int | float):
|
||||
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
if nested_type == "string" and not isinstance(item, str):
|
||||
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||
if nested_type == "object" and not isinstance(item, dict):
|
||||
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||
param_value = result.get(parameter.name)
|
||||
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
|
||||
inferred_type = SegmentType.infer_segment_type(param_value)
|
||||
raise InvalidValueTypeError(
|
||||
parameter_name=parameter.name,
|
||||
expected_type=parameter.type,
|
||||
actual_type=inferred_type,
|
||||
value=param_value,
|
||||
)
|
||||
if parameter.type == SegmentType.STRING and parameter.options:
|
||||
if param_value not in parameter.options:
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _transform_number(value: int | float | str | bool) -> int | float | None:
|
||||
"""
|
||||
Attempts to transform the input into an integer or float.
|
||||
|
||||
Returns:
|
||||
int or float: The transformed number if the conversion is successful.
|
||||
None: If the transformation fails.
|
||||
|
||||
Note:
|
||||
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
|
||||
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
|
||||
"""
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
return None
|
||||
if "." in value:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
transformed_result = {}
|
||||
transformed_result: dict[str, Any] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
param_value = result[parameter.name]
|
||||
# transform value
|
||||
if parameter.type == "number":
|
||||
if isinstance(result[parameter.name], int | float):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
elif isinstance(result[parameter.name], str):
|
||||
try:
|
||||
if "." in result[parameter.name]:
|
||||
result[parameter.name] = float(result[parameter.name])
|
||||
else:
|
||||
result[parameter.name] = int(result[parameter.name])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
# TODO: bool is not supported in the current version
|
||||
# elif parameter.type == 'bool':
|
||||
# if isinstance(result[parameter.name], bool):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ['true', 'false']:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
|
||||
# elif isinstance(result[parameter.name], int):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
elif parameter.type in {"string", "select"}:
|
||||
if isinstance(result[parameter.name], str):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
if parameter.type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(param_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
if isinstance(result[parameter.name], (bool, int)):
|
||||
transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ["true", "false"]:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
|
||||
elif parameter.type == SegmentType.STRING:
|
||||
if isinstance(param_value, str):
|
||||
transformed_result[parameter.name] = param_value
|
||||
elif parameter.is_array_type():
|
||||
if isinstance(result[parameter.name], list):
|
||||
if isinstance(param_value, list):
|
||||
nested_type = parameter.element_type()
|
||||
assert nested_type is not None
|
||||
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
|
||||
transformed_result[parameter.name] = segment_value
|
||||
for item in result[parameter.name]:
|
||||
if nested_type == "number":
|
||||
if isinstance(item, int | float):
|
||||
segment_value.value.append(item)
|
||||
elif isinstance(item, str):
|
||||
try:
|
||||
if "." in item:
|
||||
segment_value.value.append(float(item))
|
||||
else:
|
||||
segment_value.value.append(int(item))
|
||||
except ValueError:
|
||||
pass
|
||||
elif nested_type == "string":
|
||||
for item in param_value:
|
||||
if nested_type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(item)
|
||||
if transformed is not None:
|
||||
segment_value.value.append(transformed)
|
||||
elif nested_type == SegmentType.STRING:
|
||||
if isinstance(item, str):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == "object":
|
||||
elif nested_type == SegmentType.OBJECT:
|
||||
if isinstance(item, dict):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == SegmentType.BOOLEAN:
|
||||
if isinstance(item, bool):
|
||||
segment_value.value.append(item)
|
||||
|
||||
if parameter.name not in transformed_result:
|
||||
if parameter.type == "number":
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == "bool":
|
||||
transformed_result[parameter.name] = False
|
||||
elif parameter.type in {"string", "select"}:
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type.startswith("array"):
|
||||
if parameter.type.is_array_type():
|
||||
transformed_result[parameter.name] = build_segment_with_type(
|
||||
segment_type=SegmentType(parameter.type), value=[]
|
||||
)
|
||||
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type == SegmentType.NUMBER:
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
transformed_result[parameter.name] = False
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
return transformed_result
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from collections.abc import Callable, Mapping, Sequence
|
|||
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.segments import BooleanSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
|
|
@ -158,8 +159,8 @@ class VariableAssignerNode(BaseNode):
|
|||
def get_zero_value(t: SegmentType):
|
||||
# TODO(QuantumGhost): this should be a method of `SegmentType`.
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
return variable_factory.build_segment([])
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
|
||||
return variable_factory.build_segment_with_type(t, [])
|
||||
case SegmentType.OBJECT:
|
||||
return variable_factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
|
|
@ -170,5 +171,7 @@ def get_zero_value(t: SegmentType):
|
|||
return variable_factory.build_segment(0.0)
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.BOOLEAN:
|
||||
return BooleanSegment(value=False)
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
|
||||
|
|
|
|||
|
|
@ -4,9 +4,11 @@ from core.variables import SegmentType
|
|||
EMPTY_VALUE_MAPPING = {
|
||||
SegmentType.STRING: "",
|
||||
SegmentType.NUMBER: 0,
|
||||
SegmentType.BOOLEAN: False,
|
||||
SegmentType.OBJECT: {},
|
||||
SegmentType.ARRAY_ANY: [],
|
||||
SegmentType.ARRAY_STRING: [],
|
||||
SegmentType.ARRAY_NUMBER: [],
|
||||
SegmentType.ARRAY_OBJECT: [],
|
||||
SegmentType.ARRAY_BOOLEAN: [],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,28 +16,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
|||
SegmentType.NUMBER,
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.FLOAT,
|
||||
SegmentType.BOOLEAN,
|
||||
}
|
||||
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
|
||||
# Only number variable can be added, subtracted, multiplied or divided
|
||||
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
|
||||
case Operation.APPEND | Operation.EXTEND:
|
||||
case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
|
||||
# Only array variable can be appended or extended
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
|
||||
# Only array variable can have elements removed
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
return variable_type.is_array_type()
|
||||
case _:
|
||||
return False
|
||||
|
||||
|
|
@ -50,7 +37,7 @@ def is_variable_input_supported(*, operation: Operation):
|
|||
|
||||
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
match variable_type:
|
||||
case SegmentType.STRING | SegmentType.OBJECT:
|
||||
case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN:
|
||||
return operation in {Operation.OVER_WRITE, Operation.SET}
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
return operation in {
|
||||
|
|
@ -72,6 +59,9 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
|||
case SegmentType.STRING:
|
||||
return isinstance(value, str)
|
||||
|
||||
case SegmentType.BOOLEAN:
|
||||
return isinstance(value, bool)
|
||||
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
if not isinstance(value, int | float):
|
||||
return False
|
||||
|
|
@ -91,6 +81,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
|||
return isinstance(value, int | float)
|
||||
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
|
||||
return isinstance(value, dict)
|
||||
case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND:
|
||||
return isinstance(value, bool)
|
||||
|
||||
# Array & Extend / Overwrite
|
||||
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
|
|
@ -101,6 +93,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
|||
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
|
||||
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
|
||||
case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, bool) for item in value)
|
||||
|
||||
case _:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -45,5 +45,5 @@ class SubVariableCondition(BaseModel):
|
|||
class Condition(BaseModel):
|
||||
variable_selector: list[str]
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None = None
|
||||
value: str | Sequence[str] | bool | None = None
|
||||
sub_variable_condition: SubVariableCondition | None = None
|
||||
|
|
|
|||
|
|
@ -1,13 +1,27 @@
|
|||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from core.file import FileAttribute, file_manager
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
from .entities import Condition, SubCondition, SupportedComparisonOperator
|
||||
|
||||
|
||||
def _convert_to_bool(value: Any) -> bool:
|
||||
if isinstance(value, int):
|
||||
return bool(value)
|
||||
|
||||
if isinstance(value, str):
|
||||
loaded = json.loads(value)
|
||||
if isinstance(loaded, (int, bool)):
|
||||
return bool(loaded)
|
||||
|
||||
raise TypeError(f"unexpected value: type={type(value)}, value={value}")
|
||||
|
||||
|
||||
class ConditionProcessor:
|
||||
def process_conditions(
|
||||
self,
|
||||
|
|
@ -48,9 +62,16 @@ class ConditionProcessor:
|
|||
)
|
||||
else:
|
||||
actual_value = variable.value if variable else None
|
||||
expected_value = condition.value
|
||||
expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = variable_pool.convert_template(expected_value).text
|
||||
# Here we need to explicit convet the input string to boolean.
|
||||
if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None:
|
||||
# The following two lines is for compatibility with existing workflows.
|
||||
if isinstance(expected_value, list):
|
||||
expected_value = [_convert_to_bool(i) for i in expected_value]
|
||||
else:
|
||||
expected_value = _convert_to_bool(expected_value)
|
||||
input_conditions.append(
|
||||
{
|
||||
"actual_value": actual_value,
|
||||
|
|
@ -77,7 +98,7 @@ def _evaluate_condition(
|
|||
*,
|
||||
operator: SupportedComparisonOperator,
|
||||
value: Any,
|
||||
expected: str | Sequence[str] | None,
|
||||
expected: Union[str, Sequence[str], bool | Sequence[bool], None],
|
||||
) -> bool:
|
||||
match operator:
|
||||
case "contains":
|
||||
|
|
@ -130,7 +151,7 @@ def _assert_contains(*, value: Any, expected: Any) -> bool:
|
|||
if not value:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str | list):
|
||||
if not isinstance(value, (str, list)):
|
||||
raise ValueError("Invalid actual value type: string or array")
|
||||
|
||||
if expected not in value:
|
||||
|
|
@ -142,7 +163,7 @@ def _assert_not_contains(*, value: Any, expected: Any) -> bool:
|
|||
if not value:
|
||||
return True
|
||||
|
||||
if not isinstance(value, str | list):
|
||||
if not isinstance(value, (str, list)):
|
||||
raise ValueError("Invalid actual value type: string or array")
|
||||
|
||||
if expected in value:
|
||||
|
|
@ -178,8 +199,8 @@ def _assert_is(*, value: Any, expected: Any) -> bool:
|
|||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("Invalid actual value type: string")
|
||||
if not isinstance(value, (str, bool)):
|
||||
raise ValueError("Invalid actual value type: string or boolean")
|
||||
|
||||
if value != expected:
|
||||
return False
|
||||
|
|
@ -190,8 +211,8 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool:
|
|||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("Invalid actual value type: string")
|
||||
if not isinstance(value, (str, bool)):
|
||||
raise ValueError("Invalid actual value type: string or boolean")
|
||||
|
||||
if value == expected:
|
||||
return False
|
||||
|
|
@ -214,10 +235,13 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:
|
|||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
if not isinstance(value, (int, float, bool)):
|
||||
raise ValueError("Invalid actual value type: number or boolean")
|
||||
|
||||
if isinstance(value, int):
|
||||
# Handle boolean comparison
|
||||
if isinstance(value, bool):
|
||||
expected = bool(expected)
|
||||
elif isinstance(value, int):
|
||||
expected = int(expected)
|
||||
else:
|
||||
expected = float(expected)
|
||||
|
|
@ -231,10 +255,13 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:
|
|||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
if not isinstance(value, (int, float, bool)):
|
||||
raise ValueError("Invalid actual value type: number or boolean")
|
||||
|
||||
if isinstance(value, int):
|
||||
# Handle boolean comparison
|
||||
if isinstance(value, bool):
|
||||
expected = bool(expected)
|
||||
elif isinstance(value, int):
|
||||
expected = int(expected)
|
||||
else:
|
||||
expected = float(expected)
|
||||
|
|
@ -248,7 +275,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
|
|||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
|
|
@ -265,7 +292,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
|
|||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
|
|
@ -282,7 +309,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
|
|||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
|
|
@ -299,7 +326,7 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
|
|||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
|
|
|
|||
|
|
@ -7,11 +7,13 @@ from core.file import File
|
|||
from core.variables.exc import VariableError
|
||||
from core.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayBooleanSegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArraySegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
|
|
@ -23,10 +25,12 @@ from core.variables.segments import (
|
|||
from core.variables.types import SegmentType
|
||||
from core.variables.variables import (
|
||||
ArrayAnyVariable,
|
||||
ArrayBooleanVariable,
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
BooleanVariable,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
|
|
@ -49,17 +53,19 @@ class TypeMismatchError(Exception):
|
|||
|
||||
# Define the constant
|
||||
SEGMENT_TO_VARIABLE_MAP = {
|
||||
StringSegment: StringVariable,
|
||||
IntegerSegment: IntegerVariable,
|
||||
FloatSegment: FloatVariable,
|
||||
ObjectSegment: ObjectVariable,
|
||||
FileSegment: FileVariable,
|
||||
ArrayStringSegment: ArrayStringVariable,
|
||||
ArrayAnySegment: ArrayAnyVariable,
|
||||
ArrayBooleanSegment: ArrayBooleanVariable,
|
||||
ArrayFileSegment: ArrayFileVariable,
|
||||
ArrayNumberSegment: ArrayNumberVariable,
|
||||
ArrayObjectSegment: ArrayObjectVariable,
|
||||
ArrayFileSegment: ArrayFileVariable,
|
||||
ArrayAnySegment: ArrayAnyVariable,
|
||||
ArrayStringSegment: ArrayStringVariable,
|
||||
BooleanSegment: BooleanVariable,
|
||||
FileSegment: FileVariable,
|
||||
FloatSegment: FloatVariable,
|
||||
IntegerSegment: IntegerVariable,
|
||||
NoneSegment: NoneVariable,
|
||||
ObjectSegment: ObjectVariable,
|
||||
StringSegment: StringVariable,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -99,6 +105,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
|||
mapping = dict(mapping)
|
||||
mapping["value_type"] = SegmentType.FLOAT
|
||||
result = FloatVariable.model_validate(mapping)
|
||||
case SegmentType.BOOLEAN:
|
||||
result = BooleanVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||
raise VariableError(f"invalid number value {value}")
|
||||
case SegmentType.OBJECT if isinstance(value, dict):
|
||||
|
|
@ -109,6 +117,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
|||
result = ArrayNumberVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
||||
result = ArrayObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_BOOLEAN if isinstance(value, list):
|
||||
result = ArrayBooleanVariable.model_validate(mapping)
|
||||
case _:
|
||||
raise VariableError(f"not supported value type {value_type}")
|
||||
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
||||
|
|
@ -129,6 +139,8 @@ def build_segment(value: Any, /) -> Segment:
|
|||
return NoneSegment()
|
||||
if isinstance(value, str):
|
||||
return StringSegment(value=value)
|
||||
if isinstance(value, bool):
|
||||
return BooleanSegment(value=value)
|
||||
if isinstance(value, int):
|
||||
return IntegerSegment(value=value)
|
||||
if isinstance(value, float):
|
||||
|
|
@ -152,6 +164,8 @@ def build_segment(value: Any, /) -> Segment:
|
|||
return ArrayStringSegment(value=value)
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
return ArrayNumberSegment(value=value)
|
||||
case SegmentType.BOOLEAN:
|
||||
return ArrayBooleanSegment(value=value)
|
||||
case SegmentType.OBJECT:
|
||||
return ArrayObjectSegment(value=value)
|
||||
case SegmentType.FILE:
|
||||
|
|
@ -170,6 +184,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
|
|||
SegmentType.INTEGER: IntegerSegment,
|
||||
SegmentType.FLOAT: FloatSegment,
|
||||
SegmentType.FILE: FileSegment,
|
||||
SegmentType.BOOLEAN: BooleanSegment,
|
||||
SegmentType.OBJECT: ObjectSegment,
|
||||
# Array types
|
||||
SegmentType.ARRAY_ANY: ArrayAnySegment,
|
||||
|
|
@ -177,6 +192,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
|
|||
SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
|
||||
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
|
||||
SegmentType.ARRAY_FILE: ArrayFileSegment,
|
||||
SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -225,6 +241,8 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
|||
return ArrayAnySegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_STRING:
|
||||
return ArrayStringSegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_BOOLEAN:
|
||||
return ArrayBooleanSegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_NUMBER:
|
||||
return ArrayNumberSegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_OBJECT:
|
||||
|
|
|
|||
11
api/lazy_load_class.py
Normal file
11
api/lazy_load_class.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from tests.integration_tests.utils.parent_class import ParentClass
|
||||
|
||||
|
||||
class LazyLoadChildClass(ParentClass):
|
||||
"""Test lazy load child class for module import helper tests"""
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
|
|
@ -20,3 +20,6 @@ ignore_missing_imports=True
|
|||
|
||||
[mypy-flask_restx.inputs]
|
||||
ignore_missing_imports=True
|
||||
|
||||
[mypy-google.cloud.storage]
|
||||
ignore_missing_imports=True
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ class TestSegmentTypeIsArrayType:
|
|||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
expected_non_array_types = [
|
||||
SegmentType.INTEGER,
|
||||
|
|
@ -34,6 +35,7 @@ class TestSegmentTypeIsArrayType:
|
|||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
SegmentType.GROUP,
|
||||
SegmentType.BOOLEAN,
|
||||
]
|
||||
|
||||
for seg_type in expected_array_types:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,729 @@
|
|||
"""
|
||||
Comprehensive unit tests for SegmentType.is_valid and SegmentType._validate_array methods.
|
||||
|
||||
This module provides thorough testing of the validation logic for all SegmentType values,
|
||||
including edge cases, error conditions, and different ArrayValidation strategies.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
|
||||
|
||||
def create_test_file(
|
||||
file_type: FileType = FileType.DOCUMENT,
|
||||
transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
|
||||
filename: str = "test.txt",
|
||||
extension: str = ".txt",
|
||||
mime_type: str = "text/plain",
|
||||
size: int = 1024,
|
||||
) -> File:
|
||||
"""Factory function to create File objects for testing."""
|
||||
return File(
|
||||
tenant_id="test-tenant",
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
|
||||
remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
storage_key="test-storage-key",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationTestCase:
|
||||
"""Test case data structure for validation tests."""
|
||||
|
||||
segment_type: SegmentType
|
||||
value: Any
|
||||
expected: bool
|
||||
description: str
|
||||
|
||||
def get_id(self):
|
||||
return self.description
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArrayValidationTestCase:
|
||||
"""Test case data structure for array validation tests."""
|
||||
|
||||
segment_type: SegmentType
|
||||
value: Any
|
||||
array_validation: ArrayValidation
|
||||
expected: bool
|
||||
description: str
|
||||
|
||||
def get_id(self):
|
||||
return self.description
|
||||
|
||||
|
||||
# Test data construction functions
|
||||
def get_boolean_cases() -> list[ValidationTestCase]:
|
||||
return [
|
||||
# valid values
|
||||
ValidationTestCase(SegmentType.BOOLEAN, True, True, "True boolean"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, False, True, "False boolean"),
|
||||
# Invalid values
|
||||
ValidationTestCase(SegmentType.BOOLEAN, 1, False, "Integer 1 (not boolean)"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, 0, False, "Integer 0 (not boolean)"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, "true", False, "String 'true'"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, "false", False, "String 'false'"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, {}, False, "Empty dict"),
|
||||
]
|
||||
|
||||
|
||||
def get_number_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid number values."""
|
||||
return [
|
||||
# valid values
|
||||
ValidationTestCase(SegmentType.NUMBER, 42, True, "Positive integer"),
|
||||
ValidationTestCase(SegmentType.NUMBER, -42, True, "Negative integer"),
|
||||
ValidationTestCase(SegmentType.NUMBER, 0, True, "Zero integer"),
|
||||
ValidationTestCase(SegmentType.NUMBER, 3.14, True, "Positive float"),
|
||||
ValidationTestCase(SegmentType.NUMBER, -3.14, True, "Negative float"),
|
||||
ValidationTestCase(SegmentType.NUMBER, 0.0, True, "Zero float"),
|
||||
ValidationTestCase(SegmentType.NUMBER, float("inf"), True, "Positive infinity"),
|
||||
ValidationTestCase(SegmentType.NUMBER, float("-inf"), True, "Negative infinity"),
|
||||
ValidationTestCase(SegmentType.NUMBER, float("nan"), True, "float(NaN)"),
|
||||
# invalid number values
|
||||
ValidationTestCase(SegmentType.NUMBER, "42", False, "String number"),
|
||||
ValidationTestCase(SegmentType.NUMBER, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.NUMBER, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.NUMBER, {}, False, "Empty dict"),
|
||||
ValidationTestCase(SegmentType.NUMBER, "3.14", False, "String float"),
|
||||
]
|
||||
|
||||
|
||||
def get_string_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid string values."""
|
||||
return [
|
||||
# valid values
|
||||
ValidationTestCase(SegmentType.STRING, "", True, "Empty string"),
|
||||
ValidationTestCase(SegmentType.STRING, "hello", True, "Simple string"),
|
||||
ValidationTestCase(SegmentType.STRING, "🚀", True, "Unicode emoji"),
|
||||
ValidationTestCase(SegmentType.STRING, "line1\nline2", True, "Multiline string"),
|
||||
# invalid values
|
||||
ValidationTestCase(SegmentType.STRING, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.STRING, 3.14, False, "Float"),
|
||||
ValidationTestCase(SegmentType.STRING, True, False, "Boolean"),
|
||||
ValidationTestCase(SegmentType.STRING, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.STRING, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.STRING, {}, False, "Empty dict"),
|
||||
]
|
||||
|
||||
|
||||
def get_object_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid object values."""
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.OBJECT, {}, True, "Empty dict"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"key": "value"}, True, "Simple dict"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"a": 1, "b": 2}, True, "Dict with numbers"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"nested": {"key": "value"}}, True, "Nested dict"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"list": [1, 2, 3]}, True, "Dict with list"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"mixed": [1, "two", {"three": 3}]}, True, "Complex dict"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.OBJECT, "not a dict", False, "String"),
|
||||
ValidationTestCase(SegmentType.OBJECT, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.OBJECT, 3.14, False, "Float"),
|
||||
ValidationTestCase(SegmentType.OBJECT, True, False, "Boolean"),
|
||||
ValidationTestCase(SegmentType.OBJECT, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.OBJECT, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.OBJECT, [1, 2, 3], False, "List with values"),
|
||||
]
|
||||
|
||||
|
||||
def get_secret_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid secret values."""
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.SECRET, "", True, "Empty secret"),
|
||||
ValidationTestCase(SegmentType.SECRET, "secret", True, "Simple secret"),
|
||||
ValidationTestCase(SegmentType.SECRET, "api_key_123", True, "API key format"),
|
||||
ValidationTestCase(SegmentType.SECRET, "very_long_secret_key_with_special_chars!@#", True, "Complex secret"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.SECRET, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.SECRET, 3.14, False, "Float"),
|
||||
ValidationTestCase(SegmentType.SECRET, True, False, "Boolean"),
|
||||
ValidationTestCase(SegmentType.SECRET, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.SECRET, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.SECRET, {}, False, "Empty dict"),
|
||||
]
|
||||
|
||||
|
||||
def get_file_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid file values."""
|
||||
test_file = create_test_file()
|
||||
image_file = create_test_file(
|
||||
file_type=FileType.IMAGE, filename="image.jpg", extension=".jpg", mime_type="image/jpeg"
|
||||
)
|
||||
remote_file = create_test_file(
|
||||
transfer_method=FileTransferMethod.REMOTE_URL, filename="remote.pdf", extension=".pdf"
|
||||
)
|
||||
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.FILE, test_file, True, "Document file"),
|
||||
ValidationTestCase(SegmentType.FILE, image_file, True, "Image file"),
|
||||
ValidationTestCase(SegmentType.FILE, remote_file, True, "Remote file"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.FILE, "not a file", False, "String"),
|
||||
ValidationTestCase(SegmentType.FILE, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.FILE, {"filename": "test.txt"}, False, "Dict resembling file"),
|
||||
ValidationTestCase(SegmentType.FILE, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.FILE, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.FILE, True, False, "Boolean"),
|
||||
]
|
||||
|
||||
|
||||
def get_none_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid none values."""
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.NONE, None, True, "None value"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.NONE, "", False, "Empty string"),
|
||||
ValidationTestCase(SegmentType.NONE, 0, False, "Zero integer"),
|
||||
ValidationTestCase(SegmentType.NONE, 0.0, False, "Zero float"),
|
||||
ValidationTestCase(SegmentType.NONE, False, False, "False boolean"),
|
||||
ValidationTestCase(SegmentType.NONE, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.NONE, {}, False, "Empty dict"),
|
||||
ValidationTestCase(SegmentType.NONE, "null", False, "String 'null'"),
|
||||
]
|
||||
|
||||
|
||||
def get_array_any_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_ANY validation."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY,
|
||||
[1, "string", 3.14, {"key": "value"}, True],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Mixed types with NONE validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY,
|
||||
[1, "string", 3.14, {"key": "value"}, True],
|
||||
ArrayValidation.FIRST,
|
||||
True,
|
||||
"Mixed types with FIRST validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY,
|
||||
[1, "string", 3.14, {"key": "value"}, True],
|
||||
ArrayValidation.ALL,
|
||||
True,
|
||||
"Mixed types with ALL validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY, [None, None, None], ArrayValidation.ALL, True, "All None values"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_string_validation_none_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_STRING validation with NONE strategy."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
["hello", "world"],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Valid strings with NONE validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
[123, 456],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Invalid elements with NONE validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
["valid", 123, True],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Mixed types with NONE validation",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_string_validation_first_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_STRING validation with FIRST strategy."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["hello", "world"], ArrayValidation.FIRST, True, "All valid strings"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
["hello", 123, True],
|
||||
ArrayValidation.FIRST,
|
||||
True,
|
||||
"First valid, others invalid",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
[123, "hello", "world"],
|
||||
ArrayValidation.FIRST,
|
||||
False,
|
||||
"First invalid, others valid",
|
||||
),
|
||||
ArrayValidationTestCase(SegmentType.ARRAY_STRING, [None, "hello"], ArrayValidation.FIRST, False, "First None"),
|
||||
]
|
||||
|
||||
|
||||
def get_array_string_validation_all_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_STRING validation with ALL strategy."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["hello", "world", "test"], ArrayValidation.ALL, True, "All valid strings"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["hello", 123, "world"], ArrayValidation.ALL, False, "One invalid element"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, [123, 456, 789], ArrayValidation.ALL, False, "All invalid elements"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["valid", None, "also_valid"], ArrayValidation.ALL, False, "Contains None"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_number_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_NUMBER validation with different strategies."""
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [1, 2.5, 3], ArrayValidation.NONE, True, "Valid numbers with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, ["not", "numbers"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [42, "not a number"], ArrayValidation.FIRST, True, "First valid number"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, ["not a number", 42], ArrayValidation.FIRST, False, "First invalid"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [3.14, 2.71, 1.41], ArrayValidation.FIRST, True, "All valid floats"
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [1, 2, 3, 4.5], ArrayValidation.ALL, True, "All valid numbers"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [1, "invalid", 3], ArrayValidation.ALL, False, "One invalid element"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
[float("inf"), float("-inf"), float("nan")],
|
||||
ArrayValidation.ALL,
|
||||
True,
|
||||
"Special float values",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_object_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_OBJECT validation with different strategies."""
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT, [{}, {"key": "value"}], ArrayValidation.NONE, True, "Valid objects with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT, ["not", "objects"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
[{"valid": "object"}, "not an object"],
|
||||
ArrayValidation.FIRST,
|
||||
True,
|
||||
"First valid object",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
["not an object", {"valid": "object"}],
|
||||
ArrayValidation.FIRST,
|
||||
False,
|
||||
"First invalid",
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
[{}, {"a": 1}, {"nested": {"key": "value"}}],
|
||||
ArrayValidation.ALL,
|
||||
True,
|
||||
"All valid objects",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
[{"valid": "object"}, "invalid", {"another": "object"}],
|
||||
ArrayValidation.ALL,
|
||||
False,
|
||||
"One invalid element",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_file_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_FILE validation with different strategies."""
|
||||
file1 = create_test_file(filename="file1.txt")
|
||||
file2 = create_test_file(filename="file2.txt")
|
||||
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.NONE, True, "Valid files with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, ["not", "files"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, [file1, "not a file"], ArrayValidation.FIRST, True, "First valid file"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, ["not a file", file1], ArrayValidation.FIRST, False, "First invalid"
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.ALL, True, "All valid files"),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, [file1, "invalid", file2], ArrayValidation.ALL, False, "One invalid element"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_boolean_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_BOOLEAN validation with different strategies."""
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, False, True], ArrayValidation.NONE, True, "Valid booleans with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [1, 0, "true"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, 1, 0], ArrayValidation.FIRST, True, "First valid boolean"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [1, True, False], ArrayValidation.FIRST, False, "First invalid (integer 1)"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [0, True, False], ArrayValidation.FIRST, False, "First invalid (integer 0)"
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, False, True, False], ArrayValidation.ALL, True, "All valid booleans"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, 1, False], ArrayValidation.ALL, False, "One invalid element (integer)"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
[True, "false", False],
|
||||
ArrayValidation.ALL,
|
||||
False,
|
||||
"One invalid element (string)",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TestSegmentTypeIsValid:
|
||||
"""Test suite for SegmentType.is_valid method covering all non-array types."""
|
||||
|
||||
@pytest.mark.parametrize("case", get_boolean_cases(), ids=lambda case: case.description)
|
||||
def test_boolean_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_number_cases(), ids=lambda case: case.description)
|
||||
def test_number_validation(self, case: ValidationTestCase):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_string_cases(), ids=lambda case: case.description)
|
||||
def test_string_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_object_cases(), ids=lambda case: case.description)
|
||||
def test_object_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_secret_cases(), ids=lambda case: case.description)
|
||||
def test_secret_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_file_cases(), ids=lambda case: case.description)
|
||||
def test_file_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_none_cases(), ids=lambda case: case.description)
|
||||
def test_none_validation_valid_cases(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
def test_unsupported_segment_type_raises_assertion_error(self):
|
||||
"""Test that unsupported SegmentType values raise AssertionError."""
|
||||
# GROUP is not handled in is_valid method
|
||||
with pytest.raises(AssertionError, match="this statement should be unreachable"):
|
||||
SegmentType.GROUP.is_valid("any value")
|
||||
|
||||
|
||||
class TestSegmentTypeArrayValidation:
|
||||
"""Test suite for SegmentType._validate_array method and array type validation."""
|
||||
|
||||
def test_array_validation_non_list_values(self):
|
||||
"""Test that non-list values return False for all array types."""
|
||||
array_types = [
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
|
||||
non_list_values = [
|
||||
"not a list",
|
||||
123,
|
||||
3.14,
|
||||
True,
|
||||
None,
|
||||
{"key": "value"},
|
||||
create_test_file(),
|
||||
]
|
||||
|
||||
for array_type in array_types:
|
||||
for value in non_list_values:
|
||||
assert array_type.is_valid(value) is False, f"{array_type} should reject {type(value).__name__}"
|
||||
|
||||
def test_empty_array_validation(self):
|
||||
"""Test that empty arrays are valid for all array types regardless of validation strategy."""
|
||||
array_types = [
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
|
||||
validation_strategies = [ArrayValidation.NONE, ArrayValidation.FIRST, ArrayValidation.ALL]
|
||||
|
||||
for array_type in array_types:
|
||||
for strategy in validation_strategies:
|
||||
assert array_type.is_valid([], strategy) is True, (
|
||||
f"{array_type} should accept empty array with {strategy}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_any_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_any_validation(self, case):
|
||||
"""Test ARRAY_ANY validation accepts any list regardless of content."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_string_validation_none_cases(), ids=lambda case: case.description)
|
||||
def test_array_string_validation_with_none_strategy(self, case):
|
||||
"""Test ARRAY_STRING validation with NONE strategy (no element validation)."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_string_validation_first_cases(), ids=lambda case: case.description)
|
||||
def test_array_string_validation_with_first_strategy(self, case):
|
||||
"""Test ARRAY_STRING validation with FIRST strategy (validate first element only)."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_string_validation_all_cases(), ids=lambda case: case.description)
|
||||
def test_array_string_validation_with_all_strategy(self, case):
|
||||
"""Test ARRAY_STRING validation with ALL strategy (validate all elements)."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_number_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_number_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_NUMBER validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_object_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_object_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_OBJECT validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_file_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_file_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_FILE validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_boolean_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_boolean_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_BOOLEAN validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
def test_default_array_validation_strategy(self):
|
||||
"""Test that default array validation strategy is FIRST."""
|
||||
# When no array_validation parameter is provided, it should default to FIRST
|
||||
assert SegmentType.ARRAY_STRING.is_valid(["valid", 123]) is False # First element valid
|
||||
assert SegmentType.ARRAY_STRING.is_valid([123, "valid"]) is False # First element invalid
|
||||
|
||||
assert SegmentType.ARRAY_NUMBER.is_valid([42, "invalid"]) is False # First element valid
|
||||
assert SegmentType.ARRAY_NUMBER.is_valid(["invalid", 42]) is False # First element invalid
|
||||
|
||||
def test_array_validation_edge_cases(self):
|
||||
"""Test edge cases for array validation."""
|
||||
# Test with nested arrays (should be invalid for specific array types)
|
||||
nested_array = [["nested", "array"], ["another", "nested"]]
|
||||
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.FIRST) is False
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.ALL) is False
|
||||
assert SegmentType.ARRAY_ANY.is_valid(nested_array, ArrayValidation.ALL) is True
|
||||
|
||||
# Test with very large arrays (performance consideration)
|
||||
large_valid_array = ["string"] * 1000
|
||||
large_mixed_array = ["string"] * 999 + [123] # Last element invalid
|
||||
|
||||
assert SegmentType.ARRAY_STRING.is_valid(large_valid_array, ArrayValidation.ALL) is True
|
||||
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.ALL) is False
|
||||
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.FIRST) is True
|
||||
|
||||
|
||||
class TestSegmentTypeValidationIntegration:
|
||||
"""Integration tests for SegmentType validation covering interactions between methods."""
|
||||
|
||||
def test_non_array_types_ignore_array_validation_parameter(self):
|
||||
"""Test that non-array types ignore the array_validation parameter."""
|
||||
non_array_types = [
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.SECRET,
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
]
|
||||
|
||||
for segment_type in non_array_types:
|
||||
# Create appropriate valid value for each type
|
||||
valid_value: Any
|
||||
if segment_type == SegmentType.STRING:
|
||||
valid_value = "test"
|
||||
elif segment_type == SegmentType.NUMBER:
|
||||
valid_value = 42
|
||||
elif segment_type == SegmentType.BOOLEAN:
|
||||
valid_value = True
|
||||
elif segment_type == SegmentType.OBJECT:
|
||||
valid_value = {"key": "value"}
|
||||
elif segment_type == SegmentType.SECRET:
|
||||
valid_value = "secret"
|
||||
elif segment_type == SegmentType.FILE:
|
||||
valid_value = create_test_file()
|
||||
elif segment_type == SegmentType.NONE:
|
||||
valid_value = None
|
||||
else:
|
||||
continue # Skip unsupported types
|
||||
|
||||
# All array validation strategies should give the same result
|
||||
result_none = segment_type.is_valid(valid_value, ArrayValidation.NONE)
|
||||
result_first = segment_type.is_valid(valid_value, ArrayValidation.FIRST)
|
||||
result_all = segment_type.is_valid(valid_value, ArrayValidation.ALL)
|
||||
|
||||
assert result_none == result_first == result_all == True, (
|
||||
f"{segment_type} should ignore array_validation parameter"
|
||||
)
|
||||
|
||||
def test_comprehensive_type_coverage(self):
|
||||
"""Test that all SegmentType enum values are covered in validation tests."""
|
||||
all_segment_types = set(SegmentType)
|
||||
|
||||
# Types that should be handled by is_valid method
|
||||
handled_types = {
|
||||
# Non-array types
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.SECRET,
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
# Array types
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
}
|
||||
|
||||
# Types that are not handled by is_valid (should raise AssertionError)
|
||||
unhandled_types = {
|
||||
SegmentType.GROUP,
|
||||
SegmentType.INTEGER, # Handled by NUMBER validation logic
|
||||
SegmentType.FLOAT, # Handled by NUMBER validation logic
|
||||
}
|
||||
|
||||
# Verify all types are accounted for
|
||||
assert handled_types | unhandled_types == all_segment_types, "All SegmentType values should be categorized"
|
||||
|
||||
# Test that handled types work correctly
|
||||
for segment_type in handled_types:
|
||||
if segment_type.is_array_type():
|
||||
# Test with empty array (should always be valid)
|
||||
assert segment_type.is_valid([]) is True, f"{segment_type} should accept empty array"
|
||||
else:
|
||||
# Test with appropriate valid value
|
||||
if segment_type == SegmentType.STRING:
|
||||
assert segment_type.is_valid("test") is True
|
||||
elif segment_type == SegmentType.NUMBER:
|
||||
assert segment_type.is_valid(42) is True
|
||||
elif segment_type == SegmentType.BOOLEAN:
|
||||
assert segment_type.is_valid(True) is True
|
||||
elif segment_type == SegmentType.OBJECT:
|
||||
assert segment_type.is_valid({}) is True
|
||||
elif segment_type == SegmentType.SECRET:
|
||||
assert segment_type.is_valid("secret") is True
|
||||
elif segment_type == SegmentType.FILE:
|
||||
assert segment_type.is_valid(create_test_file()) is True
|
||||
elif segment_type == SegmentType.NONE:
|
||||
assert segment_type.is_valid(None) is True
|
||||
|
||||
def test_boolean_vs_integer_type_distinction(self):
|
||||
"""Test the important distinction between boolean and integer types in validation."""
|
||||
# This tests the comment in the code about bool being a subclass of int
|
||||
|
||||
# Boolean type should only accept actual booleans, not integers
|
||||
assert SegmentType.BOOLEAN.is_valid(True) is True
|
||||
assert SegmentType.BOOLEAN.is_valid(False) is True
|
||||
assert SegmentType.BOOLEAN.is_valid(1) is False # Integer 1, not boolean
|
||||
assert SegmentType.BOOLEAN.is_valid(0) is False # Integer 0, not boolean
|
||||
|
||||
# Number type should accept both integers and floats, including booleans (since bool is subclass of int)
|
||||
assert SegmentType.NUMBER.is_valid(42) is True
|
||||
assert SegmentType.NUMBER.is_valid(3.14) is True
|
||||
assert SegmentType.NUMBER.is_valid(True) is True # bool is subclass of int
|
||||
assert SegmentType.NUMBER.is_valid(False) is True # bool is subclass of int
|
||||
|
||||
def test_array_validation_recursive_behavior(self):
|
||||
"""Test that array validation correctly handles recursive validation calls."""
|
||||
# When validating array elements, _validate_array calls is_valid recursively
|
||||
# with ArrayValidation.NONE to avoid infinite recursion
|
||||
|
||||
# Test nested validation doesn't cause issues
|
||||
nested_arrays = [["inner", "array"], ["another", "inner"]]
|
||||
|
||||
# ARRAY_ANY should accept nested arrays
|
||||
assert SegmentType.ARRAY_ANY.is_valid(nested_arrays, ArrayValidation.ALL) is True
|
||||
|
||||
# ARRAY_STRING should reject nested arrays (first element is not a string)
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.FIRST) is False
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.ALL) is False
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig
|
||||
|
||||
|
||||
class TestParameterConfig:
|
||||
def test_select_type(self):
|
||||
data = {
|
||||
"name": "yes_or_no",
|
||||
"type": "select",
|
||||
"options": ["yes", "no"],
|
||||
"description": "a simple select made of `yes` and `no`",
|
||||
"required": True,
|
||||
}
|
||||
|
||||
pc = ParameterConfig.model_validate(data)
|
||||
assert pc.type == SegmentType.STRING
|
||||
assert pc.options == data["options"]
|
||||
|
||||
def test_validate_bool_type(self):
|
||||
data = {
|
||||
"name": "boolean",
|
||||
"type": "bool",
|
||||
"description": "a simple boolean parameter",
|
||||
"required": True,
|
||||
}
|
||||
pc = ParameterConfig.model_validate(data)
|
||||
assert pc.type == SegmentType.BOOLEAN
|
||||
|
|
@ -0,0 +1,567 @@
|
|||
"""
|
||||
Test cases for ParameterExtractorNode._validate_result and _transform_result methods.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities import LLMMode
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData
|
||||
from core.workflow.nodes.parameter_extractor.exc import (
|
||||
InvalidNumberOfParametersError,
|
||||
InvalidSelectValueError,
|
||||
InvalidValueTypeError,
|
||||
RequiredParameterMissingError,
|
||||
)
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidTestCase:
|
||||
"""Test case data for valid scenarios."""
|
||||
|
||||
name: str
|
||||
parameters: list[ParameterConfig]
|
||||
result: dict[str, Any]
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass
|
||||
class ErrorTestCase:
|
||||
"""Test case data for error scenarios."""
|
||||
|
||||
name: str
|
||||
parameters: list[ParameterConfig]
|
||||
result: dict[str, Any]
|
||||
expected_exception: type[Exception]
|
||||
expected_message: str
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformTestCase:
|
||||
"""Test case data for transformation scenarios."""
|
||||
|
||||
name: str
|
||||
parameters: list[ParameterConfig]
|
||||
input_result: dict[str, Any]
|
||||
expected_result: dict[str, Any]
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
class TestParameterExtractorNodeMethods:
|
||||
"""Test helper class that provides access to the methods under test."""
|
||||
|
||||
def validate_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Wrapper to call _validate_result method."""
|
||||
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
|
||||
return node._validate_result(data=data, result=result)
|
||||
|
||||
def transform_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Wrapper to call _transform_result method."""
|
||||
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
|
||||
return node._transform_result(data=data, result=result)
|
||||
|
||||
|
||||
class TestValidateResult:
|
||||
"""Test cases for _validate_result method."""
|
||||
|
||||
@staticmethod
|
||||
def get_valid_test_cases() -> list[ValidTestCase]:
|
||||
"""Get test cases that should pass validation."""
|
||||
return [
|
||||
ValidTestCase(
|
||||
name="single_string_parameter",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
result={"name": "John"},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_number_parameter_int",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
result={"age": 25},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_number_parameter_float",
|
||||
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
|
||||
result={"price": 19.99},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_bool_parameter_true",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": True},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_bool_parameter_true",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": True},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_bool_parameter_false",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": False},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="select_parameter_valid_option",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # pyright: ignore[reportArgumentType]
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
result={"status": "active"},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="array_string_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": ["tag1", "tag2", "tag3"]},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="array_number_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
result={"scores": [85, 92.5, 78]},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="array_object_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
result={"items": [{"name": "item1"}, {"name": "item2"}]},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="multiple_parameters",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
|
||||
],
|
||||
result={"name": "John", "age": 25, "active": True},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="optional_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="nickname", type=SegmentType.STRING, description="Nickname", required=False),
|
||||
],
|
||||
result={"name": "John", "nickname": "Johnny"},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="empty_array_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": []},
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_error_test_cases() -> list[ErrorTestCase]:
|
||||
"""Get test cases that should raise exceptions."""
|
||||
return [
|
||||
ErrorTestCase(
|
||||
name="invalid_number_of_parameters_too_few",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
|
||||
],
|
||||
result={"name": "John"},
|
||||
expected_exception=InvalidNumberOfParametersError,
|
||||
expected_message="Invalid number of parameters",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_number_of_parameters_too_many",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
result={"name": "John", "age": 25},
|
||||
expected_exception=InvalidNumberOfParametersError,
|
||||
expected_message="Invalid number of parameters",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_string_value_none",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
],
|
||||
result={"name": None}, # Parameter present but None value, will trigger type check first
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message="Invalid value for parameter name, expected segment type: string, actual_type: none",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_select_value",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # type: ignore
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
result={"status": "pending"},
|
||||
expected_exception=InvalidSelectValueError,
|
||||
expected_message="Invalid `select` value for parameter status",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_number_value_string",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
result={"age": "twenty-five"},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message="Invalid value for parameter age, expected segment type: number, actual_type: string",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_bool_value_string",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": "yes"},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter active, expected segment type: boolean, actual_type: string"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_string_value_number",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="description", type=SegmentType.STRING, description="Description", required=True
|
||||
)
|
||||
],
|
||||
result={"description": 123},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter description, expected segment type: string, actual_type: integer"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_value_not_list",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": "tag1,tag2,tag3"},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter tags, expected segment type: array[string], actual_type: string"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_number_wrong_element_type",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
result={"scores": [85, "ninety-two", 78]},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter scores, expected segment type: array[number], actual_type: array[any]"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_string_wrong_element_type",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": ["tag1", 123, "tag3"]},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter tags, expected segment type: array[string], actual_type: array[any]"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_object_wrong_element_type",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
result={"items": [{"name": "item1"}, "item2"]},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter items, expected segment type: array[object], actual_type: array[any]"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="required_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=False),
|
||||
],
|
||||
result={"age": 25, "other": "value"}, # Missing required 'name' parameter, but has correct count
|
||||
expected_exception=RequiredParameterMissingError,
|
||||
expected_message="Parameter name is required",
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("test_case", get_valid_test_cases(), ids=ValidTestCase.get_name)
|
||||
def test_validate_result_valid_cases(self, test_case):
|
||||
"""Test _validate_result with valid inputs."""
|
||||
helper = TestParameterExtractorNodeMethods()
|
||||
|
||||
node_data = ParameterExtractorNodeData(
|
||||
title="Test Node",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
query=["test_query"],
|
||||
parameters=test_case.parameters,
|
||||
reasoning_mode="function_call",
|
||||
vision=VisionConfig(),
|
||||
)
|
||||
|
||||
result = helper.validate_result(data=node_data, result=test_case.result)
|
||||
assert result == test_case.result, f"Failed for case: {test_case.name}"
|
||||
|
||||
@pytest.mark.parametrize("test_case", get_error_test_cases(), ids=ErrorTestCase.get_name)
|
||||
def test_validate_result_error_cases(self, test_case):
|
||||
"""Test _validate_result with invalid inputs that should raise exceptions."""
|
||||
helper = TestParameterExtractorNodeMethods()
|
||||
|
||||
node_data = ParameterExtractorNodeData(
|
||||
title="Test Node",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
query=["test_query"],
|
||||
parameters=test_case.parameters,
|
||||
reasoning_mode="function_call",
|
||||
vision=VisionConfig(),
|
||||
)
|
||||
|
||||
with pytest.raises(test_case.expected_exception) as exc_info:
|
||||
helper.validate_result(data=node_data, result=test_case.result)
|
||||
|
||||
assert test_case.expected_message in str(exc_info.value), f"Failed for case: {test_case.name}"
|
||||
|
||||
|
||||
class TestTransformResult:
|
||||
"""Test cases for _transform_result method."""
|
||||
|
||||
@staticmethod
|
||||
def get_transform_test_cases() -> list[TransformTestCase]:
|
||||
"""Get test cases for result transformation."""
|
||||
return [
|
||||
# String parameter transformation
|
||||
TransformTestCase(
|
||||
name="string_parameter_present",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
input_result={"name": "John"},
|
||||
expected_result={"name": "John"},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="string_parameter_missing",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
input_result={},
|
||||
expected_result={"name": ""},
|
||||
),
|
||||
# Number parameter transformation
|
||||
TransformTestCase(
|
||||
name="number_parameter_int_present",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": 25},
|
||||
expected_result={"age": 25},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_float_present",
|
||||
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
|
||||
input_result={"price": 19.99},
|
||||
expected_result={"price": 19.99},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_missing",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={},
|
||||
expected_result={"age": 0},
|
||||
),
|
||||
# Bool parameter transformation
|
||||
TransformTestCase(
|
||||
name="bool_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"active": False},
|
||||
),
|
||||
# Select parameter transformation
|
||||
TransformTestCase(
|
||||
name="select_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # type: ignore
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
input_result={"status": "active"},
|
||||
expected_result={"status": "active"},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="select_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # type: ignore
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"status": ""},
|
||||
),
|
||||
# Array parameter transformation - present cases
|
||||
TransformTestCase(
|
||||
name="array_string_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
input_result={"tags": ["tag1", "tag2"]},
|
||||
expected_result={
|
||||
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=["tag1", "tag2"])
|
||||
},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={"scores": [85, 92.5]},
|
||||
expected_result={
|
||||
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5])
|
||||
},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_with_string_conversion",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={"scores": [85, "92.5", "78"]},
|
||||
expected_result={
|
||||
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5, 78])
|
||||
},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_object_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
input_result={"items": [{"name": "item1"}, {"name": "item2"}]},
|
||||
expected_result={
|
||||
"items": build_segment_with_type(
|
||||
segment_type=SegmentType.ARRAY_OBJECT, value=[{"name": "item1"}, {"name": "item2"}]
|
||||
)
|
||||
},
|
||||
),
|
||||
# Array parameter transformation - missing cases
|
||||
TransformTestCase(
|
||||
name="array_string_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[])},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[])},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_object_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"items": build_segment_with_type(segment_type=SegmentType.ARRAY_OBJECT, value=[])},
|
||||
),
|
||||
# Multiple parameters transformation
|
||||
TransformTestCase(
|
||||
name="multiple_parameters_mixed",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True),
|
||||
],
|
||||
input_result={"name": "John", "age": 25},
|
||||
expected_result={
|
||||
"name": "John",
|
||||
"age": 25,
|
||||
"active": False,
|
||||
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[]),
|
||||
},
|
||||
),
|
||||
# Number parameter transformation with string conversion
|
||||
TransformTestCase(
|
||||
name="number_parameter_string_to_float",
|
||||
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
|
||||
input_result={"price": "19.99"},
|
||||
expected_result={"price": 19.99}, # String not converted, falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_string_to_int",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": "25"},
|
||||
expected_result={"age": 25}, # String not converted, falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_invalid_string",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": "invalid_number"},
|
||||
expected_result={"age": 0}, # Invalid string conversion fails, falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_non_string_non_number",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": ["not_a_number"]}, # Non-string, non-number value
|
||||
expected_result={"age": 0}, # Falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_with_invalid_string_conversion",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={"scores": [85, "invalid", "78"]},
|
||||
expected_result={
|
||||
"scores": build_segment_with_type(
|
||||
segment_type=SegmentType.ARRAY_NUMBER, value=[85, 78]
|
||||
) # Invalid string skipped
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("test_case", get_transform_test_cases(), ids=TransformTestCase.get_name)
|
||||
def test_transform_result_cases(self, test_case):
|
||||
"""Test _transform_result with various inputs."""
|
||||
helper = TestParameterExtractorNodeMethods()
|
||||
|
||||
node_data = ParameterExtractorNodeData(
|
||||
title="Test Node",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
query=["test_query"],
|
||||
parameters=test_case.parameters,
|
||||
reasoning_mode="function_call",
|
||||
vision=VisionConfig(),
|
||||
)
|
||||
|
||||
result = helper.transform_result(data=node_data, result=test_case.input_result)
|
||||
assert result == test_case.expected_result, (
|
||||
f"Failed for case: {test_case.name}. Expected: {test_case.expected_result}, Got: {result}"
|
||||
)
|
||||
|
|
@ -2,6 +2,8 @@ import time
|
|||
import uuid
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
|
|
@ -272,3 +274,220 @@ def test_array_file_contains_file_name():
|
|||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
|
||||
def _get_test_conditions() -> list:
|
||||
conditions = [
|
||||
# Test boolean "is" operator
|
||||
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "true"},
|
||||
# Test boolean "is not" operator
|
||||
{"comparison_operator": "is not", "variable_selector": ["start", "bool_false"], "value": "true"},
|
||||
# Test boolean "=" operator
|
||||
{"comparison_operator": "=", "variable_selector": ["start", "bool_true"], "value": "1"},
|
||||
# Test boolean "≠" operator
|
||||
{"comparison_operator": "≠", "variable_selector": ["start", "bool_false"], "value": "1"},
|
||||
# Test boolean "not null" operator
|
||||
{"comparison_operator": "not null", "variable_selector": ["start", "bool_true"]},
|
||||
# Test boolean array "contains" operator
|
||||
{"comparison_operator": "contains", "variable_selector": ["start", "bool_array"], "value": "true"},
|
||||
# Test boolean "in" operator
|
||||
{
|
||||
"comparison_operator": "in",
|
||||
"variable_selector": ["start", "bool_true"],
|
||||
"value": ["true", "false"],
|
||||
},
|
||||
]
|
||||
return [Condition.model_validate(i) for i in conditions]
|
||||
|
||||
|
||||
def _get_condition_test_id(c: Condition):
|
||||
return c.comparison_operator
|
||||
|
||||
|
||||
@pytest.mark.parametrize("condition", _get_test_conditions(), ids=_get_condition_test_id)
|
||||
def test_execute_if_else_boolean_conditions(condition: Condition):
|
||||
"""Test IfElseNode with boolean conditions using various operators"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
pool.add(["start", "bool_false"], False)
|
||||
pool.add(["start", "bool_array"], [True, False, True])
|
||||
pool.add(["start", "mixed_array"], [True, "false", 1, 0])
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean Test",
|
||||
"type": "if-else",
|
||||
"logical_operator": "and",
|
||||
"conditions": [condition.model_dump()],
|
||||
}
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={"id": "if-else", "data": node_data},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
|
||||
def test_execute_if_else_boolean_false_conditions():
|
||||
"""Test IfElseNode with boolean conditions that should evaluate to false"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
pool.add(["start", "bool_false"], False)
|
||||
pool.add(["start", "bool_array"], [True, False, True])
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean False Test",
|
||||
"type": "if-else",
|
||||
"logical_operator": "or",
|
||||
"conditions": [
|
||||
# Test boolean "is" operator (should be false)
|
||||
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "false"},
|
||||
# Test boolean "=" operator (should be false)
|
||||
{"comparison_operator": "=", "variable_selector": ["start", "bool_false"], "value": "1"},
|
||||
# Test boolean "not contains" operator (should be false)
|
||||
{
|
||||
"comparison_operator": "not contains",
|
||||
"variable_selector": ["start", "bool_array"],
|
||||
"value": "true",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={
|
||||
"id": "if-else",
|
||||
"data": node_data,
|
||||
},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is False
|
||||
|
||||
|
||||
def test_execute_if_else_boolean_cases_structure():
|
||||
"""Test IfElseNode with boolean conditions using the new cases structure"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
pool.add(["start", "bool_false"], False)
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean Cases Test",
|
||||
"type": "if-else",
|
||||
"cases": [
|
||||
{
|
||||
"case_id": "true",
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"variable_selector": ["start", "bool_true"],
|
||||
"value": "true",
|
||||
},
|
||||
{
|
||||
"comparison_operator": "is not",
|
||||
"variable_selector": ["start", "bool_false"],
|
||||
"value": "true",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={"id": "if-else", "data": node_data},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
assert result.outputs["selected_case_id"] == "true"
|
||||
|
|
|
|||
|
|
@ -11,7 +11,8 @@ from core.workflow.nodes.list_operator.entities import (
|
|||
FilterCondition,
|
||||
Limit,
|
||||
ListOperatorNodeData,
|
||||
OrderBy,
|
||||
Order,
|
||||
OrderByConfig,
|
||||
)
|
||||
from core.workflow.nodes.list_operator.exc import InvalidKeyError
|
||||
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
|
||||
|
|
@ -27,7 +28,7 @@ def list_operator_node():
|
|||
FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
|
||||
],
|
||||
),
|
||||
"order_by": OrderBy(enabled=False, value="asc"),
|
||||
"order_by": OrderByConfig(enabled=False, value=Order.ASC),
|
||||
"limit": Limit(enabled=False, size=0),
|
||||
"extract_by": ExtractConfig(enabled=False, serial="1"),
|
||||
"title": "Test Title",
|
||||
|
|
|
|||
|
|
@ -24,16 +24,18 @@ from core.variables.segments import (
|
|||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.variables.types import SegmentType
|
||||
from factories import variable_factory
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type
|
||||
|
||||
|
||||
def test_string_variable():
|
||||
|
|
@ -139,6 +141,26 @@ def test_array_number_variable():
|
|||
assert isinstance(variable.value[1], float)
|
||||
|
||||
|
||||
def test_build_segment_scalar_values():
|
||||
@dataclass
|
||||
class TestCase:
|
||||
value: Any
|
||||
expected: Segment
|
||||
description: str
|
||||
|
||||
cases = [
|
||||
TestCase(
|
||||
value=True,
|
||||
expected=BooleanSegment(value=True),
|
||||
description="build_segment with boolean should yield BooleanSegment",
|
||||
)
|
||||
]
|
||||
|
||||
for idx, c in enumerate(cases, 1):
|
||||
seg = build_segment(c.value)
|
||||
assert seg == c.expected, f"Test case {idx} failed: {c.description}"
|
||||
|
||||
|
||||
def test_array_object_variable():
|
||||
mapping = {
|
||||
"id": str(uuid4()),
|
||||
|
|
@ -847,15 +869,22 @@ class TestBuildSegmentValueErrors:
|
|||
f"but got: {error_message}"
|
||||
)
|
||||
|
||||
def test_build_segment_boolean_type_note(self):
|
||||
"""Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError."""
|
||||
# Boolean values in Python are subclasses of int, so they get processed as integers
|
||||
# True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0)
|
||||
def test_build_segment_boolean_type(self):
|
||||
"""Test that Boolean values are correctly handled as boolean type, not integers."""
|
||||
# Boolean values should now be processed as BooleanSegment, not IntegerSegment
|
||||
# This is because the bool check now comes before the int check in build_segment
|
||||
true_segment = variable_factory.build_segment(True)
|
||||
false_segment = variable_factory.build_segment(False)
|
||||
|
||||
# Verify they are processed as integers, not as errors
|
||||
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
|
||||
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
|
||||
assert true_segment.value_type == SegmentType.INTEGER
|
||||
assert false_segment.value_type == SegmentType.INTEGER
|
||||
# Verify they are processed as booleans, not integers
|
||||
assert true_segment.value is True, "Test case 1 (boolean_true): Expected True to be processed as boolean True"
|
||||
assert false_segment.value is False, (
|
||||
"Test case 2 (boolean_false): Expected False to be processed as boolean False"
|
||||
)
|
||||
assert true_segment.value_type == SegmentType.BOOLEAN
|
||||
assert false_segment.value_type == SegmentType.BOOLEAN
|
||||
|
||||
# Test array of booleans
|
||||
bool_array_segment = variable_factory.build_segment([True, False, True])
|
||||
assert bool_array_segment.value_type == SegmentType.ARRAY_BOOLEAN
|
||||
assert bool_array_segment.value == [True, False, True]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue