From dea322b024f6a3234e7d3e68c8b6b8a29f50879a Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 25 Jul 2024 14:09:12 -0300 Subject: [PATCH] feat: add InputTypesMap and _instantiate_input function (#2955) * refactor: add _input_type attribute to dump in BaseInputMixin * feat: add InputTypesMap and _instantiate_input function The commit adds the `InputTypesMap` dictionary and `_instantiate_input` function to the `inputs.py` file. The `InputTypesMap` is a dictionary that maps input types to their corresponding classes, and the `_instantiate_input` function is used to instantiate an input object based on its type. This change improves the flexibility and extensibility of the codebase. --- src/backend/base/langflow/inputs/input_mixin.py | 1 + src/backend/base/langflow/inputs/inputs.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/backend/base/langflow/inputs/input_mixin.py b/src/backend/base/langflow/inputs/input_mixin.py index d3a34ee2d..7092b4dee 100644 --- a/src/backend/base/langflow/inputs/input_mixin.py +++ b/src/backend/base/langflow/inputs/input_mixin.py @@ -87,6 +87,7 @@ class BaseInputMixin(BaseModel, validate_assignment=True): dump = handler(self) if "field_type" in dump: dump["type"] = dump.pop("field_type") + dump["_input_type"] = self.__class__.__name__ return dump diff --git a/src/backend/base/langflow/inputs/inputs.py b/src/backend/base/langflow/inputs/inputs.py index d6a5c0423..8ba5d3624 100644 --- a/src/backend/base/langflow/inputs/inputs.py +++ b/src/backend/base/langflow/inputs/inputs.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncIterator, Iterator, Optional, Union +from typing import Any, AsyncIterator, Iterator, Optional, Union, get_args from loguru import logger from pydantic import Field, field_validator @@ -360,3 +360,13 @@ InputTypes = Union[ MessageInput, TableInput, ] + +InputTypesMap: dict[str, type[InputTypes]] = {t.__name__: t for t in get_args(InputTypes)} + + +def _instantiate_input(input_type: str, data: dict) -> InputTypes: + input_type_class = InputTypesMap.get(input_type) + if input_type_class: + return input_type_class(**data) + else: + raise ValueError(f"Invalid input type: {input_type}")