From a89a9a3267be0274800ed42a07bd537dadba8f55 Mon Sep 17 00:00:00 2001 From: gustavoschaedler Date: Wed, 26 Jul 2023 16:56:21 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=20refactor(custom.py):=20remove=20?= =?UTF-8?q?unused=20code=20and=20class=20'CustomComponent=5Fold'=20to=20im?= =?UTF-8?q?prove=20code=20cleanliness=20and=20maintainability?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🔧 fix(test_custom_component.py): fix formatting issues in test_custom_component.py for better readability ✨ feat(test_custom_component.py): add import statements for 'patch' and 'MagicMock' to enable mocking in tests 🔬 test(test_custom_component.py): add test for the 'get_function' method of the Component class with valid code and function_entrypoint_name 🔬 test(test_custom_component.py): add test for the 'parse_assign' method of the CodeParser class 🔬 test(test_custom_component.py): add test for the 'get_code_tree' method of the Component class when given incorrect syntax 🔬 test(test_custom_component.py): add test for the '_class_template_validation' method of the CustomComponent class when the code is None 🔬 test(test_custom_component.py): add test for the 'get_function_entrypoint_args' method of the CustomComponent class 🔬 test(test_custom_component.py): add test for the 'get_function_entrypoint_return_type' method of the CustomComponent class 🔬 test(test_custom_component.py): add test for the 'get_main_class_name' method of the CustomComponent class when there is no main class 🔥 refactor(test_custom_component.py): remove commented out code and unused fixtures to improve code readability and maintainability 🔧 refactor(tests): remove commented out test cases and unused imports ✨ feat(tests): add new test case for list_flows method when there are no flows in the database ✨ feat(tests): add new test case for build_config method when code is not provided ✨ feat(tests): add new test case for list_flows method when there are multiple queries to the database --- .../langflow/interface/tools/custom.py | 26 - tests/test_custom_component.py | 475 ++---------------- 2 files changed, 28 insertions(+), 473 deletions(-) diff --git a/src/backend/langflow/interface/tools/custom.py b/src/backend/langflow/interface/tools/custom.py index a0ed5d378..321298e34 100644 --- a/src/backend/langflow/interface/tools/custom.py +++ b/src/backend/langflow/interface/tools/custom.py @@ -48,29 +48,3 @@ class PythonFunctionTool(Function, Tool): class PythonFunction(Function): code: str - - -class CustomComponent_old(BaseModel): - code: str - function: Optional[Callable] = None - imports: Optional[str] = None - - # Eval code and store the class - def __init__(self, **data): - super().__init__(**data) - - # Validate the Class code - @validator("code") - def validate_func(cls, v): - try: - validate.eval_function(v) - except Exception as e: - raise e - - return v - - def get_function(self): - """Get the function""" - function_name = validate.extract_function_name(self.code) - - return validate.create_function(self.code, function_name) diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index f4e57d10d..8963a7ece 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -1,6 +1,7 @@ import ast import pytest import types +from unittest.mock import patch, MagicMock from fastapi import HTTPException from langflow.interface.custom.base import CustomComponent @@ -447,462 +448,42 @@ def test_custom_component_build_not_implemented(): custom_component.build() -# ------------------------------------------------------- -# @pytest.fixture -# def custom_chain(): -# return ''' -# from __future__ import annotations -# from typing import Any, Dict, List, Optional +def test_list_flows_no_flows(): + session_getter_module = "langflow.database.base.session_getter" -# from pydantic import Extra + with patch(session_getter_module) as mock_session_getter: + mock_session = MagicMock() + mock_session.query.return_value.all.return_value = [] + mock_session_getter.return_value.__enter__.return_value = mock_session -# from langchain.schema import BaseLanguageModel, Document -# from langchain.callbacks.manager import ( -# AsyncCallbackManagerForChainRun, -# CallbackManagerForChainRun, -# ) -# from langchain.chains.base import Chain -# from langchain.prompts import StringPromptTemplate -# from langflow.interface.custom.base import CustomComponent + component = CustomComponent() + result = component.list_flows() -# class MyCustomChain(Chain): -# """ -# An example of a custom chain. -# """ + assert len(result) == 0 -# from typing import Any, Dict, List, Optional -# from pydantic import Extra +def test_build_config_no_code(): + component = CustomComponent(code=None) -# from langchain.schema import BaseLanguageModel, Document -# from langchain.callbacks.manager import ( -# AsyncCallbackManagerForChainRun, -# CallbackManagerForChainRun, -# ) -# from langchain.chains.base import Chain -# from langchain.prompts import StringPromptTemplate -# from langflow.interface.custom.base import CustomComponent + assert component.get_function_entrypoint_args == "" + assert component.get_function_entrypoint_return_type == "" -# class MyCustomChain(Chain): -# """ -# An example of a custom chain. -# """ -# prompt: StringPromptTemplate -# """Prompt object to use.""" -# llm: BaseLanguageModel -# output_key: str = "text" #: :meta private: +def test_list_flows_multiple_queries(): + mock_flow_1 = MagicMock() + mock_flow_2 = MagicMock() -# class Config: -# """Configuration for this pydantic object.""" + session_getter_module = "langflow.database.base.session_getter" -# extra = Extra.forbid -# arbitrary_types_allowed = True + with patch(session_getter_module) as mock_session_getter: + mock_session = MagicMock() + mock_session.query.return_value.all.side_effect = [[mock_flow_1], [mock_flow_2]] + mock_session_getter.return_value.__enter__.return_value = mock_session -# @property -# def input_keys(self) -> List[str]: -# """Will be whatever keys the prompt expects. + component = CustomComponent() + result = component.list_flows() -# :meta private: -# """ -# return self.prompt.input_variables - -# @property -# def output_keys(self) -> List[str]: -# """Will always return text key. - -# :meta private: -# """ -# return [self.output_key] - -# def _call( -# self, -# inputs: Dict[str, Any], -# run_manager: Optional[CallbackManagerForChainRun] = None, -# ) -> Dict[str, str]: -# # Your custom chain logic goes here -# # This is just an example that mimics LLMChain -# prompt_value = self.prompt.format_prompt(**inputs) - -# # Whenever you call a language model, or another chain, you should pass -# # a callback manager to it. This allows the inner run to be tracked by -# # any callbacks that are registered on the outer run. -# # You can always obtain a callback manager for this by calling -# # `run_manager.get_child()` as shown below. -# response = self.llm.generate_prompt( -# [prompt_value], -# callbacks=run_manager.get_child() if run_manager else None, -# ) - -# # If you want to log something about this run, you can do so by calling -# # methods on the `run_manager`, as shown below. This will trigger any -# # callbacks that are registered for that event. -# if run_manager: -# run_manager.on_text("Log something about this run") - -# return {self.output_key: response.generations[0][0].text} - -# async def _acall( -# self, -# inputs: Dict[str, Any], -# run_manager: Optional[AsyncCallbackManagerForChainRun] = None, -# ) -> Dict[str, str]: -# # Your custom chain logic goes here -# # This is just an example that mimics LLMChain -# prompt_value = self.prompt.format_prompt(**inputs) - -# # Whenever you call a language model, or another chain, you should pass -# # a callback manager to it. This allows the inner run to be tracked by -# # any callbacks that are registered on the outer run. -# # You can always obtain a callback manager for this by calling -# # `run_manager.get_child()` as shown below. -# response = await self.llm.agenerate_prompt( -# [prompt_value], -# callbacks=run_manager.get_child() if run_manager else None, -# ) - -# # If you want to log something about this run, you can do so by calling -# # methods on the `run_manager`, as shown below. This will trigger any -# # callbacks that are registered for that event. -# if run_manager: -# await run_manager.on_text("Log something about this run") - -# return {self.output_key: response.generations[0][0].text} - -# @property -# def _chain_type(self) -> str: -# return "my_custom_chain" - -# class CustomChain(CustomComponent): -# display_name: str = "Custom Chain" -# field_config = { -# "prompt": {"field_type": "prompt"}, -# "llm": {"field_type": "BaseLanguageModel"}, -# } - -# def build(self, prompt, llm, input: str) -> Document: -# chain = MyCustomChain(prompt=prompt, llm=llm) -# return chain(input) -# ''' - - -# @pytest.fixture -# def data_processing(): -# return """ -# import pandas as pd -# from langchain.schema import Document -# from langflow.interface.custom.base import CustomComponent - -# class CSVLoaderComponent(CustomComponent): -# display_name: str = "CSV Loader" -# field_config = { -# "filename": {"field_type": "str", "required": True}, -# "column_name": {"field_type": "str", "required": True}, -# } - -# def build(self, filename: str, column_name: str) -> Document: -# # Load the CSV file -# df = pd.read_csv(filename) - -# # Verify the column exists -# if column_name not in df.columns: -# raise ValueError(f"Column '{column_name}' not found in the CSV file") - -# # Convert each row of the specified column to a document object -# documents = [] -# for content in df[column_name]: -# metadata = {"filename": filename} -# documents.append(Document(page_content=str(content), metadata=metadata)) - -# return documents -# """ - - -# @pytest.fixture -# def filter_docs(): -# return """ -# from langchain.schema import Document -# from langflow.interface.custom.base import CustomComponent -# from typing import List - -# class DocumentFilterByLengthComponent(CustomComponent): -# display_name: str = "Document Filter By Length" -# field_config = { -# "documents": {"field_type": "Document", "required": True}, -# "max_length": {"field_type": "int", "required": True}, -# } - -# def build(self, documents: List[Document], max_length: int) -> List[Document]: -# # Filter the documents by length -# filtered_documents = [doc for doc in documents if len(doc.page_content) <= max_length] - -# return filtered_documents -# """ - - -# @pytest.fixture -# def get_request(): -# return """ -# import requests -# from typing import Dict, Union -# from langchain.schema import Document -# from langflow.interface.custom.base import CustomComponent - -# class GetRequestComponent(CustomComponent): -# display_name: str = "GET Request" -# field_config = { -# "url": {"field_type": "str", "required": True}, -# } - -# def build(self, url: str) -> Document: -# # Send a GET request to the URL -# response = requests.get(url) - -# # Raise an exception if the request was not successful -# if response.status_code != 200: -# raise ValueError(f"GET request failed: {response.status_code} status code") - -# # Create a document with the response text and the URL as metadata -# document = Document(page_content=response.text, metadata={"url": url}) - -# return document -# """ - - -# @pytest.fixture -# def post_request(): -# return """ -# import requests -# from typing import Dict, Union -# from langchain.schema import Document -# from langflow.interface.custom.base import CustomComponent - -# class PostRequestComponent(CustomComponent): -# display_name: str = "POST Request" -# field_config = { -# "url": {"field_type": "str", "required": True}, -# "data": {"field_type": "dict", "required": True}, -# } - -# def build(self, url: str, data: Dict[str, Union[str, int]]) -> Document: -# # Send a POST request to the URL -# response = requests.post(url, data=data) - -# # Raise an exception if the request was not successful -# if response.status_code != 200: -# raise ValueError(f"POST request failed: {response.status_code} status code") - -# # Create a document with the response text and the URL and data as metadata -# document = Document(page_content=response.text, metadata={"url": url, "data": data}) - -# return document -# """ - - -# @pytest.fixture -# def code_default(): -# return """ -# from langflow import Prompt -# from langflow.interface.custom.custom_component import CustomComponent - -# from langchain.llms.base import BaseLLM -# from langchain.chains import LLMChain -# from langchain import PromptTemplate -# from langchain.schema import Document - -# import requests - -# class YourComponent(CustomComponent): -# #display_name: str = "Your Component" -# #description: str = "Your description" -# #field_config = { "url": { "multiline": True, "required": True } } - -# def build(self, url: str, llm: BaseLLM, template: Prompt) -> Document: -# response = requests.get(url) -# prompt = PromptTemplate.from_template(template) -# chain = LLMChain(llm=llm, prompt=prompt) -# result = chain.run(response.text[:300]) -# return Document(page_content=str(result)) -# """ - - -# @pytest.fixture(params=[ -# 'code_default', 'custom_chain', 'data_processing', -# 'filter_docs', 'get_request', 'post_request']) -# def component_code( -# request, code_default, custom_chain, data_processing, -# filter_docs, get_request, post_request): -# return locals()[request.param] - - -# def test_empty_code_tree(component_code): -# """ -# Test the situation when the code tree is empty. -# """ -# cc = CustomComponent(code=component_code) -# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: -# mocked_get_code_tree.return_value = {} -# assert cc.get_function_entrypoint_args == '' -# assert cc.get_function_entrypoint_return_type == '' -# assert cc.get_main_class_name == '' -# assert cc.build_template_config == {} - - -# def test_class_template_validation(component_code): -# """ -# Test the _class_template_validation method. -# """ -# cc = CustomComponent(code=component_code) -# assert cc._class_template_validation(component_code) == True -# with pytest.raises(HTTPException): -# cc._class_template_validation(None) - - -# def test_get_code_tree(component_code): -# """ -# Test the get_code_tree method. -# """ -# cc = CustomComponent(code=component_code) -# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: -# mocked_get_code_tree.return_value = {'classes': []} -# assert cc.get_code_tree(component_code) == {'classes': []} - - -# def test_get_function_entrypoint_args(component_code): -# """ -# Test the get_function_entrypoint_args method. -# """ -# cc = CustomComponent(code=component_code) -# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: -# mocked_get_code_tree.return_value = {'classes': []} -# assert cc.get_function_entrypoint_args == '' - - -# def test_get_function_entrypoint_return_type(component_code): -# """ -# Test the get_function_entrypoint_return_type method. -# """ -# cc = CustomComponent(code=component_code) -# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: -# mocked_get_code_tree.return_value = {'classes': []} -# assert cc.get_function_entrypoint_return_type == '' - - -# def test_get_main_class_name(component_code): -# """ -# Test the get_main_class_name method. -# """ -# cc = CustomComponent(code=component_code) -# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: -# mocked_get_code_tree.return_value = {'classes': []} -# assert cc.get_main_class_name == '' - - -# def test_build_template_config(component_code): -# """ -# Test the build_template_config method. -# """ -# cc = CustomComponent(code=component_code) -# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: -# mocked_get_code_tree.return_value = { -# 'classes': [{'name': '', 'attributes': []}]} -# assert cc.build_template_config == {} - - -# def test_get_function(component_code): -# """ -# Test the get_function method. -# """ -# cc = CustomComponent(code=component_code, function_entrypoint_name='build') -# assert callable(cc.get_function) - - -# def test_build(component_code): -# """ -# Test the build method. -# """ -# cc = CustomComponent(code=component_code) -# with pytest.raises(NotImplementedError): -# cc.build() - - -# @pytest.mark.parametrize("entrypoint_name", ["build", "non_exist_method"]) -# def test_set_non_existing_function_entrypoint_name(component_code, entrypoint_name): -# """ -# Test setting a non-existing function entrypoint name. -# """ -# cc = CustomComponent( -# code=component_code, -# function_entrypoint_name=entrypoint_name -# ) -# with pytest.raises(AttributeError): -# cc.get_function - - -# @pytest.mark.parametrize("base_class", ["CustomComponent", "NonExistingClass"]) -# def test_set_non_existing_base_class(component_code, base_class): -# """ -# Test setting a non-existing base class. -# """ -# cc = CustomComponent(code=component_code) -# cc.code_class_base_inheritance = base_class -# with pytest.raises(AttributeError): -# cc.get_main_class_name - - -# def test_class_with_no_methods(component_code): -# """ -# Test a component class with no methods. -# """ -# cc = CustomComponent(code=component_code) -# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: -# mocked_get_code_tree.return_value = { -# 'classes': [ -# { -# 'name': 'CustomComponent', -# 'methods': [], -# 'bases': ['CustomComponent'] -# } -# ] -# } -# assert cc.get_function_entrypoint_args == '' -# assert cc.get_function_entrypoint_return_type == '' - - -# def test_class_with_no_bases(component_code): -# """ -# Test a component class with no bases. -# """ -# cc = CustomComponent(code=component_code) -# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: -# mocked_get_code_tree.return_value = { -# 'classes': [ -# { -# 'name': 'CustomComponent', -# 'methods': [], -# 'bases': [] -# } -# ] -# } -# assert cc.get_function_entrypoint_args == '' -# assert cc.get_function_entrypoint_return_type == '' - - -# def test_class_with_no_name(component_code): -# """ -# Test a component class with no name. -# """ -# cc = CustomComponent(code=component_code) -# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: -# mocked_get_code_tree.return_value = {'classes': [ -# {'name': '', 'methods': [], 'bases': ['CustomComponent']}]} -# assert cc.get_main_class_name == '' - - -# @pytest.mark.parametrize("input_code", ["", "not a valid python code"]) -# def test_invalid_input_code(input_code): -# """ -# Test inputting an invalid Python code. -# """ -# with pytest.raises(SyntaxError): -# cc = CustomComponent(code=input_code) + # Only the result of the second query is returned + assert len(result) == 1 + assert result[0] == mock_flow_2 + assert mock_session.query.call_count == 2