refactor: input field handling and serialization (#3118)
* refactor: update code references to use _code instead of code * refactor: add backwards compatible attributes to Component class * refactor: update Component constructor to pass config params with underscore Refactored the `Component` class in `component.py` to handle inputs and outputs. Added a new method `map_outputs` to map a list of outputs to the component. Also updated the `__init__` method to properly initialize the inputs, outputs, and other attributes. This change improves the flexibility and extensibility of the `Component` class. Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> * refactor: change attribute to use underscore * refactor: update CustomComponent initialization parameters Refactored the `instantiate_class` function in `loading.py` to update the initialization parameters for the `CustomComponent` class. Changed the parameter names from `user_id`, `parameters`, `vertex`, and `tracing_service` to `_user_id`, `_parameters`, `_vertex`, and `_tracing_service` respectively. This change ensures consistency and improves code readability. Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> * refactor: update BaseComponent to accept UUID for _user_id Updated the `BaseComponent` class in `base_component.py` to accept a `UUID` type for the `_user_id` attribute. This change improves the type safety and ensures consistency with the usage of `_user_id` throughout the codebase. * refactor: import nanoid with type annotation The `nanoid` import in `component.py` has been updated to include a type annotation `# type: ignore`. This change ensures that the type checker ignores any errors related to the `nanoid` import. * fix(custom_component.py): convert _user_id to string before passing to functions to ensure compatibility with function signatures * feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components * refactor: extract method to get method return type in CustomComponent * refactor(utils.py): refactor code to use _user_id instead of user_id for consistency and clarity perf(utils.py): optimize code by reusing cc_instance instead of calling get_component_instance multiple times * refactor(utils.py, base.py): change parameter name 'add_name' to 'keep_name' for clarity and consistency in codebase * [autofix.ci] apply automated fixes * refactor: update schema.py to include Edge related typres The `schema.py` file in the `src/backend/base/langflow/graph/edge` directory has been updated to include the `TargetHandle` and `SourceHandle` models. These models define the structure and attributes of the target and source handles used in the edge data. This change improves the clarity and consistency of the codebase. * refactor: update BaseInputMixin to handle invalid field types gracefully The `BaseInputMixin` class in `input_mixin.py` has been updated to handle invalid field types gracefully. Instead of raising an exception, it now returns `FieldTypes.OTHER` for any invalid field type. This change improves the robustness and reliability of the codebase. * refactor: update file_types field alias in FileMixin The `file_types` field in the `FileMixin` class of `input_mixin.py` has been updated to use the `alias` parameter instead of `serialization_alias`. This change ensures consistency and improves the clarity of the codebase. * refactor(inputs): update field_type declarations in various input classes to use SerializableFieldTypes enum for better type safety and clarity * refactor(inputs): convert dict to Message object in _validate_value method * refactor(inputs): convert dict to Message object in _validate_value method * refactor(inputs): update model_config in BaseInputMixin to enable populating by name The `model_config` attribute in the `BaseInputMixin` class of `input_mixin.py` has been updated to include the `populate_by_name=True` parameter. This change allows the model configuration to be populated by name, improving the flexibility and usability of the codebase. * refactor: update _extract_return_type method in CustomComponent to accept Any type The _extract_return_type method in CustomComponent has been updated to accept the Any type as the return_type parameter. This change improves the flexibility and compatibility of the method, allowing it to handle a wider range of return types. * refactor: update BaseComponent to use get_template_config method Refactored the `BaseComponent` class in `base_component.py` to use the `get_template_config` method instead of duplicating the code. This change improves code readability and reduces redundancy. * feat: add BaseModel class with model_config attribute A new `BaseModel` class has been added to the `base_model.py` file. This class extends the `PydanticBaseModel` and includes a `model_config` attribute of type `ConfigDict`. This change improves the codebase by providing a base model with a configuration dictionary for models. Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> * refactor: update langflow.graph.edge.schema.py Refactor the `langflow.graph.edge.schema.py` file to include the `TargetHandle` and `SourceHandle` models. This change improves the clarity and consistency of the codebase. Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> * refactor: update build_custom_component_template to use add_name instead of keep_name Refactor the `build_custom_component_template` function in `utils.py` to use the `add_name` parameter instead of the deprecated `keep_name` parameter. This change ensures consistency with the updated method signature and improves code clarity. * feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components (#3115) * feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components * refactor: extract method to get method return type in CustomComponent * refactor: update _extract_return_type method in CustomComponent to accept Any type The _extract_return_type method in CustomComponent has been updated to accept the Any type as the return_type parameter. This change improves the flexibility and compatibility of the method, allowing it to handle a wider range of return types. * refactor: add _template_config property to BaseComponent Add a new `_template_config` property to the `BaseComponent` class in `base_component.py`. This property is used to store the template configuration for the custom component. If the `_template_config` property is empty, it is populated by calling the `build_template_config` method. This change improves the efficiency of accessing the template configuration and ensures that it is only built when needed. * refactor: add type checking for Output types in add_types method Improve type checking in the `add_types` method of the `Output` class in `base.py`. Check if the `type_` already exists in the `types` list before adding it. This change ensures that duplicate types are not added to the list. * update starter projects * refactor: optimize imports in base.py Optimize imports in the `base.py` file by removing unused imports and organizing the remaining imports. This change improves code readability and reduces unnecessary clutter. * fix(base.py): fix condition to check if self.types is not None before checking if type_ is in self.types * refactor: update build_custom_component_template to use add_name instead of keep_name * feat: update logger warning message for invalid value type in StrInput * refactor: update logger warning message for invalid value type in StrInput * refactor: add unit tests for inputs in test_inputs.py * refactor: update validation for IntInput value type Improve the validation for the value type in the IntInput class. The updated code ensures that the value is of a valid type (int or float) and converts float values to integers. This change enhances the accuracy and reliability of the input validation process. * refactor: improve validation for FloatInput value type Improve the validation for the value type in the FloatInput class. The updated code ensures that the value is of a valid type (int or float) and converts integer values to floats. This change enhances the accuracy and reliability of the input validation process. * fix(data.py): add validation to check if data is a dictionary before processing to prevent potential errors * refactor: update test_inputs.py to include comprehensive unit tests for input classes Add comprehensive unit tests for the input classes in the test_inputs.py file. This change ensures that the input classes are thoroughly tested and functioning correctly. The unit tests cover various scenarios and edge cases to validate the behavior of the input classes. This improvement enhances the reliability and stability of the codebase. * refactor: remove invalid input tests for DataInput and FileInput in test_inputs.py Cleaned up unnecessary test cases for invalid input types * refactor: improve validation for IntInput and FloatInput value types * refactor: add async version of get_file_content_dicts method in Message class * refactor: update get_file_content_dicts method in Message class The `get_file_content_dicts` method in the `Message` class has been updated to use an async version for improved performance. This change enhances the efficiency of retrieving file content dictionaries. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
4fb96d6160
commit
b5180c4d70
5 changed files with 332 additions and 39 deletions
|
|
@ -27,9 +27,9 @@ SerializableFieldTypes = Annotated[FieldTypes, PlainSerializer(lambda v: v.value
|
|||
|
||||
# Base mixin for common input field attributes and methods
|
||||
class BaseInputMixin(BaseModel, validate_assignment=True):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid", populate_by_name=True)
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = Field(default=FieldTypes.TEXT)
|
||||
field_type: SerializableFieldTypes = Field(default=FieldTypes.TEXT)
|
||||
|
||||
required: bool = False
|
||||
"""Specifies if the field is required. Defaults to False."""
|
||||
|
|
@ -78,9 +78,10 @@ class BaseInputMixin(BaseModel, validate_assignment=True):
|
|||
@field_validator("field_type", mode="before")
|
||||
@classmethod
|
||||
def validate_field_type(cls, v):
|
||||
if v not in FieldTypes:
|
||||
try:
|
||||
return FieldTypes(v)
|
||||
except ValueError:
|
||||
return FieldTypes.OTHER
|
||||
return FieldTypes(v)
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize_model(self, handler):
|
||||
|
|
@ -101,7 +102,7 @@ class MetadataTraceMixin(BaseModel):
|
|||
|
||||
# Mixin for input fields that can be listable
|
||||
class ListableInputMixin(BaseModel):
|
||||
is_list: bool = Field(default=False, serialization_alias="list")
|
||||
is_list: bool = Field(default=False, alias="list")
|
||||
|
||||
|
||||
# Specific mixin for fields needing database interaction
|
||||
|
|
@ -112,7 +113,7 @@ class DatabaseLoadMixin(BaseModel):
|
|||
# Specific mixin for fields needing file interaction
|
||||
class FileMixin(BaseModel):
|
||||
file_path: Optional[str] = Field(default="")
|
||||
file_types: list[str] = Field(default=[], serialization_alias="fileTypes")
|
||||
file_types: list[str] = Field(default=[], alias="fileTypes")
|
||||
|
||||
@field_validator("file_types")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import warnings
|
||||
from typing import Any, AsyncIterator, Iterator, Optional, Union, get_args
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from langflow.inputs.validators import CoalesceBool
|
||||
|
|
@ -24,7 +24,7 @@ from .input_mixin import (
|
|||
|
||||
|
||||
class TableInput(BaseInputMixin, MetadataTraceMixin, TableMixin, ListableInputMixin):
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.TABLE
|
||||
field_type: SerializableFieldTypes = FieldTypes.TABLE
|
||||
is_list: bool = True
|
||||
|
||||
@field_validator("value")
|
||||
|
|
@ -50,11 +50,11 @@ class HandleInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin):
|
|||
|
||||
Attributes:
|
||||
input_types (list[str]): A list of input types.
|
||||
field_type (Optional[SerializableFieldTypes]): The field type of the input.
|
||||
field_type (SerializableFieldTypes): The field type of the input.
|
||||
"""
|
||||
|
||||
input_types: list[str] = Field(default_factory=list)
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.OTHER
|
||||
field_type: SerializableFieldTypes = FieldTypes.OTHER
|
||||
|
||||
|
||||
class DataInput(HandleInput, InputTraceMixin):
|
||||
|
|
@ -69,12 +69,12 @@ class DataInput(HandleInput, InputTraceMixin):
|
|||
|
||||
|
||||
class PromptInput(BaseInputMixin, ListableInputMixin, InputTraceMixin):
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.PROMPT
|
||||
field_type: SerializableFieldTypes = FieldTypes.PROMPT
|
||||
|
||||
|
||||
# Applying mixins to a specific input type
|
||||
class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin, MetadataTraceMixin):
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT
|
||||
field_type: SerializableFieldTypes = FieldTypes.TEXT
|
||||
load_from_db: CoalesceBool = False
|
||||
"""Defines if the field will allow the user to open a text editor. Default is False."""
|
||||
|
||||
|
|
@ -94,8 +94,13 @@ class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin, MetadataTr
|
|||
ValueError: If the value is not of a valid type or if the input is missing a required key.
|
||||
"""
|
||||
if not isinstance(v, str) and v is not None:
|
||||
# Keep the warning for now, but we should change it to an error
|
||||
if _info.data.get("input_types") and v.__class__.__name__ not in _info.data.get("input_types"):
|
||||
logger.warning(f"Invalid value type {type(v)}")
|
||||
warnings.warn(
|
||||
f"Invalid value type {type(v)} for input {_info.data.get('name')}. Expected types: {_info.data.get('input_types')}"
|
||||
)
|
||||
else:
|
||||
warnings.warn(f"Invalid value type {type(v)} for input {_info.data.get('name')}.")
|
||||
return v
|
||||
|
||||
@field_validator("value")
|
||||
|
|
@ -129,6 +134,8 @@ class MessageInput(StrInput, InputTraceMixin):
|
|||
@staticmethod
|
||||
def _validate_value(v: Any, _info):
|
||||
# If v is a instance of Message, then its fine
|
||||
if isinstance(v, dict):
|
||||
return Message(**v)
|
||||
if isinstance(v, Message):
|
||||
return v
|
||||
if isinstance(v, str):
|
||||
|
|
@ -164,6 +171,8 @@ class MessageTextInput(StrInput, MetadataTraceMixin, InputTraceMixin):
|
|||
ValueError: If the value is not of a valid type or if the input is missing a required key.
|
||||
"""
|
||||
value: str | AsyncIterator | Iterator | None = None
|
||||
if isinstance(v, dict):
|
||||
v = Message(**v)
|
||||
if isinstance(v, str):
|
||||
value = v
|
||||
elif isinstance(v, Message):
|
||||
|
|
@ -190,11 +199,11 @@ class MultilineInput(MessageTextInput, MultilineMixin, InputTraceMixin):
|
|||
Represents a multiline input field.
|
||||
|
||||
Attributes:
|
||||
field_type (Optional[SerializableFieldTypes]): The type of the field. Defaults to FieldTypes.TEXT.
|
||||
field_type (SerializableFieldTypes): The type of the field. Defaults to FieldTypes.TEXT.
|
||||
multiline (CoalesceBool): Indicates whether the input field should support multiple lines. Defaults to True.
|
||||
"""
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT
|
||||
field_type: SerializableFieldTypes = FieldTypes.TEXT
|
||||
multiline: CoalesceBool = True
|
||||
|
||||
|
||||
|
|
@ -203,11 +212,11 @@ class MultilineSecretInput(MessageTextInput, MultilineMixin, InputTraceMixin):
|
|||
Represents a multiline input field.
|
||||
|
||||
Attributes:
|
||||
field_type (Optional[SerializableFieldTypes]): The type of the field. Defaults to FieldTypes.TEXT.
|
||||
field_type (SerializableFieldTypes): The type of the field. Defaults to FieldTypes.TEXT.
|
||||
multiline (CoalesceBool): Indicates whether the input field should support multiple lines. Defaults to True.
|
||||
"""
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.PASSWORD
|
||||
field_type: SerializableFieldTypes = FieldTypes.PASSWORD
|
||||
multiline: CoalesceBool = True
|
||||
password: CoalesceBool = Field(default=True)
|
||||
|
||||
|
|
@ -219,12 +228,12 @@ class SecretStrInput(BaseInputMixin, DatabaseLoadMixin):
|
|||
This class inherits from `BaseInputMixin` and `DatabaseLoadMixin`.
|
||||
|
||||
Attributes:
|
||||
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to `FieldTypes.PASSWORD`.
|
||||
field_type (SerializableFieldTypes): The field type of the input. Defaults to `FieldTypes.PASSWORD`.
|
||||
password (CoalesceBool): A boolean indicating whether the input is a password. Defaults to `True`.
|
||||
input_types (list[str]): A list of input types associated with this input. Defaults to an empty list.
|
||||
"""
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.PASSWORD
|
||||
field_type: SerializableFieldTypes = FieldTypes.PASSWORD
|
||||
password: CoalesceBool = Field(default=True)
|
||||
input_types: list[str] = ["Message"]
|
||||
load_from_db: CoalesceBool = True
|
||||
|
|
@ -275,10 +284,33 @@ class IntInput(BaseInputMixin, ListableInputMixin, RangeMixin, MetadataTraceMixi
|
|||
It inherits from the `BaseInputMixin`, `ListableInputMixin`, and `RangeMixin` classes.
|
||||
|
||||
Attributes:
|
||||
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.INTEGER.
|
||||
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.INTEGER.
|
||||
"""
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.INTEGER
|
||||
field_type: SerializableFieldTypes = FieldTypes.INTEGER
|
||||
|
||||
@field_validator("value")
|
||||
@classmethod
|
||||
def validate_value(cls, v: Any, _info):
|
||||
"""
|
||||
Validates the given value and returns the processed value.
|
||||
|
||||
Args:
|
||||
v (Any): The value to be validated.
|
||||
_info: Additional information about the input.
|
||||
|
||||
Returns:
|
||||
The processed value.
|
||||
|
||||
Raises:
|
||||
ValueError: If the value is not of a valid type or if the input is missing a required key.
|
||||
"""
|
||||
|
||||
if v and not isinstance(v, (int, float)):
|
||||
raise ValueError(f"Invalid value type {type(v)} for input {_info.data.get('name')}.")
|
||||
if isinstance(v, float):
|
||||
v = int(v)
|
||||
return v
|
||||
|
||||
|
||||
class FloatInput(BaseInputMixin, ListableInputMixin, RangeMixin, MetadataTraceMixin):
|
||||
|
|
@ -289,10 +321,32 @@ class FloatInput(BaseInputMixin, ListableInputMixin, RangeMixin, MetadataTraceMi
|
|||
It inherits from the `BaseInputMixin`, `ListableInputMixin`, and `RangeMixin` classes.
|
||||
|
||||
Attributes:
|
||||
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.FLOAT.
|
||||
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.FLOAT.
|
||||
"""
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.FLOAT
|
||||
field_type: SerializableFieldTypes = FieldTypes.FLOAT
|
||||
|
||||
@field_validator("value")
|
||||
@classmethod
|
||||
def validate_value(cls, v: Any, _info):
|
||||
"""
|
||||
Validates the given value and returns the processed value.
|
||||
|
||||
Args:
|
||||
v (Any): The value to be validated.
|
||||
_info: Additional information about the input.
|
||||
|
||||
Returns:
|
||||
The processed value.
|
||||
|
||||
Raises:
|
||||
ValueError: If the value is not of a valid type or if the input is missing a required key.
|
||||
"""
|
||||
if v and not isinstance(v, (int, float)):
|
||||
raise ValueError(f"Invalid value type {type(v)} for input {_info.data.get('name')}.")
|
||||
if isinstance(v, int):
|
||||
v = float(v)
|
||||
return v
|
||||
|
||||
|
||||
class BoolInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin):
|
||||
|
|
@ -303,11 +357,11 @@ class BoolInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin):
|
|||
It inherits from the `BaseInputMixin` and `ListableInputMixin` classes.
|
||||
|
||||
Attributes:
|
||||
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.BOOLEAN.
|
||||
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.BOOLEAN.
|
||||
value (CoalesceBool): The value of the boolean input.
|
||||
"""
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.BOOLEAN
|
||||
field_type: SerializableFieldTypes = FieldTypes.BOOLEAN
|
||||
value: CoalesceBool = False
|
||||
|
||||
|
||||
|
|
@ -319,11 +373,11 @@ class NestedDictInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin, In
|
|||
It inherits from the `BaseInputMixin` and `ListableInputMixin` classes.
|
||||
|
||||
Attributes:
|
||||
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.NESTED_DICT.
|
||||
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.NESTED_DICT.
|
||||
value (Optional[dict]): The value of the input. Defaults to an empty dictionary.
|
||||
"""
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.NESTED_DICT
|
||||
field_type: SerializableFieldTypes = FieldTypes.NESTED_DICT
|
||||
value: Optional[dict | Data] = {}
|
||||
|
||||
|
||||
|
|
@ -335,11 +389,11 @@ class DictInput(BaseInputMixin, ListableInputMixin, InputTraceMixin):
|
|||
It inherits from the `BaseInputMixin` and `ListableInputMixin` classes.
|
||||
|
||||
Attributes:
|
||||
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.DICT.
|
||||
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.DICT.
|
||||
value (Optional[dict]): The value of the dictionary input. Defaults to an empty dictionary.
|
||||
"""
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.DICT
|
||||
field_type: SerializableFieldTypes = FieldTypes.DICT
|
||||
value: Optional[dict] = {}
|
||||
|
||||
|
||||
|
|
@ -351,12 +405,12 @@ class DropdownInput(BaseInputMixin, DropDownMixin, MetadataTraceMixin):
|
|||
It inherits from the `BaseInputMixin` and `DropDownMixin` classes.
|
||||
|
||||
Attributes:
|
||||
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.TEXT.
|
||||
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.TEXT.
|
||||
options (Optional[Union[list[str], Callable]]): List of options for the field.
|
||||
Default is None.
|
||||
"""
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT
|
||||
field_type: SerializableFieldTypes = FieldTypes.TEXT
|
||||
options: list[str] = Field(default_factory=list)
|
||||
combobox: CoalesceBool = False
|
||||
|
||||
|
|
@ -369,12 +423,12 @@ class MultiselectInput(BaseInputMixin, ListableInputMixin, DropDownMixin, Metada
|
|||
It inherits from the `BaseInputMixin`, `ListableInputMixin` and `DropDownMixin` classes.
|
||||
|
||||
Attributes:
|
||||
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.TEXT.
|
||||
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.TEXT.
|
||||
options (Optional[Union[list[str], Callable]]): List of options for the field. Only used when is_list=True.
|
||||
Default is None.
|
||||
"""
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT
|
||||
field_type: SerializableFieldTypes = FieldTypes.TEXT
|
||||
options: list[str] = Field(default_factory=list)
|
||||
is_list: bool = Field(default=True, serialization_alias="list")
|
||||
combobox: CoalesceBool = False
|
||||
|
|
@ -399,10 +453,10 @@ class FileInput(BaseInputMixin, ListableInputMixin, FileMixin, MetadataTraceMixi
|
|||
It inherits from the `BaseInputMixin`, `ListableInputMixin`, and `FileMixin` classes.
|
||||
|
||||
Attributes:
|
||||
field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.FILE.
|
||||
field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.FILE.
|
||||
"""
|
||||
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.FILE
|
||||
field_type: SerializableFieldTypes = FieldTypes.FILE
|
||||
|
||||
|
||||
InputTypes = Union[
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ class Data(BaseModel):
|
|||
|
||||
@model_validator(mode="before")
|
||||
def validate_data(cls, values):
|
||||
if not isinstance(values, dict):
|
||||
raise ValueError("Data must be a dictionary")
|
||||
if not values.get("data"):
|
||||
values["data"] = {}
|
||||
# Any other keyword should be added to the data dictionary
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class Message(Data):
|
|||
if self.sender == MESSAGE_SENDER_USER or not self.sender:
|
||||
if self.files:
|
||||
contents = [{"type": "text", "text": text}]
|
||||
contents.extend(self.get_file_content_dicts())
|
||||
contents.extend(self.sync_get_file_content_dicts())
|
||||
human_message = HumanMessage(content=contents) # type: ignore
|
||||
else:
|
||||
human_message = HumanMessage(content=text)
|
||||
|
|
@ -156,6 +156,11 @@ class Message(Data):
|
|||
return ""
|
||||
return value
|
||||
|
||||
def sync_get_file_content_dicts(self):
|
||||
coro = self.aget_file_content_dicts()
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(coro)
|
||||
|
||||
async def get_file_content_dicts(self):
|
||||
content_dicts = []
|
||||
files = await get_file_paths(self.files)
|
||||
|
|
@ -165,9 +170,7 @@ class Message(Data):
|
|||
content_dicts.append(file.to_content_dict())
|
||||
else:
|
||||
image_template = ImagePromptTemplate()
|
||||
image_prompt_value: ImagePromptValue = image_template.invoke(
|
||||
input={"path": file}, config={"callbacks": self.get_langchain_callbacks()}
|
||||
) # type: ignore
|
||||
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file}) # type: ignore
|
||||
content_dicts.append({"type": "image_url", "image_url": image_prompt_value.image_url})
|
||||
return content_dicts
|
||||
|
||||
|
|
|
|||
233
src/backend/tests/unit/inputs/test_inputs.py
Normal file
233
src/backend/tests/unit/inputs/test_inputs.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langflow.inputs.inputs import (
|
||||
BoolInput,
|
||||
DataInput,
|
||||
DictInput,
|
||||
DropdownInput,
|
||||
FileInput,
|
||||
FloatInput,
|
||||
HandleInput,
|
||||
InputTypesMap,
|
||||
IntInput,
|
||||
MessageTextInput,
|
||||
MultilineInput,
|
||||
MultilineSecretInput,
|
||||
MultiselectInput,
|
||||
NestedDictInput,
|
||||
PromptInput,
|
||||
SecretStrInput,
|
||||
StrInput,
|
||||
TableInput,
|
||||
_instantiate_input,
|
||||
)
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
pass
|
||||
|
||||
|
||||
def test_table_input_valid():
|
||||
# Test with a valid list of dictionaries
|
||||
data = TableInput(value=[{"key": "value"}, {"key2": "value2"}])
|
||||
assert data.value == [{"key": "value"}, {"key2": "value2"}]
|
||||
|
||||
|
||||
def test_table_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
# Test with an invalid value
|
||||
TableInput(value="invalid")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
# Test with a list containing invalid item
|
||||
TableInput(value=[{"key": "value"}, "invalid"])
|
||||
|
||||
|
||||
def test_str_input_valid():
|
||||
data = StrInput(value="This is a string")
|
||||
assert data.value == "This is a string"
|
||||
|
||||
|
||||
def test_str_input_invalid():
|
||||
with pytest.warns(UserWarning):
|
||||
# Test with an invalid value
|
||||
StrInput(value=1234)
|
||||
|
||||
|
||||
def test_message_text_input_valid():
|
||||
# Test with a valid string
|
||||
data = MessageTextInput(value="This is a message")
|
||||
assert data.value == "This is a message"
|
||||
|
||||
# Test with a valid Message object
|
||||
msg = Message(text="This is a message")
|
||||
data = MessageTextInput(value=msg)
|
||||
assert data.value == "This is a message"
|
||||
|
||||
|
||||
def test_message_text_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
# Test with an invalid value
|
||||
MessageTextInput(value=1234)
|
||||
|
||||
|
||||
def test_instantiate_input_valid():
|
||||
data = {"value": "This is a string"}
|
||||
input_instance = _instantiate_input("StrInput", data)
|
||||
assert isinstance(input_instance, StrInput)
|
||||
assert input_instance.value == "This is a string"
|
||||
|
||||
|
||||
def test_instantiate_input_invalid():
|
||||
with pytest.raises(ValueError):
|
||||
# Test with an invalid input type
|
||||
_instantiate_input("InvalidInput", {"value": "This is a string"})
|
||||
|
||||
|
||||
def test_handle_input_valid():
|
||||
data = HandleInput(input_types=["BaseLanguageModel"])
|
||||
assert data.input_types == ["BaseLanguageModel"]
|
||||
|
||||
|
||||
def test_handle_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
HandleInput(input_types="BaseLanguageModel") # should be a list, not a string
|
||||
|
||||
|
||||
def test_data_input_valid():
|
||||
data_input = DataInput(input_types=["Data"])
|
||||
assert data_input.input_types == ["Data"]
|
||||
|
||||
|
||||
def test_prompt_input_valid():
|
||||
prompt_input = PromptInput(value="Enter your name")
|
||||
assert prompt_input.value == "Enter your name"
|
||||
|
||||
|
||||
def test_multiline_input_valid():
|
||||
multiline_input = MultilineInput(value="This is a\nmultiline input")
|
||||
assert multiline_input.value == "This is a\nmultiline input"
|
||||
assert multiline_input.multiline is True
|
||||
|
||||
|
||||
def test_multiline_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
MultilineInput(value=1234) # should be a string, not an integer
|
||||
|
||||
|
||||
def test_multiline_secret_input_valid():
|
||||
multiline_secret_input = MultilineSecretInput(value="secret")
|
||||
assert multiline_secret_input.value == "secret"
|
||||
assert multiline_secret_input.password is True
|
||||
|
||||
|
||||
def test_multiline_secret_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
MultilineSecretInput(value=1234) # should be a string, not an integer
|
||||
|
||||
|
||||
def test_secret_str_input_valid():
|
||||
secret_str_input = SecretStrInput(value="supersecret")
|
||||
assert secret_str_input.value == "supersecret"
|
||||
assert secret_str_input.password is True
|
||||
|
||||
|
||||
def test_secret_str_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
SecretStrInput(value=1234) # should be a string, not an integer
|
||||
|
||||
|
||||
def test_int_input_valid():
|
||||
int_input = IntInput(value=10)
|
||||
assert int_input.value == 10
|
||||
|
||||
|
||||
def test_int_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
IntInput(value="not_an_int") # should be an integer, not a string
|
||||
|
||||
|
||||
def test_float_input_valid():
|
||||
float_input = FloatInput(value=10.5)
|
||||
assert float_input.value == 10.5
|
||||
|
||||
|
||||
def test_float_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
FloatInput(value="not_a_float") # should be a float, not a string
|
||||
|
||||
|
||||
def test_bool_input_valid():
|
||||
bool_input = BoolInput(value=True)
|
||||
assert bool_input.value is True
|
||||
|
||||
|
||||
def test_bool_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
BoolInput(value="not_a_bool") # should be a bool, not a string
|
||||
|
||||
|
||||
def test_nested_dict_input_valid():
|
||||
nested_dict_input = NestedDictInput(value={"key": "value"})
|
||||
assert nested_dict_input.value == {"key": "value"}
|
||||
|
||||
|
||||
def test_nested_dict_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
NestedDictInput(value="not_a_dict") # should be a dict, not a string
|
||||
|
||||
|
||||
def test_dict_input_valid():
|
||||
dict_input = DictInput(value={"key": "value"})
|
||||
assert dict_input.value == {"key": "value"}
|
||||
|
||||
|
||||
def test_dict_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
DictInput(value="not_a_dict") # should be a dict, not a string
|
||||
|
||||
|
||||
def test_dropdown_input_valid():
|
||||
dropdown_input = DropdownInput(options=["option1", "option2"])
|
||||
assert dropdown_input.options == ["option1", "option2"]
|
||||
|
||||
|
||||
def test_dropdown_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
DropdownInput(options="option1") # should be a list, not a string
|
||||
|
||||
|
||||
def test_multiselect_input_valid():
|
||||
multiselect_input = MultiselectInput(value=["option1", "option2"])
|
||||
assert multiselect_input.value == ["option1", "option2"]
|
||||
|
||||
|
||||
def test_multiselect_input_invalid():
|
||||
with pytest.raises(ValidationError):
|
||||
MultiselectInput(value="option1") # should be a list, not a string
|
||||
|
||||
|
||||
def test_file_input_valid():
|
||||
file_input = FileInput(value=["/path/to/file"])
|
||||
assert file_input.value == ["/path/to/file"]
|
||||
|
||||
|
||||
def test_instantiate_input_comprehensive():
|
||||
valid_data = {
|
||||
"StrInput": {"value": "A string"},
|
||||
"IntInput": {"value": 10},
|
||||
"FloatInput": {"value": 10.5},
|
||||
"BoolInput": {"value": True},
|
||||
"DictInput": {"value": {"key": "value"}},
|
||||
"MultiselectInput": {"value": ["option1", "option2"]},
|
||||
}
|
||||
|
||||
for input_type, data in valid_data.items():
|
||||
input_instance = _instantiate_input(input_type, data)
|
||||
assert isinstance(input_instance, InputTypesMap[input_type])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_instantiate_input("InvalidInput", {"value": "Invalid"}) # Invalid input type
|
||||
Loading…
Add table
Add a link
Reference in a new issue