Refactor code formatting in validate.py

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-27 16:36:27 -03:00
commit af69ac3333

View file

@ -45,7 +45,9 @@ def validate_code(code):
# Evaluate the function definition
for node in tree.body:
if isinstance(node, ast.FunctionDef):
code_obj = compile(ast.Module(body=[node], type_ignores=[]), "<string>", "exec")
code_obj = compile(
ast.Module(body=[node], type_ignores=[]), "<string>", "exec"
)
try:
exec(code_obj)
except Exception as e:
@ -89,15 +91,23 @@ def execute_function(code, function_name, *args, **kwargs):
exec_globals,
locals(),
)
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
function_code = next(
node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name
node
for node in module.body
if isinstance(node, ast.FunctionDef) and node.name == function_name
)
function_code.parent = None
code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec")
code_obj = compile(
ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec"
)
try:
exec(code_obj, exec_globals, locals())
except Exception as exc:
@ -124,15 +134,23 @@ def create_function(code, function_name):
if isinstance(node, ast.Import):
for alias in node.names:
try:
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
function_code = next(
node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name
node
for node in module.body
if isinstance(node, ast.FunctionDef) and node.name == function_name
)
function_code.parent = None
code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec")
code_obj = compile(
ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec"
)
with contextlib.suppress(Exception):
exec(code_obj, exec_globals, locals())
exec_globals[function_name] = locals()[function_name]
@ -194,9 +212,13 @@ def prepare_global_scope(code, module):
if isinstance(node, ast.Import):
for alias in node.names:
try:
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
elif isinstance(node, ast.ImportFrom) and node.module is not None:
try:
imported_module = importlib.import_module(node.module)
@ -217,7 +239,11 @@ def extract_class_code(module, class_name):
:param class_name: Name of the class to extract
:return: AST node of the specified class
"""
class_code = next(node for node in module.body if isinstance(node, ast.ClassDef) and node.name == class_name)
class_code = next(
node
for node in module.body
if isinstance(node, ast.ClassDef) and node.name == class_name
)
class_code.parent = None
return class_code
@ -230,7 +256,9 @@ def compile_class_code(class_code):
:param class_code: AST node of the class
:return: Compiled code object of the class
"""
code_obj = compile(ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec")
code_obj = compile(
ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec"
)
return code_obj
@ -274,7 +302,9 @@ def get_default_imports(code_string):
langflow_imports = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
necessary_imports = find_names_in_code(code_string, langflow_imports)
langflow_module = importlib.import_module("langflow.field_typing")
default_imports.update({name: getattr(langflow_module, name) for name in necessary_imports})
default_imports.update(
{name: getattr(langflow_module, name) for name in necessary_imports}
)
return default_imports