🐛 fix(custom.py): remove unused imports and unused code block to improve code cleanliness and performance

 feat(custom.py): add NotImplementedError to the build method to indicate that it needs to be implemented in subclasses
🚧 chore(test_custom_component.py): add test cases for various methods in the CustomComponent class to improve test coverage and ensure code correctness
This commit is contained in:
gustavoschaedler 2023-07-10 19:34:36 +01:00
commit 6122521783
2 changed files with 183 additions and 5 deletions

View file

@ -1,4 +1,5 @@
import ast
import contextlib
import traceback
from typing import Callable, Optional
from fastapi import HTTPException
@ -89,12 +90,9 @@ class CustomComponent(BaseModel):
else:
split_item.append(None)
for i in range(len(split_item)):
try:
with contextlib.suppress(ValueError):
# Try to evaluate the item
split_item[i] = ast.literal_eval(split_item[i])
except ValueError:
# If it fails, just pass
pass
output_list.append(split_item)
@ -198,7 +196,7 @@ class CustomComponent(BaseModel):
return validate.create_function(self.code, self.function_entrypoint_name)
def build(self):
pass
raise NotImplementedError
@property
def data(self):

View file

@ -0,0 +1,180 @@
import ast
import pytest
from fastapi import HTTPException
from langflow.interface.custom.custom import CustomComponent
from langflow.interface.custom.constants import DEFAULT_CUSTOM_COMPONENT_CODE
# Test the __init__ method
def test_init():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
assert isinstance(component, CustomComponent)
assert component.code == DEFAULT_CUSTOM_COMPONENT_CODE
# Test the _handle_import method
def test_handle_import():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
node = ast.parse("import math").body[0]
component._handle_import(node)
assert "math" in component.class_template["imports"]
# Test the _handle_class method
def test_handle_class():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
node = ast.parse("class Test: pass").body[0]
component._handle_class(node)
assert component.class_template["class"]["name"] == "Test"
# Test the _handle_function method
def test_handle_function():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
node = ast.parse("def func(): pass").body[0]
component._handle_function(node)
function_data = {"name": "func", "arguments": [], "return_type": "None"}
assert function_data in component.class_template["functions"]
# Test the transform_list method
def test_transform_list():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
input_list = ["var1: int", "var2: str", "var3"]
output_list = [["var1", "int"], ["var2", "str"], ["var3", None]]
assert component.transform_list(input_list) == output_list
# Test the extract_class_info method with valid code
def test_extract_class_info():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
class_info = component.extract_class_info()
assert "requests" in class_info["imports"]
assert class_info["class"]["name"] == "YourComponent"
function_data = {
"name": "build",
"arguments": ["self", "url: str", "llm: BaseLLM", "template: Prompt"],
"return_type": "Document",
}
assert function_data in class_info["functions"]
# Test the extract_class_info method with invalid code
def test_extract_class_info_invalid_code():
component = CustomComponent(field_config={}, code="invalid code")
with pytest.raises(HTTPException) as e:
component.extract_class_info()
exception = e.value
assert exception.status_code == 400
assert exception.detail["error"] == "invalid syntax"
# Test the get_entrypoint_function_args_and_return_type method
def test_get_entrypoint_function_args_and_return_type():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
(
function_args,
return_type,
template_config,
) = component.get_entrypoint_function_args_and_return_type()
assert function_args == [
["self", None],
["url", "str"],
["llm", "BaseLLM"],
["template", "Prompt"],
]
assert return_type == "Document"
assert template_config == {
"description": "Your description",
"display_name": "Your Component",
"field_config": {"url": {"multiline": True, "required": True}},
}
# Test the _build_template_config method
def test__build_template_config():
attributes = {
"field_config": "'field_config_value'",
"display_name": "'display_name_value'",
"description": "'description_value'",
}
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
template_config = component._build_template_config(attributes)
assert template_config == {
"field_config": "field_config_value",
"display_name": "display_name_value",
"description": "description_value",
}
# Test the _class_template_validation method with a valid class template
def test__class_template_validation_valid():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
assert component._class_template_validation(code=component.data) is True
# Test the _class_template_validation method with an invalid class template
def test__class_template_validation_invalid():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
class_template = {}
with pytest.raises(Exception) as e:
component._class_template_validation(class_template)
exception = e.value
assert exception.status_code == 400
assert exception.detail["error"] == "The main class must have a valid name."
# Test the build method
def test_build():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
with pytest.raises(Exception) as e:
component.build()
assert e.type == NotImplementedError
# Test the data property
def test_data():
code = DEFAULT_CUSTOM_COMPONENT_CODE
component = CustomComponent(field_config={}, code=code)
class_info = component.data
assert "requests" in class_info["imports"]
assert class_info["class"]["name"] == "YourComponent"
function_data = {
"name": "build",
"arguments": ["self", "url: str", "llm: BaseLLM", "template: Prompt"],
"return_type": "Document",
}
assert function_data in class_info["functions"]
# Test the is_check_valid method
def test_is_check_valid():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
assert component.is_check_valid() is True
# Test the args_and_return_type property
def test_args_and_return_type():
component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE)
function_args, return_type, template_config = component.args_and_return_type
assert function_args == [
["self", None],
["url", "str"],
["llm", "BaseLLM"],
["template", "Prompt"],
]
assert return_type == "Document"
assert template_config == {
"description": "Your description",
"display_name": "Your Component",
"field_config": {"url": {"multiline": True, "required": True}},
}