Refactor build_inputs method to add extra fields in ComponentFrontendNode
This commit is contained in:
parent
24b41eb59d
commit
45e2691598
3 changed files with 32 additions and 15 deletions
|
|
@ -110,6 +110,7 @@ class Component(CustomComponent):
|
|||
"""
|
||||
# This function is similar to build_config, but it will process the inputs
|
||||
# and return them as a dict with keys being the Input.name and values being the Input.model_dump()
|
||||
self.inputs = self.template_config.get("inputs", [])
|
||||
if not self.inputs:
|
||||
return {}
|
||||
build_config = {_input.name: _input.model_dump(by_alias=True, exclude_none=True) for _input in self.inputs}
|
||||
|
|
|
|||
|
|
@ -103,6 +103,8 @@ def extract_type_from_optional(field_type):
|
|||
Returns:
|
||||
str: The extracted type, or an empty string if no type was found.
|
||||
"""
|
||||
if "optional" not in field_type.lower():
|
||||
return field_type
|
||||
match = re.search(r"\[(.*?)\]$", field_type)
|
||||
return match[1] if match else field_type
|
||||
|
||||
|
|
@ -249,10 +251,16 @@ def get_field_dict(field: Union[Input, dict]):
|
|||
return field
|
||||
|
||||
|
||||
def run_build_inputs(custom_component: Component, user_id: Optional[Union[str, UUID]] = None):
|
||||
def run_build_inputs(
|
||||
frontend_node: ComponentFrontendNode,
|
||||
custom_component: Component,
|
||||
user_id: Optional[Union[str, UUID]] = None,
|
||||
):
|
||||
"""Run the build inputs of a custom component."""
|
||||
try:
|
||||
return custom_component.build_inputs(user_id=user_id)
|
||||
field_config = custom_component.build_inputs(user_id=user_id)
|
||||
add_extra_fields(frontend_node, field_config, field_config.values())
|
||||
return field_config
|
||||
except Exception as exc:
|
||||
logger.error(f"Error running build inputs: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
|
@ -326,14 +334,18 @@ def build_custom_component_template_from_inputs(
|
|||
custom_component: Component, user_id: Optional[Union[str, UUID]] = None
|
||||
):
|
||||
# The List of Inputs fills the role of the build_config and the entrypoint_args
|
||||
frontend_node = ComponentFrontendNode.from_inputs(**custom_component.template_config)
|
||||
field_config = custom_component.template_config
|
||||
frontend_node = ComponentFrontendNode.from_inputs(**field_config)
|
||||
field_config = run_build_inputs(
|
||||
custom_component,
|
||||
frontend_node=frontend_node,
|
||||
custom_component=custom_component,
|
||||
user_id=user_id,
|
||||
)
|
||||
frontend_node = add_code_field(frontend_node, custom_component.code, field_config.get("code", {}))
|
||||
# But we now need to calculate the return_type of the methods in the outputs
|
||||
for output in frontend_node.outputs:
|
||||
if output.types:
|
||||
continue
|
||||
return_types = custom_component.get_method_return_type(output.method)
|
||||
return_types = [format_type(return_type) for return_type in return_types]
|
||||
output.add_types(return_types)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,16 @@
|
|||
from types import GenericAlias
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator
|
||||
|
||||
from langflow.field_typing import Text
|
||||
from langflow.field_typing.range_spec import RangeSpec
|
||||
from langflow.helpers.custom import format_type
|
||||
|
||||
|
||||
class Input(BaseModel):
|
||||
model_config = ConfigDict()
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
field_type: str = Field(
|
||||
default="str",
|
||||
serialization_alias="type",
|
||||
)
|
||||
field_type: str | type | None = Field(default=str, serialization_alias="type")
|
||||
"""The type of field this is. Default is a string."""
|
||||
|
||||
required: bool = False
|
||||
|
|
@ -86,10 +84,10 @@ class Input(BaseModel):
|
|||
def serialize_model(self, handler):
|
||||
result = handler(self)
|
||||
# If the field is str, we add the Text input type
|
||||
if self.field_type in ["str", "Text"]:
|
||||
if self.field_type in [str, Text]:
|
||||
if "input_types" not in result:
|
||||
result["input_types"] = ["Text"]
|
||||
if self.field_type == "Text":
|
||||
if self.field_type == Text:
|
||||
result["type"] = "str"
|
||||
else:
|
||||
result["type"] = self.field_type
|
||||
|
|
@ -111,15 +109,15 @@ class Input(BaseModel):
|
|||
# If the user passes CustomComponent as a type insteado of "CustomComponent" we need to convert it to a string
|
||||
# this should be done for all types
|
||||
# How to check if v is a type?
|
||||
if isinstance(v, type):
|
||||
return format_type(v)
|
||||
if isinstance(v, (type, GenericAlias)):
|
||||
return str(v)
|
||||
elif not isinstance(v, str):
|
||||
raise ValueError(f"type must be a string or a type, not {type(v)}")
|
||||
return v
|
||||
|
||||
@field_serializer("field_type")
|
||||
def serialize_field_type(self, value, _info):
|
||||
if value == "float" and self.range_spec is None:
|
||||
if value == float and self.range_spec is None:
|
||||
self.range_spec = RangeSpec()
|
||||
return value
|
||||
|
||||
|
|
@ -180,3 +178,9 @@ class Output(BaseModel):
|
|||
else:
|
||||
raise ValueError("If display_name is not set, name must be set")
|
||||
return v
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize_model(self, handler):
|
||||
result = handler(self)
|
||||
|
||||
return result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue