🔥 refactor(custom.py): remove unused code and class 'CustomComponent_old' to improve code cleanliness and maintainability
🔧 fix(test_custom_component.py): fix formatting issues in test_custom_component.py for better readability ✨ feat(test_custom_component.py): add import statements for 'patch' and 'MagicMock' to enable mocking in tests 🔬 test(test_custom_component.py): add test for the 'get_function' method of the Component class with valid code and function_entrypoint_name 🔬 test(test_custom_component.py): add test for the 'parse_assign' method of the CodeParser class 🔬 test(test_custom_component.py): add test for the 'get_code_tree' method of the Component class when given incorrect syntax 🔬 test(test_custom_component.py): add test for the '_class_template_validation' method of the CustomComponent class when the code is None 🔬 test(test_custom_component.py): add test for the 'get_function_entrypoint_args' method of the CustomComponent class 🔬 test(test_custom_component.py): add test for the 'get_function_entrypoint_return_type' method of the CustomComponent class 🔬 test(test_custom_component.py): add test for the 'get_main_class_name' method of the CustomComponent class when there is no main class 🔥 refactor(test_custom_component.py): remove commented out code and unused fixtures to improve code readability and maintainability 🔧 refactor(tests): remove commented out test cases and unused imports ✨ feat(tests): add new test case for list_flows method when there are no flows in the database ✨ feat(tests): add new test case for build_config method when code is not provided ✨ feat(tests): add new test case for list_flows method when there are multiple queries to the database
This commit is contained in:
parent
63ead274c4
commit
a89a9a3267
2 changed files with 28 additions and 473 deletions
|
|
@ -48,29 +48,3 @@ class PythonFunctionTool(Function, Tool):
|
|||
|
||||
class PythonFunction(Function):
|
||||
code: str
|
||||
|
||||
|
||||
class CustomComponent_old(BaseModel):
|
||||
code: str
|
||||
function: Optional[Callable] = None
|
||||
imports: Optional[str] = None
|
||||
|
||||
# Eval code and store the class
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
# Validate the Class code
|
||||
@validator("code")
|
||||
def validate_func(cls, v):
|
||||
try:
|
||||
validate.eval_function(v)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return v
|
||||
|
||||
def get_function(self):
|
||||
"""Get the function"""
|
||||
function_name = validate.extract_function_name(self.code)
|
||||
|
||||
return validate.create_function(self.code, function_name)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import ast
|
||||
import pytest
|
||||
import types
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from fastapi import HTTPException
|
||||
from langflow.interface.custom.base import CustomComponent
|
||||
|
|
@ -447,462 +448,42 @@ def test_custom_component_build_not_implemented():
|
|||
custom_component.build()
|
||||
|
||||
|
||||
# -------------------------------------------------------
|
||||
# @pytest.fixture
|
||||
# def custom_chain():
|
||||
# return '''
|
||||
# from __future__ import annotations
|
||||
# from typing import Any, Dict, List, Optional
|
||||
def test_list_flows_no_flows():
|
||||
session_getter_module = "langflow.database.base.session_getter"
|
||||
|
||||
# from pydantic import Extra
|
||||
with patch(session_getter_module) as mock_session_getter:
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.all.return_value = []
|
||||
mock_session_getter.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# from langchain.schema import BaseLanguageModel, Document
|
||||
# from langchain.callbacks.manager import (
|
||||
# AsyncCallbackManagerForChainRun,
|
||||
# CallbackManagerForChainRun,
|
||||
# )
|
||||
# from langchain.chains.base import Chain
|
||||
# from langchain.prompts import StringPromptTemplate
|
||||
# from langflow.interface.custom.base import CustomComponent
|
||||
component = CustomComponent()
|
||||
result = component.list_flows()
|
||||
|
||||
# class MyCustomChain(Chain):
|
||||
# """
|
||||
# An example of a custom chain.
|
||||
# """
|
||||
assert len(result) == 0
|
||||
|
||||
# from typing import Any, Dict, List, Optional
|
||||
|
||||
# from pydantic import Extra
|
||||
def test_build_config_no_code():
|
||||
component = CustomComponent(code=None)
|
||||
|
||||
# from langchain.schema import BaseLanguageModel, Document
|
||||
# from langchain.callbacks.manager import (
|
||||
# AsyncCallbackManagerForChainRun,
|
||||
# CallbackManagerForChainRun,
|
||||
# )
|
||||
# from langchain.chains.base import Chain
|
||||
# from langchain.prompts import StringPromptTemplate
|
||||
# from langflow.interface.custom.base import CustomComponent
|
||||
assert component.get_function_entrypoint_args == ""
|
||||
assert component.get_function_entrypoint_return_type == ""
|
||||
|
||||
# class MyCustomChain(Chain):
|
||||
# """
|
||||
# An example of a custom chain.
|
||||
# """
|
||||
|
||||
# prompt: StringPromptTemplate
|
||||
# """Prompt object to use."""
|
||||
# llm: BaseLanguageModel
|
||||
# output_key: str = "text" #: :meta private:
|
||||
def test_list_flows_multiple_queries():
|
||||
mock_flow_1 = MagicMock()
|
||||
mock_flow_2 = MagicMock()
|
||||
|
||||
# class Config:
|
||||
# """Configuration for this pydantic object."""
|
||||
session_getter_module = "langflow.database.base.session_getter"
|
||||
|
||||
# extra = Extra.forbid
|
||||
# arbitrary_types_allowed = True
|
||||
with patch(session_getter_module) as mock_session_getter:
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.all.side_effect = [[mock_flow_1], [mock_flow_2]]
|
||||
mock_session_getter.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# @property
|
||||
# def input_keys(self) -> List[str]:
|
||||
# """Will be whatever keys the prompt expects.
|
||||
component = CustomComponent()
|
||||
result = component.list_flows()
|
||||
|
||||
# :meta private:
|
||||
# """
|
||||
# return self.prompt.input_variables
|
||||
|
||||
# @property
|
||||
# def output_keys(self) -> List[str]:
|
||||
# """Will always return text key.
|
||||
|
||||
# :meta private:
|
||||
# """
|
||||
# return [self.output_key]
|
||||
|
||||
# def _call(
|
||||
# self,
|
||||
# inputs: Dict[str, Any],
|
||||
# run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
# ) -> Dict[str, str]:
|
||||
# # Your custom chain logic goes here
|
||||
# # This is just an example that mimics LLMChain
|
||||
# prompt_value = self.prompt.format_prompt(**inputs)
|
||||
|
||||
# # Whenever you call a language model, or another chain, you should pass
|
||||
# # a callback manager to it. This allows the inner run to be tracked by
|
||||
# # any callbacks that are registered on the outer run.
|
||||
# # You can always obtain a callback manager for this by calling
|
||||
# # `run_manager.get_child()` as shown below.
|
||||
# response = self.llm.generate_prompt(
|
||||
# [prompt_value],
|
||||
# callbacks=run_manager.get_child() if run_manager else None,
|
||||
# )
|
||||
|
||||
# # If you want to log something about this run, you can do so by calling
|
||||
# # methods on the `run_manager`, as shown below. This will trigger any
|
||||
# # callbacks that are registered for that event.
|
||||
# if run_manager:
|
||||
# run_manager.on_text("Log something about this run")
|
||||
|
||||
# return {self.output_key: response.generations[0][0].text}
|
||||
|
||||
# async def _acall(
|
||||
# self,
|
||||
# inputs: Dict[str, Any],
|
||||
# run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
# ) -> Dict[str, str]:
|
||||
# # Your custom chain logic goes here
|
||||
# # This is just an example that mimics LLMChain
|
||||
# prompt_value = self.prompt.format_prompt(**inputs)
|
||||
|
||||
# # Whenever you call a language model, or another chain, you should pass
|
||||
# # a callback manager to it. This allows the inner run to be tracked by
|
||||
# # any callbacks that are registered on the outer run.
|
||||
# # You can always obtain a callback manager for this by calling
|
||||
# # `run_manager.get_child()` as shown below.
|
||||
# response = await self.llm.agenerate_prompt(
|
||||
# [prompt_value],
|
||||
# callbacks=run_manager.get_child() if run_manager else None,
|
||||
# )
|
||||
|
||||
# # If you want to log something about this run, you can do so by calling
|
||||
# # methods on the `run_manager`, as shown below. This will trigger any
|
||||
# # callbacks that are registered for that event.
|
||||
# if run_manager:
|
||||
# await run_manager.on_text("Log something about this run")
|
||||
|
||||
# return {self.output_key: response.generations[0][0].text}
|
||||
|
||||
# @property
|
||||
# def _chain_type(self) -> str:
|
||||
# return "my_custom_chain"
|
||||
|
||||
# class CustomChain(CustomComponent):
|
||||
# display_name: str = "Custom Chain"
|
||||
# field_config = {
|
||||
# "prompt": {"field_type": "prompt"},
|
||||
# "llm": {"field_type": "BaseLanguageModel"},
|
||||
# }
|
||||
|
||||
# def build(self, prompt, llm, input: str) -> Document:
|
||||
# chain = MyCustomChain(prompt=prompt, llm=llm)
|
||||
# return chain(input)
|
||||
# '''
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def data_processing():
|
||||
# return """
|
||||
# import pandas as pd
|
||||
# from langchain.schema import Document
|
||||
# from langflow.interface.custom.base import CustomComponent
|
||||
|
||||
# class CSVLoaderComponent(CustomComponent):
|
||||
# display_name: str = "CSV Loader"
|
||||
# field_config = {
|
||||
# "filename": {"field_type": "str", "required": True},
|
||||
# "column_name": {"field_type": "str", "required": True},
|
||||
# }
|
||||
|
||||
# def build(self, filename: str, column_name: str) -> Document:
|
||||
# # Load the CSV file
|
||||
# df = pd.read_csv(filename)
|
||||
|
||||
# # Verify the column exists
|
||||
# if column_name not in df.columns:
|
||||
# raise ValueError(f"Column '{column_name}' not found in the CSV file")
|
||||
|
||||
# # Convert each row of the specified column to a document object
|
||||
# documents = []
|
||||
# for content in df[column_name]:
|
||||
# metadata = {"filename": filename}
|
||||
# documents.append(Document(page_content=str(content), metadata=metadata))
|
||||
|
||||
# return documents
|
||||
# """
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def filter_docs():
|
||||
# return """
|
||||
# from langchain.schema import Document
|
||||
# from langflow.interface.custom.base import CustomComponent
|
||||
# from typing import List
|
||||
|
||||
# class DocumentFilterByLengthComponent(CustomComponent):
|
||||
# display_name: str = "Document Filter By Length"
|
||||
# field_config = {
|
||||
# "documents": {"field_type": "Document", "required": True},
|
||||
# "max_length": {"field_type": "int", "required": True},
|
||||
# }
|
||||
|
||||
# def build(self, documents: List[Document], max_length: int) -> List[Document]:
|
||||
# # Filter the documents by length
|
||||
# filtered_documents = [doc for doc in documents if len(doc.page_content) <= max_length]
|
||||
|
||||
# return filtered_documents
|
||||
# """
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def get_request():
|
||||
# return """
|
||||
# import requests
|
||||
# from typing import Dict, Union
|
||||
# from langchain.schema import Document
|
||||
# from langflow.interface.custom.base import CustomComponent
|
||||
|
||||
# class GetRequestComponent(CustomComponent):
|
||||
# display_name: str = "GET Request"
|
||||
# field_config = {
|
||||
# "url": {"field_type": "str", "required": True},
|
||||
# }
|
||||
|
||||
# def build(self, url: str) -> Document:
|
||||
# # Send a GET request to the URL
|
||||
# response = requests.get(url)
|
||||
|
||||
# # Raise an exception if the request was not successful
|
||||
# if response.status_code != 200:
|
||||
# raise ValueError(f"GET request failed: {response.status_code} status code")
|
||||
|
||||
# # Create a document with the response text and the URL as metadata
|
||||
# document = Document(page_content=response.text, metadata={"url": url})
|
||||
|
||||
# return document
|
||||
# """
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def post_request():
|
||||
# return """
|
||||
# import requests
|
||||
# from typing import Dict, Union
|
||||
# from langchain.schema import Document
|
||||
# from langflow.interface.custom.base import CustomComponent
|
||||
|
||||
# class PostRequestComponent(CustomComponent):
|
||||
# display_name: str = "POST Request"
|
||||
# field_config = {
|
||||
# "url": {"field_type": "str", "required": True},
|
||||
# "data": {"field_type": "dict", "required": True},
|
||||
# }
|
||||
|
||||
# def build(self, url: str, data: Dict[str, Union[str, int]]) -> Document:
|
||||
# # Send a POST request to the URL
|
||||
# response = requests.post(url, data=data)
|
||||
|
||||
# # Raise an exception if the request was not successful
|
||||
# if response.status_code != 200:
|
||||
# raise ValueError(f"POST request failed: {response.status_code} status code")
|
||||
|
||||
# # Create a document with the response text and the URL and data as metadata
|
||||
# document = Document(page_content=response.text, metadata={"url": url, "data": data})
|
||||
|
||||
# return document
|
||||
# """
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def code_default():
|
||||
# return """
|
||||
# from langflow import Prompt
|
||||
# from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
# from langchain.llms.base import BaseLLM
|
||||
# from langchain.chains import LLMChain
|
||||
# from langchain import PromptTemplate
|
||||
# from langchain.schema import Document
|
||||
|
||||
# import requests
|
||||
|
||||
# class YourComponent(CustomComponent):
|
||||
# #display_name: str = "Your Component"
|
||||
# #description: str = "Your description"
|
||||
# #field_config = { "url": { "multiline": True, "required": True } }
|
||||
|
||||
# def build(self, url: str, llm: BaseLLM, template: Prompt) -> Document:
|
||||
# response = requests.get(url)
|
||||
# prompt = PromptTemplate.from_template(template)
|
||||
# chain = LLMChain(llm=llm, prompt=prompt)
|
||||
# result = chain.run(response.text[:300])
|
||||
# return Document(page_content=str(result))
|
||||
# """
|
||||
|
||||
|
||||
# @pytest.fixture(params=[
|
||||
# 'code_default', 'custom_chain', 'data_processing',
|
||||
# 'filter_docs', 'get_request', 'post_request'])
|
||||
# def component_code(
|
||||
# request, code_default, custom_chain, data_processing,
|
||||
# filter_docs, get_request, post_request):
|
||||
# return locals()[request.param]
|
||||
|
||||
|
||||
# def test_empty_code_tree(component_code):
|
||||
# """
|
||||
# Test the situation when the code tree is empty.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree:
|
||||
# mocked_get_code_tree.return_value = {}
|
||||
# assert cc.get_function_entrypoint_args == ''
|
||||
# assert cc.get_function_entrypoint_return_type == ''
|
||||
# assert cc.get_main_class_name == ''
|
||||
# assert cc.build_template_config == {}
|
||||
|
||||
|
||||
# def test_class_template_validation(component_code):
|
||||
# """
|
||||
# Test the _class_template_validation method.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# assert cc._class_template_validation(component_code) == True
|
||||
# with pytest.raises(HTTPException):
|
||||
# cc._class_template_validation(None)
|
||||
|
||||
|
||||
# def test_get_code_tree(component_code):
|
||||
# """
|
||||
# Test the get_code_tree method.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree:
|
||||
# mocked_get_code_tree.return_value = {'classes': []}
|
||||
# assert cc.get_code_tree(component_code) == {'classes': []}
|
||||
|
||||
|
||||
# def test_get_function_entrypoint_args(component_code):
|
||||
# """
|
||||
# Test the get_function_entrypoint_args method.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree:
|
||||
# mocked_get_code_tree.return_value = {'classes': []}
|
||||
# assert cc.get_function_entrypoint_args == ''
|
||||
|
||||
|
||||
# def test_get_function_entrypoint_return_type(component_code):
|
||||
# """
|
||||
# Test the get_function_entrypoint_return_type method.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree:
|
||||
# mocked_get_code_tree.return_value = {'classes': []}
|
||||
# assert cc.get_function_entrypoint_return_type == ''
|
||||
|
||||
|
||||
# def test_get_main_class_name(component_code):
|
||||
# """
|
||||
# Test the get_main_class_name method.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree:
|
||||
# mocked_get_code_tree.return_value = {'classes': []}
|
||||
# assert cc.get_main_class_name == ''
|
||||
|
||||
|
||||
# def test_build_template_config(component_code):
|
||||
# """
|
||||
# Test the build_template_config method.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree:
|
||||
# mocked_get_code_tree.return_value = {
|
||||
# 'classes': [{'name': '', 'attributes': []}]}
|
||||
# assert cc.build_template_config == {}
|
||||
|
||||
|
||||
# def test_get_function(component_code):
|
||||
# """
|
||||
# Test the get_function method.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code, function_entrypoint_name='build')
|
||||
# assert callable(cc.get_function)
|
||||
|
||||
|
||||
# def test_build(component_code):
|
||||
# """
|
||||
# Test the build method.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# with pytest.raises(NotImplementedError):
|
||||
# cc.build()
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("entrypoint_name", ["build", "non_exist_method"])
|
||||
# def test_set_non_existing_function_entrypoint_name(component_code, entrypoint_name):
|
||||
# """
|
||||
# Test setting a non-existing function entrypoint name.
|
||||
# """
|
||||
# cc = CustomComponent(
|
||||
# code=component_code,
|
||||
# function_entrypoint_name=entrypoint_name
|
||||
# )
|
||||
# with pytest.raises(AttributeError):
|
||||
# cc.get_function
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("base_class", ["CustomComponent", "NonExistingClass"])
|
||||
# def test_set_non_existing_base_class(component_code, base_class):
|
||||
# """
|
||||
# Test setting a non-existing base class.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# cc.code_class_base_inheritance = base_class
|
||||
# with pytest.raises(AttributeError):
|
||||
# cc.get_main_class_name
|
||||
|
||||
|
||||
# def test_class_with_no_methods(component_code):
|
||||
# """
|
||||
# Test a component class with no methods.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree:
|
||||
# mocked_get_code_tree.return_value = {
|
||||
# 'classes': [
|
||||
# {
|
||||
# 'name': 'CustomComponent',
|
||||
# 'methods': [],
|
||||
# 'bases': ['CustomComponent']
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# assert cc.get_function_entrypoint_args == ''
|
||||
# assert cc.get_function_entrypoint_return_type == ''
|
||||
|
||||
|
||||
# def test_class_with_no_bases(component_code):
|
||||
# """
|
||||
# Test a component class with no bases.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree:
|
||||
# mocked_get_code_tree.return_value = {
|
||||
# 'classes': [
|
||||
# {
|
||||
# 'name': 'CustomComponent',
|
||||
# 'methods': [],
|
||||
# 'bases': []
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# assert cc.get_function_entrypoint_args == ''
|
||||
# assert cc.get_function_entrypoint_return_type == ''
|
||||
|
||||
|
||||
# def test_class_with_no_name(component_code):
|
||||
# """
|
||||
# Test a component class with no name.
|
||||
# """
|
||||
# cc = CustomComponent(code=component_code)
|
||||
# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree:
|
||||
# mocked_get_code_tree.return_value = {'classes': [
|
||||
# {'name': '', 'methods': [], 'bases': ['CustomComponent']}]}
|
||||
# assert cc.get_main_class_name == ''
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("input_code", ["", "not a valid python code"])
|
||||
# def test_invalid_input_code(input_code):
|
||||
# """
|
||||
# Test inputting an invalid Python code.
|
||||
# """
|
||||
# with pytest.raises(SyntaxError):
|
||||
# cc = CustomComponent(code=input_code)
|
||||
# Only the result of the second query is returned
|
||||
assert len(result) == 1
|
||||
assert result[0] == mock_flow_2
|
||||
assert mock_session.query.call_count == 2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue