langflow/tests/test_custom_component.py
ogabrielluiz 7fb5644a87 refactor: Update langflow custom imports and base classes
This commit updates the langflow custom imports and base classes in the code. It adds the "Component" import and base class to the langflow custom __init__.py file. It also updates the langflow template __init__.py file to include the "Input", "Output", "FrontendNode", and "Template" imports and base classes. Additionally, it modifies the langflow base ChatComponent class to inherit from the Component class. These changes improve the organization and functionality of the langflow custom and template modules.

Note: The commit message has been generated based on the provided code changes and recent commits.
2024-06-02 20:35:59 -03:00

532 lines
16 KiB
Python

import ast
import types
from uuid import uuid4
import pytest
from langchain_core.documents import Document
from langflow.custom import CustomComponent
from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError
from langflow.custom.utils import build_custom_component_template
from langflow.services.database.models.flow import Flow, FlowCreate
@pytest.fixture
def code_component_with_multiple_outputs():
    """Provide a CustomComponent built from the multiple-outputs sample file.

    The sample source is read from ``tests/data/component_multiple_outputs.py``
    (resolved against the current working directory) and passed to
    ``CustomComponent`` unchanged.
    """
    # An explicit encoding keeps the read independent of the platform's
    # locale default.
    with open("tests/data/component_multiple_outputs.py", "r", encoding="utf-8") as source_file:
        code = source_file.read()
    return CustomComponent(code=code)
# Sample component source shared by the tests in this module. It is handed to
# CodeParser / BaseComponent / CustomComponent as text, so the class body and
# the ``build`` method must be indented for the snippet to be parseable.
code_default = """
from langflow.field_typing import Prompt
from langflow.custom import CustomComponent
from langflow.field_typing import BaseLanguageModel
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document
import requests
class YourComponent(CustomComponent):
    display_name: str = "Your Component"
    description: str = "Your description"
    field_config = { "url": { "multiline": True, "required": True } }
    def build(self, url: str, llm: BaseLanguageModel, template: Prompt) -> Document:
        response = requests.get(url)
        prompt = PromptTemplate.from_template(template)
        chain = LLMChain(llm=llm, prompt=prompt)
        result = chain.run(response.text[:300])
        return Document(page_content=str(result))
"""
def test_code_parser_init():
    """A CodeParser exposes the source it was constructed with as ``code``."""
    source = code_default
    assert CodeParser(source).code == source
def test_code_parser_get_tree():
    """``CodeParser.get_tree`` returns an ``ast.AST`` for the sample source."""
    assert isinstance(CodeParser(code_default).get_tree(), ast.AST)
def test_code_parser_syntax_error():
    """``CodeParser.get_tree`` raises ``CodeSyntaxError`` on malformed source."""
    # Construction happens outside the ``raises`` block so that only
    # ``get_tree`` is expected to fail.
    broken_parser = CodeParser("zzz import os")
    with pytest.raises(CodeSyntaxError):
        broken_parser.get_tree()
def test_component_init():
    """BaseComponent stores ``code`` and ``function_entrypoint_name`` as given."""
    base = BaseComponent(code=code_default, function_entrypoint_name="build")
    stored_code = base.code
    assert stored_code == code_default
    stored_entrypoint = base.function_entrypoint_name
    assert stored_entrypoint == "build"
def test_component_get_code_tree():
    """``BaseComponent.get_code_tree`` output contains an ``"imports"`` entry."""
    base = BaseComponent(code=code_default, function_entrypoint_name="build")
    parsed = base.get_code_tree(base.code)
    assert "imports" in parsed
def test_component_code_null_error():
    """``BaseComponent.get_function`` raises ``ComponentCodeNullError`` for empty code."""
    empty_component = BaseComponent(code="", function_entrypoint_name="")
    with pytest.raises(ComponentCodeNullError):
        empty_component.get_function()
def test_custom_component_init():
    """CustomComponent stores ``code`` and ``function_entrypoint_name`` as given."""
    entrypoint = "build"
    created = CustomComponent(code=code_default, function_entrypoint_name=entrypoint)
    stored_code = created.code
    assert stored_code == code_default
    stored_entrypoint = created.function_entrypoint_name
    assert stored_entrypoint == entrypoint
def test_custom_component_build_template_config():
    """``CustomComponent.build_template_config()`` returns a dict."""
    created = CustomComponent(code=code_default, function_entrypoint_name="build")
    assert isinstance(created.build_template_config(), dict)
def test_custom_component_get_function():
    """``CustomComponent.get_function()`` yields a plain Python function."""
    created = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
    assert isinstance(created.get_function(), types.FunctionType)
def test_code_parser_parse_imports_import():
    """``parse_imports`` records plain ``import`` statements by module name."""
    parser = CodeParser(code_default)
    # Lazily pick out the ``import x`` nodes while walking the tree.
    import_nodes = (node for node in ast.walk(parser.get_tree()) if isinstance(node, ast.Import))
    for import_node in import_nodes:
        parser.parse_imports(import_node)
    recorded = parser.data["imports"]
    assert "requests" in recorded
def test_code_parser_parse_imports_importfrom():
    """``parse_imports`` records ``from x import y`` as a ``(module, name)`` pair."""
    parser = CodeParser("from os import path")
    # Lazily pick out the ``from ... import ...`` nodes while walking the tree.
    from_nodes = (node for node in ast.walk(parser.get_tree()) if isinstance(node, ast.ImportFrom))
    for from_node in from_nodes:
        parser.parse_imports(from_node)
    recorded = parser.data["imports"]
    assert ("os", "path") in recorded
def test_code_parser_parse_functions():
    """``parse_functions`` records one entry, named ``test``, for a single def."""
    parser = CodeParser("def test(): pass")
    function_nodes = (node for node in ast.walk(parser.get_tree()) if isinstance(node, ast.FunctionDef))
    for function_node in function_nodes:
        parser.parse_functions(function_node)
    assert len(parser.data["functions"]) == 1
    first_function = parser.data["functions"][0]
    assert first_function["name"] == "test"
def test_code_parser_parse_classes():
    """``parse_classes`` records one entry, named ``Test``, for a single class."""
    parser = CodeParser("class Test: pass")
    class_nodes = (node for node in ast.walk(parser.get_tree()) if isinstance(node, ast.ClassDef))
    for class_node in class_nodes:
        parser.parse_classes(class_node)
    assert len(parser.data["classes"]) == 1
    first_class = parser.data["classes"][0]
    assert first_class["name"] == "Test"
def test_code_parser_parse_global_vars():
    """``parse_global_vars`` records one assignment whose targets are ``["x"]``."""
    parser = CodeParser("x = 1")
    assign_nodes = (node for node in ast.walk(parser.get_tree()) if isinstance(node, ast.Assign))
    for assign_node in assign_nodes:
        parser.parse_global_vars(assign_node)
    assert len(parser.data["global_vars"]) == 1
    first_global = parser.data["global_vars"][0]
    assert first_global["targets"] == ["x"]
def test_component_get_function_valid():
    """``BaseComponent.get_function()`` returns a callable for a valid entrypoint."""
    base = BaseComponent(code="def build(): pass", function_entrypoint_name="build")
    assert callable(base.get_function())
def test_custom_component_get_function_entrypoint_args():
    """``get_function_entrypoint_args`` lists every parameter of ``build`` in order.

    The sample ``build`` method is declared as
    ``build(self, url, llm, template)``, so all four names are checked rather
    than only the first three.
    """
    custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
    args = custom_component.get_function_entrypoint_args
    assert len(args) == 4
    assert [arg["name"] for arg in args] == ["self", "url", "llm", "template"]
def test_custom_component_get_function_entrypoint_return_type():
    """``get_function_entrypoint_return_type`` is ``[Document]`` for the sample component."""
    created = CustomComponent(code=code_default, function_entrypoint_name="build")
    expected = [Document]
    assert created.get_function_entrypoint_return_type == expected
def test_custom_component_get_main_class_name():
    """``get_main_class_name`` is ``"YourComponent"`` for the sample component."""
    created = CustomComponent(code=code_default, function_entrypoint_name="build")
    assert created.get_main_class_name == "YourComponent"
def test_custom_component_get_function_valid():
    """``CustomComponent.get_function()`` returns a callable for valid code.

    ``get_function`` is a method and has to be invoked: asserting
    ``callable`` on the bound method itself would pass regardless of what the
    component's code contains.
    """
    custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
    my_function = custom_component.get_function()
    assert callable(my_function)
def test_code_parser_parse_arg_no_annotation():
    """``parse_arg`` on an unannotated argument yields a name and no ``"type"`` key."""
    bare_arg = ast.arg(arg="x", annotation=None)
    parsed = CodeParser("").parse_arg(bare_arg, None)
    assert parsed["name"] == "x"
    assert "type" not in parsed
def test_code_parser_parse_arg_with_annotation():
    """``parse_arg`` on an ``int``-annotated argument reports type ``"int"``."""
    int_annotation = ast.Name(id="int", ctx=ast.Load())
    annotated_arg = ast.arg(arg="x", annotation=int_annotation)
    parsed = CodeParser("").parse_arg(annotated_arg, None)
    assert parsed["name"] == "x"
    assert parsed["type"] == "int"
