From 6ff8b01e9c3773afde11cc69dece26dd3fdda3e5 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Fri, 14 Jun 2024 11:52:08 -0300 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20(inputs/=5F=5Finit=5F=5F.py):=20Add?= =?UTF-8?q?=20TextInput=20class=20to=20support=20text=20input=20type=20in?= =?UTF-8?q?=20langflow=20inputs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 📝 (inputs/inputs.py): Add TextInput class with validation logic for different input types like Data, Message, and Text 📝 (schema/message.py): Add text_key attribute to Message class to specify the key for text data in the message object --- src/backend/base/langflow/inputs/__init__.py | 4 ++- src/backend/base/langflow/inputs/inputs.py | 30 ++++++++++++++++++-- src/backend/base/langflow/schema/message.py | 1 + 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/backend/base/langflow/inputs/__init__.py b/src/backend/base/langflow/inputs/__init__.py index 53927ca0f..6939c470c 100644 --- a/src/backend/base/langflow/inputs/__init__.py +++ b/src/backend/base/langflow/inputs/__init__.py @@ -4,13 +4,14 @@ from .inputs import ( DropdownInput, FileInput, FloatInput, + HandleInput, IntInput, MultilineInput, NestedDictInput, PromptInput, SecretStrInput, StrInput, - HandleInput, + TextInput, ) __all__ = [ @@ -26,4 +27,5 @@ __all__ = [ "PromptInput", "MultilineInput", "HandleInput", + "TextInput", ] diff --git a/src/backend/base/langflow/inputs/inputs.py b/src/backend/base/langflow/inputs/inputs.py index c8146b2c2..17d10d78c 100644 --- a/src/backend/base/langflow/inputs/inputs.py +++ b/src/backend/base/langflow/inputs/inputs.py @@ -1,8 +1,10 @@ -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union -from pydantic import Field, model_validator +from pydantic import Field, field_validator, model_validator from langflow.inputs.validators import StrictBoolean +from langflow.schema.data import Data +from langflow.schema.message import Message from .input_mixin import ( BaseInputMixin, @@ -39,6 +41,30 @@ class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin): # noqa: """Defines if the field will allow the user to open a text editor. Default is False.""" +class TextInput(StrInput): + input_types: list[str] = ["Data", "Message", "Text"] + + @field_validator("value") + @classmethod + def validate_value(cls, v: Any, _info): + if isinstance(v, str): + return v + elif isinstance(v, Message): + return v.text + elif isinstance(v, Data): + if v.text_key in v.data: + return v.data[v.text_key] + else: + keys = ", ".join(v.data.keys()) + input_name = _info.data["name"] + raise ValueError( + f"The input to '{input_name}' must contain the key '{v.text_key}'." + f"You can set `text_key` to one of the following keys: {keys} or set the value using another Component." + ) + else: + raise ValueError(f"Invalid input type {type(v)}") + + class MultilineInput(BaseInputMixin): field_type: Optional[SerializableFieldTypes] = FieldTypes.TEXT multiline: StrictBoolean = True diff --git a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index caa7d95bf..706d337aa 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -17,6 +17,7 @@ def _timestamp_to_str(timestamp: datetime) -> str: class Message(Data): model_config = ConfigDict(arbitrary_types_allowed=True) # Helper class to deal with image data + text_key: str = "text" text: Optional[str | AsyncIterator | Iterator] = Field(default="") sender: str sender_name: str