diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 4df8df96b..31591f735 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -34,6 +34,9 @@ "GitHub.vscode-pull-request-github" ] } + }, + "tasks": { + "test": "make unit_tests" } // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. diff --git a/localhost-19-2.db b/localhost-19-2.db new file mode 100644 index 000000000..7ee7c113a Binary files /dev/null and b/localhost-19-2.db differ diff --git a/src/backend/base/langflow/utils/validate.py b/src/backend/base/langflow/utils/validate.py index 5178cf735..4de5f9884 100644 --- a/src/backend/base/langflow/utils/validate.py +++ b/src/backend/base/langflow/utils/validate.py @@ -227,6 +227,16 @@ def prepare_global_scope(code, module): except ModuleNotFoundError as e: msg = f"Module {node.module} not found. Please install it and try again" raise ModuleNotFoundError(msg) from e + elif isinstance(node, ast.ClassDef): + # Compile and execute the class definition to properly create the class + class_code = compile(ast.Module(body=[node], type_ignores=[]), "", "exec") + exec(class_code, exec_globals) + elif isinstance(node, ast.FunctionDef): + function_code = compile(ast.Module(body=[node], type_ignores=[]), "", "exec") + exec(function_code, exec_globals) + elif isinstance(node, ast.Assign): + assign_code = compile(ast.Module(body=[node], type_ignores=[]), "", "exec") + exec(assign_code, exec_globals) return exec_globals diff --git a/src/backend/tests/unit/test_validate_code.py b/src/backend/tests/unit/test_validate_code.py index 93b567cd9..4470b3efe 100644 --- a/src/backend/tests/unit/test_validate_code.py +++ b/src/backend/tests/unit/test_validate_code.py @@ -2,7 +2,13 @@ from pathlib import Path from unittest import mock import pytest -from langflow.utils.validate import create_function, execute_function, extract_function_name, validate_code +from langflow.utils.validate import ( + create_class, + create_function, + execute_function, + extract_function_name, + validate_code, +) from requests.exceptions import MissingSchema @@ -100,3 +106,67 @@ def my_function(x): """ with mock.patch("requests.get", side_effect=MissingSchema), pytest.raises(MissingSchema): execute_function(code, "my_function", "invalid_url") + + +def test_create_class(): + code = """ +from langflow.custom import CustomComponent + +class ExternalClass: + def __init__(self, value): + self.value = value + +class MyComponent(CustomComponent): + def build(self): + return ExternalClass("test") +""" + class_name = "MyComponent" + created_class = create_class(code, class_name) + instance = created_class() + result = instance.build() + assert result.value == "test" + + +def test_create_class_with_multiple_external_classes(): + code = """ +from langflow.custom import CustomComponent + +class ExternalClass1: + def __init__(self, value): + self.value = value + +class ExternalClass2: + def __init__(self, value): + self.value = value + +class MyComponent(CustomComponent): + def build(self): + return ExternalClass1("test1"), ExternalClass2("test2") +""" + class_name = "MyComponent" + created_class = create_class(code, class_name) + instance = created_class() + result1, result2 = instance.build() + assert result1.value == "test1" + assert result2.value == "test2" + + +def test_create_class_with_external_variables_and_functions(): + code = """ +from langflow.custom import CustomComponent + +external_variable = "external_value" + +def external_function(): + return "external_function_value" + +class MyComponent(CustomComponent): + def build(self): + return external_variable, external_function() +""" + class_name = "MyComponent" + created_class = create_class(code, class_name) + instance = created_class() + result_variable, result_function = instance.build() + assert result_variable == "external_value" + assert result_function == "external_function_value"