diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index 8963a7ece..2811f68bd 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -1,9 +1,12 @@ import ast import pytest import types +from uuid import uuid4 + from unittest.mock import patch, MagicMock from fastapi import HTTPException +from langflow.database.models.flow import Flow, FlowCreate from langflow.interface.custom.base import CustomComponent from langflow.interface.custom.component import ( Component, @@ -57,7 +60,8 @@ def test_code_parser_get_tree(): def test_code_parser_syntax_error(): """ - Test the __get_tree method raises the CodeSyntaxError when given incorrect syntax. + Test the __get_tree method raises the + CodeSyntaxError when given incorrect syntax. """ code_syntax_error = "zzz import os" @@ -86,7 +90,8 @@ def test_component_get_code_tree(): def test_component_code_null_error(): """ - Test the get_function method raises the ComponentCodeNullError when the code is empty. + Test the get_function method raises the + ComponentCodeNullError when the code is empty. """ component = Component(code="", function_entrypoint_name="") with pytest.raises(ComponentCodeNullError): @@ -140,7 +145,8 @@ def test_custom_component_get_function(): def test_code_parser_parse_imports_import(): """ - Test the parse_imports method of the CodeParser class with an import statement. + Test the parse_imports method of the CodeParser + class with an import statement. """ parser = CodeParser(code_default) tree = parser._CodeParser__get_tree() @@ -152,7 +158,8 @@ def test_code_parser_parse_imports_import(): def test_code_parser_parse_imports_importfrom(): """ - Test the parse_imports method of the CodeParser class with an import from statement. + Test the parse_imports method of the CodeParser + class with an import from statement. """ parser = CodeParser("from os import path") tree = parser._CodeParser__get_tree() @@ -203,16 +210,18 @@ def test_code_parser_parse_global_vars(): def test_component_get_function_valid(): """ - Test the get_function method of the Component class with valid code and function_entrypoint_name. + Test the get_function method of the Component + class with valid code and function_entrypoint_name. """ component = Component(code="def build(): pass", function_entrypoint_name="build") - function = component.get_function() - assert callable(function) + my_function = component.get_function() + assert callable(my_function) def test_custom_component_get_function_entrypoint_args(): """ - Test the get_function_entrypoint_args property of the CustomComponent class. + Test the get_function_entrypoint_args + property of the CustomComponent class. """ custom_component = CustomComponent( code=code_default, function_entrypoint_name="build" @@ -226,7 +235,8 @@ def test_custom_component_get_function_entrypoint_args(): def test_custom_component_get_function_entrypoint_return_type(): """ - Test the get_function_entrypoint_return_type property of the CustomComponent class. + Test the get_function_entrypoint_return_type + property of the CustomComponent class. """ custom_component = CustomComponent( code=code_default, function_entrypoint_name="build" @@ -248,7 +258,8 @@ def test_custom_component_get_main_class_name(): def test_custom_component_get_function_valid(): """ - Test the get_function property of the CustomComponent class with valid code and function_entrypoint_name. + Test the get_function property of the CustomComponent + class with valid code and function_entrypoint_name. """ custom_component = CustomComponent( code="def build(): pass", function_entrypoint_name="build" @@ -281,7 +292,8 @@ def test_code_parser_parse_arg_with_annotation(): def test_code_parser_parse_callable_details_no_args(): """ - Test the parse_callable_details method of the CodeParser class with a function with no arguments. + Test the parse_callable_details method of the + CodeParser class with a function with no arguments. """ parser = CodeParser("") node = ast.FunctionDef( @@ -328,7 +340,8 @@ def test_code_parser_parse_ann_assign(): def test_code_parser_parse_function_def_not_init(): """ - Test the parse_function_def method of the CodeParser class with a function that is not __init__. + Test the parse_function_def method of the + CodeParser class with a function that is not __init__. """ parser = CodeParser("") stmt = ast.FunctionDef( @@ -347,7 +360,8 @@ def test_code_parser_parse_function_def_not_init(): def test_code_parser_parse_function_def_init(): """ - Test the parse_function_def method of the CodeParser class with an __init__ function. + Test the parse_function_def method of the + CodeParser class with an __init__ function. """ parser = CodeParser("") stmt = ast.FunctionDef( @@ -386,7 +400,8 @@ def test_custom_component_class_template_validation_no_code(): def test_custom_component_get_code_tree_syntax_error(): """ - Test the get_code_tree method of the CustomComponent class raises the CodeSyntaxError when given incorrect syntax. + Test the get_code_tree method of the CustomComponent class + raises the CodeSyntaxError when given incorrect syntax. """ custom_component = CustomComponent( code="import os as", function_entrypoint_name="build" @@ -397,7 +412,8 @@ def test_custom_component_get_code_tree_syntax_error(): def test_custom_component_get_function_entrypoint_args_no_args(): """ - Test the get_function_entrypoint_args property of the CustomComponent class with a build method with no arguments. + Test the get_function_entrypoint_args property of + the CustomComponent class with a build method with no arguments. """ my_code = """ class MyMainClass(CustomComponent): @@ -426,7 +442,8 @@ class MyClass(CustomComponent): def test_custom_component_get_main_class_name_no_main_class(): """ - Test the get_main_class_name property of the CustomComponent class when there is no main class. + Test the get_main_class_name property of the + CustomComponent class when there is no main class. """ my_code = """ def build(): @@ -439,7 +456,8 @@ def build(): def test_custom_component_build_not_implemented(): """ - Test the build method of the CustomComponent class raises the NotImplementedError. + Test the build method of the CustomComponent + class raises the NotImplementedError. """ custom_component = CustomComponent( code="def build(): pass", function_entrypoint_name="build" @@ -469,21 +487,87 @@ def test_build_config_no_code(): assert component.get_function_entrypoint_return_type == "" -def test_list_flows_multiple_queries(): - mock_flow_1 = MagicMock() - mock_flow_2 = MagicMock() +@pytest.fixture +def component(): + return CustomComponent( + field_config={ + "fields": { + "llm": {"type": "str"}, + "url": {"type": "str"}, + "year": {"type": "int"}, + } + } + ) - session_getter_module = "langflow.database.base.session_getter" - with patch(session_getter_module) as mock_session_getter: - mock_session = MagicMock() - mock_session.query.return_value.all.side_effect = [[mock_flow_1], [mock_flow_2]] - mock_session_getter.return_value.__enter__.return_value = mock_session +@pytest.fixture(scope="session") +def test_flow(db): + flow_data = { + "nodes": [{"id": "1"}, {"id": "2"}], + "edges": [{"source": "1", "target": "2"}], + } - component = CustomComponent() - result = component.list_flows() + # Create flow + flow = FlowCreate( + id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data + ) - # Only the result of the second query is returned - assert len(result) == 1 - assert result[0] == mock_flow_2 - assert mock_session.query.call_count == 2 + # Add to database + db.add(flow) + db.commit() + + yield flow + + # Clean up + db.delete(flow) + db.commit() + + +@pytest.fixture(scope="session") +def db(app): + # Setup database for tests + yield app.db + + # Teardown + app.db.drop_all() + + +def test_list_flows_return_type(component): + flows = component.list_flows() + assert isinstance(flows, list) + + +def test_list_flows_flow_objects(component): + flows = component.list_flows() + assert all(isinstance(flow, Flow) for flow in flows) + + +def test_build_config_return_type(component): + config = component.build_config() + assert isinstance(config, dict) + + +def test_build_config_has_fields(component): + config = component.build_config() + assert "fields" in config + + +def test_build_config_fields_dict(component): + config = component.build_config() + assert isinstance(config["fields"], dict) + + +def test_build_config_field_keys(component): + config = component.build_config() + assert all(isinstance(key, str) for key in config["fields"]) + + +def test_build_config_field_values_dict(component): + config = component.build_config() + assert all(isinstance(value, dict) for value in config["fields"].values()) + + +def test_build_config_field_value_keys(component): + config = component.build_config() + field_values = config["fields"].values() + assert all("type" in value for value in field_values)