From 230c4a69ed6b7509ff78d9afee2424809dc9632c Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Thu, 30 May 2024 22:47:29 -0300 Subject: [PATCH] feat: update template and custom component to load inputs and outputs --- .../base/langflow/custom/attributes.py | 8 + .../custom/custom_component/component.py | 4 + .../custom_component/custom_component.py | 30 ++- src/backend/base/langflow/helpers/custom.py | 13 ++ .../base/langflow/initial_setup/setup.py | 11 +- .../base/langflow/template/field/base.py | 26 ++- .../langflow/template/frontend_node/base.py | 196 ++---------------- tests/test_initial_setup.py | 4 +- 8 files changed, 95 insertions(+), 197 deletions(-) create mode 100644 src/backend/base/langflow/helpers/custom.py diff --git a/src/backend/base/langflow/custom/attributes.py b/src/backend/base/langflow/custom/attributes.py index 1fc7e8c1f..d96500c78 100644 --- a/src/backend/base/langflow/custom/attributes.py +++ b/src/backend/base/langflow/custom/attributes.py @@ -37,6 +37,12 @@ def getattr_return_list_of_str(value): return [] +def getattr_return_list_of_object(value): + if isinstance(value, list): + return value + return [] + + ATTR_FUNC_MAPPING: dict[str, Callable] = { "display_name": getattr_return_str, "description": getattr_return_str, @@ -47,4 +53,6 @@ ATTR_FUNC_MAPPING: dict[str, Callable] = { "is_input": getattr_return_bool, "is_output": getattr_return_bool, "conditional_paths": getattr_return_list_of_str, + "outputs": getattr_return_list_of_object, + "inputs": getattr_return_list_of_object, } diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index d45b5daed..ba4472986 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -84,6 +84,10 @@ class Component: if value is not None: template_config[attribute] = func(value=value) + for key in template_config.copy(): + if key not in ATTR_FUNC_MAPPING.keys(): + template_config.pop(key, None) + return template_config def build(self, *args: Any, **kwargs: Any) -> Any: diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index 75cbabfe8..535bf0cf3 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -83,6 +83,23 @@ class CustomComponent(Component): inputs: Optional[List[Input]] = None outputs: Optional[List[Output]] = None + def build_inputs(self, user_id: Optional[Union[str, UUID]] = None): + """ + Builds the inputs for the custom component. + + Args: + user_id (Optional[Union[str, UUID]], optional): The user ID. Defaults to None. + + Returns: + List[Input]: The list of inputs. + """ + # This function is similar to build_config, but it will process the inputs + # and return them as a dict with keys being the Input.name and values being the Input.model_dump() + if not self.inputs: + return {} + build_config = {_input.name: _input.model_dump(by_alias=True, exclude_none=True) for _input in self.inputs} + return build_config + def update_state(self, name: str, value: Any): if not self.vertex: raise ValueError("Vertex is not set") @@ -275,7 +292,7 @@ class CustomComponent(Component): Returns: list: The arguments of the function entrypoint. """ - build_method = self.get_build_method() + build_method = self.get_method(self.function_entrypoint_name) if not build_method: return [] @@ -287,7 +304,7 @@ class CustomComponent(Component): return args @cachedmethod(operator.attrgetter("cache")) - def get_build_method(self): + def get_method(self, method_name: str): """ Gets the build method for the custom component. @@ -303,9 +320,7 @@ class CustomComponent(Component): # Assume the first Component class is the one we're interested in component_class = component_classes[0] - build_methods = [ - method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name - ] + build_methods = [method for method in component_class["methods"] if method["name"] == (method_name)] return build_methods[0] if build_methods else {} @@ -317,7 +332,10 @@ class CustomComponent(Component): Returns: List[Any]: The return type of the function entrypoint. """ - build_method = self.get_build_method() + return self.get_method_return_type(self.function_entrypoint_name) + + def get_method_return_type(self, method_name: str): + build_method = self.get_method(method_name) if not build_method or not build_method.get("has_return"): return [] return_type = build_method["return_type"] diff --git a/src/backend/base/langflow/helpers/custom.py b/src/backend/base/langflow/helpers/custom.py new file mode 100644 index 000000000..bdbb128f4 --- /dev/null +++ b/src/backend/base/langflow/helpers/custom.py @@ -0,0 +1,13 @@ +from typing import Any + + +def format_type(type_: Any) -> str: + if type_ == str: + type_ = "Text" + elif hasattr(type_, "__name__"): + type_ = type_.__name__ + elif hasattr(type_, "__class__"): + type_ = type_.__class__.__name__ + else: + type_ = str(type_) + return type_ diff --git a/src/backend/base/langflow/initial_setup/setup.py b/src/backend/base/langflow/initial_setup/setup.py index 27574950c..62997fbb1 100644 --- a/src/backend/base/langflow/initial_setup/setup.py +++ b/src/backend/base/langflow/initial_setup/setup.py @@ -16,12 +16,9 @@ from langflow.interface.types import get_all_components from langflow.services.auth.utils import create_super_user from langflow.services.database.models.flow.model import Flow, FlowCreate from langflow.services.database.models.folder.model import Folder, FolderCreate -from langflow.services.database.models.user.crud import get_user_by_username -from langflow.services.deps import get_settings_service, session_scope - from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist -from langflow.services.deps import get_settings_service, session_scope, get_variable_service - +from langflow.services.database.models.user.crud import get_user_by_username +from langflow.services.deps import get_settings_service, get_variable_service, session_scope STARTER_FOLDER_NAME = "Starter Projects" STARTER_FOLDER_DESCRIPTION = "Starter projects to help you get started in Langflow." @@ -221,6 +218,7 @@ def _is_valid_uuid(val): return False return str(uuid_obj) == val + def load_flows_from_directory(): settings_service = get_settings_service() flows_path = settings_service.settings.load_flows_path @@ -262,6 +260,7 @@ def load_flows_from_directory(): session.add(flow) session.commit() + def find_existing_flow(session, flow_id, flow_endpoint_name): if flow_endpoint_name: stmt = select(Flow).where(Flow.endpoint_name == flow_endpoint_name) @@ -271,6 +270,8 @@ def find_existing_flow(session, flow_id, flow_endpoint_name): if existing := session.exec(stmt).first(): return existing return None + + def create_or_update_starter_projects(): components_paths = get_settings_service().settings.components_path try: diff --git a/src/backend/base/langflow/template/field/base.py b/src/backend/base/langflow/template/field/base.py index 0ba947a28..59d56f1a3 100644 --- a/src/backend/base/langflow/template/field/base.py +++ b/src/backend/base/langflow/template/field/base.py @@ -3,12 +3,16 @@ from typing import Any, Callable, Optional, Union from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator from langflow.field_typing.range_spec import RangeSpec +from langflow.helpers.custom import format_type class Input(BaseModel): model_config = ConfigDict() - field_type: str = Field(default="str", serialization_alias="type") + field_type: str = Field( + default="str", + serialization_alias="type", + ) """The type of field this is. Default is a string.""" required: bool = False @@ -102,6 +106,17 @@ class Input(BaseModel): def serialize_file_path(self, value): return value if self.field_type == "file" else "" + @field_validator("field_type", mode="before") + def validate_type(cls, v): + # If the user passes CustomComponent as a type insteado of "CustomComponent" we need to convert it to a string + # this should be done for all types + # How to check if v is a type? + if isinstance(v, type): + return format_type(v) + elif not isinstance(v, str): + raise ValueError(f"type must be a string or a type, not {type(v)}") + return v + @field_serializer("field_type") def serialize_field_type(self, value, _info): if value == "float" and self.range_spec is None: @@ -131,7 +146,7 @@ class Input(BaseModel): class Output(BaseModel): - type: list[str] = Field(default=[], serialization_alias="types") + types: Optional[list[str]] = Field(default=[], serialization_alias="types") """List of output types for the field.""" selected: Optional[str] = Field(default=None, serialization_alias="selected") @@ -140,5 +155,12 @@ class Output(BaseModel): name: str = Field(default="", serialization_alias="name") """The name of the field.""" + method: Optional[str] = Field(default=None, serialization_alias="method") + """The method to use for the output.""" + def to_dict(self): return self.model_dump(by_alias=True, exclude_none=True) + + def add_types(self, _type: list[Any]): + for type_ in _type: + self.types.append(type_) diff --git a/src/backend/base/langflow/template/frontend_node/base.py b/src/backend/base/langflow/template/frontend_node/base.py index 2ac4acbab..7f4ad437c 100644 --- a/src/backend/base/langflow/template/frontend_node/base.py +++ b/src/backend/base/langflow/template/frontend_node/base.py @@ -1,42 +1,10 @@ -import re from collections import defaultdict -from typing import ClassVar, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union -from pydantic import BaseModel, Field, field_serializer, model_serializer +from pydantic import BaseModel, field_serializer, model_serializer -from langflow.template.field.base import Input, Output -from langflow.template.frontend_node.constants import FORCE_SHOW_FIELDS -from langflow.template.frontend_node.formatter import field_formatters +from langflow.template.field.base import Output from langflow.template.template.base import Template -from langflow.utils import constants - - -class FieldFormatters(BaseModel): - formatters: ClassVar[Dict] = { - "openai_api_key": field_formatters.OpenAIAPIKeyFormatter(), - } - base_formatters: ClassVar[Dict] = { - "kwargs": field_formatters.KwargsFormatter(), - "optional": field_formatters.RemoveOptionalFormatter(), - "list": field_formatters.ListTypeFormatter(), - "dict": field_formatters.DictTypeFormatter(), - "union": field_formatters.UnionTypeFormatter(), - "multiline": field_formatters.MultilineFieldFormatter(), - "show": field_formatters.ShowFieldFormatter(), - "password": field_formatters.PasswordFieldFormatter(), - "default": field_formatters.DefaultValueFormatter(), - "headers": field_formatters.HeadersDefaultValueFormatter(), - "dict_code_file": field_formatters.DictCodeFileFormatter(), - "model_fields": field_formatters.ModelSpecificFieldFormatter(), - } - - def format(self, field: Input, name: Optional[str] = None) -> None: - for key, formatter in self.base_formatters.items(): - formatter.format(field, name) - - for key, formatter in self.formatters.items(): - if key == field.name: - formatter.format(field, name) class FrontendNode(BaseModel): @@ -69,8 +37,6 @@ class FrontendNode(BaseModel): """List of output types for the frontend node.""" full_path: Optional[str] = None """Full path of the frontend node.""" - field_formatters: FieldFormatters = Field(default_factory=FieldFormatters) - """Field formatters for the frontend node.""" pinned: bool = False """Whether the frontend node is pinned.""" conditional_paths: List[str] = [] @@ -85,12 +51,6 @@ class FrontendNode(BaseModel): beta: bool = False error: Optional[str] = None - # field formatters is an instance attribute but it is not used in the class - # so we need to create a method to get it - @staticmethod - def get_field_formatters() -> FieldFormatters: - return FieldFormatters() - def set_documentation(self, documentation: str) -> None: """Sets the documentation of the frontend node.""" self.documentation = documentation @@ -121,7 +81,7 @@ class FrontendNode(BaseModel): for base_class in result["output_types"]: output = Output( name=base_class, - type=[base_class], + types=[base_class], ) result["outputs"].append(output.model_dump()) @@ -155,142 +115,12 @@ class FrontendNode(BaseModel): elif isinstance(output_type, list): self.output_types.extend(output_type) - @staticmethod - def format_field(field: Input, name: Optional[str] = None) -> None: - """Formats a given field based on its attributes and value.""" - - FrontendNode.get_field_formatters().format(field, name) - - @staticmethod - def remove_optional(_type: str) -> str: - """Removes 'Optional' wrapper from the type if present.""" - return re.sub(r"Optional\[(.*)\]", r"\1", _type) - - @staticmethod - def check_for_list_type(_type: str) -> tuple: - """Checks for list type and returns the modified type and a boolean indicating if it's a list.""" - is_list = "List" in _type or "Sequence" in _type - if is_list: - _type = re.sub(r"(List|Sequence)\[(.*)\]", r"\2", _type) - return _type, is_list - - @staticmethod - def replace_mapping_with_dict(_type: str) -> str: - """Replaces 'Mapping' with 'dict'.""" - return _type.replace("Mapping", "dict") - - @staticmethod - def handle_union_type(_type: str) -> str: - """Simplifies the 'Union' type to the first type in the Union.""" - if "Union" in _type: - _type = _type.replace("Union[", "")[:-1] - _type = _type.split(",")[0] - _type = _type.replace("]", "").replace("[", "") - return _type - - @staticmethod - def handle_special_field(field, key: str, _type: str, SPECIAL_FIELD_HANDLERS) -> str: - """Handles special field by using the respective handler if present.""" - handler = SPECIAL_FIELD_HANDLERS.get(key) - return handler(field) if handler else _type - - @staticmethod - def handle_dict_type(field: Input, _type: str) -> str: - """Handles 'dict' type by replacing it with 'code' or 'file' based on the field name.""" - if "dict" in _type.lower() and field.name == "dict_": - field.field_type = "file" - field.file_types = [".json", ".yaml", ".yml"] - elif _type.startswith("Dict") or _type.startswith("Mapping") or _type.startswith("dict"): - field.field_type = "dict" - return _type - - @staticmethod - def replace_default_value(field: Input, value: dict) -> None: - """Replaces default value with actual value if 'default' is present in value.""" - if "default" in value: - field.value = value["default"] - - @staticmethod - def handle_specific_field_values(field: Input, key: str, name: Optional[str] = None) -> None: - """Handles specific field values for certain fields.""" - if key == "headers": - field.value = """{"Authorization": "Bearer "}""" - FrontendNode._handle_model_specific_field_values(field, key, name) - FrontendNode._handle_api_key_specific_field_values(field, key, name) - - @staticmethod - def _handle_model_specific_field_values(field: Input, key: str, name: Optional[str] = None) -> None: - """Handles specific field values related to models.""" - model_dict = { - "OpenAI": constants.OPENAI_MODELS, - "ChatOpenAI": constants.CHAT_OPENAI_MODELS, - "Anthropic": constants.ANTHROPIC_MODELS, - "ChatAnthropic": constants.ANTHROPIC_MODELS, - } - if name in model_dict and key == "model_name": - field.options = model_dict[name] - field.is_list = True - - @staticmethod - def _handle_api_key_specific_field_values(field: Input, key: str, name: Optional[str] = None) -> None: - """Handles specific field values related to API keys.""" - if "api_key" in key and "OpenAI" in str(name): - field.display_name = "OpenAI API Key" - field.required = False - if field.value is None: - field.value = "" - - @staticmethod - def handle_kwargs_field(field: Input) -> None: - """Handles kwargs field by setting certain attributes.""" - - if "kwargs" in (field.name or "").lower(): - field.advanced = True - field.required = False - field.show = False - - @staticmethod - def handle_api_key_field(field: Input, key: str) -> None: - """Handles api key field by setting certain attributes.""" - if "api" in key.lower() and "key" in key.lower(): - field.required = False - field.advanced = False - - field.display_name = key.replace("_", " ").title() - field.display_name = field.display_name.replace("Api", "API") - - @staticmethod - def should_show_field(key: str, required: bool) -> bool: - """Determines whether the field should be shown.""" - return ( - (required and key not in ["input_variables"]) - or key in FORCE_SHOW_FIELDS - or "api" in key - or ("key" in key and "input" not in key and "output" not in key) - ) - - @staticmethod - def should_be_password(key: str, show: bool) -> bool: - """Determines whether the field should be a password field.""" - return any(text in key.lower() for text in {"password", "token", "api", "key"}) and show - - @staticmethod - def should_be_multiline(key: str) -> bool: - """Determines whether the field should be multiline.""" - return key in { - "suffix", - "prefix", - "template", - "examples", - "code", - "headers", - "description", - } - - @staticmethod - def set_field_default_value(field: Input, value: dict, key: str) -> None: - """Sets the field value with the default value if present.""" - if "default" in value: - field.value = value["default"] - if key == "headers": - field.value = """{"Authorization": "Bearer "}""" + @classmethod + def from_inputs(cls, **kwargs): + """Create a frontend node from inputs.""" + if "inputs" not in kwargs: + raise ValueError("Missing 'inputs' argument.") + inputs = kwargs.pop("inputs") + template = Template(type_name="CustomComponent", fields=inputs) + kwargs["template"] = template + return cls(**kwargs) diff --git a/tests/test_initial_setup.py b/tests/test_initial_setup.py index d4f86a73f..9773b9ca4 100644 --- a/tests/test_initial_setup.py +++ b/tests/test_initial_setup.py @@ -1,6 +1,7 @@ from datetime import datetime from pathlib import Path +import pytest from sqlmodel import select from langflow.initial_setup.setup import ( @@ -41,7 +42,8 @@ def test_get_project_data(): assert isinstance(project_icon_bg_color, str) or project_icon_bg_color is None -def test_create_or_update_starter_projects(client): +@pytest.mark.asyncio +async def test_create_or_update_starter_projects(client): with session_scope() as session: # Run the function to create or update projects create_or_update_starter_projects()