langflow/tests/test_custom_component.py
gustavoschaedler a89a9a3267 🔥 refactor(custom.py): remove unused code and class 'CustomComponent_old' to improve code cleanliness and maintainability
🔧 fix(test_custom_component.py): fix formatting issues in test_custom_component.py for better readability
✨ feat(test_custom_component.py): add import statements for 'patch' and 'MagicMock' to enable mocking in tests
🔬 test(test_custom_component.py): add test for the 'get_function' method of the Component class with valid code and function_entrypoint_name
🔬 test(test_custom_component.py): add test for the 'parse_assign' method of the CodeParser class
🔬 test(test_custom_component.py): add test for the 'get_code_tree' method of the Component class when given incorrect syntax
🔬 test(test_custom_component.py): add test for the '_class_template_validation' method of the CustomComponent class when the code is None
🔬 test(test_custom_component.py): add test for the 'get_function_entrypoint_args' method of the CustomComponent class
🔬 test(test_custom_component.py): add test for the 'get_function_entrypoint_return_type' method of the CustomComponent class
🔬 test(test_custom_component.py): add test for the 'get_main_class_name' method of the CustomComponent class when there is no main class

🔥 refactor(test_custom_component.py): remove commented out code and unused fixtures to improve code readability and maintainability

🔧 refactor(tests): remove commented out test cases and unused imports
✨ feat(tests): add new test case for list_flows method when there are no flows in the database
✨ feat(tests): add new test case for build_config method when code is not provided
✨ feat(tests): add new test case for list_flows method when there are multiple queries to the database
2023-07-26 16:56:21 +01:00

489 lines
15 KiB
Python

import ast
import pytest
import types
from unittest.mock import patch, MagicMock
from fastapi import HTTPException
from langflow.interface.custom.base import CustomComponent
from langflow.interface.custom.component import (
Component,
ComponentCodeNullError,
ComponentFunctionEntrypointNameNullError,
)
from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError
# Sample custom-component source shared by many tests in this module.
# The tests hand it to CodeParser / Component / CustomComponent, which parse
# it, so it must be syntactically valid (properly indented) Python.
# Tests rely on these details of it:
#   - a plain ``import requests`` statement (parse_imports test),
#   - a class named ``YourComponent`` (main-class-name test),
#   - ``build(self, url, llm, template) -> Document`` (entrypoint tests).
code_default = """
from langflow import Prompt
from langflow.interface.custom.custom_component import CustomComponent
from langchain.llms.base import BaseLLM
from langchain.chains import LLMChain
from langchain import PromptTemplate
from langchain.schema import Document
import requests
class YourComponent(CustomComponent):
    display_name: str = "Your Component"
    description: str = "Your description"
    field_config = { "url": { "multiline": True, "required": True } }
    def build(self, url: str, llm: BaseLLM, template: Prompt) -> Document:
        response = requests.get(url)
        prompt = PromptTemplate.from_template(template)
        chain = LLMChain(llm=llm, prompt=prompt)
        result = chain.run(response.text[:300])
        return Document(page_content=str(result))
"""
def test_code_parser_init():
    """
    CodeParser keeps the exact source text it was constructed with on ``.code``.
    """
    source = code_default
    assert CodeParser(source).code == source
def test_code_parser_get_tree():
    """
    The private ``__get_tree`` helper of CodeParser returns an ``ast.AST`` node.
    """
    code_parser = CodeParser(code_default)
    # ``__get_tree`` is name-mangled, so it is reached via its mangled name.
    get_tree = code_parser._CodeParser__get_tree
    assert isinstance(get_tree(), ast.AST)
def test_code_parser_syntax_error():
    """
    ``__get_tree`` raises CodeSyntaxError for source that is not valid Python.
    """
    broken_parser = CodeParser("zzz import os")
    with pytest.raises(CodeSyntaxError):
        broken_parser._CodeParser__get_tree()
def test_component_init():
    """
    Component stores both constructor arguments unchanged.
    """
    entrypoint = "build"
    component = Component(code=code_default, function_entrypoint_name=entrypoint)
    assert component.code == code_default
    assert component.function_entrypoint_name == entrypoint
def test_component_get_code_tree():
    """
    The result of ``Component.get_code_tree`` includes an ``"imports"`` entry.
    """
    component = Component(code=code_default, function_entrypoint_name="build")
    assert "imports" in component.get_code_tree(component.code)
def test_component_code_null_error():
    """
    ``get_function`` raises ComponentCodeNullError when ``code`` is an empty string.
    """
    empty_component = Component(code="", function_entrypoint_name="")
    with pytest.raises(ComponentCodeNullError):
        empty_component.get_function()