def test_code_parser_parse_callable_details_no_args():
    """``parse_callable_details`` reports the name and an empty arg list.

    The function node is built by hand with no parameters of any kind.
    """
    parser = CodeParser("")
    node = ast.FunctionDef(
        name="test",
        # ``posonlyargs`` is a regular field of ``ast.arguments``; before
        # Python 3.13 an omitted list field is simply absent from the node, so
        # it is set explicitly here.
        args=ast.arguments(
            posonlyargs=[],
            args=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        ),
        body=[],
        decorator_list=[],
        returns=None,
    )
    result = parser.parse_callable_details(node)
    assert result["name"] == "test"
    assert len(result["args"]) == 0
def test_code_parser_parse_assign():
    """``parse_assign`` reports the target name and the value as source text."""
    parser = CodeParser("")
    # ``ast.Num`` is a deprecated alias (removed in Python 3.12);
    # ``ast.Constant`` is the node type the parser itself produces for literals.
    stmt = ast.Assign(targets=[ast.Name(id="x", ctx=ast.Store())], value=ast.Constant(value=1))
    result = parser.parse_assign(stmt)
    assert result["name"] == "x"
    assert result["value"] == "1"
def test_code_parser_parse_ann_assign():
    """``parse_ann_assign`` reports name, value and annotation as source text."""
    parser = CodeParser("")
    stmt = ast.AnnAssign(
        target=ast.Name(id="x", ctx=ast.Store()),
        annotation=ast.Name(id="int", ctx=ast.Load()),
        # ``ast.Num`` is a deprecated alias (removed in Python 3.12); use the
        # ``ast.Constant`` node that replaced it.
        value=ast.Constant(value=1),
        simple=1,
    )
    result = parser.parse_ann_assign(stmt)
    assert result["name"] == "x"
    assert result["value"] == "1"
    assert result["annotation"] == "int"
def test_code_parser_parse_function_def_not_init():
    """``parse_function_def`` flags a function not named ``__init__`` as non-init.

    The method returns a ``(details, is_init)`` pair; for a function called
    ``test`` the flag is expected to be falsy.
    """
    parser = CodeParser("")
    stmt = ast.FunctionDef(
        name="test",
        # ``posonlyargs`` is set explicitly: before Python 3.13 an omitted
        # list field is absent from a hand-built node.
        args=ast.arguments(
            posonlyargs=[],
            args=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        ),
        body=[],
        decorator_list=[],
        returns=None,
    )
    result, is_init = parser.parse_function_def(stmt)
    assert result["name"] == "test"
    assert not is_init
def test_code_parser_parse_function_def_init():
    """``parse_function_def`` flags a function named ``__init__`` as init.

    The method returns a ``(details, is_init)`` pair; for ``__init__`` the
    flag is expected to be truthy.
    """
    parser = CodeParser("")
    stmt = ast.FunctionDef(
        name="__init__",
        # ``posonlyargs`` is set explicitly: before Python 3.13 an omitted
        # list field is absent from a hand-built node.
        args=ast.arguments(
            posonlyargs=[],
            args=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        ),
        body=[],
        decorator_list=[],
        returns=None,
    )
    result, is_init = parser.parse_function_def(stmt)
    assert result["name"] == "__init__"
    assert is_init
def test_component_get_code_tree_syntax_error():
    """``BaseComponent.get_code_tree`` raises ``CodeSyntaxError`` on malformed code."""
    broken = BaseComponent(code="import os as", function_entrypoint_name="build")
    with pytest.raises(CodeSyntaxError):
        broken.get_code_tree(broken.code)
def test_custom_component_class_template_validation_no_code():
    """``CustomComponent.get_function`` raises ``TypeError`` when ``code`` is ``None``."""
    codeless = CustomComponent(code=None, function_entrypoint_name="build")
    with pytest.raises(TypeError):
        codeless.get_function()
def test_custom_component_get_code_tree_syntax_error():
    """``CustomComponent.get_code_tree`` raises ``CodeSyntaxError`` on malformed code."""
    broken = CustomComponent(code="import os as", function_entrypoint_name="build")
    with pytest.raises(CodeSyntaxError):
        broken.get_code_tree(broken.code)
