🔧 fix(test_custom_component.py): fix import statements and formatting issues in test_custom_component.py

 feat(test_custom_component.py): add tests for list_flows, build_config methods in CustomComponent class
🔧 fix(test_custom_component.py): fix formatting issues in test_list_flows_multiple_queries test
 feat(test_custom_component.py): add tests for list_flows, build_config methods in CustomComponent class
 feat(test_custom_component.py): add test for return type of list_flows method in CustomComponent class
 feat(test_custom_component.py): add test for return type of build_config method in CustomComponent class
 feat(test_custom_component.py): add test for presence of 'fields' key in build_config method in CustomComponent class
 feat(test_custom_component.py): add test for type of 'fields' value in build_config method in CustomComponent class
 feat(test_custom_component.py): add test for type of keys in 'fields' value in build_config method in CustomComponent class
 feat(test_custom_component.py): add test for type of values in 'fields' value in build_config method in CustomComponent class
This commit is contained in:
gustavoschaedler 2023-07-26 17:40:53 +01:00
commit d5ee293590

View file

@ -1,9 +1,12 @@
import ast
import pytest
import types
from uuid import uuid4
from unittest.mock import patch, MagicMock
from fastapi import HTTPException
from langflow.database.models.flow import Flow, FlowCreate
from langflow.interface.custom.base import CustomComponent
from langflow.interface.custom.component import (
Component,
@ -57,7 +60,8 @@ def test_code_parser_get_tree():
def test_code_parser_syntax_error():
"""
Test the __get_tree method raises the CodeSyntaxError when given incorrect syntax.
Test the __get_tree method raises the
CodeSyntaxError when given incorrect syntax.
"""
code_syntax_error = "zzz import os"
@ -86,7 +90,8 @@ def test_component_get_code_tree():
def test_component_code_null_error():
"""
Test the get_function method raises the ComponentCodeNullError when the code is empty.
Test the get_function method raises the
ComponentCodeNullError when the code is empty.
"""
component = Component(code="", function_entrypoint_name="")
with pytest.raises(ComponentCodeNullError):
@ -140,7 +145,8 @@ def test_custom_component_get_function():
def test_code_parser_parse_imports_import():
"""
Test the parse_imports method of the CodeParser class with an import statement.
Test the parse_imports method of the CodeParser
class with an import statement.
"""
parser = CodeParser(code_default)
tree = parser._CodeParser__get_tree()
@ -152,7 +158,8 @@ def test_code_parser_parse_imports_import():
def test_code_parser_parse_imports_importfrom():
"""
Test the parse_imports method of the CodeParser class with an import from statement.
Test the parse_imports method of the CodeParser
class with an import from statement.
"""
parser = CodeParser("from os import path")
tree = parser._CodeParser__get_tree()
@ -203,16 +210,18 @@ def test_code_parser_parse_global_vars():
def test_component_get_function_valid():
"""
Test the get_function method of the Component class with valid code and function_entrypoint_name.
Test the get_function method of the Component
class with valid code and function_entrypoint_name.
"""
component = Component(code="def build(): pass", function_entrypoint_name="build")
function = component.get_function()
assert callable(function)
my_function = component.get_function()
assert callable(my_function)
def test_custom_component_get_function_entrypoint_args():
"""
Test the get_function_entrypoint_args property of the CustomComponent class.
Test the get_function_entrypoint_args
property of the CustomComponent class.
"""
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
@ -226,7 +235,8 @@ def test_custom_component_get_function_entrypoint_args():
def test_custom_component_get_function_entrypoint_return_type():
"""
Test the get_function_entrypoint_return_type property of the CustomComponent class.
Test the get_function_entrypoint_return_type
property of the CustomComponent class.
"""
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
@ -248,7 +258,8 @@ def test_custom_component_get_main_class_name():
def test_custom_component_get_function_valid():
"""
Test the get_function property of the CustomComponent class with valid code and function_entrypoint_name.
Test the get_function property of the CustomComponent
class with valid code and function_entrypoint_name.
"""
custom_component = CustomComponent(
code="def build(): pass", function_entrypoint_name="build"
@ -281,7 +292,8 @@ def test_code_parser_parse_arg_with_annotation():
def test_code_parser_parse_callable_details_no_args():
"""
Test the parse_callable_details method of the CodeParser class with a function with no arguments.
Test the parse_callable_details method of the
CodeParser class with a function with no arguments.
"""
parser = CodeParser("")
node = ast.FunctionDef(
@ -328,7 +340,8 @@ def test_code_parser_parse_ann_assign():
def test_code_parser_parse_function_def_not_init():
"""
Test the parse_function_def method of the CodeParser class with a function that is not __init__.
Test the parse_function_def method of the
CodeParser class with a function that is not __init__.
"""
parser = CodeParser("")
stmt = ast.FunctionDef(
@ -347,7 +360,8 @@ def test_code_parser_parse_function_def_not_init():
def test_code_parser_parse_function_def_init():
"""
Test the parse_function_def method of the CodeParser class with an __init__ function.
Test the parse_function_def method of the
CodeParser class with an __init__ function.
"""
parser = CodeParser("")
stmt = ast.FunctionDef(
@ -386,7 +400,8 @@ def test_custom_component_class_template_validation_no_code():
def test_custom_component_get_code_tree_syntax_error():
"""
Test the get_code_tree method of the CustomComponent class raises the CodeSyntaxError when given incorrect syntax.
Test the get_code_tree method of the CustomComponent class
raises the CodeSyntaxError when given incorrect syntax.
"""
custom_component = CustomComponent(
code="import os as", function_entrypoint_name="build"
@ -397,7 +412,8 @@ def test_custom_component_get_code_tree_syntax_error():
def test_custom_component_get_function_entrypoint_args_no_args():
"""
Test the get_function_entrypoint_args property of the CustomComponent class with a build method with no arguments.
Test the get_function_entrypoint_args property of
the CustomComponent class with a build method with no arguments.
"""
my_code = """
class MyMainClass(CustomComponent):
@ -426,7 +442,8 @@ class MyClass(CustomComponent):
def test_custom_component_get_main_class_name_no_main_class():
"""
Test the get_main_class_name property of the CustomComponent class when there is no main class.
Test the get_main_class_name property of the
CustomComponent class when there is no main class.
"""
my_code = """
def build():
@ -439,7 +456,8 @@ def build():
def test_custom_component_build_not_implemented():
"""
Test the build method of the CustomComponent class raises the NotImplementedError.
Test the build method of the CustomComponent
class raises the NotImplementedError.
"""
custom_component = CustomComponent(
code="def build(): pass", function_entrypoint_name="build"
@ -469,21 +487,87 @@ def test_build_config_no_code():
assert component.get_function_entrypoint_return_type == ""
def test_list_flows_multiple_queries():
mock_flow_1 = MagicMock()
mock_flow_2 = MagicMock()
@pytest.fixture
def component():
return CustomComponent(
field_config={
"fields": {
"llm": {"type": "str"},
"url": {"type": "str"},
"year": {"type": "int"},
}
}
)
session_getter_module = "langflow.database.base.session_getter"
with patch(session_getter_module) as mock_session_getter:
mock_session = MagicMock()
mock_session.query.return_value.all.side_effect = [[mock_flow_1], [mock_flow_2]]
mock_session_getter.return_value.__enter__.return_value = mock_session
@pytest.fixture(scope="session")
def test_flow(db):
flow_data = {
"nodes": [{"id": "1"}, {"id": "2"}],
"edges": [{"source": "1", "target": "2"}],
}
component = CustomComponent()
result = component.list_flows()
# Create flow
flow = FlowCreate(
id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data
)
# Only the result of the second query is returned
assert len(result) == 1
assert result[0] == mock_flow_2
assert mock_session.query.call_count == 2
# Add to database
db.add(flow)
db.commit()
yield flow
# Clean up
db.delete(flow)
db.commit()
@pytest.fixture(scope="session")
def db(app):
# Setup database for tests
yield app.db
# Teardown
app.db.drop_all()
def test_list_flows_return_type(component):
flows = component.list_flows()
assert isinstance(flows, list)
def test_list_flows_flow_objects(component):
flows = component.list_flows()
assert all(isinstance(flow, Flow) for flow in flows)
def test_build_config_return_type(component):
config = component.build_config()
assert isinstance(config, dict)
def test_build_config_has_fields(component):
config = component.build_config()
assert "fields" in config
def test_build_config_fields_dict(component):
config = component.build_config()
assert isinstance(config["fields"], dict)
def test_build_config_field_keys(component):
config = component.build_config()
assert all(isinstance(key, str) for key in config["fields"])
def test_build_config_field_values_dict(component):
config = component.build_config()
assert all(isinstance(value, dict) for value in config["fields"].values())
def test_build_config_field_value_keys(component):
config = component.build_config()
field_values = config["fields"].values()
assert all("type" in value for value in field_values)