parent
1718b3c9ea
commit
e7a48d58e6
5 changed files with 9 additions and 9 deletions
|
|
@ -44,7 +44,7 @@ def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: l
|
|||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
)
|
||||
if return_type == str:
|
||||
if return_type is str:
|
||||
return_type = "Text"
|
||||
elif hasattr(return_type, "__name__"):
|
||||
return_type = return_type.__name__
|
||||
|
|
@ -85,7 +85,7 @@ def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: l
|
|||
)
|
||||
|
||||
base_classes = get_base_classes(return_type_instance)
|
||||
if return_type_instance == str:
|
||||
if return_type_instance is str:
|
||||
base_classes.append("Text")
|
||||
|
||||
for base_class in base_classes:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Any
|
|||
|
||||
|
||||
def format_type(type_: Any) -> str:
|
||||
if type_ == str:
|
||||
if type_ is str:
|
||||
type_ = "Text"
|
||||
elif hasattr(type_, "__name__"):
|
||||
type_ = type_.__name__
|
||||
|
|
|
|||
|
|
@ -36,8 +36,8 @@ def is_list_of_any(field: FieldInfo) -> bool:
|
|||
else:
|
||||
union_args = []
|
||||
|
||||
return field.annotation.__origin__ == list or any(
|
||||
arg.__origin__ == list for arg in union_args if hasattr(arg, "__origin__")
|
||||
return field.annotation.__origin__ is list or any(
|
||||
arg.__origin__ is list for arg in union_args if hasattr(arg, "__origin__")
|
||||
)
|
||||
except AttributeError:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ class Input(BaseModel):
|
|||
|
||||
@field_serializer("field_type")
|
||||
def serialize_field_type(self, value, _info):
|
||||
if value == float and self.range_spec is None:
|
||||
if value is float and self.range_spec is None:
|
||||
self.range_spec = RangeSpec()
|
||||
return value
|
||||
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class TestCreateInputSchema:
|
|||
input_instance = FileInput(name="file_field")
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["file_field"]
|
||||
assert field_info.annotation == str
|
||||
assert field_info.annotation is str
|
||||
|
||||
# Inputs with mixed required and optional fields are processed correctly
|
||||
def test_mixed_required_optional_fields_processing(self):
|
||||
|
|
@ -184,7 +184,7 @@ class TestCreateInputSchema:
|
|||
input_instance = StrInput(name="test_field", is_list=True)
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["test_field"]
|
||||
assert field_info.annotation == list[str]
|
||||
assert field_info.annotation == list[str] # type: ignore
|
||||
|
||||
# Converting FieldTypes to corresponding Python types
|
||||
def test_field_types_conversion(self):
|
||||
|
|
@ -194,7 +194,7 @@ class TestCreateInputSchema:
|
|||
input_instance = IntInput(name="int_field")
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["int_field"]
|
||||
assert field_info.annotation == int
|
||||
assert field_info.annotation is int # Use 'is' for type comparison
|
||||
|
||||
# Setting default values for non-required fields
|
||||
def test_default_values_for_non_required_fields(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue