Refactor build_inputs method to add extra fields in ComponentFrontendNode

This commit is contained in:
ogabrielluiz 2024-06-06 15:34:00 -03:00
commit 45e2691598
3 changed files with 32 additions and 15 deletions

View file

@ -110,6 +110,7 @@ class Component(CustomComponent):
"""
# This function is similar to build_config, but it will process the inputs
# and return them as a dict with keys being the Input.name and values being the Input.model_dump()
self.inputs = self.template_config.get("inputs", [])
if not self.inputs:
return {}
build_config = {_input.name: _input.model_dump(by_alias=True, exclude_none=True) for _input in self.inputs}

View file

@ -103,6 +103,8 @@ def extract_type_from_optional(field_type):
Returns:
str: The extracted type, or an empty string if no type was found.
"""
if "optional" not in field_type.lower():
return field_type
match = re.search(r"\[(.*?)\]$", field_type)
return match[1] if match else field_type
@ -249,10 +251,16 @@ def get_field_dict(field: Union[Input, dict]):
return field
def run_build_inputs(custom_component: Component, user_id: Optional[Union[str, UUID]] = None):
def run_build_inputs(
frontend_node: ComponentFrontendNode,
custom_component: Component,
user_id: Optional[Union[str, UUID]] = None,
):
"""Run the build inputs of a custom component."""
try:
return custom_component.build_inputs(user_id=user_id)
field_config = custom_component.build_inputs(user_id=user_id)
add_extra_fields(frontend_node, field_config, field_config.values())
return field_config
except Exception as exc:
logger.error(f"Error running build inputs: {exc}")
raise HTTPException(status_code=500, detail=str(exc)) from exc
@ -326,14 +334,18 @@ def build_custom_component_template_from_inputs(
custom_component: Component, user_id: Optional[Union[str, UUID]] = None
):
# The List of Inputs fills the role of the build_config and the entrypoint_args
frontend_node = ComponentFrontendNode.from_inputs(**custom_component.template_config)
field_config = custom_component.template_config
frontend_node = ComponentFrontendNode.from_inputs(**field_config)
field_config = run_build_inputs(
custom_component,
frontend_node=frontend_node,
custom_component=custom_component,
user_id=user_id,
)
frontend_node = add_code_field(frontend_node, custom_component.code, field_config.get("code", {}))
# But we now need to calculate the return_type of the methods in the outputs
for output in frontend_node.outputs:
if output.types:
continue
return_types = custom_component.get_method_return_type(output.method)
return_types = [format_type(return_type) for return_type in return_types]
output.add_types(return_types)

View file

@ -1,18 +1,16 @@
from types import GenericAlias
from typing import Any, Callable, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator
from langflow.field_typing import Text
from langflow.field_typing.range_spec import RangeSpec
from langflow.helpers.custom import format_type
class Input(BaseModel):
model_config = ConfigDict()
model_config = ConfigDict(arbitrary_types_allowed=True)
field_type: str = Field(
default="str",
serialization_alias="type",
)
field_type: str | type | None = Field(default=str, serialization_alias="type")
"""The type of field this is. Default is a string."""
required: bool = False
@ -86,10 +84,10 @@ class Input(BaseModel):
def serialize_model(self, handler):
result = handler(self)
# If the field is str, we add the Text input type
if self.field_type in ["str", "Text"]:
if self.field_type in [str, Text]:
if "input_types" not in result:
result["input_types"] = ["Text"]
if self.field_type == "Text":
if self.field_type == Text:
result["type"] = "str"
else:
result["type"] = self.field_type
@ -111,15 +109,15 @@ class Input(BaseModel):
# If the user passes CustomComponent as a type insteado of "CustomComponent" we need to convert it to a string
# this should be done for all types
# How to check if v is a type?
if isinstance(v, type):
return format_type(v)
if isinstance(v, (type, GenericAlias)):
return str(v)
elif not isinstance(v, str):
raise ValueError(f"type must be a string or a type, not {type(v)}")
return v
@field_serializer("field_type")
def serialize_field_type(self, value, _info):
if value == "float" and self.range_spec is None:
if value == float and self.range_spec is None:
self.range_spec = RangeSpec()
return value
@ -180,3 +178,9 @@ class Output(BaseModel):
else:
raise ValueError("If display_name is not set, name must be set")
return v
@model_serializer(mode="wrap")
def serialize_model(self, handler):
result = handler(self)
return result