feat: update template and custom component to load inputs and outputs

This commit is contained in:
ogabrielluiz 2024-05-30 22:47:29 -03:00
commit 230c4a69ed
8 changed files with 95 additions and 197 deletions

View file

@ -37,6 +37,12 @@ def getattr_return_list_of_str(value):
return []
def getattr_return_list_of_object(value):
if isinstance(value, list):
return value
return []
ATTR_FUNC_MAPPING: dict[str, Callable] = {
"display_name": getattr_return_str,
"description": getattr_return_str,
@ -47,4 +53,6 @@ ATTR_FUNC_MAPPING: dict[str, Callable] = {
"is_input": getattr_return_bool,
"is_output": getattr_return_bool,
"conditional_paths": getattr_return_list_of_str,
"outputs": getattr_return_list_of_object,
"inputs": getattr_return_list_of_object,
}

View file

@ -84,6 +84,10 @@ class Component:
if value is not None:
template_config[attribute] = func(value=value)
for key in template_config.copy():
if key not in ATTR_FUNC_MAPPING.keys():
template_config.pop(key, None)
return template_config
def build(self, *args: Any, **kwargs: Any) -> Any:

View file

@ -83,6 +83,23 @@ class CustomComponent(Component):
inputs: Optional[List[Input]] = None
outputs: Optional[List[Output]] = None
def build_inputs(self, user_id: Optional[Union[str, UUID]] = None):
"""
Builds the inputs for the custom component.
Args:
user_id (Optional[Union[str, UUID]], optional): The user ID. Defaults to None.
Returns:
List[Input]: The list of inputs.
"""
# This function is similar to build_config, but it will process the inputs
# and return them as a dict with keys being the Input.name and values being the Input.model_dump()
if not self.inputs:
return {}
build_config = {_input.name: _input.model_dump(by_alias=True, exclude_none=True) for _input in self.inputs}
return build_config
def update_state(self, name: str, value: Any):
if not self.vertex:
raise ValueError("Vertex is not set")
@ -275,7 +292,7 @@ class CustomComponent(Component):
Returns:
list: The arguments of the function entrypoint.
"""
build_method = self.get_build_method()
build_method = self.get_method(self.function_entrypoint_name)
if not build_method:
return []
@ -287,7 +304,7 @@ class CustomComponent(Component):
return args
@cachedmethod(operator.attrgetter("cache"))
def get_build_method(self):
def get_method(self, method_name: str):
"""
Gets the build method for the custom component.
@ -303,9 +320,7 @@ class CustomComponent(Component):
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
]
build_methods = [method for method in component_class["methods"] if method["name"] == (method_name)]
return build_methods[0] if build_methods else {}
@ -317,7 +332,10 @@ class CustomComponent(Component):
Returns:
List[Any]: The return type of the function entrypoint.
"""
build_method = self.get_build_method()
return self.get_method_return_type(self.function_entrypoint_name)
def get_method_return_type(self, method_name: str):
build_method = self.get_method(method_name)
if not build_method or not build_method.get("has_return"):
return []
return_type = build_method["return_type"]

View file

@ -0,0 +1,13 @@
from typing import Any
def format_type(type_: Any) -> str:
if type_ == str:
type_ = "Text"
elif hasattr(type_, "__name__"):
type_ = type_.__name__
elif hasattr(type_, "__class__"):
type_ = type_.__class__.__name__
else:
type_ = str(type_)
return type_

View file

@ -16,12 +16,9 @@ from langflow.interface.types import get_all_components
from langflow.services.auth.utils import create_super_user
from langflow.services.database.models.flow.model import Flow, FlowCreate
from langflow.services.database.models.folder.model import Folder, FolderCreate
from langflow.services.database.models.user.crud import get_user_by_username
from langflow.services.deps import get_settings_service, session_scope
from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist
from langflow.services.deps import get_settings_service, session_scope, get_variable_service
from langflow.services.database.models.user.crud import get_user_by_username
from langflow.services.deps import get_settings_service, get_variable_service, session_scope
STARTER_FOLDER_NAME = "Starter Projects"
STARTER_FOLDER_DESCRIPTION = "Starter projects to help you get started in Langflow."
@ -221,6 +218,7 @@ def _is_valid_uuid(val):
return False
return str(uuid_obj) == val
def load_flows_from_directory():
settings_service = get_settings_service()
flows_path = settings_service.settings.load_flows_path
@ -262,6 +260,7 @@ def load_flows_from_directory():
session.add(flow)
session.commit()
def find_existing_flow(session, flow_id, flow_endpoint_name):
if flow_endpoint_name:
stmt = select(Flow).where(Flow.endpoint_name == flow_endpoint_name)
@ -271,6 +270,8 @@ def find_existing_flow(session, flow_id, flow_endpoint_name):
if existing := session.exec(stmt).first():
return existing
return None
def create_or_update_starter_projects():
components_paths = get_settings_service().settings.components_path
try:

View file

@ -3,12 +3,16 @@ from typing import Any, Callable, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator
from langflow.field_typing.range_spec import RangeSpec
from langflow.helpers.custom import format_type
class Input(BaseModel):
model_config = ConfigDict()
field_type: str = Field(default="str", serialization_alias="type")
field_type: str = Field(
default="str",
serialization_alias="type",
)
"""The type of field this is. Default is a string."""
required: bool = False
@ -102,6 +106,17 @@ class Input(BaseModel):
def serialize_file_path(self, value):
return value if self.field_type == "file" else ""
@field_validator("field_type", mode="before")
def validate_type(cls, v):
# If the user passes CustomComponent as a type insteado of "CustomComponent" we need to convert it to a string
# this should be done for all types
# How to check if v is a type?
if isinstance(v, type):
return format_type(v)
elif not isinstance(v, str):
raise ValueError(f"type must be a string or a type, not {type(v)}")
return v
@field_serializer("field_type")
def serialize_field_type(self, value, _info):
if value == "float" and self.range_spec is None:
@ -131,7 +146,7 @@ class Input(BaseModel):
class Output(BaseModel):
type: list[str] = Field(default=[], serialization_alias="types")
types: Optional[list[str]] = Field(default=[], serialization_alias="types")
"""List of output types for the field."""
selected: Optional[str] = Field(default=None, serialization_alias="selected")
@ -140,5 +155,12 @@ class Output(BaseModel):
name: str = Field(default="", serialization_alias="name")
"""The name of the field."""
method: Optional[str] = Field(default=None, serialization_alias="method")
"""The method to use for the output."""
def to_dict(self):
return self.model_dump(by_alias=True, exclude_none=True)
def add_types(self, _type: list[Any]):
for type_ in _type:
self.types.append(type_)

View file

@ -1,42 +1,10 @@
import re
from collections import defaultdict
from typing import ClassVar, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, field_serializer, model_serializer
from pydantic import BaseModel, field_serializer, model_serializer
from langflow.template.field.base import Input, Output
from langflow.template.frontend_node.constants import FORCE_SHOW_FIELDS
from langflow.template.frontend_node.formatter import field_formatters
from langflow.template.field.base import Output
from langflow.template.template.base import Template
from langflow.utils import constants
class FieldFormatters(BaseModel):
formatters: ClassVar[Dict] = {
"openai_api_key": field_formatters.OpenAIAPIKeyFormatter(),
}
base_formatters: ClassVar[Dict] = {
"kwargs": field_formatters.KwargsFormatter(),
"optional": field_formatters.RemoveOptionalFormatter(),
"list": field_formatters.ListTypeFormatter(),
"dict": field_formatters.DictTypeFormatter(),
"union": field_formatters.UnionTypeFormatter(),
"multiline": field_formatters.MultilineFieldFormatter(),
"show": field_formatters.ShowFieldFormatter(),
"password": field_formatters.PasswordFieldFormatter(),
"default": field_formatters.DefaultValueFormatter(),
"headers": field_formatters.HeadersDefaultValueFormatter(),
"dict_code_file": field_formatters.DictCodeFileFormatter(),
"model_fields": field_formatters.ModelSpecificFieldFormatter(),
}
def format(self, field: Input, name: Optional[str] = None) -> None:
for key, formatter in self.base_formatters.items():
formatter.format(field, name)
for key, formatter in self.formatters.items():
if key == field.name:
formatter.format(field, name)
class FrontendNode(BaseModel):
@ -69,8 +37,6 @@ class FrontendNode(BaseModel):
"""List of output types for the frontend node."""
full_path: Optional[str] = None
"""Full path of the frontend node."""
field_formatters: FieldFormatters = Field(default_factory=FieldFormatters)
"""Field formatters for the frontend node."""
pinned: bool = False
"""Whether the frontend node is pinned."""
conditional_paths: List[str] = []
@ -85,12 +51,6 @@ class FrontendNode(BaseModel):
beta: bool = False
error: Optional[str] = None
# field formatters is an instance attribute but it is not used in the class
# so we need to create a method to get it
@staticmethod
def get_field_formatters() -> FieldFormatters:
return FieldFormatters()
def set_documentation(self, documentation: str) -> None:
"""Sets the documentation of the frontend node."""
self.documentation = documentation
@ -121,7 +81,7 @@ class FrontendNode(BaseModel):
for base_class in result["output_types"]:
output = Output(
name=base_class,
type=[base_class],
types=[base_class],
)
result["outputs"].append(output.model_dump())
@ -155,142 +115,12 @@ class FrontendNode(BaseModel):
elif isinstance(output_type, list):
self.output_types.extend(output_type)
@staticmethod
def format_field(field: Input, name: Optional[str] = None) -> None:
"""Formats a given field based on its attributes and value."""
FrontendNode.get_field_formatters().format(field, name)
@staticmethod
def remove_optional(_type: str) -> str:
"""Removes 'Optional' wrapper from the type if present."""
return re.sub(r"Optional\[(.*)\]", r"\1", _type)
@staticmethod
def check_for_list_type(_type: str) -> tuple:
"""Checks for list type and returns the modified type and a boolean indicating if it's a list."""
is_list = "List" in _type or "Sequence" in _type
if is_list:
_type = re.sub(r"(List|Sequence)\[(.*)\]", r"\2", _type)
return _type, is_list
@staticmethod
def replace_mapping_with_dict(_type: str) -> str:
"""Replaces 'Mapping' with 'dict'."""
return _type.replace("Mapping", "dict")
@staticmethod
def handle_union_type(_type: str) -> str:
"""Simplifies the 'Union' type to the first type in the Union."""
if "Union" in _type:
_type = _type.replace("Union[", "")[:-1]
_type = _type.split(",")[0]
_type = _type.replace("]", "").replace("[", "")
return _type
@staticmethod
def handle_special_field(field, key: str, _type: str, SPECIAL_FIELD_HANDLERS) -> str:
"""Handles special field by using the respective handler if present."""
handler = SPECIAL_FIELD_HANDLERS.get(key)
return handler(field) if handler else _type
@staticmethod
def handle_dict_type(field: Input, _type: str) -> str:
"""Handles 'dict' type by replacing it with 'code' or 'file' based on the field name."""
if "dict" in _type.lower() and field.name == "dict_":
field.field_type = "file"
field.file_types = [".json", ".yaml", ".yml"]
elif _type.startswith("Dict") or _type.startswith("Mapping") or _type.startswith("dict"):
field.field_type = "dict"
return _type
@staticmethod
def replace_default_value(field: Input, value: dict) -> None:
"""Replaces default value with actual value if 'default' is present in value."""
if "default" in value:
field.value = value["default"]
@staticmethod
def handle_specific_field_values(field: Input, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values for certain fields."""
if key == "headers":
field.value = """{"Authorization": "Bearer <token>"}"""
FrontendNode._handle_model_specific_field_values(field, key, name)
FrontendNode._handle_api_key_specific_field_values(field, key, name)
@staticmethod
def _handle_model_specific_field_values(field: Input, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values related to models."""
model_dict = {
"OpenAI": constants.OPENAI_MODELS,
"ChatOpenAI": constants.CHAT_OPENAI_MODELS,
"Anthropic": constants.ANTHROPIC_MODELS,
"ChatAnthropic": constants.ANTHROPIC_MODELS,
}
if name in model_dict and key == "model_name":
field.options = model_dict[name]
field.is_list = True
@staticmethod
def _handle_api_key_specific_field_values(field: Input, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values related to API keys."""
if "api_key" in key and "OpenAI" in str(name):
field.display_name = "OpenAI API Key"
field.required = False
if field.value is None:
field.value = ""
@staticmethod
def handle_kwargs_field(field: Input) -> None:
"""Handles kwargs field by setting certain attributes."""
if "kwargs" in (field.name or "").lower():
field.advanced = True
field.required = False
field.show = False
@staticmethod
def handle_api_key_field(field: Input, key: str) -> None:
"""Handles api key field by setting certain attributes."""
if "api" in key.lower() and "key" in key.lower():
field.required = False
field.advanced = False
field.display_name = key.replace("_", " ").title()
field.display_name = field.display_name.replace("Api", "API")
@staticmethod
def should_show_field(key: str, required: bool) -> bool:
"""Determines whether the field should be shown."""
return (
(required and key not in ["input_variables"])
or key in FORCE_SHOW_FIELDS
or "api" in key
or ("key" in key and "input" not in key and "output" not in key)
)
@staticmethod
def should_be_password(key: str, show: bool) -> bool:
"""Determines whether the field should be a password field."""
return any(text in key.lower() for text in {"password", "token", "api", "key"}) and show
@staticmethod
def should_be_multiline(key: str) -> bool:
"""Determines whether the field should be multiline."""
return key in {
"suffix",
"prefix",
"template",
"examples",
"code",
"headers",
"description",
}
@staticmethod
def set_field_default_value(field: Input, value: dict, key: str) -> None:
"""Sets the field value with the default value if present."""
if "default" in value:
field.value = value["default"]
if key == "headers":
field.value = """{"Authorization": "Bearer <token>"}"""
@classmethod
def from_inputs(cls, **kwargs):
"""Create a frontend node from inputs."""
if "inputs" not in kwargs:
raise ValueError("Missing 'inputs' argument.")
inputs = kwargs.pop("inputs")
template = Template(type_name="CustomComponent", fields=inputs)
kwargs["template"] = template
return cls(**kwargs)

View file

@ -1,6 +1,7 @@
from datetime import datetime
from pathlib import Path
import pytest
from sqlmodel import select
from langflow.initial_setup.setup import (
@ -41,7 +42,8 @@ def test_get_project_data():
assert isinstance(project_icon_bg_color, str) or project_icon_bg_color is None
def test_create_or_update_starter_projects(client):
@pytest.mark.asyncio
async def test_create_or_update_starter_projects(client):
with session_scope() as session:
# Run the function to create or update projects
create_or_update_starter_projects()