def test_custom_component_get_function_entrypoint_args_no_args():
    """``get_function_entrypoint_args`` is empty when ``build`` takes no arguments.

    The embedded source is kept at column 0 inside the literal, with the
    method and its body indented, so that it is valid Python for the parser.
    """
    my_code = """
class MyMainClass(CustomComponent):
    def build():
        pass"""
    custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
    args = custom_component.get_function_entrypoint_args
    assert len(args) == 0
def test_custom_component_get_function_entrypoint_return_type_no_return_type():
    """``get_function_entrypoint_return_type`` is ``[]`` when ``build`` has no annotation.

    The embedded source is kept at column 0 inside the literal, with the
    method and its body indented, so that it is valid Python for the parser.
    """
    my_code = """
class MyClass(CustomComponent):
    def build():
        pass"""
    custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
    return_type = custom_component.get_function_entrypoint_return_type
    assert return_type == []
def test_custom_component_get_main_class_name_no_main_class():
    """``get_main_class_name`` is ``""`` when the code defines no class.

    The embedded source is a bare module-level function; its body is indented
    inside the literal so that it is valid Python for the parser.
    """
    my_code = """
def build():
    pass"""
    custom_component = CustomComponent(code=my_code, function_entrypoint_name="build")
    class_name = custom_component.get_main_class_name
    assert class_name == ""
def test_custom_component_build_not_implemented():
    """Calling ``build()`` on a plain CustomComponent raises ``NotImplementedError``."""
    plain = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
    with pytest.raises(NotImplementedError):
        plain.build()
def test_build_config_no_code():
    """With ``code=None`` both entrypoint args and return type are empty lists."""
    codeless = CustomComponent(code=None)
    entrypoint_args = codeless.get_function_entrypoint_args
    assert entrypoint_args == []
    entrypoint_return_type = codeless.get_function_entrypoint_return_type
    assert entrypoint_return_type == []
@pytest.fixture
def component(client, active_user):
    """Provide a CustomComponent owned by ``active_user`` with three typed fields.

    ``client`` is requested but not used in the body; presumably it is needed
    for its setup side effects -- confirm against the fixture's definition.
    """
    owner_id = active_user.id
    field_types = {"llm": "str", "url": "str", "year": "int"}
    fields = {field_name: {"type": type_name} for field_name, type_name in field_types.items()}
    return CustomComponent(user_id=owner_id, field_config={"fields": fields})
@pytest.fixture(scope="session")
def test_flow(db):
    """Session-scoped fixture that stores a two-node flow and removes it afterwards."""
    graph = {
        "nodes": [{"id": "1"}, {"id": "2"}],
        "edges": [{"source": "1", "target": "2"}],
    }
    flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=graph)

    # Persist the flow before handing it to the tests.
    db.add(flow)
    db.commit()

    yield flow

    # Remove the flow again once the session is over.
    db.delete(flow)
    db.commit()
@pytest.fixture(scope="session")
def db(app):
    """Session-scoped fixture yielding ``app.db`` and calling ``drop_all`` on teardown."""
    database = app.db
    yield database
    # Teardown reads ``app.db`` again rather than reusing the yielded handle.
    app.db.drop_all()
def test_list_flows_return_type(component):
    """``list_flows`` returns a list."""
    assert isinstance(component.list_flows(), list)
def test_list_flows_flow_objects(component):
    """Every item returned by ``list_flows`` is a ``Flow``."""
    for listed in component.list_flows():
        assert isinstance(listed, Flow)
def test_build_config_return_type(component):
    """``build_config`` returns a dict."""
    assert isinstance(component.build_config(), dict)
def test_build_config_has_fields(component):
    """The config returned by ``build_config`` has a ``"fields"`` entry."""
    assert "fields" in component.build_config()
def test_build_config_fields_dict(component):
    """The ``"fields"`` entry of the config is itself a dict."""
    fields = component.build_config()["fields"]
    assert isinstance(fields, dict)
def test_build_config_field_keys(component):
    """Every key under ``"fields"`` is a string."""
    for field_name in component.build_config()["fields"]:
        assert isinstance(field_name, str)
def test_build_config_field_values_dict(component):
    """Every value under ``"fields"`` is a dict."""
    for field_spec in component.build_config()["fields"].values():
        assert isinstance(field_spec, dict)
def test_build_config_field_value_keys(component):
    """Every field spec under ``"fields"`` carries a ``"type"`` key."""
    for field_spec in component.build_config()["fields"].values():
        assert "type" in field_spec
def test_custom_component_multiple_outputs(code_component_with_multiple_outputs, active_user):
    """The first output of the template built for the sample has types ``["Text"]``."""
    frontend_node, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
    first_output = frontend_node["outputs"][0]
    assert first_output["types"] == ["Text"]