Refactor field serialization and add model

serialization in Template and FrontendNode classes
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-09 18:11:51 -03:00
commit 3c955de5e3
3 changed files with 69 additions and 48 deletions

View file

@ -1,11 +1,11 @@
from abc import ABC
from typing import Any, Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, Field, field_serializer, model_serializer
class TemplateFieldCreator(BaseModel, ABC):
field_type: str = "str"
field_type: str = Field(default="str", alias="type")
"""The type of field this is. Default is a string."""
required: bool = False
@ -14,7 +14,7 @@ class TemplateFieldCreator(BaseModel, ABC):
placeholder: str = ""
"""A placeholder string for the field. Default is an empty string."""
is_list: bool = False
is_list: bool = Field(default=False, alias="list")
"""Defines if the field is a list. Default is False."""
show: bool = True
@ -26,7 +26,7 @@ class TemplateFieldCreator(BaseModel, ABC):
value: Any = None
"""The value of the field. Default is None."""
file_types: list[str] = []
file_types: list[str] = Field(default=[], alias="fileTypes")
"""List of file types associated with the field. Default is an empty list. (duplicate)"""
file_path: Union[str, None] = None
@ -35,7 +35,7 @@ class TemplateFieldCreator(BaseModel, ABC):
password: bool = False
"""Specifies if the field is a password. Defaults to False."""
options: list[str] = []
options: list[str] = None
"""List of options for the field. Only used when is_list=True. Default is an empty list."""
name: str = ""
@ -47,7 +47,7 @@ class TemplateFieldCreator(BaseModel, ABC):
advanced: bool = False
"""Specifies if the field will an advanced parameter (hidden). Defaults to False."""
input_types: list[str] = []
input_types: Optional[list[str]] = None
"""List of input types for the handle when the field has more than one type. Default is an empty list."""
dynamic: bool = False
@ -59,22 +59,31 @@ class TemplateFieldCreator(BaseModel, ABC):
refresh: Optional[bool] = None
"""Specifies if the field should be refreshed. Defaults to False."""
def to_dict(self):
result = self.model_dump()
# Remove key if it is None
for key in list(result.keys()):
if result[key] is None or result[key] == [] and key != "value":
del result[key]
result["type"] = result.pop("field_type")
result["list"] = result.pop("is_list")
if result.get("file_types"):
result["fileTypes"] = result.pop("file_types")
if self.field_type == "file":
result["file_path"] = self.file_path
@model_serializer(mode="wrap")
def serialize_model(self, handler):
# This will be the result of model_dump or dict()
# so we need to build a dict to return
result = handler(self)
result["value"] = self.value
return result
# for key in list(result.keys()):
# if result[key] is None or result[key] == [] and key != "value":
# del result[key]
# return result
def to_dict(self):
return self.model_dump(by_alias=True, exclude_none=True)
@field_serializer("file_path")
def serialize_file_path(self, value):
if self.field_type == "file":
return value
return None
class TemplateField(TemplateFieldCreator):
pass

View file

@ -3,11 +3,12 @@ from collections import defaultdict
from typing import ClassVar, Dict, List, Optional
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.constants import CLASSES_TO_REMOVE, FORCE_SHOW_FIELDS
from langflow.template.frontend_node.constants import (CLASSES_TO_REMOVE,
FORCE_SHOW_FIELDS)
from langflow.template.frontend_node.formatter import field_formatters
from langflow.template.template.base import Template
from langflow.utils import constants
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_serializer, model_serializer
class FieldFormatters(BaseModel):
@ -63,26 +64,31 @@ class FrontendNode(BaseModel):
"""Sets the documentation of the frontend node."""
self.documentation = documentation
def process_base_classes(self) -> None:
@field_serializer("base_classes")
def process_base_classes(self, base_classes: List[str]) -> List[str]:
"""Removes unwanted base classes from the list of base classes."""
self.base_classes = [base_class for base_class in self.base_classes if base_class not in CLASSES_TO_REMOVE]
return [base_class for base_class in base_classes if base_class not in CLASSES_TO_REMOVE]
@field_serializer("display_name")
def process_display_name(self, display_name: str) -> str:
"""Sets the display name of the frontend node."""
return display_name or self.name
@model_serializer(mode="wrap")
def serialize(self, handler):
result = handler(self)
result["template"] = self.template.to_dict(self.format_field)
name = result.pop("name")
return {name: result}
def to_dict(self) -> dict:
"""Returns a dict representation of the frontend node."""
self.process_base_classes()
return {
self.name: {
"template": self.template.to_dict(self.format_field),
"description": self.description,
"base_classes": self.base_classes,
"display_name": self.display_name or self.name,
"custom_fields": self.custom_fields,
"output_types": self.output_types,
"documentation": self.documentation,
"beta": self.beta,
"error": self.error,
},
}
return self.model_dump(by_alias=True, exclude_none=True)
def add_extra_fields(self) -> None:
pass

View file

@ -1,9 +1,8 @@
from typing import Callable, Optional, Union
from pydantic import BaseModel
from typing import Callable, Union
from langflow.template.field.base import TemplateField
from langflow.utils.constants import DIRECT_TYPES
from pydantic import BaseModel, model_serializer
class Template(BaseModel):
@ -12,12 +11,11 @@ class Template(BaseModel):
def process_fields(
self,
name: Optional[str] = None,
format_field_func: Union[Callable, None] = None,
):
if format_field_func:
for field in self.fields:
format_field_func(field, name)
format_field_func(field, self.type_name)
def sort_fields(self):
# first sort alphabetically
@ -25,12 +23,20 @@ class Template(BaseModel):
self.fields.sort(key=lambda x: x.name)
self.fields.sort(key=lambda x: x.field_type in DIRECT_TYPES, reverse=False)
def to_dict(self, format_field_func=None):
self.process_fields(self.type_name, format_field_func)
self.sort_fields()
result = {field.name: field.to_dict() for field in self.fields}
result["_type"] = self.type_name # type: ignore
@model_serializer(mode="wrap")
def serialize_model(self, handler):
result = handler(self)
for field in self.fields:
result[field.name] = field.to_dict()
result["_type"] = result.pop("type_name")
return result
def to_dict(self, format_field_func=None):
self.process_fields(format_field_func)
self.sort_fields()
# result = {field.name: field.to_dict() for field in self.fields}
# result["_type"] = self.type_name # type: ignore
return self.model_dump(by_alias=True, exclude_none=True, exclude={"fields"})
def add_field(self, field: TemplateField) -> None:
self.fields.append(field)