From 45e2691598e51c4797f3d7cf8da786c7f514b1c9 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Thu, 6 Jun 2024 15:34:00 -0300 Subject: [PATCH] Refactor build_inputs method to add extra fields in ComponentFrontendNode --- .../custom/custom_component/component.py | 1 + src/backend/base/langflow/custom/utils.py | 20 +++++++++++--- .../base/langflow/template/field/base.py | 26 +++++++++++-------- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index 6c88e51be..767f5fc7b 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -110,6 +110,7 @@ class Component(CustomComponent): """ # This function is similar to build_config, but it will process the inputs # and return them as a dict with keys being the Input.name and values being the Input.model_dump() + self.inputs = self.template_config.get("inputs", []) if not self.inputs: return {} build_config = {_input.name: _input.model_dump(by_alias=True, exclude_none=True) for _input in self.inputs} diff --git a/src/backend/base/langflow/custom/utils.py b/src/backend/base/langflow/custom/utils.py index 4a6ca2bd2..7b0220024 100644 --- a/src/backend/base/langflow/custom/utils.py +++ b/src/backend/base/langflow/custom/utils.py @@ -103,6 +103,8 @@ def extract_type_from_optional(field_type): Returns: str: The extracted type, or an empty string if no type was found. """ + if "optional" not in field_type.lower(): + return field_type match = re.search(r"\[(.*?)\]$", field_type) return match[1] if match else field_type @@ -249,10 +251,16 @@ def get_field_dict(field: Union[Input, dict]): return field -def run_build_inputs(custom_component: Component, user_id: Optional[Union[str, UUID]] = None): +def run_build_inputs( + frontend_node: ComponentFrontendNode, + custom_component: Component, + user_id: Optional[Union[str, UUID]] = None, +): """Run the build inputs of a custom component.""" try: - return custom_component.build_inputs(user_id=user_id) + field_config = custom_component.build_inputs(user_id=user_id) + add_extra_fields(frontend_node, field_config, field_config.values()) + return field_config except Exception as exc: logger.error(f"Error running build inputs: {exc}") raise HTTPException(status_code=500, detail=str(exc)) from exc @@ -326,14 +334,18 @@ def build_custom_component_template_from_inputs( custom_component: Component, user_id: Optional[Union[str, UUID]] = None ): # The List of Inputs fills the role of the build_config and the entrypoint_args - frontend_node = ComponentFrontendNode.from_inputs(**custom_component.template_config) + field_config = custom_component.template_config + frontend_node = ComponentFrontendNode.from_inputs(**field_config) field_config = run_build_inputs( - custom_component, + frontend_node=frontend_node, + custom_component=custom_component, user_id=user_id, ) frontend_node = add_code_field(frontend_node, custom_component.code, field_config.get("code", {})) # But we now need to calculate the return_type of the methods in the outputs for output in frontend_node.outputs: + if output.types: + continue return_types = custom_component.get_method_return_type(output.method) return_types = [format_type(return_type) for return_type in return_types] output.add_types(return_types) diff --git a/src/backend/base/langflow/template/field/base.py b/src/backend/base/langflow/template/field/base.py index 4e9f059fa..a073f8c93 100644 --- a/src/backend/base/langflow/template/field/base.py +++ b/src/backend/base/langflow/template/field/base.py @@ -1,18 +1,16 @@ +from types import GenericAlias from typing import Any, Callable, Optional, Union from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator +from langflow.field_typing import Text from langflow.field_typing.range_spec import RangeSpec -from langflow.helpers.custom import format_type class Input(BaseModel): - model_config = ConfigDict() + model_config = ConfigDict(arbitrary_types_allowed=True) - field_type: str = Field( - default="str", - serialization_alias="type", - ) + field_type: str | type | None = Field(default=str, serialization_alias="type") """The type of field this is. Default is a string.""" required: bool = False @@ -86,10 +84,10 @@ class Input(BaseModel): def serialize_model(self, handler): result = handler(self) # If the field is str, we add the Text input type - if self.field_type in ["str", "Text"]: + if self.field_type in [str, Text]: if "input_types" not in result: result["input_types"] = ["Text"] - if self.field_type == "Text": + if self.field_type == Text: result["type"] = "str" else: result["type"] = self.field_type @@ -111,15 +109,15 @@ class Input(BaseModel): # If the user passes CustomComponent as a type insteado of "CustomComponent" we need to convert it to a string # this should be done for all types # How to check if v is a type? - if isinstance(v, type): - return format_type(v) + if isinstance(v, (type, GenericAlias)): + return str(v) elif not isinstance(v, str): raise ValueError(f"type must be a string or a type, not {type(v)}") return v @field_serializer("field_type") def serialize_field_type(self, value, _info): - if value == "float" and self.range_spec is None: + if value == float and self.range_spec is None: self.range_spec = RangeSpec() return value @@ -180,3 +178,9 @@ class Output(BaseModel): else: raise ValueError("If display_name is not set, name must be set") return v + + @model_serializer(mode="wrap") + def serialize_model(self, handler): + result = handler(self) + + return result