diff --git a/src/backend/base/langflow/inputs/input_mixin.py b/src/backend/base/langflow/inputs/input_mixin.py index fe7f54f56..5b1f112d0 100644 --- a/src/backend/base/langflow/inputs/input_mixin.py +++ b/src/backend/base/langflow/inputs/input_mixin.py @@ -27,9 +27,9 @@ SerializableFieldTypes = Annotated[FieldTypes, PlainSerializer(lambda v: v.value # Base mixin for common input field attributes and methods class BaseInputMixin(BaseModel, validate_assignment=True): - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid", populate_by_name=True) - field_type: Optional[SerializableFieldTypes] = Field(default=FieldTypes.TEXT) + field_type: SerializableFieldTypes = Field(default=FieldTypes.TEXT) required: bool = False """Specifies if the field is required. Defaults to False.""" @@ -78,9 +78,10 @@ class BaseInputMixin(BaseModel, validate_assignment=True): @field_validator("field_type", mode="before") @classmethod def validate_field_type(cls, v): - if v not in FieldTypes: + try: + return FieldTypes(v) + except ValueError: return FieldTypes.OTHER - return FieldTypes(v) @model_serializer(mode="wrap") def serialize_model(self, handler): @@ -101,7 +102,7 @@ class MetadataTraceMixin(BaseModel): # Mixin for input fields that can be listable class ListableInputMixin(BaseModel): - is_list: bool = Field(default=False, serialization_alias="list") + is_list: bool = Field(default=False, alias="list") # Specific mixin for fields needing database interaction @@ -112,7 +113,7 @@ class DatabaseLoadMixin(BaseModel): # Specific mixin for fields needing file interaction class FileMixin(BaseModel): file_path: Optional[str] = Field(default="") - file_types: list[str] = Field(default=[], serialization_alias="fileTypes") + file_types: list[str] = Field(default=[], alias="fileTypes") @field_validator("file_types") @classmethod diff --git a/src/backend/base/langflow/inputs/inputs.py b/src/backend/base/langflow/inputs/inputs.py index 4d616e175..bc244ec3d 100644 --- a/src/backend/base/langflow/inputs/inputs.py +++ b/src/backend/base/langflow/inputs/inputs.py @@ -1,6 +1,6 @@ +import warnings from typing import Any, AsyncIterator, Iterator, Optional, Union, get_args -from loguru import logger from pydantic import Field, field_validator from langflow.inputs.validators import CoalesceBool @@ -24,7 +24,7 @@ from .input_mixin import ( class TableInput(BaseInputMixin, MetadataTraceMixin, TableMixin, ListableInputMixin): - field_type: Optional[SerializableFieldTypes] = FieldTypes.TABLE + field_type: SerializableFieldTypes = FieldTypes.TABLE is_list: bool = True @field_validator("value") @@ -50,11 +50,11 @@ class HandleInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin): Attributes: input_types (list[str]): A list of input types. - field_type (Optional[SerializableFieldTypes]): The field type of the input. + field_type (SerializableFieldTypes): The field type of the input. """ input_types: list[str] = Field(default_factory=list) - field_type: Optional[SerializableFieldTypes] = FieldTypes.OTHER + field_type: SerializableFieldTypes = FieldTypes.OTHER class DataInput(HandleInput, InputTraceMixin): @@ -69,12 +69,12 @@ class DataInput(HandleInput, InputTraceMixin): class PromptInput(BaseInputMixin, ListableInputMixin, InputTraceMixin): - field_type: Optional[SerializableFieldTypes] = FieldTypes.PROMPT + field_type: SerializableFieldTypes = FieldTypes.PROMPT # Applying mixins to a specific input type class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin, MetadataTraceMixin): - field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT + field_type: SerializableFieldTypes = FieldTypes.TEXT load_from_db: CoalesceBool = False """Defines if the field will allow the user to open a text editor. Default is False.""" @@ -94,8 +94,13 @@ class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin, MetadataTr ValueError: If the value is not of a valid type or if the input is missing a required key. """ if not isinstance(v, str) and v is not None: + # Keep the warning for now, but we should change it to an error if _info.data.get("input_types") and v.__class__.__name__ not in _info.data.get("input_types"): - logger.warning(f"Invalid value type {type(v)}") + warnings.warn( + f"Invalid value type {type(v)} for input {_info.data.get('name')}. Expected types: {_info.data.get('input_types')}" + ) + else: + warnings.warn(f"Invalid value type {type(v)} for input {_info.data.get('name')}.") return v @field_validator("value") @@ -129,6 +134,8 @@ class MessageInput(StrInput, InputTraceMixin): @staticmethod def _validate_value(v: Any, _info): # If v is a instance of Message, then its fine + if isinstance(v, dict): + return Message(**v) if isinstance(v, Message): return v if isinstance(v, str): @@ -164,6 +171,8 @@ class MessageTextInput(StrInput, MetadataTraceMixin, InputTraceMixin): ValueError: If the value is not of a valid type or if the input is missing a required key. """ value: str | AsyncIterator | Iterator | None = None + if isinstance(v, dict): + v = Message(**v) if isinstance(v, str): value = v elif isinstance(v, Message): @@ -190,11 +199,11 @@ class MultilineInput(MessageTextInput, MultilineMixin, InputTraceMixin): Represents a multiline input field. Attributes: - field_type (Optional[SerializableFieldTypes]): The type of the field. Defaults to FieldTypes.TEXT. + field_type (SerializableFieldTypes): The type of the field. Defaults to FieldTypes.TEXT. multiline (CoalesceBool): Indicates whether the input field should support multiple lines. Defaults to True. """ - field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT + field_type: SerializableFieldTypes = FieldTypes.TEXT multiline: CoalesceBool = True @@ -203,11 +212,11 @@ class MultilineSecretInput(MessageTextInput, MultilineMixin, InputTraceMixin): Represents a multiline input field. Attributes: - field_type (Optional[SerializableFieldTypes]): The type of the field. Defaults to FieldTypes.TEXT. + field_type (SerializableFieldTypes): The type of the field. Defaults to FieldTypes.TEXT. multiline (CoalesceBool): Indicates whether the input field should support multiple lines. Defaults to True. """ - field_type: Optional[SerializableFieldTypes] = FieldTypes.PASSWORD + field_type: SerializableFieldTypes = FieldTypes.PASSWORD multiline: CoalesceBool = True password: CoalesceBool = Field(default=True) @@ -219,12 +228,12 @@ class SecretStrInput(BaseInputMixin, DatabaseLoadMixin): This class inherits from `BaseInputMixin` and `DatabaseLoadMixin`. Attributes: - field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to `FieldTypes.PASSWORD`. + field_type (SerializableFieldTypes): The field type of the input. Defaults to `FieldTypes.PASSWORD`. password (CoalesceBool): A boolean indicating whether the input is a password. Defaults to `True`. input_types (list[str]): A list of input types associated with this input. Defaults to an empty list. """ - field_type: Optional[SerializableFieldTypes] = FieldTypes.PASSWORD + field_type: SerializableFieldTypes = FieldTypes.PASSWORD password: CoalesceBool = Field(default=True) input_types: list[str] = ["Message"] load_from_db: CoalesceBool = True @@ -275,10 +284,33 @@ class IntInput(BaseInputMixin, ListableInputMixin, RangeMixin, MetadataTraceMixi It inherits from the `BaseInputMixin`, `ListableInputMixin`, and `RangeMixin` classes. Attributes: - field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.INTEGER. + field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.INTEGER. """ - field_type: Optional[SerializableFieldTypes] = FieldTypes.INTEGER + field_type: SerializableFieldTypes = FieldTypes.INTEGER + + @field_validator("value") + @classmethod + def validate_value(cls, v: Any, _info): + """ + Validates the given value and returns the processed value. + + Args: + v (Any): The value to be validated. + _info: Additional information about the input. + + Returns: + The processed value. + + Raises: + ValueError: If the value is not of a valid type or if the input is missing a required key. + """ + + if v and not isinstance(v, (int, float)): + raise ValueError(f"Invalid value type {type(v)} for input {_info.data.get('name')}.") + if isinstance(v, float): + v = int(v) + return v class FloatInput(BaseInputMixin, ListableInputMixin, RangeMixin, MetadataTraceMixin): @@ -289,10 +321,32 @@ class FloatInput(BaseInputMixin, ListableInputMixin, RangeMixin, MetadataTraceMi It inherits from the `BaseInputMixin`, `ListableInputMixin`, and `RangeMixin` classes. Attributes: - field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.FLOAT. + field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.FLOAT. """ - field_type: Optional[SerializableFieldTypes] = FieldTypes.FLOAT + field_type: SerializableFieldTypes = FieldTypes.FLOAT + + @field_validator("value") + @classmethod + def validate_value(cls, v: Any, _info): + """ + Validates the given value and returns the processed value. + + Args: + v (Any): The value to be validated. + _info: Additional information about the input. + + Returns: + The processed value. + + Raises: + ValueError: If the value is not of a valid type or if the input is missing a required key. + """ + if v and not isinstance(v, (int, float)): + raise ValueError(f"Invalid value type {type(v)} for input {_info.data.get('name')}.") + if isinstance(v, int): + v = float(v) + return v class BoolInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin): @@ -303,11 +357,11 @@ class BoolInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin): It inherits from the `BaseInputMixin` and `ListableInputMixin` classes. Attributes: - field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.BOOLEAN. + field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.BOOLEAN. value (CoalesceBool): The value of the boolean input. """ - field_type: Optional[SerializableFieldTypes] = FieldTypes.BOOLEAN + field_type: SerializableFieldTypes = FieldTypes.BOOLEAN value: CoalesceBool = False @@ -319,11 +373,11 @@ class NestedDictInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin, In It inherits from the `BaseInputMixin` and `ListableInputMixin` classes. Attributes: - field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.NESTED_DICT. + field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.NESTED_DICT. value (Optional[dict]): The value of the input. Defaults to an empty dictionary. """ - field_type: Optional[SerializableFieldTypes] = FieldTypes.NESTED_DICT + field_type: SerializableFieldTypes = FieldTypes.NESTED_DICT value: Optional[dict | Data] = {} @@ -335,11 +389,11 @@ class DictInput(BaseInputMixin, ListableInputMixin, InputTraceMixin): It inherits from the `BaseInputMixin` and `ListableInputMixin` classes. Attributes: - field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.DICT. + field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.DICT. value (Optional[dict]): The value of the dictionary input. Defaults to an empty dictionary. """ - field_type: Optional[SerializableFieldTypes] = FieldTypes.DICT + field_type: SerializableFieldTypes = FieldTypes.DICT value: Optional[dict] = {} @@ -351,12 +405,12 @@ class DropdownInput(BaseInputMixin, DropDownMixin, MetadataTraceMixin): It inherits from the `BaseInputMixin` and `DropDownMixin` classes. Attributes: - field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.TEXT. + field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.TEXT. options (Optional[Union[list[str], Callable]]): List of options for the field. Default is None. """ - field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT + field_type: SerializableFieldTypes = FieldTypes.TEXT options: list[str] = Field(default_factory=list) combobox: CoalesceBool = False @@ -369,12 +423,12 @@ class MultiselectInput(BaseInputMixin, ListableInputMixin, DropDownMixin, Metada It inherits from the `BaseInputMixin`, `ListableInputMixin` and `DropDownMixin` classes. Attributes: - field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.TEXT. + field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.TEXT. options (Optional[Union[list[str], Callable]]): List of options for the field. Only used when is_list=True. Default is None. """ - field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT + field_type: SerializableFieldTypes = FieldTypes.TEXT options: list[str] = Field(default_factory=list) is_list: bool = Field(default=True, serialization_alias="list") combobox: CoalesceBool = False @@ -399,10 +453,10 @@ class FileInput(BaseInputMixin, ListableInputMixin, FileMixin, MetadataTraceMixi It inherits from the `BaseInputMixin`, `ListableInputMixin`, and `FileMixin` classes. Attributes: - field_type (Optional[SerializableFieldTypes]): The field type of the input. Defaults to FieldTypes.FILE. + field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.FILE. """ - field_type: Optional[SerializableFieldTypes] = FieldTypes.FILE + field_type: SerializableFieldTypes = FieldTypes.FILE InputTypes = Union[ diff --git a/src/backend/base/langflow/schema/data.py b/src/backend/base/langflow/schema/data.py index 38710177c..2becf0e6c 100644 --- a/src/backend/base/langflow/schema/data.py +++ b/src/backend/base/langflow/schema/data.py @@ -25,6 +25,8 @@ class Data(BaseModel): @model_validator(mode="before") def validate_data(cls, values): + if not isinstance(values, dict): + raise ValueError("Data must be a dictionary") if not values.get("data"): values["data"] = {} # Any other keyword should be added to the data dictionary diff --git a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index 765cf0e70..f941a943f 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -101,7 +101,7 @@ class Message(Data): if self.sender == MESSAGE_SENDER_USER or not self.sender: if self.files: contents = [{"type": "text", "text": text}] - contents.extend(self.get_file_content_dicts()) + contents.extend(self.sync_get_file_content_dicts()) human_message = HumanMessage(content=contents) # type: ignore else: human_message = HumanMessage(content=text) @@ -156,6 +156,11 @@ class Message(Data): return "" return value + def sync_get_file_content_dicts(self): + coro = self.aget_file_content_dicts() + loop = asyncio.get_event_loop() + return loop.run_until_complete(coro) + async def get_file_content_dicts(self): content_dicts = [] files = await get_file_paths(self.files) @@ -165,9 +170,7 @@ class Message(Data): content_dicts.append(file.to_content_dict()) else: image_template = ImagePromptTemplate() - image_prompt_value: ImagePromptValue = image_template.invoke( - input={"path": file}, config={"callbacks": self.get_langchain_callbacks()} - ) # type: ignore + image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file}) # type: ignore content_dicts.append({"type": "image_url", "image_url": image_prompt_value.image_url}) return content_dicts diff --git a/src/backend/tests/unit/inputs/test_inputs.py b/src/backend/tests/unit/inputs/test_inputs.py new file mode 100644 index 000000000..535de0cbf --- /dev/null +++ b/src/backend/tests/unit/inputs/test_inputs.py @@ -0,0 +1,233 @@ +import pytest +from pydantic import ValidationError + +from langflow.inputs.inputs import ( + BoolInput, + DataInput, + DictInput, + DropdownInput, + FileInput, + FloatInput, + HandleInput, + InputTypesMap, + IntInput, + MessageTextInput, + MultilineInput, + MultilineSecretInput, + MultiselectInput, + NestedDictInput, + PromptInput, + SecretStrInput, + StrInput, + TableInput, + _instantiate_input, +) +from langflow.schema.message import Message + + +@pytest.fixture +def client(): + pass + + +def test_table_input_valid(): + # Test with a valid list of dictionaries + data = TableInput(value=[{"key": "value"}, {"key2": "value2"}]) + assert data.value == [{"key": "value"}, {"key2": "value2"}] + + +def test_table_input_invalid(): + with pytest.raises(ValidationError): + # Test with an invalid value + TableInput(value="invalid") + + with pytest.raises(ValidationError): + # Test with a list containing invalid item + TableInput(value=[{"key": "value"}, "invalid"]) + + +def test_str_input_valid(): + data = StrInput(value="This is a string") + assert data.value == "This is a string" + + +def test_str_input_invalid(): + with pytest.warns(UserWarning): + # Test with an invalid value + StrInput(value=1234) + + +def test_message_text_input_valid(): + # Test with a valid string + data = MessageTextInput(value="This is a message") + assert data.value == "This is a message" + + # Test with a valid Message object + msg = Message(text="This is a message") + data = MessageTextInput(value=msg) + assert data.value == "This is a message" + + +def test_message_text_input_invalid(): + with pytest.raises(ValidationError): + # Test with an invalid value + MessageTextInput(value=1234) + + +def test_instantiate_input_valid(): + data = {"value": "This is a string"} + input_instance = _instantiate_input("StrInput", data) + assert isinstance(input_instance, StrInput) + assert input_instance.value == "This is a string" + + +def test_instantiate_input_invalid(): + with pytest.raises(ValueError): + # Test with an invalid input type + _instantiate_input("InvalidInput", {"value": "This is a string"}) + + +def test_handle_input_valid(): + data = HandleInput(input_types=["BaseLanguageModel"]) + assert data.input_types == ["BaseLanguageModel"] + + +def test_handle_input_invalid(): + with pytest.raises(ValidationError): + HandleInput(input_types="BaseLanguageModel") # should be a list, not a string + + +def test_data_input_valid(): + data_input = DataInput(input_types=["Data"]) + assert data_input.input_types == ["Data"] + + +def test_prompt_input_valid(): + prompt_input = PromptInput(value="Enter your name") + assert prompt_input.value == "Enter your name" + + +def test_multiline_input_valid(): + multiline_input = MultilineInput(value="This is a\nmultiline input") + assert multiline_input.value == "This is a\nmultiline input" + assert multiline_input.multiline is True + + +def test_multiline_input_invalid(): + with pytest.raises(ValidationError): + MultilineInput(value=1234) # should be a string, not an integer + + +def test_multiline_secret_input_valid(): + multiline_secret_input = MultilineSecretInput(value="secret") + assert multiline_secret_input.value == "secret" + assert multiline_secret_input.password is True + + +def test_multiline_secret_input_invalid(): + with pytest.raises(ValidationError): + MultilineSecretInput(value=1234) # should be a string, not an integer + + +def test_secret_str_input_valid(): + secret_str_input = SecretStrInput(value="supersecret") + assert secret_str_input.value == "supersecret" + assert secret_str_input.password is True + + +def test_secret_str_input_invalid(): + with pytest.raises(ValidationError): + SecretStrInput(value=1234) # should be a string, not an integer + + +def test_int_input_valid(): + int_input = IntInput(value=10) + assert int_input.value == 10 + + +def test_int_input_invalid(): + with pytest.raises(ValidationError): + IntInput(value="not_an_int") # should be an integer, not a string + + +def test_float_input_valid(): + float_input = FloatInput(value=10.5) + assert float_input.value == 10.5 + + +def test_float_input_invalid(): + with pytest.raises(ValidationError): + FloatInput(value="not_a_float") # should be a float, not a string + + +def test_bool_input_valid(): + bool_input = BoolInput(value=True) + assert bool_input.value is True + + +def test_bool_input_invalid(): + with pytest.raises(ValidationError): + BoolInput(value="not_a_bool") # should be a bool, not a string + + +def test_nested_dict_input_valid(): + nested_dict_input = NestedDictInput(value={"key": "value"}) + assert nested_dict_input.value == {"key": "value"} + + +def test_nested_dict_input_invalid(): + with pytest.raises(ValidationError): + NestedDictInput(value="not_a_dict") # should be a dict, not a string + + +def test_dict_input_valid(): + dict_input = DictInput(value={"key": "value"}) + assert dict_input.value == {"key": "value"} + + +def test_dict_input_invalid(): + with pytest.raises(ValidationError): + DictInput(value="not_a_dict") # should be a dict, not a string + + +def test_dropdown_input_valid(): + dropdown_input = DropdownInput(options=["option1", "option2"]) + assert dropdown_input.options == ["option1", "option2"] + + +def test_dropdown_input_invalid(): + with pytest.raises(ValidationError): + DropdownInput(options="option1") # should be a list, not a string + + +def test_multiselect_input_valid(): + multiselect_input = MultiselectInput(value=["option1", "option2"]) + assert multiselect_input.value == ["option1", "option2"] + + +def test_multiselect_input_invalid(): + with pytest.raises(ValidationError): + MultiselectInput(value="option1") # should be a list, not a string + + +def test_file_input_valid(): + file_input = FileInput(value=["/path/to/file"]) + assert file_input.value == ["/path/to/file"] + + +def test_instantiate_input_comprehensive(): + valid_data = { + "StrInput": {"value": "A string"}, + "IntInput": {"value": 10}, + "FloatInput": {"value": 10.5}, + "BoolInput": {"value": True}, + "DictInput": {"value": {"key": "value"}}, + "MultiselectInput": {"value": ["option1", "option2"]}, + } + + for input_type, data in valid_data.items(): + input_instance = _instantiate_input(input_type, data) + assert isinstance(input_instance, InputTypesMap[input_type]) + + with pytest.raises(ValueError): + _instantiate_input("InvalidInput", {"value": "Invalid"}) # Invalid input type