def test_component_function_entrypoint_name_null_error():
    """
    ``get_function`` raises ComponentFunctionEntrypointNameNullError when code is
    present but ``function_entrypoint_name`` is an empty string.
    """
    nameless = Component(code=code_default, function_entrypoint_name="")
    with pytest.raises(ComponentFunctionEntrypointNameNullError):
        nameless.get_function()
def test_custom_component_init():
    """
    CustomComponent keeps the ``code`` and ``function_entrypoint_name`` it is given.
    """
    expected = {"code": code_default, "function_entrypoint_name": "build"}
    custom_component = CustomComponent(**expected)
    for attribute, value in expected.items():
        assert getattr(custom_component, attribute) == value
def test_custom_component_build_template_config():
    """
    The ``build_template_config`` property of CustomComponent evaluates to a dict.
    """
    component = CustomComponent(code=code_default, function_entrypoint_name="build")
    template_config = component.build_template_config
    assert isinstance(template_config, dict)
def test_custom_component_get_function():
    """
    The ``get_function`` property of CustomComponent yields a plain Python function.
    """
    component = CustomComponent(
        code="def build(): pass",
        function_entrypoint_name="build",
    )
    assert isinstance(component.get_function, types.FunctionType)
def test_code_parser_parse_imports_import():
    """
    ``parse_imports`` records the plain ``import requests`` statement of the sample code.
    """
    parser = CodeParser(code_default)
    import_nodes = (
        node
        for node in ast.walk(parser._CodeParser__get_tree())
        if isinstance(node, ast.Import)
    )
    for import_node in import_nodes:
        parser.parse_imports(import_node)
    assert "requests" in parser.data["imports"]
def test_code_parser_parse_imports_importfrom():
    """
    ``parse_imports`` records ``from os import path`` as the ``("os", "path")`` pair.
    """
    parser = CodeParser("from os import path")
    from_nodes = (
        node
        for node in ast.walk(parser._CodeParser__get_tree())
        if isinstance(node, ast.ImportFrom)
    )
    for from_node in from_nodes:
        parser.parse_imports(from_node)
    assert ("os", "path") in parser.data["imports"]
def test_code_parser_parse_functions():
    """
    ``parse_functions`` stores exactly one entry, named ``test``, for a single ``def``.
    """
    parser = CodeParser("def test(): pass")
    func_nodes = (
        node
        for node in ast.walk(parser._CodeParser__get_tree())
        if isinstance(node, ast.FunctionDef)
    )
    for func_node in func_nodes:
        parser.parse_functions(func_node)
    functions = parser.data["functions"]
    assert len(functions) == 1
    assert functions[0]["name"] == "test"
def test_code_parser_parse_classes():
    """
    ``parse_classes`` stores exactly one entry, named ``Test``, for a single class.
    """
    parser = CodeParser("class Test: pass")
    class_nodes = (
        node
        for node in ast.walk(parser._CodeParser__get_tree())
        if isinstance(node, ast.ClassDef)
    )
    for class_node in class_nodes:
        parser.parse_classes(class_node)
    classes = parser.data["classes"]
    assert len(classes) == 1
    assert classes[0]["name"] == "Test"
def test_code_parser_parse_global_vars():
    """
    ``parse_global_vars`` stores one entry whose targets are ``["x"]`` for ``x = 1``.
    """
    parser = CodeParser("x = 1")
    assign_nodes = (
        node
        for node in ast.walk(parser._CodeParser__get_tree())
        if isinstance(node, ast.Assign)
    )
    for assign_node in assign_nodes:
        parser.parse_global_vars(assign_node)
    global_vars = parser.data["global_vars"]
    assert len(global_vars) == 1
    assert global_vars[0]["targets"] == ["x"]
def test_component_get_function_valid():
    """
    With valid code and entrypoint name, ``Component.get_function`` returns a callable.
    """
    entry = Component(
        code="def build(): pass",
        function_entrypoint_name="build",
    ).get_function()
    assert callable(entry)
def test_custom_component_get_function_entrypoint_args():
    """
    Test the get_function_entrypoint_args property of the CustomComponent class.

    ``code_default`` defines ``build(self, url, llm, template)``, so all four
    parameters, including ``self``, are expected in declaration order.
    """
    custom_component = CustomComponent(
        code=code_default, function_entrypoint_name="build"
    )
    args = custom_component.get_function_entrypoint_args
    # Compare the full ordered list so that a missing, extra, or reordered
    # parameter (including the last one, ``template``) makes the test fail.
    assert [arg["name"] for arg in args] == ["self", "url", "llm", "template"]
