From 2276d050fe483cc0cb883dff749336468e5805bf Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Fri, 14 Jun 2024 11:52:50 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20(component.py):=20Add=20input=20?= =?UTF-8?q?validation=20to=20Component?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 📝 (component.py): Update imports and type annotations for better readability and maintainability 🔧 (component.py): Refactor map_inputs method to accept InputTypes and validate inputs in Component class ♻️ (component.py): Refactor _validate_inputs method to check if input is a class method 🔧 (util.py): Add is_class_method function to check if a function is a class method --- .../custom/custom_component/component.py | 42 ++++++++++++------- src/backend/base/langflow/utils/util.py | 7 ++++ 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index f13d8a0a2..74b5d190f 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -1,25 +1,17 @@ import inspect -from typing import ( - TYPE_CHECKING, - AsyncIterator, - Awaitable, - Callable, - ClassVar, - Generator, - Iterator, - List, - Optional, - Union, -) +from typing import AsyncIterator, Awaitable, Callable, ClassVar, Generator, Iterator, List, Optional, Union from uuid import UUID import yaml +from git import TYPE_CHECKING from loguru import logger from pydantic import BaseModel +from langflow.inputs.inputs import InputTypes from langflow.schema.artifact import get_artifact_type, post_process_raw from langflow.schema.data import Data from langflow.template.field.base import UNDEFINED, Input, Output +from langflow.utils.util import is_class_method from .custom_component import CustomComponent @@ -36,7 +28,6 @@ def recursive_serialize_or_str(obj): elif isinstance(obj, BaseModel): return {k: recursive_serialize_or_str(v) for k, v in obj.model_dump().items()} elif isinstance(obj, (AsyncIterator, Generator, Iterator)): - # Turn it into something readable that does not # contain memory addresses # without consuming the iterator # return list(obj) consumes the iterator @@ -49,13 +40,35 @@ def recursive_serialize_or_str(obj): class Component(CustomComponent): - inputs: Optional[List[Input]] = None + inputs: Optional[List[InputTypes]] = None outputs: Optional[List[Output]] = None code_class_base_inheritance: ClassVar[str] = "Component" _results: dict = {} _arguments: dict = {} + _inputs: dict[str, InputTypes] = {} + + def __init__(self, **data): + super().__init__(**data) + if self.inputs is not None: + self.map_inputs(self.inputs) + + def map_inputs(self, inputs: List[Input]): + self.inputs = inputs + for input_ in inputs: + self._inputs[input_.name] = input_ + + def _validate_inputs(self, params: dict): + # Params keys are the `name` attribute of the Input objects + for key, value in params.items(): + if key not in self._inputs: + raise ValueError(f"Input {key} not found in arguments") + input_ = self._inputs[key] + # validate_inputs must be a classmethod + if hasattr(input_, "validate_value") and is_class_method(func=input_.validate_value, cls=input_): + input_.validate_value(value) def set_attributes(self, params: dict): + self._validate_inputs(params) for key, value in params.items(): if key in self.__dict__: raise ValueError(f"Key {key} already exists in {self.__class__.__name__}") @@ -149,3 +162,4 @@ class Component(CustomComponent): return [field.name for field in inputs] except KeyError: return [] + return [] diff --git a/src/backend/base/langflow/utils/util.py b/src/backend/base/langflow/utils/util.py index cbf1d393b..d58fdc4f8 100644 --- a/src/backend/base/langflow/utils/util.py +++ b/src/backend/base/langflow/utils/util.py @@ -449,3 +449,10 @@ def update_settings( if not store: logger.debug("Setting store to False") settings_service.settings.update_settings(store=False) + + +def is_class_method(func, cls): + """ + Check if a function is a class method. + """ + return inspect.ismethod(func) and func.__self__ is cls.__class__