Add update_build_config method that updates any part of the template

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-05 16:29:03 -03:00
commit 240b1900b9
11 changed files with 257 additions and 113 deletions

View file

@ -83,15 +83,13 @@ The CustomComponent class serves as the foundation for creating custom component
| _`file_types: List[str]`_ | This is a requirement if the _`field_type`_ is _file_. Defines which file types will be accepted. For example, _json_, _yaml_ or _yml_. |
| _`range_spec: langflow.field_typing.RangeSpec`_ | This is a requirement if the _`field_type`_ is _`float`_. Defines the range of values accepted and the step size. If none is defined, the default is _`[-1, 1, 0.1]`_. |
| _`title_case: bool`_ | Formats the name of the field when _`display_name`_ is not defined. Set it to False to keep the name as you set it in the _`build`_ method. |
| _`refresh: bool`_ | If set to True a button will appear to the right of the field, and when clicked, it will call the _`update_build_config`_ method which takes in the _`build_config`_, the name of the field (_`field_name`_) and the latest value of the field (_`field_value`_). This is useful when you want to update the _`build_config`_ based on the value of the field. |
<Admonition type="info" label="Tip">
Keys _`options`_ and _`value`_ can receive a method or function that returns a list of strings or a string, respectively. This is useful when you want to dynamically generate the options or the default value of a field. A refresh button will appear next to the field in the component, allowing the user to update the options or the default value.
</Admonition>
<Admonition type="info" label="Tip">
By using the _`update_build_config`_ method, you can update the _`build_config`_ in whatever way you want based on the value of the field or not.
</Admonition>
- The CustomComponent class also provides helpful methods for specific tasks (e.g., to load and use other flows from the Langflow platform):

View file

@ -306,7 +306,10 @@ async def custom_component_update(
component = CustomComponent(code=raw_code.code)
component_node = build_custom_component_template(
component, user_id=user.id, update_field=raw_code.field
component,
user_id=user.id,
update_field=raw_code.field,
update_field_value=raw_code.field_value,
)
# Update the field
return component_node

View file

@ -166,6 +166,7 @@ class StreamData(BaseModel):
class CustomComponentCode(BaseModel):
code: str
field: Optional[str] = None
field_value: Optional[str] = None
frontend_node: Optional[dict] = None

View file

@ -1,4 +1,6 @@
from fastapi import APIRouter, HTTPException
from loguru import logger
from langflow.api.v1.base import (
Code,
CodeValidationResponse,
@ -6,9 +8,8 @@ from langflow.api.v1.base import (
ValidatePromptRequest,
validate_prompt,
)
from langflow.template.field.base import TemplateField
from langflow.utils.validate import PROMPT_INPUT_TYPES, validate_code
from loguru import logger
from langflow.template.field.prompt import DefaultPromptField
from langflow.utils.validate import validate_code
# build router
router = APIRouter(prefix="/validate", tags=["Validate"])
@ -40,7 +41,9 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
add_new_variables_to_template(input_variables, prompt_request)
remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request)
remove_old_variables_from_template(
old_custom_fields, input_variables, prompt_request
)
update_input_variables_field(input_variables, prompt_request)
@ -55,12 +58,19 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
def get_old_custom_fields(prompt_request):
try:
if len(prompt_request.frontend_node.custom_fields) == 1 and prompt_request.name == "":
if (
len(prompt_request.frontend_node.custom_fields) == 1
and prompt_request.name == ""
):
# If there is only one custom field and the name is empty string
# then we are dealing with the first prompt request after the node was created
prompt_request.name = list(prompt_request.frontend_node.custom_fields.keys())[0]
prompt_request.name = list(
prompt_request.frontend_node.custom_fields.keys()
)[0]
old_custom_fields = prompt_request.frontend_node.custom_fields[prompt_request.name]
old_custom_fields = prompt_request.frontend_node.custom_fields[
prompt_request.name
]
if old_custom_fields is None:
old_custom_fields = []
@ -74,38 +84,43 @@ def get_old_custom_fields(prompt_request):
def add_new_variables_to_template(input_variables, prompt_request):
for variable in input_variables:
try:
template_field = TemplateField(
name=variable,
display_name=variable,
field_type="str",
show=True,
advanced=False,
multiline=True,
input_types=PROMPT_INPUT_TYPES,
value="", # Set the value to empty string
)
template_field = DefaultPromptField(name=variable, display_name=variable)
if variable in prompt_request.frontend_node.template:
# Set the new field with the old value
template_field.value = prompt_request.frontend_node.template[variable]["value"]
template_field.value = prompt_request.frontend_node.template[variable][
"value"
]
prompt_request.frontend_node.template[variable] = template_field.to_dict()
# Check if variable is not already in the list before appending
if variable not in prompt_request.frontend_node.custom_fields[prompt_request.name]:
prompt_request.frontend_node.custom_fields[prompt_request.name].append(variable)
if (
variable
not in prompt_request.frontend_node.custom_fields[prompt_request.name]
):
prompt_request.frontend_node.custom_fields[prompt_request.name].append(
variable
)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request):
def remove_old_variables_from_template(
old_custom_fields, input_variables, prompt_request
):
for variable in old_custom_fields:
if variable not in input_variables:
try:
# Remove the variable from custom_fields associated with the given name
if variable in prompt_request.frontend_node.custom_fields[prompt_request.name]:
prompt_request.frontend_node.custom_fields[prompt_request.name].remove(variable)
if (
variable
in prompt_request.frontend_node.custom_fields[prompt_request.name]
):
prompt_request.frontend_node.custom_fields[
prompt_request.name
].remove(variable)
# Remove the variable from the template
prompt_request.frontend_node.template.pop(variable, None)
@ -117,4 +132,6 @@ def remove_old_variables_from_template(old_custom_fields, input_variables, promp
def update_input_variables_field(input_variables, prompt_request):
if "input_variables" in prompt_request.frontend_node.template:
prompt_request.frontend_node.template["input_variables"]["value"] = input_variables
prompt_request.frontend_node.template["input_variables"][
"value"
] = input_variables

View file

@ -95,7 +95,9 @@ class CodeParser:
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
if alias.asname:
self.data["imports"].append((node.module, f"{alias.name} as {alias.asname}"))
self.data["imports"].append(
(node.module, f"{alias.name} as {alias.asname}")
)
else:
self.data["imports"].append((node.module, alias.name))
@ -144,7 +146,9 @@ class CodeParser:
return_type = None
if node.returns:
return_type_str = ast.unparse(node.returns)
eval_env = self.construct_eval_env(return_type_str, tuple(self.data["imports"]))
eval_env = self.construct_eval_env(
return_type_str, tuple(self.data["imports"])
)
try:
return_type = eval(return_type_str, eval_env)
@ -174,7 +178,7 @@ class CodeParser:
args += self.parse_keyword_args(node)
# Commented out because we don't want kwargs
# showing up as fields in the frontend
# args += self.parse_kwargs(node)
args += self.parse_kwargs(node)
return args
@ -186,14 +190,22 @@ class CodeParser:
num_defaults = len(node.args.defaults)
num_missing_defaults = num_args - num_defaults
missing_defaults = [None] * num_missing_defaults
default_values = [ast.unparse(default).strip("'") if default else None for default in node.args.defaults]
default_values = [
ast.unparse(default).strip("'") if default else None
for default in node.args.defaults
]
# Now check all default values to see if there
# are any "None" values in the middle
default_values = [None if value == "None" else value for value in default_values]
default_values = [
None if value == "None" else value for value in default_values
]
defaults = missing_defaults + default_values
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.args, defaults)]
args = [
self.parse_arg(arg, default)
for arg, default in zip(node.args.args, defaults)
]
return args
def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
@ -211,11 +223,17 @@ class CodeParser:
"""
Parses the keyword-only arguments of a function or method node.
"""
kw_defaults = [None] * (len(node.args.kwonlyargs) - len(node.args.kw_defaults)) + [
ast.unparse(default) if default else None for default in node.args.kw_defaults
kw_defaults = [None] * (
len(node.args.kwonlyargs) - len(node.args.kw_defaults)
) + [
ast.unparse(default) if default else None
for default in node.args.kw_defaults
]
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.kwonlyargs, kw_defaults)]
args = [
self.parse_arg(arg, default)
for arg, default in zip(node.args.kwonlyargs, kw_defaults)
]
return args
def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
@ -319,7 +337,9 @@ class CodeParser:
Extracts global variables from the code.
"""
global_var = {
"targets": [t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets],
"targets": [
t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets
],
"value": ast.unparse(node.value),
}
self.data["global_vars"].append(global_var)

View file

@ -1,15 +1,7 @@
import operator
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
List,
Optional,
Sequence,
Union,
)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, List, Optional,
Sequence, Union)
from uuid import UUID
import yaml
@ -20,17 +12,13 @@ from sqlmodel import select
from langflow.interface.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
extract_union_types_from_generic_alias)
from langflow.interface.custom.custom_component.component import Component
from langflow.schema import Record
from langflow.services.database.models.flow import Flow
from langflow.services.database.utils import session_getter
from langflow.services.deps import (
get_credential_service,
get_db_service,
get_storage_service,
)
from langflow.services.deps import (get_credential_service, get_db_service,
get_storage_service)
from langflow.services.storage.service import StorageService
from langflow.utils import validate
@ -138,6 +126,10 @@ class CustomComponent(Component):
def build_config(self):
return self.field_config
def update_build_config(self, build_config: dict, field_name: str, field_value: Any):
build_config[field_name] = field_value
return build_config
@property
def tree(self):
return self.get_code_tree(self.code or "")

View file

@ -8,6 +8,7 @@ from uuid import UUID
from fastapi import HTTPException
from loguru import logger
from pydantic import BaseModel
from langflow.field_typing.range_spec import RangeSpec
from langflow.interface.custom.attributes import ATTR_FUNC_MAPPING
@ -27,6 +28,10 @@ from langflow.utils import validate
from langflow.utils.util import get_base_classes
class UpdateBuildConfigError(Exception):
pass
def add_output_types(
frontend_node: CustomComponentFrontendNode, return_types: List[str]
):
@ -190,15 +195,23 @@ def add_extra_fields(frontend_node, field_config, function_args):
"""Add extra fields to the frontend node"""
if not function_args:
return
_field_config = field_config.copy()
function_args_names = [arg["name"] for arg in function_args]
# If kwargs is in the function_args and not all field_config keys are in function_args
# then we need to add the extra fields
for extra_field in function_args:
if "name" not in extra_field or extra_field["name"] == "self":
if "name" not in extra_field or extra_field["name"] in [
"self",
"kwargs",
"args",
]:
continue
field_name, field_type, field_value, field_required = get_field_properties(
extra_field
)
config = field_config.get(field_name, {})
config = _field_config.pop(field_name, {})
frontend_node = add_new_custom_field(
frontend_node,
field_name,
@ -207,6 +220,23 @@ def add_extra_fields(frontend_node, field_config, function_args):
field_required,
config,
)
if "kwargs" in function_args_names and not all(
key in function_args_names for key in field_config.keys()
):
for field_name, field_config in _field_config.copy().items():
config = _field_config.get(field_name, {})
config = config.model_dump() if isinstance(config, BaseModel) else config
field_name, field_type, field_value, field_required = get_field_properties(
extra_field=config
)
frontend_node = add_new_custom_field(
frontend_node,
field_name,
field_type,
field_value,
field_required,
config,
)
def get_field_dict(field: Union[TemplateField, dict]):
@ -220,6 +250,7 @@ def run_build_config(
custom_component: CustomComponent,
user_id: Optional[Union[str, UUID]] = None,
update_field=None,
update_field_value=None,
):
"""Build the field configuration for a custom component"""
@ -246,33 +277,48 @@ def run_build_config(
custom_instance = custom_class(user_id=user_id)
build_config: Dict = custom_instance.build_config()
for field_name, field in build_config.items():
for field_name, field in build_config.copy().items():
# Allow user to build TemplateField as well
# as a dict with the same keys as TemplateField
field_dict = get_field_dict(field)
# This has to be done to set refresh if options or value are callable
update_field_dict(field_dict)
if update_field is not None and field_name != update_field:
build_config = update_field_dict(
custom_component_instance=custom_instance,
field_dict=field_dict,
build_config=build_config,
call=False,
)
continue
try:
update_field_dict(field_dict, call=True)
build_config = update_field_dict(
custom_component_instance=custom_instance,
field_dict=field_dict,
build_config=build_config,
update_field=update_field,
update_field_value=update_field_value,
call=True,
)
build_config[field_name] = field_dict
except Exception as exc:
logger.error(f"Error while getting build_config: {str(exc)}")
if isinstance(exc, UpdateBuildConfigError):
message = str(exc)
else:
message = f"Error while getting build_config: {str(exc)}"
raise HTTPException(
status_code=400,
detail={
"error": message,
"traceback": traceback.format_exc(),
},
) from exc
return build_config, custom_instance
except Exception as exc:
logger.error(f"Error while building field config: {str(exc)}")
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid type convertion. Please check your code and try again."
),
"traceback": traceback.format_exc(),
},
) from exc
raise exc
def sanitize_template_config(template_config):
@ -317,13 +363,17 @@ def build_custom_component_template(
custom_component: CustomComponent,
user_id: Optional[Union[str, UUID]] = None,
update_field: Optional[str] = None,
update_field_value: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
"""Build a custom component template for the langchain"""
try:
frontend_node = build_frontend_node(custom_component.template_config)
field_config, custom_instance = run_build_config(
custom_component, user_id=user_id, update_field=update_field
custom_component,
user_id=user_id,
update_field=update_field,
update_field_value=update_field_value,
)
entrypoint_args = custom_component.get_function_entrypoint_args
@ -402,22 +452,32 @@ def build_custom_components(settings_service):
return custom_components_from_file
def update_field_dict(field_dict, call=False):
def update_field_dict(
custom_component_instance: "CustomComponent",
field_dict: Dict,
build_config: Dict,
update_field: Optional[str] = None,
update_field_value: Optional[Any] = None,
call: bool = False,
):
"""Update the field dictionary by calling options() or value() if they are callable"""
if "options" in field_dict and callable(field_dict["options"]):
if "refresh" in field_dict:
if call:
field_dict["options"] = field_dict["options"]()
# Also update the "refresh" key
field_dict["refresh"] = True
if "value" in field_dict and callable(field_dict["value"]):
if call:
field_dict["value"] = field_dict["value"]()
try:
custom_component_instance.update_build_config(
build_config, update_field, update_field_value
)
except Exception as exc:
logger.error(f"Error while running update_build_config: {str(exc)}")
raise UpdateBuildConfigError(
f"Error while running update_build_config: {str(exc)}"
) from exc
field_dict["refresh"] = True
# Let's check if "range_spec" is a RangeSpec object
if "rangeSpec" in field_dict and isinstance(field_dict["rangeSpec"], RangeSpec):
field_dict["rangeSpec"] = field_dict["rangeSpec"].model_dump()
return build_config
def sanitize_field_config(field_config: Dict):

View file

@ -0,0 +1,14 @@
from typing import Optional
from langflow.template.field.base import TemplateField
class DefaultPromptField(TemplateField):
name: str
display_name: Optional[str] = None
field_type: str = "str"
advanced: bool = False
multiline: bool = True
input_types: list[str] = ["Document", "BaseOutputParser", "Text", "Record"]
value: str = "" # Set the value to empty string

View file

@ -6,8 +6,6 @@ from typing import Dict, List, Optional, Union
from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
PROMPT_INPUT_TYPES = ["Document", "BaseOutputParser", "Text", "Record"]
def add_type_ignores():
if not hasattr(ast, "TypeIgnore"):
@ -45,7 +43,9 @@ def validate_code(code):
# Evaluate the function definition
for node in tree.body:
if isinstance(node, ast.FunctionDef):
code_obj = compile(ast.Module(body=[node], type_ignores=[]), "<string>", "exec")
code_obj = compile(
ast.Module(body=[node], type_ignores=[]), "<string>", "exec"
)
try:
exec(code_obj)
except Exception as e:
@ -89,15 +89,23 @@ def execute_function(code, function_name, *args, **kwargs):
exec_globals,
locals(),
)
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
function_code = next(
node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name
node
for node in module.body
if isinstance(node, ast.FunctionDef) and node.name == function_name
)
function_code.parent = None
code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec")
code_obj = compile(
ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec"
)
try:
exec(code_obj, exec_globals, locals())
except Exception as exc:
@ -124,15 +132,23 @@ def create_function(code, function_name):
if isinstance(node, ast.Import):
for alias in node.names:
try:
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
function_code = next(
node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name
node
for node in module.body
if isinstance(node, ast.FunctionDef) and node.name == function_name
)
function_code.parent = None
code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec")
code_obj = compile(
ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec"
)
with contextlib.suppress(Exception):
exec(code_obj, exec_globals, locals())
exec_globals[function_name] = locals()[function_name]
@ -194,9 +210,13 @@ def prepare_global_scope(code, module):
if isinstance(node, ast.Import):
for alias in node.names:
try:
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
elif isinstance(node, ast.ImportFrom) and node.module is not None:
try:
imported_module = importlib.import_module(node.module)
@ -217,7 +237,11 @@ def extract_class_code(module, class_name):
:param class_name: Name of the class to extract
:return: AST node of the specified class
"""
class_code = next(node for node in module.body if isinstance(node, ast.ClassDef) and node.name == class_name)
class_code = next(
node
for node in module.body
if isinstance(node, ast.ClassDef) and node.name == class_name
)
class_code.parent = None
return class_code
@ -230,7 +254,9 @@ def compile_class_code(class_code):
:param class_code: AST node of the class
:return: Compiled code object of the class
"""
code_obj = compile(ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec")
code_obj = compile(
ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec"
)
return code_obj
@ -274,7 +300,9 @@ def get_default_imports(code_string):
langflow_imports = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
necessary_imports = find_names_in_code(code_string, langflow_imports)
langflow_module = importlib.import_module("langflow.field_typing")
default_imports.update({name: getattr(langflow_module, name) for name in necessary_imports})
default_imports.update(
{name: getattr(langflow_module, name) for name in necessary_imports}
)
return default_imports

View file

@ -98,24 +98,33 @@ export default function ParameterComponent({
return;
}
try {
const res = await postCustomComponentUpdate(code, name);
if (res.status === 200 && data.node?.template) {
setNode(data.id, (oldNode) => {
let newNode = cloneDeep(oldNode);
await postCustomComponentUpdate(
code,
name,
data.node?.template[name]?.value
)
.then((res) => {
if (res.status === 200 && data.node?.template) {
setNode(data.id, (oldNode) => {
let newNode = cloneDeep(oldNode);
newNode.data = {
...newNode.data,
};
newNode.data = {
...newNode.data,
};
newNode.data.node.template[name] = res.data.template[name];
newNode.data.node.template = res.data.template;
return newNode;
return newNode;
});
}
})
.catch((error) => {
console.error("Error occurred while updating the node:", error);
setErrorData({
title: "Error while updating the Component",
list: [error.response.data.detail.error ?? "Unknown error"],
});
}
} catch (err) {
setErrorData(err as { title: string; list?: Array<string> });
}
});
renderTooltips();
if (delayAnimation) {

View file

@ -369,11 +369,13 @@ export async function postCustomComponent(
export async function postCustomComponentUpdate(
code: string,
field: string
field: string,
field_value: string
): Promise<AxiosResponse<APIClassType>> {
return await api.post(`${BASE_URL_API}custom_component/update`, {
code,
field,
field_value,
});
}