def test_custom_component_get_function_entrypoint_return_type():
    """
    The entrypoint return annotation of the sample code is reported as ``"Document"``.
    """
    component = CustomComponent(code=code_default, function_entrypoint_name="build")
    assert component.get_function_entrypoint_return_type == "Document"
def test_custom_component_get_main_class_name():
    """
    ``get_main_class_name`` reports the class defined in the sample code.
    """
    component = CustomComponent(code=code_default, function_entrypoint_name="build")
    main_class = component.get_main_class_name
    assert main_class == "YourComponent"
def test_custom_component_get_function_valid():
    """
    With valid code and entrypoint name, the CustomComponent ``get_function``
    property yields a callable.
    """
    component = CustomComponent(
        code="def build(): pass",
        function_entrypoint_name="build",
    )
    entry = component.get_function
    assert callable(entry)
def test_code_parser_parse_arg_no_annotation():
    """
    ``parse_arg`` on an unannotated argument returns its name and no ``"type"`` key.
    """
    bare_arg = ast.arg(arg="x", annotation=None)
    parsed = CodeParser("").parse_arg(bare_arg, None)
    assert parsed["name"] == "x"
    assert "type" not in parsed
def test_code_parser_parse_arg_with_annotation():
    """
    ``parse_arg`` reports a simple name annotation under the ``"type"`` key.
    """
    int_annotation = ast.Name(id="int", ctx=ast.Load())
    annotated_arg = ast.arg(arg="x", annotation=int_annotation)
    parsed = CodeParser("").parse_arg(annotated_arg, None)
    assert parsed["name"] == "x"
    assert parsed["type"] == "int"
def test_code_parser_parse_callable_details_no_args():
    """
    Test the parse_callable_details method of the CodeParser class with a function with no arguments.
    """
    parser = CodeParser("")
    # ``posonlyargs`` is a field of ``ast.arguments`` since Python 3.8; before
    # 3.13 omitting it leaves the attribute unset, which breaks visitors such
    # as ``ast.unparse`` that read it.
    node = ast.FunctionDef(
        name="test",
        args=ast.arguments(
            posonlyargs=[],
            args=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        ),
        body=[],
        decorator_list=[],
        returns=None,
    )
    result = parser.parse_callable_details(node)
    assert result["name"] == "test"
    assert len(result["args"]) == 0
def test_code_parser_parse_assign():
    """
    Test the parse_assign method of the CodeParser class.
    """
    parser = CodeParser("")
    # ``ast.Num`` is deprecated since Python 3.8 and removed in 3.14;
    # ``ast.Constant`` is the node type ``ast.Num(...)`` already produced.
    stmt = ast.Assign(
        targets=[ast.Name(id="x", ctx=ast.Store())],
        value=ast.Constant(value=1),
    )
    result = parser.parse_assign(stmt)
    assert result["name"] == "x"
    assert result["value"] == "1"
def test_code_parser_parse_ann_assign():
    """
    Test the parse_ann_assign method of the CodeParser class.
    """
    parser = CodeParser("")
    # ``ast.Constant`` replaces the deprecated (removed in 3.14) ``ast.Num``.
    stmt = ast.AnnAssign(
        target=ast.Name(id="x", ctx=ast.Store()),
        annotation=ast.Name(id="int", ctx=ast.Load()),
        value=ast.Constant(value=1),
        simple=1,
    )
    result = parser.parse_ann_assign(stmt)
    assert result["name"] == "x"
    assert result["value"] == "1"
    assert result["annotation"] == "int"
def test_code_parser_parse_function_def_not_init():
    """
    Test the parse_function_def method of the CodeParser class with a function that is not __init__.
    """
    parser = CodeParser("")
    # ``posonlyargs`` is supplied explicitly so the hand-built node is complete
    # on Python versions where omitted list fields are left unset.
    stmt = ast.FunctionDef(
        name="test",
        args=ast.arguments(
            posonlyargs=[],
            args=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        ),
        body=[],
        decorator_list=[],
        returns=None,
    )
    result, is_init = parser.parse_function_def(stmt)
    assert result["name"] == "test"
    assert not is_init
def test_code_parser_parse_function_def_init():
    """
    Test the parse_function_def method of the CodeParser class with an __init__ function.
    """
    parser = CodeParser("")
    # ``posonlyargs`` is supplied explicitly so the hand-built node is complete
    # on Python versions where omitted list fields are left unset.
    stmt = ast.FunctionDef(
        name="__init__",
        args=ast.arguments(
            posonlyargs=[],
            args=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        ),
        body=[],
        decorator_list=[],
        returns=None,
    )
    result, is_init = parser.parse_function_def(stmt)
    assert result["name"] == "__init__"
    assert is_init
