🔧 (component.py): Add input validation to Component
📝 (component.py): Update imports and type annotations for better readability and maintainability 🔧 (component.py): Refactor map_inputs method to accept InputTypes and validate inputs in Component class ♻️ (component.py): Refactor _validate_inputs method to check if input is a class method 🔧 (util.py): Add is_class_method function to check if a function is a class method
This commit is contained in:
parent
6ff8b01e9c
commit
2276d050fe
2 changed files with 35 additions and 14 deletions
|
|
@ -1,25 +1,17 @@
|
|||
import inspect
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from typing import AsyncIterator, Awaitable, Callable, ClassVar, Generator, Iterator, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
import yaml
|
||||
from git import TYPE_CHECKING
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.schema.artifact import get_artifact_type, post_process_raw
|
||||
from langflow.schema.data import Data
|
||||
from langflow.template.field.base import UNDEFINED, Input, Output
|
||||
from langflow.utils.util import is_class_method
|
||||
|
||||
from .custom_component import CustomComponent
|
||||
|
||||
|
|
@ -36,7 +28,6 @@ def recursive_serialize_or_str(obj):
|
|||
elif isinstance(obj, BaseModel):
|
||||
return {k: recursive_serialize_or_str(v) for k, v in obj.model_dump().items()}
|
||||
elif isinstance(obj, (AsyncIterator, Generator, Iterator)):
|
||||
# Turn it into something readable that does not
|
||||
# contain memory addresses
|
||||
# without consuming the iterator
|
||||
# return list(obj) consumes the iterator
|
||||
|
|
@ -49,13 +40,35 @@ def recursive_serialize_or_str(obj):
|
|||
|
||||
|
||||
class Component(CustomComponent):
|
||||
inputs: Optional[List[Input]] = None
|
||||
inputs: Optional[List[InputTypes]] = None
|
||||
outputs: Optional[List[Output]] = None
|
||||
code_class_base_inheritance: ClassVar[str] = "Component"
|
||||
_results: dict = {}
|
||||
_arguments: dict = {}
|
||||
_inputs: dict[str, InputTypes] = {}
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
if self.inputs is not None:
|
||||
self.map_inputs(self.inputs)
|
||||
|
||||
def map_inputs(self, inputs: List[Input]):
|
||||
self.inputs = inputs
|
||||
for input_ in inputs:
|
||||
self._inputs[input_.name] = input_
|
||||
|
||||
def _validate_inputs(self, params: dict):
|
||||
# Params keys are the `name` attribute of the Input objects
|
||||
for key, value in params.items():
|
||||
if key not in self._inputs:
|
||||
raise ValueError(f"Input {key} not found in arguments")
|
||||
input_ = self._inputs[key]
|
||||
# validate_inputs must be a classmethod
|
||||
if hasattr(input_, "validate_value") and is_class_method(func=input_.validate_value, cls=input_):
|
||||
input_.validate_value(value)
|
||||
|
||||
def set_attributes(self, params: dict):
|
||||
self._validate_inputs(params)
|
||||
for key, value in params.items():
|
||||
if key in self.__dict__:
|
||||
raise ValueError(f"Key {key} already exists in {self.__class__.__name__}")
|
||||
|
|
@ -149,3 +162,4 @@ class Component(CustomComponent):
|
|||
return [field.name for field in inputs]
|
||||
except KeyError:
|
||||
return []
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -449,3 +449,10 @@ def update_settings(
|
|||
if not store:
|
||||
logger.debug("Setting store to False")
|
||||
settings_service.settings.update_settings(store=False)
|
||||
|
||||
|
||||
def is_class_method(func, cls):
|
||||
"""
|
||||
Check if a function is a class method.
|
||||
"""
|
||||
return inspect.ismethod(func) and func.__self__ is cls.__class__
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue