Fix code formatting and add MissingDefault class

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-07 16:24:08 -03:00
commit 17d1841e28
3 changed files with 57 additions and 14 deletions

View file

@ -9,7 +9,11 @@ from fastapi import HTTPException
from loguru import logger
from langflow.interface.custom.eval import eval_custom_component_code
from langflow.interface.custom.schema import CallableCodeDetails, ClassCodeDetails
from langflow.interface.custom.schema import (
CallableCodeDetails,
ClassCodeDetails,
MissingDefault,
)
class CodeSyntaxError(HTTPException):
@ -95,7 +99,9 @@ class CodeParser:
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
if alias.asname:
self.data["imports"].append((node.module, f"{alias.name} as {alias.asname}"))
self.data["imports"].append(
(node.module, f"{alias.name} as {alias.asname}")
)
else:
self.data["imports"].append((node.module, alias.name))
@ -144,7 +150,9 @@ class CodeParser:
return_type = None
if node.returns:
return_type_str = ast.unparse(node.returns)
eval_env = self.construct_eval_env(return_type_str, tuple(self.data["imports"]))
eval_env = self.construct_eval_env(
return_type_str, tuple(self.data["imports"])
)
try:
return_type = eval(return_type_str, eval_env)
@ -185,15 +193,23 @@ class CodeParser:
num_args = len(node.args.args)
num_defaults = len(node.args.defaults)
num_missing_defaults = num_args - num_defaults
missing_defaults = [None] * num_missing_defaults
default_values = [ast.unparse(default).strip("'") if default else None for default in node.args.defaults]
missing_defaults = [MissingDefault()] * num_missing_defaults
default_values = [
ast.unparse(default).strip("'") if default else None
for default in node.args.defaults
]
# Now check all default values to see if there
# are any "None" values in the middle
default_values = [None if value == "None" else value for value in default_values]
default_values = [
None if value == "None" else value for value in default_values
]
defaults = missing_defaults + default_values
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.args, defaults)]
args = [
self.parse_arg(arg, default)
for arg, default in zip(node.args.args, defaults)
]
return args
def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
@ -211,11 +227,17 @@ class CodeParser:
"""
Parses the keyword-only arguments of a function or method node.
"""
kw_defaults = [None] * (len(node.args.kwonlyargs) - len(node.args.kw_defaults)) + [
ast.unparse(default) if default else None for default in node.args.kw_defaults
kw_defaults = [None] * (
len(node.args.kwonlyargs) - len(node.args.kw_defaults)
) + [
ast.unparse(default) if default else None
for default in node.args.kw_defaults
]
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.kwonlyargs, kw_defaults)]
args = [
self.parse_arg(arg, default)
for arg, default in zip(node.args.kwonlyargs, kw_defaults)
]
return args
def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
@ -319,7 +341,9 @@ class CodeParser:
Extracts global variables from the code.
"""
global_var = {
"targets": [t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets],
"targets": [
t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets
],
"value": ast.unparse(node.value),
}
self.data["global_vars"].append(global_var)

View file

@ -27,3 +27,12 @@ class CallableCodeDetails(BaseModel):
body: list
return_type: Optional[Any] = None
has_return: bool = False
class MissingDefault:
"""
A class to represent a missing default value.
"""
def __repr__(self):
return "MISSING"

View file

@ -20,6 +20,7 @@ from langflow.interface.custom.directory_reader.utils import (
merge_nested_dicts_with_renaming,
)
from langflow.interface.custom.eval import eval_custom_component_code
from langflow.interface.custom.schema import MissingDefault
from langflow.schema import dotdict
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.custom_components import (
@ -111,7 +112,7 @@ def extract_type_from_optional(field_type):
str: The extracted type, or an empty string if no type was found.
"""
match = re.search(r"\[(.*?)\]$", field_type)
return match[1] if match else None
return match[1] if match else field_type
def get_field_properties(extra_field):
@ -119,7 +120,13 @@ def get_field_properties(extra_field):
field_name = extra_field["name"]
field_type = extra_field.get("type", "str")
field_value = extra_field.get("default", "")
field_required = "optional" not in field_type.lower()
# a required field is a field that does not contain
# optional in field_type
# and a field that does not have a default value
field_required = "optional" not in field_type.lower() and isinstance(
field_value, MissingDefault
)
field_value = field_value if not isinstance(field_value, MissingDefault) else None
if not field_required:
field_type = extract_type_from_optional(field_type)
@ -469,7 +476,10 @@ def update_field_dict(
):
"""Update the field dictionary by calling options() or value() if they are callable"""
if ("real_time_refresh" in field_dict or "refresh_button" in field_dict) and any(
(field_dict["real_time_refresh"], field_dict["refresh_button"])
(
field_dict.get("real_time_refresh", False),
field_dict.get("refresh_button", False),
)
):
if call:
try: