feat: update template and custom component to load inputs and outputs
This commit is contained in:
parent
5dae23bcb1
commit
230c4a69ed
8 changed files with 95 additions and 197 deletions
|
|
@ -37,6 +37,12 @@ def getattr_return_list_of_str(value):
|
|||
return []
|
||||
|
||||
|
||||
def getattr_return_list_of_object(value):
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
return []
|
||||
|
||||
|
||||
ATTR_FUNC_MAPPING: dict[str, Callable] = {
|
||||
"display_name": getattr_return_str,
|
||||
"description": getattr_return_str,
|
||||
|
|
@ -47,4 +53,6 @@ ATTR_FUNC_MAPPING: dict[str, Callable] = {
|
|||
"is_input": getattr_return_bool,
|
||||
"is_output": getattr_return_bool,
|
||||
"conditional_paths": getattr_return_list_of_str,
|
||||
"outputs": getattr_return_list_of_object,
|
||||
"inputs": getattr_return_list_of_object,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -84,6 +84,10 @@ class Component:
|
|||
if value is not None:
|
||||
template_config[attribute] = func(value=value)
|
||||
|
||||
for key in template_config.copy():
|
||||
if key not in ATTR_FUNC_MAPPING.keys():
|
||||
template_config.pop(key, None)
|
||||
|
||||
return template_config
|
||||
|
||||
def build(self, *args: Any, **kwargs: Any) -> Any:
|
||||
|
|
|
|||
|
|
@ -83,6 +83,23 @@ class CustomComponent(Component):
|
|||
inputs: Optional[List[Input]] = None
|
||||
outputs: Optional[List[Output]] = None
|
||||
|
||||
def build_inputs(self, user_id: Optional[Union[str, UUID]] = None):
|
||||
"""
|
||||
Builds the inputs for the custom component.
|
||||
|
||||
Args:
|
||||
user_id (Optional[Union[str, UUID]], optional): The user ID. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Input]: The list of inputs.
|
||||
"""
|
||||
# This function is similar to build_config, but it will process the inputs
|
||||
# and return them as a dict with keys being the Input.name and values being the Input.model_dump()
|
||||
if not self.inputs:
|
||||
return {}
|
||||
build_config = {_input.name: _input.model_dump(by_alias=True, exclude_none=True) for _input in self.inputs}
|
||||
return build_config
|
||||
|
||||
def update_state(self, name: str, value: Any):
|
||||
if not self.vertex:
|
||||
raise ValueError("Vertex is not set")
|
||||
|
|
@ -275,7 +292,7 @@ class CustomComponent(Component):
|
|||
Returns:
|
||||
list: The arguments of the function entrypoint.
|
||||
"""
|
||||
build_method = self.get_build_method()
|
||||
build_method = self.get_method(self.function_entrypoint_name)
|
||||
if not build_method:
|
||||
return []
|
||||
|
||||
|
|
@ -287,7 +304,7 @@ class CustomComponent(Component):
|
|||
return args
|
||||
|
||||
@cachedmethod(operator.attrgetter("cache"))
|
||||
def get_build_method(self):
|
||||
def get_method(self, method_name: str):
|
||||
"""
|
||||
Gets the build method for the custom component.
|
||||
|
||||
|
|
@ -303,9 +320,7 @@ class CustomComponent(Component):
|
|||
|
||||
# Assume the first Component class is the one we're interested in
|
||||
component_class = component_classes[0]
|
||||
build_methods = [
|
||||
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
|
||||
]
|
||||
build_methods = [method for method in component_class["methods"] if method["name"] == (method_name)]
|
||||
|
||||
return build_methods[0] if build_methods else {}
|
||||
|
||||
|
|
@ -317,7 +332,10 @@ class CustomComponent(Component):
|
|||
Returns:
|
||||
List[Any]: The return type of the function entrypoint.
|
||||
"""
|
||||
build_method = self.get_build_method()
|
||||
return self.get_method_return_type(self.function_entrypoint_name)
|
||||
|
||||
def get_method_return_type(self, method_name: str):
|
||||
build_method = self.get_method(method_name)
|
||||
if not build_method or not build_method.get("has_return"):
|
||||
return []
|
||||
return_type = build_method["return_type"]
|
||||
|
|
|
|||
13
src/backend/base/langflow/helpers/custom.py
Normal file
13
src/backend/base/langflow/helpers/custom.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from typing import Any
|
||||
|
||||
|
||||
def format_type(type_: Any) -> str:
|
||||
if type_ == str:
|
||||
type_ = "Text"
|
||||
elif hasattr(type_, "__name__"):
|
||||
type_ = type_.__name__
|
||||
elif hasattr(type_, "__class__"):
|
||||
type_ = type_.__class__.__name__
|
||||
else:
|
||||
type_ = str(type_)
|
||||
return type_
|
||||
|
|
@ -16,12 +16,9 @@ from langflow.interface.types import get_all_components
|
|||
from langflow.services.auth.utils import create_super_user
|
||||
from langflow.services.database.models.flow.model import Flow, FlowCreate
|
||||
from langflow.services.database.models.folder.model import Folder, FolderCreate
|
||||
from langflow.services.database.models.user.crud import get_user_by_username
|
||||
from langflow.services.deps import get_settings_service, session_scope
|
||||
|
||||
from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist
|
||||
from langflow.services.deps import get_settings_service, session_scope, get_variable_service
|
||||
|
||||
from langflow.services.database.models.user.crud import get_user_by_username
|
||||
from langflow.services.deps import get_settings_service, get_variable_service, session_scope
|
||||
|
||||
STARTER_FOLDER_NAME = "Starter Projects"
|
||||
STARTER_FOLDER_DESCRIPTION = "Starter projects to help you get started in Langflow."
|
||||
|
|
@ -221,6 +218,7 @@ def _is_valid_uuid(val):
|
|||
return False
|
||||
return str(uuid_obj) == val
|
||||
|
||||
|
||||
def load_flows_from_directory():
|
||||
settings_service = get_settings_service()
|
||||
flows_path = settings_service.settings.load_flows_path
|
||||
|
|
@ -262,6 +260,7 @@ def load_flows_from_directory():
|
|||
session.add(flow)
|
||||
session.commit()
|
||||
|
||||
|
||||
def find_existing_flow(session, flow_id, flow_endpoint_name):
|
||||
if flow_endpoint_name:
|
||||
stmt = select(Flow).where(Flow.endpoint_name == flow_endpoint_name)
|
||||
|
|
@ -271,6 +270,8 @@ def find_existing_flow(session, flow_id, flow_endpoint_name):
|
|||
if existing := session.exec(stmt).first():
|
||||
return existing
|
||||
return None
|
||||
|
||||
|
||||
def create_or_update_starter_projects():
|
||||
components_paths = get_settings_service().settings.components_path
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -3,12 +3,16 @@ from typing import Any, Callable, Optional, Union
|
|||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator
|
||||
|
||||
from langflow.field_typing.range_spec import RangeSpec
|
||||
from langflow.helpers.custom import format_type
|
||||
|
||||
|
||||
class Input(BaseModel):
|
||||
model_config = ConfigDict()
|
||||
|
||||
field_type: str = Field(default="str", serialization_alias="type")
|
||||
field_type: str = Field(
|
||||
default="str",
|
||||
serialization_alias="type",
|
||||
)
|
||||
"""The type of field this is. Default is a string."""
|
||||
|
||||
required: bool = False
|
||||
|
|
@ -102,6 +106,17 @@ class Input(BaseModel):
|
|||
def serialize_file_path(self, value):
|
||||
return value if self.field_type == "file" else ""
|
||||
|
||||
@field_validator("field_type", mode="before")
|
||||
def validate_type(cls, v):
|
||||
# If the user passes CustomComponent as a type insteado of "CustomComponent" we need to convert it to a string
|
||||
# this should be done for all types
|
||||
# How to check if v is a type?
|
||||
if isinstance(v, type):
|
||||
return format_type(v)
|
||||
elif not isinstance(v, str):
|
||||
raise ValueError(f"type must be a string or a type, not {type(v)}")
|
||||
return v
|
||||
|
||||
@field_serializer("field_type")
|
||||
def serialize_field_type(self, value, _info):
|
||||
if value == "float" and self.range_spec is None:
|
||||
|
|
@ -131,7 +146,7 @@ class Input(BaseModel):
|
|||
|
||||
|
||||
class Output(BaseModel):
|
||||
type: list[str] = Field(default=[], serialization_alias="types")
|
||||
types: Optional[list[str]] = Field(default=[], serialization_alias="types")
|
||||
"""List of output types for the field."""
|
||||
|
||||
selected: Optional[str] = Field(default=None, serialization_alias="selected")
|
||||
|
|
@ -140,5 +155,12 @@ class Output(BaseModel):
|
|||
name: str = Field(default="", serialization_alias="name")
|
||||
"""The name of the field."""
|
||||
|
||||
method: Optional[str] = Field(default=None, serialization_alias="method")
|
||||
"""The method to use for the output."""
|
||||
|
||||
def to_dict(self):
|
||||
return self.model_dump(by_alias=True, exclude_none=True)
|
||||
|
||||
def add_types(self, _type: list[Any]):
|
||||
for type_ in _type:
|
||||
self.types.append(type_)
|
||||
|
|
|
|||
|
|
@ -1,42 +1,10 @@
|
|||
import re
|
||||
from collections import defaultdict
|
||||
from typing import ClassVar, Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, model_serializer
|
||||
from pydantic import BaseModel, field_serializer, model_serializer
|
||||
|
||||
from langflow.template.field.base import Input, Output
|
||||
from langflow.template.frontend_node.constants import FORCE_SHOW_FIELDS
|
||||
from langflow.template.frontend_node.formatter import field_formatters
|
||||
from langflow.template.field.base import Output
|
||||
from langflow.template.template.base import Template
|
||||
from langflow.utils import constants
|
||||
|
||||
|
||||
class FieldFormatters(BaseModel):
|
||||
formatters: ClassVar[Dict] = {
|
||||
"openai_api_key": field_formatters.OpenAIAPIKeyFormatter(),
|
||||
}
|
||||
base_formatters: ClassVar[Dict] = {
|
||||
"kwargs": field_formatters.KwargsFormatter(),
|
||||
"optional": field_formatters.RemoveOptionalFormatter(),
|
||||
"list": field_formatters.ListTypeFormatter(),
|
||||
"dict": field_formatters.DictTypeFormatter(),
|
||||
"union": field_formatters.UnionTypeFormatter(),
|
||||
"multiline": field_formatters.MultilineFieldFormatter(),
|
||||
"show": field_formatters.ShowFieldFormatter(),
|
||||
"password": field_formatters.PasswordFieldFormatter(),
|
||||
"default": field_formatters.DefaultValueFormatter(),
|
||||
"headers": field_formatters.HeadersDefaultValueFormatter(),
|
||||
"dict_code_file": field_formatters.DictCodeFileFormatter(),
|
||||
"model_fields": field_formatters.ModelSpecificFieldFormatter(),
|
||||
}
|
||||
|
||||
def format(self, field: Input, name: Optional[str] = None) -> None:
|
||||
for key, formatter in self.base_formatters.items():
|
||||
formatter.format(field, name)
|
||||
|
||||
for key, formatter in self.formatters.items():
|
||||
if key == field.name:
|
||||
formatter.format(field, name)
|
||||
|
||||
|
||||
class FrontendNode(BaseModel):
|
||||
|
|
@ -69,8 +37,6 @@ class FrontendNode(BaseModel):
|
|||
"""List of output types for the frontend node."""
|
||||
full_path: Optional[str] = None
|
||||
"""Full path of the frontend node."""
|
||||
field_formatters: FieldFormatters = Field(default_factory=FieldFormatters)
|
||||
"""Field formatters for the frontend node."""
|
||||
pinned: bool = False
|
||||
"""Whether the frontend node is pinned."""
|
||||
conditional_paths: List[str] = []
|
||||
|
|
@ -85,12 +51,6 @@ class FrontendNode(BaseModel):
|
|||
beta: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
# field formatters is an instance attribute but it is not used in the class
|
||||
# so we need to create a method to get it
|
||||
@staticmethod
|
||||
def get_field_formatters() -> FieldFormatters:
|
||||
return FieldFormatters()
|
||||
|
||||
def set_documentation(self, documentation: str) -> None:
|
||||
"""Sets the documentation of the frontend node."""
|
||||
self.documentation = documentation
|
||||
|
|
@ -121,7 +81,7 @@ class FrontendNode(BaseModel):
|
|||
for base_class in result["output_types"]:
|
||||
output = Output(
|
||||
name=base_class,
|
||||
type=[base_class],
|
||||
types=[base_class],
|
||||
)
|
||||
result["outputs"].append(output.model_dump())
|
||||
|
||||
|
|
@ -155,142 +115,12 @@ class FrontendNode(BaseModel):
|
|||
elif isinstance(output_type, list):
|
||||
self.output_types.extend(output_type)
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: Input, name: Optional[str] = None) -> None:
|
||||
"""Formats a given field based on its attributes and value."""
|
||||
|
||||
FrontendNode.get_field_formatters().format(field, name)
|
||||
|
||||
@staticmethod
|
||||
def remove_optional(_type: str) -> str:
|
||||
"""Removes 'Optional' wrapper from the type if present."""
|
||||
return re.sub(r"Optional\[(.*)\]", r"\1", _type)
|
||||
|
||||
@staticmethod
|
||||
def check_for_list_type(_type: str) -> tuple:
|
||||
"""Checks for list type and returns the modified type and a boolean indicating if it's a list."""
|
||||
is_list = "List" in _type or "Sequence" in _type
|
||||
if is_list:
|
||||
_type = re.sub(r"(List|Sequence)\[(.*)\]", r"\2", _type)
|
||||
return _type, is_list
|
||||
|
||||
@staticmethod
|
||||
def replace_mapping_with_dict(_type: str) -> str:
|
||||
"""Replaces 'Mapping' with 'dict'."""
|
||||
return _type.replace("Mapping", "dict")
|
||||
|
||||
@staticmethod
|
||||
def handle_union_type(_type: str) -> str:
|
||||
"""Simplifies the 'Union' type to the first type in the Union."""
|
||||
if "Union" in _type:
|
||||
_type = _type.replace("Union[", "")[:-1]
|
||||
_type = _type.split(",")[0]
|
||||
_type = _type.replace("]", "").replace("[", "")
|
||||
return _type
|
||||
|
||||
@staticmethod
|
||||
def handle_special_field(field, key: str, _type: str, SPECIAL_FIELD_HANDLERS) -> str:
|
||||
"""Handles special field by using the respective handler if present."""
|
||||
handler = SPECIAL_FIELD_HANDLERS.get(key)
|
||||
return handler(field) if handler else _type
|
||||
|
||||
@staticmethod
|
||||
def handle_dict_type(field: Input, _type: str) -> str:
|
||||
"""Handles 'dict' type by replacing it with 'code' or 'file' based on the field name."""
|
||||
if "dict" in _type.lower() and field.name == "dict_":
|
||||
field.field_type = "file"
|
||||
field.file_types = [".json", ".yaml", ".yml"]
|
||||
elif _type.startswith("Dict") or _type.startswith("Mapping") or _type.startswith("dict"):
|
||||
field.field_type = "dict"
|
||||
return _type
|
||||
|
||||
@staticmethod
|
||||
def replace_default_value(field: Input, value: dict) -> None:
|
||||
"""Replaces default value with actual value if 'default' is present in value."""
|
||||
if "default" in value:
|
||||
field.value = value["default"]
|
||||
|
||||
@staticmethod
|
||||
def handle_specific_field_values(field: Input, key: str, name: Optional[str] = None) -> None:
|
||||
"""Handles specific field values for certain fields."""
|
||||
if key == "headers":
|
||||
field.value = """{"Authorization": "Bearer <token>"}"""
|
||||
FrontendNode._handle_model_specific_field_values(field, key, name)
|
||||
FrontendNode._handle_api_key_specific_field_values(field, key, name)
|
||||
|
||||
@staticmethod
|
||||
def _handle_model_specific_field_values(field: Input, key: str, name: Optional[str] = None) -> None:
|
||||
"""Handles specific field values related to models."""
|
||||
model_dict = {
|
||||
"OpenAI": constants.OPENAI_MODELS,
|
||||
"ChatOpenAI": constants.CHAT_OPENAI_MODELS,
|
||||
"Anthropic": constants.ANTHROPIC_MODELS,
|
||||
"ChatAnthropic": constants.ANTHROPIC_MODELS,
|
||||
}
|
||||
if name in model_dict and key == "model_name":
|
||||
field.options = model_dict[name]
|
||||
field.is_list = True
|
||||
|
||||
@staticmethod
|
||||
def _handle_api_key_specific_field_values(field: Input, key: str, name: Optional[str] = None) -> None:
|
||||
"""Handles specific field values related to API keys."""
|
||||
if "api_key" in key and "OpenAI" in str(name):
|
||||
field.display_name = "OpenAI API Key"
|
||||
field.required = False
|
||||
if field.value is None:
|
||||
field.value = ""
|
||||
|
||||
@staticmethod
|
||||
def handle_kwargs_field(field: Input) -> None:
|
||||
"""Handles kwargs field by setting certain attributes."""
|
||||
|
||||
if "kwargs" in (field.name or "").lower():
|
||||
field.advanced = True
|
||||
field.required = False
|
||||
field.show = False
|
||||
|
||||
@staticmethod
|
||||
def handle_api_key_field(field: Input, key: str) -> None:
|
||||
"""Handles api key field by setting certain attributes."""
|
||||
if "api" in key.lower() and "key" in key.lower():
|
||||
field.required = False
|
||||
field.advanced = False
|
||||
|
||||
field.display_name = key.replace("_", " ").title()
|
||||
field.display_name = field.display_name.replace("Api", "API")
|
||||
|
||||
@staticmethod
|
||||
def should_show_field(key: str, required: bool) -> bool:
|
||||
"""Determines whether the field should be shown."""
|
||||
return (
|
||||
(required and key not in ["input_variables"])
|
||||
or key in FORCE_SHOW_FIELDS
|
||||
or "api" in key
|
||||
or ("key" in key and "input" not in key and "output" not in key)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_be_password(key: str, show: bool) -> bool:
|
||||
"""Determines whether the field should be a password field."""
|
||||
return any(text in key.lower() for text in {"password", "token", "api", "key"}) and show
|
||||
|
||||
@staticmethod
|
||||
def should_be_multiline(key: str) -> bool:
|
||||
"""Determines whether the field should be multiline."""
|
||||
return key in {
|
||||
"suffix",
|
||||
"prefix",
|
||||
"template",
|
||||
"examples",
|
||||
"code",
|
||||
"headers",
|
||||
"description",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def set_field_default_value(field: Input, value: dict, key: str) -> None:
|
||||
"""Sets the field value with the default value if present."""
|
||||
if "default" in value:
|
||||
field.value = value["default"]
|
||||
if key == "headers":
|
||||
field.value = """{"Authorization": "Bearer <token>"}"""
|
||||
@classmethod
|
||||
def from_inputs(cls, **kwargs):
|
||||
"""Create a frontend node from inputs."""
|
||||
if "inputs" not in kwargs:
|
||||
raise ValueError("Missing 'inputs' argument.")
|
||||
inputs = kwargs.pop("inputs")
|
||||
template = Template(type_name="CustomComponent", fields=inputs)
|
||||
kwargs["template"] = template
|
||||
return cls(**kwargs)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.initial_setup.setup import (
|
||||
|
|
@ -41,7 +42,8 @@ def test_get_project_data():
|
|||
assert isinstance(project_icon_bg_color, str) or project_icon_bg_color is None
|
||||
|
||||
|
||||
def test_create_or_update_starter_projects(client):
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_or_update_starter_projects(client):
|
||||
with session_scope() as session:
|
||||
# Run the function to create or update projects
|
||||
create_or_update_starter_projects()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue