diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index c60247668..b07753b8d 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -1,18 +1,17 @@ import ast -import pytest import types from uuid import uuid4 - +import pytest from fastapi import HTTPException -from langflow.services.database.models.flow import Flow, FlowCreate +from langflow.field_typing.constants import Data from langflow.interface.custom.base import CustomComponent +from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError from langflow.interface.custom.component import ( Component, ComponentCodeNullError, ) -from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError - +from langflow.services.database.models.flow import Flow, FlowCreate code_default = """ from langflow import Prompt @@ -229,9 +228,11 @@ def test_custom_component_get_function_entrypoint_return_type(): Test the get_function_entrypoint_return_type property of the CustomComponent class. """ + from langchain.schema import Document + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") return_type = custom_component.get_function_entrypoint_return_type - assert return_type == ["Document"] + assert return_type == [Document] def test_custom_component_get_main_class_name(): @@ -414,7 +415,7 @@ class MyClass(CustomComponent): custom_component = CustomComponent(code=my_code, function_entrypoint_name="build") return_type = custom_component.get_function_entrypoint_return_type - assert return_type == [] + assert return_type == [Data] def test_custom_component_get_main_class_name_no_main_class():