diff --git a/src/backend/base/langflow/inputs/input_mixin.py b/src/backend/base/langflow/inputs/input_mixin.py index d3a34ee2d..7092b4dee 100644 --- a/src/backend/base/langflow/inputs/input_mixin.py +++ b/src/backend/base/langflow/inputs/input_mixin.py @@ -87,6 +87,7 @@ class BaseInputMixin(BaseModel, validate_assignment=True): dump = handler(self) if "field_type" in dump: dump["type"] = dump.pop("field_type") + dump["_input_type"] = self.__class__.__name__ return dump diff --git a/src/backend/base/langflow/inputs/inputs.py b/src/backend/base/langflow/inputs/inputs.py index d6a5c0423..8ba5d3624 100644 --- a/src/backend/base/langflow/inputs/inputs.py +++ b/src/backend/base/langflow/inputs/inputs.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncIterator, Iterator, Optional, Union +from typing import Any, AsyncIterator, Iterator, Optional, Union, get_args from loguru import logger from pydantic import Field, field_validator @@ -360,3 +360,13 @@ InputTypes = Union[ MessageInput, TableInput, ] + +InputTypesMap: dict[str, type[InputTypes]] = {t.__name__: t for t in get_args(InputTypes)} + + +def _instantiate_input(input_type: str, data: dict) -> InputTypes: + input_type_class = InputTypesMap.get(input_type) + if input_type_class: + return input_type_class(**data) + else: + raise ValueError(f"Invalid input type: {input_type}")