diff --git a/src/backend/langflow/template/field/base.py b/src/backend/langflow/template/field/base.py index 751b72af0..2c24e75bc 100644 --- a/src/backend/langflow/template/field/base.py +++ b/src/backend/langflow/template/field/base.py @@ -1,11 +1,11 @@ from abc import ABC from typing import Any, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_serializer, model_serializer class TemplateFieldCreator(BaseModel, ABC): - field_type: str = "str" + field_type: str = Field(default="str", alias="type") """The type of field this is. Default is a string.""" required: bool = False @@ -14,7 +14,7 @@ class TemplateFieldCreator(BaseModel, ABC): placeholder: str = "" """A placeholder string for the field. Default is an empty string.""" - is_list: bool = False + is_list: bool = Field(default=False, alias="list") """Defines if the field is a list. Default is False.""" show: bool = True @@ -26,7 +26,7 @@ class TemplateFieldCreator(BaseModel, ABC): value: Any = None """The value of the field. Default is None.""" - file_types: list[str] = [] + file_types: list[str] = Field(default=[], alias="fileTypes") """List of file types associated with the field. Default is an empty list. (duplicate)""" file_path: Union[str, None] = None @@ -35,7 +35,7 @@ class TemplateFieldCreator(BaseModel, ABC): password: bool = False """Specifies if the field is a password. Defaults to False.""" - options: list[str] = [] + options: list[str] = None """List of options for the field. Only used when is_list=True. Default is an empty list.""" name: str = "" @@ -47,7 +47,7 @@ class TemplateFieldCreator(BaseModel, ABC): advanced: bool = False """Specifies if the field will an advanced parameter (hidden). Defaults to False.""" - input_types: list[str] = [] + input_types: Optional[list[str]] = None """List of input types for the handle when the field has more than one type. Default is an empty list.""" dynamic: bool = False @@ -59,22 +59,31 @@ class TemplateFieldCreator(BaseModel, ABC): refresh: Optional[bool] = None """Specifies if the field should be refreshed. Defaults to False.""" - def to_dict(self): - result = self.model_dump() - # Remove key if it is None - for key in list(result.keys()): - if result[key] is None or result[key] == [] and key != "value": - del result[key] - result["type"] = result.pop("field_type") - result["list"] = result.pop("is_list") - - if result.get("file_types"): - result["fileTypes"] = result.pop("file_types") - - if self.field_type == "file": - result["file_path"] = self.file_path + @model_serializer(mode="wrap") + def serialize_model(self, handler): + # This will be the result of model_dump or dict() + # so we need to build a dict to return + result = handler(self) + result["value"] = self.value return result + + # for key in list(result.keys()): + # if result[key] is None or result[key] == [] and key != "value": + # del result[key] + # return result + + def to_dict(self): + return self.model_dump(by_alias=True, exclude_none=True) + + + @field_serializer("file_path") + def serialize_file_path(self, value): + if self.field_type == "file": + return value + return None + + class TemplateField(TemplateFieldCreator): pass diff --git a/src/backend/langflow/template/frontend_node/base.py b/src/backend/langflow/template/frontend_node/base.py index 04467a094..f05260f68 100644 --- a/src/backend/langflow/template/frontend_node/base.py +++ b/src/backend/langflow/template/frontend_node/base.py @@ -3,11 +3,12 @@ from collections import defaultdict from typing import ClassVar, Dict, List, Optional from langflow.template.field.base import TemplateField -from langflow.template.frontend_node.constants import CLASSES_TO_REMOVE, FORCE_SHOW_FIELDS +from langflow.template.frontend_node.constants import (CLASSES_TO_REMOVE, + FORCE_SHOW_FIELDS) from langflow.template.frontend_node.formatter import field_formatters from langflow.template.template.base import Template from langflow.utils import constants -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_serializer, model_serializer class FieldFormatters(BaseModel): @@ -63,26 +64,31 @@ class FrontendNode(BaseModel): """Sets the documentation of the frontend node.""" self.documentation = documentation - def process_base_classes(self) -> None: + @field_serializer("base_classes") + def process_base_classes(self, base_classes: List[str]) -> List[str]: """Removes unwanted base classes from the list of base classes.""" - self.base_classes = [base_class for base_class in self.base_classes if base_class not in CLASSES_TO_REMOVE] + + return [base_class for base_class in base_classes if base_class not in CLASSES_TO_REMOVE] + + @field_serializer("display_name") + def process_display_name(self, display_name: str) -> str: + """Sets the display name of the frontend node.""" + + return display_name or self.name + + + @model_serializer(mode="wrap") + def serialize(self, handler): + result = handler(self) + result["template"] = self.template.to_dict(self.format_field) + name = result.pop("name") + + return {name: result} def to_dict(self) -> dict: """Returns a dict representation of the frontend node.""" - self.process_base_classes() - return { - self.name: { - "template": self.template.to_dict(self.format_field), - "description": self.description, - "base_classes": self.base_classes, - "display_name": self.display_name or self.name, - "custom_fields": self.custom_fields, - "output_types": self.output_types, - "documentation": self.documentation, - "beta": self.beta, - "error": self.error, - }, - } + + return self.model_dump(by_alias=True, exclude_none=True) def add_extra_fields(self) -> None: pass diff --git a/src/backend/langflow/template/template/base.py b/src/backend/langflow/template/template/base.py index c680fd468..6b1fc1b8e 100644 --- a/src/backend/langflow/template/template/base.py +++ b/src/backend/langflow/template/template/base.py @@ -1,9 +1,8 @@ -from typing import Callable, Optional, Union - -from pydantic import BaseModel +from typing import Callable, Union from langflow.template.field.base import TemplateField from langflow.utils.constants import DIRECT_TYPES +from pydantic import BaseModel, model_serializer class Template(BaseModel): @@ -12,12 +11,11 @@ class Template(BaseModel): def process_fields( self, - name: Optional[str] = None, format_field_func: Union[Callable, None] = None, ): if format_field_func: for field in self.fields: - format_field_func(field, name) + format_field_func(field, self.type_name) def sort_fields(self): # first sort alphabetically @@ -25,12 +23,20 @@ class Template(BaseModel): self.fields.sort(key=lambda x: x.name) self.fields.sort(key=lambda x: x.field_type in DIRECT_TYPES, reverse=False) - def to_dict(self, format_field_func=None): - self.process_fields(self.type_name, format_field_func) - self.sort_fields() - result = {field.name: field.to_dict() for field in self.fields} - result["_type"] = self.type_name # type: ignore + @model_serializer(mode="wrap") + def serialize_model(self, handler): + result = handler(self) + for field in self.fields: + result[field.name] = field.to_dict() + result["_type"] = result.pop("type_name") return result + def to_dict(self, format_field_func=None): + self.process_fields(format_field_func) + self.sort_fields() + # result = {field.name: field.to_dict() for field in self.fields} + # result["_type"] = self.type_name # type: ignore + return self.model_dump(by_alias=True, exclude_none=True, exclude={"fields"}) + def add_field(self, field: TemplateField) -> None: self.fields.append(field)