🔧 (component.py): Add input validation to Component

📝 (component.py): Update imports and type annotations for better readability and maintainability
🔧 (component.py): Refactor map_inputs method to accept InputTypes and validate inputs in Component class
♻️ (component.py): Refactor _validate_inputs method to check if input is a class method
🔧 (util.py): Add is_class_method function to check if a function is a class method
This commit is contained in:
ogabrielluiz 2024-06-14 11:52:50 -03:00
commit 2276d050fe
2 changed files with 35 additions and 14 deletions

View file

@ -1,25 +1,17 @@
import inspect
from typing import (
TYPE_CHECKING,
AsyncIterator,
Awaitable,
Callable,
ClassVar,
Generator,
Iterator,
List,
Optional,
Union,
)
from typing import AsyncIterator, Awaitable, Callable, ClassVar, Generator, Iterator, List, Optional, Union
from uuid import UUID
import yaml
from git import TYPE_CHECKING
from loguru import logger
from pydantic import BaseModel
from langflow.inputs.inputs import InputTypes
from langflow.schema.artifact import get_artifact_type, post_process_raw
from langflow.schema.data import Data
from langflow.template.field.base import UNDEFINED, Input, Output
from langflow.utils.util import is_class_method
from .custom_component import CustomComponent
@ -36,7 +28,6 @@ def recursive_serialize_or_str(obj):
elif isinstance(obj, BaseModel):
return {k: recursive_serialize_or_str(v) for k, v in obj.model_dump().items()}
elif isinstance(obj, (AsyncIterator, Generator, Iterator)):
# Turn it into something readable that does not
# contain memory addresses
# without consuming the iterator
# return list(obj) consumes the iterator
@ -49,13 +40,35 @@ def recursive_serialize_or_str(obj):
class Component(CustomComponent):
inputs: Optional[List[Input]] = None
inputs: Optional[List[InputTypes]] = None
outputs: Optional[List[Output]] = None
code_class_base_inheritance: ClassVar[str] = "Component"
_results: dict = {}
_arguments: dict = {}
_inputs: dict[str, InputTypes] = {}
def __init__(self, **data):
super().__init__(**data)
if self.inputs is not None:
self.map_inputs(self.inputs)
def map_inputs(self, inputs: List[Input]):
self.inputs = inputs
for input_ in inputs:
self._inputs[input_.name] = input_
def _validate_inputs(self, params: dict):
# Params keys are the `name` attribute of the Input objects
for key, value in params.items():
if key not in self._inputs:
raise ValueError(f"Input {key} not found in arguments")
input_ = self._inputs[key]
# validate_inputs must be a classmethod
if hasattr(input_, "validate_value") and is_class_method(func=input_.validate_value, cls=input_):
input_.validate_value(value)
def set_attributes(self, params: dict):
self._validate_inputs(params)
for key, value in params.items():
if key in self.__dict__:
raise ValueError(f"Key {key} already exists in {self.__class__.__name__}")
@ -149,3 +162,4 @@ class Component(CustomComponent):
return [field.name for field in inputs]
except KeyError:
return []
return []

View file

@ -449,3 +449,10 @@ def update_settings(
if not store:
logger.debug("Setting store to False")
settings_service.settings.update_settings(store=False)
def is_class_method(func, cls):
"""
Check if a function is a class method.
"""
return inspect.ismethod(func) and func.__self__ is cls.__class__