refactor: input field handling and serialization (#3118)

* refactor: update code references to use _code instead of code

* refactor: add backwards compatible attributes to Component class

* refactor: update Component constructor to pass config params with underscore

Refactored the `Component` class in `component.py` to handle inputs and outputs. Added a new method `map_outputs` to map a list of outputs to the component. Also updated the `__init__` method to properly initialize the inputs, outputs, and other attributes. This change improves the flexibility and extensibility of the `Component` class.

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* refactor: change attribute to use underscore

* refactor: update CustomComponent initialization parameters

Refactored the `instantiate_class` function in `loading.py` to update the initialization parameters for the `CustomComponent` class. Changed the parameter names from `user_id`, `parameters`, `vertex`, and `tracing_service` to `_user_id`, `_parameters`, `_vertex`, and `_tracing_service` respectively. This change ensures consistency and improves code readability.

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* refactor: update BaseComponent to accept UUID for _user_id

Updated the `BaseComponent` class in `base_component.py` to accept a `UUID` type for the `_user_id` attribute. This change improves the type safety and ensures consistency with the usage of `_user_id` throughout the codebase.

* refactor: import nanoid with type annotation

The `nanoid` import in `component.py` has been updated to include a type annotation `# type: ignore`. This change ensures that the type checker ignores any errors related to the `nanoid` import.

* fix(custom_component.py): convert _user_id to string before passing to functions to ensure compatibility with function signatures

* feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components

* refactor: extract method to get method return type in CustomComponent

* refactor(utils.py): refactor code to use _user_id instead of user_id for consistency and clarity

perf(utils.py): optimize code by reusing cc_instance instead of calling get_component_instance multiple times

* refactor(utils.py, base.py): change parameter name 'add_name' to 'keep_name' for clarity and consistency in codebase

* [autofix.ci] apply automated fixes

* refactor: update schema.py to include Edge related typres

The `schema.py` file in the `src/backend/base/langflow/graph/edge` directory has been updated to include the `TargetHandle` and `SourceHandle` models. These models define the structure and attributes of the target and source handles used in the edge data. This change improves the clarity and consistency of the codebase.

* refactor: update BaseInputMixin to handle invalid field types gracefully

The `BaseInputMixin` class in `input_mixin.py` has been updated to handle invalid field types gracefully. Instead of raising an exception, it now returns `FieldTypes.OTHER` for any invalid field type. This change improves the robustness and reliability of the codebase.

* refactor: update file_types field alias in FileMixin

The `file_types` field in the `FileMixin` class of `input_mixin.py` has been updated to use the `alias` parameter instead of `serialization_alias`. This change ensures consistency and improves the clarity of the codebase.

* refactor(inputs): update field_type declarations in various input classes to use SerializableFieldTypes enum for better type safety and clarity

* refactor(inputs): convert dict to Message object in _validate_value method

* refactor(inputs): convert dict to Message object in _validate_value method

* refactor(inputs): update model_config in BaseInputMixin to enable populating by name

The `model_config` attribute in the `BaseInputMixin` class of `input_mixin.py` has been updated to include the `populate_by_name=True` parameter. This change allows the model configuration to be populated by name, improving the flexibility and usability of the codebase.

* refactor: update _extract_return_type method in CustomComponent to accept Any type

The _extract_return_type method in CustomComponent has been updated to accept the Any type as the return_type parameter. This change improves the flexibility and compatibility of the method, allowing it to handle a wider range of return types.

* refactor: update BaseComponent to use get_template_config method

Refactored the `BaseComponent` class in `base_component.py` to use the `get_template_config` method instead of duplicating the code. This change improves code readability and reduces redundancy.

* feat: add BaseModel class with model_config attribute

A new `BaseModel` class has been added to the `base_model.py` file. This class extends the `PydanticBaseModel` and includes a `model_config` attribute of type `ConfigDict`. This change improves the codebase by providing a base model with a configuration dictionary for models.

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* refactor: update langflow.graph.edge.schema.py

Refactor the `langflow.graph.edge.schema.py` file to include the `TargetHandle` and `SourceHandle` models. This change improves the clarity and consistency of the codebase.

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* refactor: update build_custom_component_template to use add_name instead of keep_name

Refactor the `build_custom_component_template` function in `utils.py` to use the `add_name` parameter instead of the deprecated `keep_name` parameter. This change ensures consistency with the updated method signature and improves code clarity.

* feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components (#3115)

* feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components

* refactor: extract method to get method return type in CustomComponent

* refactor: update _extract_return_type method in CustomComponent to accept Any type

The _extract_return_type method in CustomComponent has been updated to accept the Any type as the return_type parameter. This change improves the flexibility and compatibility of the method, allowing it to handle a wider range of return types.

* refactor: add _template_config property to BaseComponent

Add a new `_template_config` property to the `BaseComponent` class in `base_component.py`. This property is used to store the template configuration for the custom component. If the `_template_config` property is empty, it is populated by calling the `build_template_config` method. This change improves the efficiency of accessing the template configuration and ensures that it is only built when needed.

* refactor: add type checking for Output types in add_types method

Improve type checking in the `add_types` method of the `Output` class in `base.py`. Check if the `type_` already exists in the `types` list before adding it. This change ensures that duplicate types are not added to the list.

* update starter projects

* refactor: optimize imports in base.py

Optimize imports in the `base.py` file by removing unused imports and organizing the remaining imports. This change improves code readability and reduces unnecessary clutter.

* fix(base.py): fix condition to check if self.types is not None before checking if type_ is in self.types

* refactor: update build_custom_component_template to use add_name instead of keep_name

* feat: update logger warning message for invalid value type in StrInput

* refactor: update logger warning message for invalid value type in StrInput

* refactor: add unit tests for inputs in test_inputs.py

* refactor: update validation for IntInput value type

Improve the validation for the value type in the IntInput class. The updated code ensures that the value is of a valid type (int or float) and converts float values to integers. This change enhances the accuracy and reliability of the input validation process.

* refactor: improve validation for FloatInput value type

Improve the validation for the value type in the FloatInput class. The updated code ensures that the value is of a valid type (int or float) and converts integer values to floats. This change enhances the accuracy and reliability of the input validation process.

* fix(data.py): add validation to check if data is a dictionary before processing to prevent potential errors

* refactor: update test_inputs.py to include comprehensive unit tests for input classes

Add comprehensive unit tests for the input classes in the test_inputs.py file. This change ensures that the input classes are thoroughly tested and functioning correctly. The unit tests cover various scenarios and edge cases to validate the behavior of the input classes. This improvement enhances the reliability and stability of the codebase.

* refactor: remove invalid input tests for DataInput and FileInput in test_inputs.py

Cleaned up unnecessary test cases for invalid input types

* refactor: improve validation for IntInput and FloatInput value types

* refactor: add async version of get_file_content_dicts method in Message class

* refactor: update get_file_content_dicts method in Message class

The `get_file_content_dicts` method in the `Message` class has been updated to use an async version for improved performance. This change enhances the efficiency of retrieving file content dictionaries.

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-08-01 15:29:22 -03:00 committed by GitHub
commit b5180c4d70
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 332 additions and 39 deletions

View file

@ -27,9 +27,9 @@ SerializableFieldTypes = Annotated[FieldTypes, PlainSerializer(lambda v: v.value
# Base mixin for common input field attributes and methods
class BaseInputMixin(BaseModel, validate_assignment=True):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid", populate_by_name=True)
field_type: Optional[SerializableFieldTypes] = Field(default=FieldTypes.TEXT)
field_type: SerializableFieldTypes = Field(default=FieldTypes.TEXT)
required: bool = False
"""Specifies if the field is required. Defaults to False."""
@ -78,9 +78,10 @@ class BaseInputMixin(BaseModel, validate_assignment=True):
@field_validator("field_type", mode="before")
@classmethod
def validate_field_type(cls, v):
if v not in FieldTypes:
try:
return FieldTypes(v)
except ValueError:
return FieldTypes.OTHER
return FieldTypes(v)
@model_serializer(mode="wrap")
def serialize_model(self, handler):
@ -101,7 +102,7 @@ class MetadataTraceMixin(BaseModel):
# Mixin for input fields that can be listable
class ListableInputMixin(BaseModel):
is_list: bool = Field(default=False, serialization_alias="list")
is_list: bool = Field(default=False, alias="list")
# Specific mixin for fields needing database interaction
@ -112,7 +113,7 @@ class DatabaseLoadMixin(BaseModel):
# Specific mixin for fields needing file interaction
class FileMixin(BaseModel):
file_path: Optional[str] = Field(default="")
file_types: list[str] = Field(default=[], serialization_alias="fileTypes")
file_types: list[str] = Field(default=[], alias="fileTypes")
@field_validator("file_types")
@classmethod

View file

@ -1,6 +1,6 @@
import warnings
from typing import Any, AsyncIterator, Iterator, Optional, Union, get_args
from loguru import logger
from pydantic import Field, field_validator
from langflow.inputs.validators import CoalesceBool
@ -24,7 +24,7 @@ from .input_mixin import (
class TableInput(BaseInputMixin, MetadataTraceMixin, TableMixin, ListableInputMixin):
field_type: Optional[SerializableFieldTypes] = FieldTypes.TABLE
field_type: SerializableFieldTypes = FieldTypes.TABLE
is_list: bool = True
@field_validator("value")
@ -50,11 +50,11 @@ class HandleInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin):
Attributes:
input_types (list[str]): A list of input types.
field_type (Optional[SerializableFieldTypes]): The field type of the input.
field_type (SerializableFieldTypes): The field type of the input.
"""
input_types: list[str] = Field(default_factory=list)
field_type: Optional[SerializableFieldTypes] = FieldTypes.OTHER
field_type: SerializableFieldTypes = FieldTypes.OTHER
class DataInput(HandleInput, InputTraceMixin):
@ -69,12 +69,12 @@ class DataInput(HandleInput, InputTraceMixin):
class PromptInput(BaseInputMixin, ListableInputMixin, InputTraceMixin):
field_type: Optional[SerializableFieldTypes] = FieldTypes.PROMPT
field_type: SerializableFieldTypes = FieldTypes.PROMPT
# Applying mixins to a specific input type
class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin, MetadataTraceMixin):
field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT
field_type: SerializableFieldTypes = FieldTypes.TEXT
load_from_db: CoalesceBool = False
"""Defines if the field will allow the user to open a text editor. Default is False."""
@ -94,8 +94,13 @@ class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin, MetadataTr
ValueError: If the value is not of a valid type or if the input is missing a required key.
"""
if not isinstance(v, str) and v is not None:
# Keep the warning for now, but we should change it to an error
if _info.data.get("input_types") and v.__class__.__name__ not in _info.data.get("input_types"):
logger.warning(f"Invalid value type {type(v)}")
warnings.warn(
f"Invalid value type {type(v)} for input {_info.data.get('name')}. Expected types: {_info.data.get('input_types')}"
)
else:
warnings.warn(f"Invalid value type {type(v)} for input {_info.data.get('name')}.")
return v
@field_validator("value")
@ -129,6 +134,8 @@ class MessageInput(StrInput, InputTraceMixin):
@staticmethod
def _validate_value(v: Any, _info):
# If v is a instance of Message, then its fine
if isinstance(v, dict):
return Message(**v)
if isinstance(v, Message):
return v
if isinstance(v, str):
@ -164,6 +171,8 @@ class MessageTextInput(StrInput, MetadataTraceMixin, InputTraceMixin):
ValueError: If the value is not of a valid type or if the input is missing a required key.
"""
value: str | AsyncIterator | Iterator | None = None
if isinstance(v, dict):
v = Message(**v)
if isinstance(v, str):
value = v
elif isinstance(v, Message):
@ -190,11 +199,11 @@ class MultilineInput(MessageTextInput, MultilineMixin, InputTraceMixin):
Represents a multiline input field.
Attributes:
field_type (Optional[SerializableFieldTypes]): The type of the field. Defaults to FieldTypes.TEXT.
field_type (SerializableFieldTypes): The type of the field. Defaults to FieldTypes.TEXT.
multiline (CoalesceBool): Indicates whether the input field should support multiple lines. Defaults to True.
"""
field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT
field_type: SerializableFieldTypes = FieldTypes.TEXT
multiline: CoalesceBool = True
@ -203,11 +212,11 @@ class MultilineSecretInput(MessageTextInput, MultilineMixin, InputTraceMixin):
Represents a multiline input field.
Attributes:
field_type (Optional[SerializableFieldTypes]): The type of the field. Defaults to FieldTypes.TEXT.
field_type (SerializableFieldTypes): The type of the field. Defaults to FieldTypes.TEXT.
multiline (CoalesceBool): Indicates whether the input field should support multiple lines. Defaults to True.
"""
field_type: Optional[SerializableFieldTypes] = FieldTypes.PASSWORD
field_type: SerializableFieldTypes = FieldTypes.PASSWORD
multiline: CoalesceBool = True
password: CoalesceBool = Field(default=True)
@ -219,12 +228,12 @@ class SecretStrInput(BaseInputMixin, DatabaseLoadMixin):
This class inherits from `BaseInputMixin` and `DatabaseLoadMixin`.
Attributes:
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to `FieldTypes.PASSWORD`.
field_type (SerializableFieldTypes): The field type of the input. Defaults to `FieldTypes.PASSWORD`.
password (CoalesceBool): A boolean indicating whether the input is a password. Defaults to `True`.
input_types (list[str]): A list of input types associated with this input. Defaults to an empty list.
"""
field_type: Optional[SerializableFieldTypes] = FieldTypes.PASSWORD
field_type: SerializableFieldTypes = FieldTypes.PASSWORD
password: CoalesceBool = Field(default=True)
input_types: list[str] = ["Message"]
load_from_db: CoalesceBool = True
@ -275,10 +284,33 @@ class IntInput(BaseInputMixin, ListableInputMixin, RangeMixin, MetadataTraceMixi
It inherits from the `BaseInputMixin`, `ListableInputMixin`, and `RangeMixin` classes.
Attributes:
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.INTEGER.
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.INTEGER.
"""
field_type: Optional[SerializableFieldTypes] = FieldTypes.INTEGER
field_type: SerializableFieldTypes = FieldTypes.INTEGER
@field_validator("value")
@classmethod
def validate_value(cls, v: Any, _info):
"""
Validates the given value and returns the processed value.
Args:
v (Any): The value to be validated.
_info: Additional information about the input.
Returns:
The processed value.
Raises:
ValueError: If the value is not of a valid type or if the input is missing a required key.
"""
if v and not isinstance(v, (int, float)):
raise ValueError(f"Invalid value type {type(v)} for input {_info.data.get('name')}.")
if isinstance(v, float):
v = int(v)
return v
class FloatInput(BaseInputMixin, ListableInputMixin, RangeMixin, MetadataTraceMixin):
@ -289,10 +321,32 @@ class FloatInput(BaseInputMixin, ListableInputMixin, RangeMixin, MetadataTraceMi
It inherits from the `BaseInputMixin`, `ListableInputMixin`, and `RangeMixin` classes.
Attributes:
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.FLOAT.
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.FLOAT.
"""
field_type: Optional[SerializableFieldTypes] = FieldTypes.FLOAT
field_type: SerializableFieldTypes = FieldTypes.FLOAT
@field_validator("value")
@classmethod
def validate_value(cls, v: Any, _info):
"""
Validates the given value and returns the processed value.
Args:
v (Any): The value to be validated.
_info: Additional information about the input.
Returns:
The processed value.
Raises:
ValueError: If the value is not of a valid type or if the input is missing a required key.
"""
if v and not isinstance(v, (int, float)):
raise ValueError(f"Invalid value type {type(v)} for input {_info.data.get('name')}.")
if isinstance(v, int):
v = float(v)
return v
class BoolInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin):
@ -303,11 +357,11 @@ class BoolInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin):
It inherits from the `BaseInputMixin` and `ListableInputMixin` classes.
Attributes:
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.BOOLEAN.
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.BOOLEAN.
value (CoalesceBool): The value of the boolean input.
"""
field_type: Optional[SerializableFieldTypes] = FieldTypes.BOOLEAN
field_type: SerializableFieldTypes = FieldTypes.BOOLEAN
value: CoalesceBool = False
@ -319,11 +373,11 @@ class NestedDictInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin, In
It inherits from the `BaseInputMixin` and `ListableInputMixin` classes.
Attributes:
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.NESTED_DICT.
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.NESTED_DICT.
value (Optional[dict]): The value of the input. Defaults to an empty dictionary.
"""
field_type: Optional[SerializableFieldTypes] = FieldTypes.NESTED_DICT
field_type: SerializableFieldTypes = FieldTypes.NESTED_DICT
value: Optional[dict | Data] = {}
@ -335,11 +389,11 @@ class DictInput(BaseInputMixin, ListableInputMixin, InputTraceMixin):
It inherits from the `BaseInputMixin` and `ListableInputMixin` classes.
Attributes:
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.DICT.
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.DICT.
value (Optional[dict]): The value of the dictionary input. Defaults to an empty dictionary.
"""
field_type: Optional[SerializableFieldTypes] = FieldTypes.DICT
field_type: SerializableFieldTypes = FieldTypes.DICT
value: Optional[dict] = {}
@ -351,12 +405,12 @@ class DropdownInput(BaseInputMixin, DropDownMixin, MetadataTraceMixin):
It inherits from the `BaseInputMixin` and `DropDownMixin` classes.
Attributes:
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.TEXT.
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.TEXT.
options (Optional[Union[list[str], Callable]]): List of options for the field.
Default is None.
"""
field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT
field_type: SerializableFieldTypes = FieldTypes.TEXT
options: list[str] = Field(default_factory=list)
combobox: CoalesceBool = False
@ -369,12 +423,12 @@ class MultiselectInput(BaseInputMixin, ListableInputMixin, DropDownMixin, Metada
It inherits from the `BaseInputMixin`, `ListableInputMixin` and `DropDownMixin` classes.
Attributes:
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.TEXT.
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.TEXT.
options (Optional[Union[list[str], Callable]]): List of options for the field. Only used when is_list=True.
Default is None.
"""
field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT
field_type: SerializableFieldTypes = FieldTypes.TEXT
options: list[str] = Field(default_factory=list)
is_list: bool = Field(default=True, serialization_alias="list")
combobox: CoalesceBool = False
@ -399,10 +453,10 @@ class FileInput(BaseInputMixin, ListableInputMixin, FileMixin, MetadataTraceMixi
It inherits from the `BaseInputMixin`, `ListableInputMixin`, and `FileMixin` classes.
Attributes:
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.FILE.
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.FILE.
"""
field_type: Optional[SerializableFieldTypes] = FieldTypes.FILE
field_type: SerializableFieldTypes = FieldTypes.FILE
InputTypes = Union[

View file

@ -25,6 +25,8 @@ class Data(BaseModel):
@model_validator(mode="before")
def validate_data(cls, values):
if not isinstance(values, dict):
raise ValueError("Data must be a dictionary")
if not values.get("data"):
values["data"] = {}
# Any other keyword should be added to the data dictionary

View file

@ -101,7 +101,7 @@ class Message(Data):
if self.sender == MESSAGE_SENDER_USER or not self.sender:
if self.files:
contents = [{"type": "text", "text": text}]
contents.extend(self.get_file_content_dicts())
contents.extend(self.sync_get_file_content_dicts())
human_message = HumanMessage(content=contents) # type: ignore
else:
human_message = HumanMessage(content=text)
@ -156,6 +156,11 @@ class Message(Data):
return ""
return value
def sync_get_file_content_dicts(self):
coro = self.aget_file_content_dicts()
loop = asyncio.get_event_loop()
return loop.run_until_complete(coro)
async def get_file_content_dicts(self):
content_dicts = []
files = await get_file_paths(self.files)
@ -165,9 +170,7 @@ class Message(Data):
content_dicts.append(file.to_content_dict())
else:
image_template = ImagePromptTemplate()
image_prompt_value: ImagePromptValue = image_template.invoke(
input={"path": file}, config={"callbacks": self.get_langchain_callbacks()}
) # type: ignore
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file}) # type: ignore
content_dicts.append({"type": "image_url", "image_url": image_prompt_value.image_url})
return content_dicts

View file

@ -0,0 +1,233 @@
import pytest
from pydantic import ValidationError
from langflow.inputs.inputs import (
BoolInput,
DataInput,
DictInput,
DropdownInput,
FileInput,
FloatInput,
HandleInput,
InputTypesMap,
IntInput,
MessageTextInput,
MultilineInput,
MultilineSecretInput,
MultiselectInput,
NestedDictInput,
PromptInput,
SecretStrInput,
StrInput,
TableInput,
_instantiate_input,
)
from langflow.schema.message import Message
@pytest.fixture
def client():
pass
def test_table_input_valid():
# Test with a valid list of dictionaries
data = TableInput(value=[{"key": "value"}, {"key2": "value2"}])
assert data.value == [{"key": "value"}, {"key2": "value2"}]
def test_table_input_invalid():
with pytest.raises(ValidationError):
# Test with an invalid value
TableInput(value="invalid")
with pytest.raises(ValidationError):
# Test with a list containing invalid item
TableInput(value=[{"key": "value"}, "invalid"])
def test_str_input_valid():
data = StrInput(value="This is a string")
assert data.value == "This is a string"
def test_str_input_invalid():
with pytest.warns(UserWarning):
# Test with an invalid value
StrInput(value=1234)
def test_message_text_input_valid():
# Test with a valid string
data = MessageTextInput(value="This is a message")
assert data.value == "This is a message"
# Test with a valid Message object
msg = Message(text="This is a message")
data = MessageTextInput(value=msg)
assert data.value == "This is a message"
def test_message_text_input_invalid():
with pytest.raises(ValidationError):
# Test with an invalid value
MessageTextInput(value=1234)
def test_instantiate_input_valid():
data = {"value": "This is a string"}
input_instance = _instantiate_input("StrInput", data)
assert isinstance(input_instance, StrInput)
assert input_instance.value == "This is a string"
def test_instantiate_input_invalid():
with pytest.raises(ValueError):
# Test with an invalid input type
_instantiate_input("InvalidInput", {"value": "This is a string"})
def test_handle_input_valid():
data = HandleInput(input_types=["BaseLanguageModel"])
assert data.input_types == ["BaseLanguageModel"]
def test_handle_input_invalid():
with pytest.raises(ValidationError):
HandleInput(input_types="BaseLanguageModel") # should be a list, not a string
def test_data_input_valid():
data_input = DataInput(input_types=["Data"])
assert data_input.input_types == ["Data"]
def test_prompt_input_valid():
prompt_input = PromptInput(value="Enter your name")
assert prompt_input.value == "Enter your name"
def test_multiline_input_valid():
multiline_input = MultilineInput(value="This is a\nmultiline input")
assert multiline_input.value == "This is a\nmultiline input"
assert multiline_input.multiline is True
def test_multiline_input_invalid():
with pytest.raises(ValidationError):
MultilineInput(value=1234) # should be a string, not an integer
def test_multiline_secret_input_valid():
multiline_secret_input = MultilineSecretInput(value="secret")
assert multiline_secret_input.value == "secret"
assert multiline_secret_input.password is True
def test_multiline_secret_input_invalid():
with pytest.raises(ValidationError):
MultilineSecretInput(value=1234) # should be a string, not an integer
def test_secret_str_input_valid():
secret_str_input = SecretStrInput(value="supersecret")
assert secret_str_input.value == "supersecret"
assert secret_str_input.password is True
def test_secret_str_input_invalid():
with pytest.raises(ValidationError):
SecretStrInput(value=1234) # should be a string, not an integer
def test_int_input_valid():
int_input = IntInput(value=10)
assert int_input.value == 10
def test_int_input_invalid():
with pytest.raises(ValidationError):
IntInput(value="not_an_int") # should be an integer, not a string
def test_float_input_valid():
float_input = FloatInput(value=10.5)
assert float_input.value == 10.5
def test_float_input_invalid():
with pytest.raises(ValidationError):
FloatInput(value="not_a_float") # should be a float, not a string
def test_bool_input_valid():
bool_input = BoolInput(value=True)
assert bool_input.value is True
def test_bool_input_invalid():
with pytest.raises(ValidationError):
BoolInput(value="not_a_bool") # should be a bool, not a string
def test_nested_dict_input_valid():
nested_dict_input = NestedDictInput(value={"key": "value"})
assert nested_dict_input.value == {"key": "value"}
def test_nested_dict_input_invalid():
with pytest.raises(ValidationError):
NestedDictInput(value="not_a_dict") # should be a dict, not a string
def test_dict_input_valid():
dict_input = DictInput(value={"key": "value"})
assert dict_input.value == {"key": "value"}
def test_dict_input_invalid():
with pytest.raises(ValidationError):
DictInput(value="not_a_dict") # should be a dict, not a string
def test_dropdown_input_valid():
dropdown_input = DropdownInput(options=["option1", "option2"])
assert dropdown_input.options == ["option1", "option2"]
def test_dropdown_input_invalid():
with pytest.raises(ValidationError):
DropdownInput(options="option1") # should be a list, not a string
def test_multiselect_input_valid():
multiselect_input = MultiselectInput(value=["option1", "option2"])
assert multiselect_input.value == ["option1", "option2"]
def test_multiselect_input_invalid():
with pytest.raises(ValidationError):
MultiselectInput(value="option1") # should be a list, not a string
def test_file_input_valid():
file_input = FileInput(value=["/path/to/file"])
assert file_input.value == ["/path/to/file"]
def test_instantiate_input_comprehensive():
valid_data = {
"StrInput": {"value": "A string"},
"IntInput": {"value": 10},
"FloatInput": {"value": 10.5},
"BoolInput": {"value": True},
"DictInput": {"value": {"key": "value"}},
"MultiselectInput": {"value": ["option1", "option2"]},
}
for input_type, data in valid_data.items():
input_instance = _instantiate_input(input_type, data)
assert isinstance(input_instance, InputTypesMap[input_type])
with pytest.raises(ValueError):
_instantiate_input("InvalidInput", {"value": "Invalid"}) # Invalid input type