def test_component_get_code_tree_syntax_error():
    """
    ``Component.get_code_tree`` raises CodeSyntaxError for an incomplete
    ``import ... as`` statement.
    """
    broken = Component(code="import os as", function_entrypoint_name="build")
    with pytest.raises(CodeSyntaxError):
        broken.get_code_tree(broken.code)
def test_custom_component_class_template_validation_no_code():
    """
    ``_class_template_validation`` raises HTTPException when the component's
    code is None.
    """
    codeless = CustomComponent(code=None, function_entrypoint_name="build")
    with pytest.raises(HTTPException):
        codeless._class_template_validation(codeless.code)
def test_custom_component_get_code_tree_syntax_error():
    """
    ``CustomComponent.get_code_tree`` raises CodeSyntaxError for an incomplete
    ``import ... as`` statement.
    """
    broken = CustomComponent(code="import os as", function_entrypoint_name="build")
    with pytest.raises(CodeSyntaxError):
        broken.get_code_tree(broken.code)
def test_custom_component_get_function_entrypoint_args_no_args():
    """
    Test the get_function_entrypoint_args property of the CustomComponent class with a build method with no arguments.
    """
    # The snippet must be correctly indented Python: the component parses this
    # source, and an unindented method body would be rejected as a syntax
    # error before the entrypoint arguments are inspected.
    my_code = """
class MyMainClass(CustomComponent):
    def build():
        pass"""
    custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
    args = custom_component.get_function_entrypoint_args
    assert len(args) == 0
def test_custom_component_get_function_entrypoint_return_type_no_return_type():
    """
    Test the get_function_entrypoint_return_type property of the
    CustomComponent class with a build method with no return type.
    """
    # Correct indentation is required here as well: the component parses this
    # snippet, so a flattened class body would fail as a syntax error.
    my_code = """
class MyClass(CustomComponent):
    def build():
        pass"""
    custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
    return_type = custom_component.get_function_entrypoint_return_type
    assert return_type is None
def test_custom_component_get_main_class_name_no_main_class():
    """
    Test the get_main_class_name property of the CustomComponent class when there is no main class.
    """
    # A module-level function only (no class); the body must be indented so
    # the snippet parses and the "no main class" path is what gets exercised.
    my_code = """
def build():
    pass"""
    custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
    class_name = custom_component.get_main_class_name
    assert class_name == ""
def test_custom_component_build_not_implemented():
    """
    Calling ``build`` directly on a CustomComponent instance raises NotImplementedError.
    """
    component = CustomComponent(
        code="def build(): pass",
        function_entrypoint_name="build",
    )
    with pytest.raises(NotImplementedError):
        component.build()
def test_list_flows_no_flows():
    """
    ``list_flows`` returns an empty result when the patched database session's
    query yields no rows.
    """
    target = "langflow.database.base.session_getter"
    with patch(target) as fake_getter:
        fake_session = MagicMock()
        fake_session.query.return_value.all.return_value = []
        # The mock wiring mirrors ``with session_getter(...) as session``.
        fake_getter.return_value.__enter__.return_value = fake_session
        flows = CustomComponent().list_flows()
        assert len(flows) == 0
def test_build_config_no_code():
    """
    With ``code=None``, both the entrypoint-args and the entrypoint-return-type
    properties evaluate to an empty string.
    """
    codeless = CustomComponent(code=None)
    entrypoint_args = codeless.get_function_entrypoint_args
    assert entrypoint_args == ""
    entrypoint_return = codeless.get_function_entrypoint_return_type
    assert entrypoint_return == ""
def test_list_flows_multiple_queries():
    """
    When the session's successive ``query(...).all()`` calls return different
    rows, ``list_flows`` issues two queries and returns only the second one's rows.
    """
    first_flow, second_flow = MagicMock(), MagicMock()
    target = "langflow.database.base.session_getter"
    with patch(target) as fake_getter:
        fake_session = MagicMock()
        # Each ``.all()`` call consumes the next list from side_effect.
        fake_session.query.return_value.all.side_effect = [[first_flow], [second_flow]]
        fake_getter.return_value.__enter__.return_value = fake_session
        flows = CustomComponent().list_flows()
        # Only the result of the second query is returned
        assert len(flows) == 1
        assert flows[0] == second_flow
        assert fake_session.query.call_count == 2