diff --git a/tests/conftest.py b/tests/conftest.py index 1773ebf23..dfb2b56f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -116,254 +116,3 @@ def client_fixture(session: Session): yield TestClient(app) app.dependency_overrides.clear() - - -@pytest.fixture -def custom_chain(): - return ''' -from __future__ import annotations -from typing import Any, Dict, List, Optional - -from pydantic import Extra - -from langchain.schema import BaseLanguageModel, Document -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain.chains.base import Chain -from langchain.prompts import StringPromptTemplate -from langflow.interface.custom.base import CustomComponent - -class MyCustomChain(Chain): - """ - An example of a custom chain. - """ - -from typing import Any, Dict, List, Optional - -from pydantic import Extra - -from langchain.schema import BaseLanguageModel, Document -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain.chains.base import Chain -from langchain.prompts import StringPromptTemplate -from langflow.interface.custom.base import CustomComponent - -class MyCustomChain(Chain): - """ - An example of a custom chain. - """ - - prompt: StringPromptTemplate - """Prompt object to use.""" - llm: BaseLanguageModel - output_key: str = "text" #: :meta private: - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @property - def input_keys(self) -> List[str]: - """Will be whatever keys the prompt expects. - - :meta private: - """ - return self.prompt.input_variables - - @property - def output_keys(self) -> List[str]: - """Will always return text key. - - :meta private: - """ - return [self.output_key] - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - # Your custom chain logic goes here - # This is just an example that mimics LLMChain - prompt_value = self.prompt.format_prompt(**inputs) - - # Whenever you call a language model, or another chain, you should pass - # a callback manager to it. This allows the inner run to be tracked by - # any callbacks that are registered on the outer run. - # You can always obtain a callback manager for this by calling - # `run_manager.get_child()` as shown below. - response = self.llm.generate_prompt( - [prompt_value], - callbacks=run_manager.get_child() if run_manager else None, - ) - - # If you want to log something about this run, you can do so by calling - # methods on the `run_manager`, as shown below. This will trigger any - # callbacks that are registered for that event. - if run_manager: - run_manager.on_text("Log something about this run") - - return {self.output_key: response.generations[0][0].text} - - async def _acall( - self, - inputs: Dict[str, Any], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - # Your custom chain logic goes here - # This is just an example that mimics LLMChain - prompt_value = self.prompt.format_prompt(**inputs) - - # Whenever you call a language model, or another chain, you should pass - # a callback manager to it. This allows the inner run to be tracked by - # any callbacks that are registered on the outer run. - # You can always obtain a callback manager for this by calling - # `run_manager.get_child()` as shown below. - response = await self.llm.agenerate_prompt( - [prompt_value], - callbacks=run_manager.get_child() if run_manager else None, - ) - - # If you want to log something about this run, you can do so by calling - # methods on the `run_manager`, as shown below. This will trigger any - # callbacks that are registered for that event. - if run_manager: - await run_manager.on_text("Log something about this run") - - return {self.output_key: response.generations[0][0].text} - - @property - def _chain_type(self) -> str: - return "my_custom_chain" - -class CustomChain(CustomComponent): - display_name: str = "Custom Chain" - field_config = { - "prompt": {"field_type": "prompt"}, - "llm": {"field_type": "BaseLanguageModel"}, - } - - def build(self, prompt, llm, input: str) -> Document: - chain = MyCustomChain(prompt=prompt, llm=llm) - return chain(input) -''' - - -@pytest.fixture -def data_processing(): - return """ -import pandas as pd -from langchain.schema import Document -from langflow.interface.custom.base import CustomComponent - -class CSVLoaderComponent(CustomComponent): - display_name: str = "CSV Loader" - field_config = { - "filename": {"field_type": "str", "required": True}, - "column_name": {"field_type": "str", "required": True}, - } - - def build(self, filename: str, column_name: str) -> Document: - # Load the CSV file - df = pd.read_csv(filename) - - # Verify the column exists - if column_name not in df.columns: - raise ValueError(f"Column '{column_name}' not found in the CSV file") - - # Convert each row of the specified column to a document object - documents = [] - for content in df[column_name]: - metadata = {"filename": filename} - documents.append(Document(page_content=str(content), metadata=metadata)) - - return documents -""" - - -@pytest.fixture -def filter_docs(): - return """ -from langchain.schema import Document -from langflow.interface.custom.base import CustomComponent -from typing import List - -class DocumentFilterByLengthComponent(CustomComponent): - display_name: str = "Document Filter By Length" - field_config = { - "documents": {"field_type": "Document", "required": True}, - "max_length": {"field_type": "int", "required": True}, - } - - def build(self, documents: List[Document], max_length: int) -> List[Document]: - # Filter the documents by length - filtered_documents = [doc for doc in documents if len(doc.page_content) <= max_length] - - return filtered_documents -""" - - -@pytest.fixture -def get_request(): - return """ -import requests -from typing import Dict, Union -from langchain.schema import Document -from langflow.interface.custom.base import CustomComponent - -class GetRequestComponent(CustomComponent): - display_name: str = "GET Request" - field_config = { - "url": {"field_type": "str", "required": True}, - } - - def build(self, url: str) -> Document: - # Send a GET request to the URL - response = requests.get(url) - - # Raise an exception if the request was not successful - if response.status_code != 200: - raise ValueError(f"GET request failed: {response.status_code} status code") - - # Create a document with the response text and the URL as metadata - document = Document(page_content=response.text, metadata={"url": url}) - - return document -""" - - -@pytest.fixture -def post_request(): - return """ -import requests -from typing import Dict, Union -from langchain.schema import Document -from langflow.interface.custom.base import CustomComponent - -class PostRequestComponent(CustomComponent): - display_name: str = "POST Request" - field_config = { - "url": {"field_type": "str", "required": True}, - "data": {"field_type": "dict", "required": True}, - } - - def build(self, url: str, data: Dict[str, Union[str, int]]) -> Document: - # Send a POST request to the URL - response = requests.post(url, data=data) - - # Raise an exception if the request was not successful - if response.status_code != 200: - raise ValueError(f"POST request failed: {response.status_code} status code") - - # Create a document with the response text and the URL and data as metadata - document = Document(page_content=response.text, metadata={"url": url, "data": data}) - - return document -""" diff --git a/tests/test_agents_template.py b/tests/test_agents_template.py index 93f4f8b5b..62c237b5a 100644 --- a/tests/test_agents_template.py +++ b/tests/test_agents_template.py @@ -18,6 +18,7 @@ def test_zero_shot_agent(client: TestClient): assert template["tools"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -32,6 +33,7 @@ def test_zero_shot_agent(client: TestClient): # Additional assertions for other template variables assert template["callback_manager"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -44,6 +46,7 @@ def test_zero_shot_agent(client: TestClient): } assert template["llm"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -56,6 +59,7 @@ def test_zero_shot_agent(client: TestClient): } assert template["output_parser"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -68,6 +72,7 @@ def test_zero_shot_agent(client: TestClient): } assert template["input_variables"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -80,6 +85,7 @@ def test_zero_shot_agent(client: TestClient): } assert template["prefix"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": True, @@ -93,6 +99,7 @@ def test_zero_shot_agent(client: TestClient): } assert template["suffix"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": True, @@ -118,6 +125,7 @@ def test_json_agent(client: TestClient): assert template["toolkit"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -130,6 +138,7 @@ def test_json_agent(client: TestClient): } assert template["llm"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -155,12 +164,12 @@ def test_csv_agent(client: TestClient): assert template["path"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, "value": "", "suffixes": [".csv"], - "fileTypes": ["csv"], "password": False, "name": "path", "type": "file", @@ -171,6 +180,7 @@ def test_csv_agent(client: TestClient): } assert template["llm"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -196,6 +206,7 @@ def test_initialize_agent(client: TestClient): assert template["agent"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -217,6 +228,7 @@ def test_initialize_agent(client: TestClient): } assert template["memory"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -229,6 +241,7 @@ def test_initialize_agent(client: TestClient): } assert template["tools"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -241,6 +254,7 @@ def test_initialize_agent(client: TestClient): } assert template["llm"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, diff --git a/tests/test_chains_template.py b/tests/test_chains_template.py index e183cb0d0..2e2d84b9d 100644 --- a/tests/test_chains_template.py +++ b/tests/test_chains_template.py @@ -29,6 +29,7 @@ def test_conversation_chain(client: TestClient): template = chain["template"] assert template["memory"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -41,6 +42,7 @@ def test_conversation_chain(client: TestClient): } assert template["verbose"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -53,6 +55,7 @@ def test_conversation_chain(client: TestClient): } assert template["llm"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -65,6 +68,7 @@ def test_conversation_chain(client: TestClient): } assert template["input_key"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -78,6 +82,7 @@ def test_conversation_chain(client: TestClient): } assert template["output_key"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -115,6 +120,7 @@ def test_llm_chain(client: TestClient): template = chain["template"] assert template["memory"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -127,6 +133,7 @@ def test_llm_chain(client: TestClient): } assert template["verbose"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -140,6 +147,7 @@ def test_llm_chain(client: TestClient): } assert template["llm"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -152,6 +160,7 @@ def test_llm_chain(client: TestClient): } assert template["output_key"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -182,6 +191,7 @@ def test_llm_checker_chain(client: TestClient): template = chain["template"] assert template["llm"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -215,6 +225,7 @@ def test_llm_math_chain(client: TestClient): template = chain["template"] assert template["memory"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -227,6 +238,7 @@ def test_llm_math_chain(client: TestClient): } assert template["verbose"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -240,6 +252,7 @@ def test_llm_math_chain(client: TestClient): } assert template["llm"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -252,6 +265,7 @@ def test_llm_math_chain(client: TestClient): } assert template["input_key"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -265,6 +279,7 @@ def test_llm_math_chain(client: TestClient): } assert template["output_key"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -306,6 +321,7 @@ def test_series_character_chain(client: TestClient): assert template["llm"] == { "required": True, + "dynamic": False, "display_name": "LLM", "placeholder": "", "show": True, @@ -319,6 +335,7 @@ def test_series_character_chain(client: TestClient): } assert template["character"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -331,6 +348,7 @@ def test_series_character_chain(client: TestClient): } assert template["series"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -372,6 +390,7 @@ def test_mid_journey_prompt_chain(client: TestClient): assert template["llm"] == { "required": True, + "dynamic": False, "display_name": "LLM", "placeholder": "", "show": True, @@ -412,6 +431,7 @@ def test_time_travel_guide_chain(client: TestClient): assert template["llm"] == { "required": True, + "dynamic": False, "placeholder": "", "display_name": "LLM", "show": True, @@ -425,6 +445,7 @@ def test_time_travel_guide_chain(client: TestClient): } assert template["memory"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, diff --git a/tests/test_creators.py b/tests/test_creators.py index 5453b57eb..2098e87cd 100644 --- a/tests/test_creators.py +++ b/tests/test_creators.py @@ -35,6 +35,7 @@ def test_lang_chain_type_creator_to_dict( sample_lang_chain_type_creator: LangChainTypeCreator, ): type_dict = sample_lang_chain_type_creator.to_dict() + assert len(type_dict) == 1 assert "test_type" in type_dict assert "node1" in type_dict["test_type"] diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index b73a80d69..bb7e00dcf 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -1,180 +1,904 @@ import ast import pytest +import types + from fastapi import HTTPException +from langflow.interface.custom.base import CustomComponent +from langflow.interface.custom.component import ( + Component, + ComponentCodeNullError, + ComponentFunctionEntrypointNameNullError, +) +from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError + + +code_default = """ +from langflow import Prompt from langflow.interface.custom.custom_component import CustomComponent -from langflow.interface.custom.constants import DEFAULT_CUSTOM_COMPONENT_CODE + +from langchain.llms.base import BaseLLM +from langchain.chains import LLMChain +from langchain import PromptTemplate +from langchain.schema import Document + +import requests + +class YourComponent(CustomComponent): + langflow_display_name: str = "Your Component" + langflow_description: str = "Your description" + langflow_field_config = { "url": { "multiline": True, "required": True } } + + def build(self, url: str, llm: BaseLLM, template: Prompt) -> Document: + response = requests.get(url) + prompt = PromptTemplate.from_template(template) + chain = LLMChain(llm=llm, prompt=prompt) + result = chain.run(response.text[:300]) + return Document(page_content=str(result)) +""" -# Test the __init__ method -def test_init(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - assert isinstance(component, CustomComponent) - assert component.code == DEFAULT_CUSTOM_COMPONENT_CODE +def test_code_parser_init(): + """ + Test the initialization of the CodeParser class. + """ + parser = CodeParser(code_default) + assert parser.code == code_default -# Test the _handle_import method -def test_handle_import(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - node = ast.parse("import math").body[0] - component._handle_import(node) - assert "math" in component.class_template["imports"] +def test_code_parser_get_tree(): + """ + Test the __get_tree method of the CodeParser class. + """ + parser = CodeParser(code_default) + tree = parser._CodeParser__get_tree() + assert isinstance(tree, ast.AST) -# Test the _handle_class method -def test_handle_class(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - node = ast.parse("class Test: pass").body[0] - component._handle_class(node) - assert component.class_template["class"]["name"] == "Test" +def test_code_parser_syntax_error(): + """ + Test the __get_tree method raises the CodeSyntaxError when given incorrect syntax. + """ + code_syntax_error = "zzz import os" + + parser = CodeParser(code_syntax_error) + with pytest.raises(CodeSyntaxError): + parser._CodeParser__get_tree() -# Test the _handle_function method -def test_handle_function(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - node = ast.parse("def func(): pass").body[0] - component._handle_function(node) - function_data = {"name": "func", "arguments": [], "return_type": "None"} - assert function_data in component.class_template["functions"] +def test_component_init(): + """ + Test the initialization of the Component class. + """ + component = Component(code=code_default, function_entrypoint_name="build") + assert component.code == code_default + assert component.function_entrypoint_name == "build" -# Test the transform_list method -def test_transform_list(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - input_list = ["var1: int", "var2: str", "var3"] - output_list = [["var1", "int"], ["var2", "str"], ["var3", None]] - assert component.transform_list(input_list) == output_list +def test_component_get_code_tree(): + """ + Test the get_code_tree method of the Component class. + """ + component = Component(code=code_default, function_entrypoint_name="build") + tree = component.get_code_tree(component.code) + assert "imports" in tree -# Test the extract_class_info method with valid code -def test_extract_class_info(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - class_info = component.extract_class_info() - assert "requests" in class_info["imports"] - assert class_info["class"]["name"] == "YourComponent" - function_data = { - "name": "build", - "arguments": ["self", "url: str", "llm: BaseLLM", "template: Prompt"], - "return_type": "Document", - } - assert function_data in class_info["functions"] +def test_component_code_null_error(): + """ + Test the get_function method raises the ComponentCodeNullError when the code is empty. + """ + component = Component(code="", function_entrypoint_name="") + with pytest.raises(ComponentCodeNullError): + component.get_function() -# Test the extract_class_info method with invalid code -def test_extract_class_info_invalid_code(): - component = CustomComponent(field_config={}, code="invalid code") - with pytest.raises(HTTPException) as e: - component.extract_class_info() - - exception = e.value - assert exception.status_code == 400 - assert exception.detail["error"] == "invalid syntax" +def test_component_function_entrypoint_name_null_error(): + """ + Test the get_function method raises the ComponentFunctionEntrypointNameNullError when the function_entrypoint_name is empty. + """ + component = Component(code=code_default, function_entrypoint_name="") + with pytest.raises(ComponentFunctionEntrypointNameNullError): + component.get_function() -# Test the get_entrypoint_function_args_and_return_type method -def test_get_entrypoint_function_args_and_return_type(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - ( - function_args, - return_type, - template_config, - ) = component.get_entrypoint_function_args_and_return_type() - assert function_args == [ - ["self", None], - ["url", "str"], - ["llm", "BaseLLM"], - ["template", "Prompt"], - ] +def test_custom_component_init(): + """ + Test the initialization of the CustomComponent class. + """ + function_entrypoint_name = "build" + + custom_component = CustomComponent( + code=code_default, function_entrypoint_name=function_entrypoint_name + ) + assert custom_component.code == code_default + assert custom_component.function_entrypoint_name == function_entrypoint_name + + +def test_custom_component_build_template_config(): + """ + Test the build_template_config property of the CustomComponent class. + """ + custom_component = CustomComponent( + code=code_default, function_entrypoint_name="build" + ) + config = custom_component.build_template_config + assert isinstance(config, dict) + + +def test_custom_component_get_function(): + """ + Test the get_function property of the CustomComponent class. + """ + custom_component = CustomComponent( + code="def build(): pass", function_entrypoint_name="build" + ) + my_function = custom_component.get_function + assert isinstance(my_function, types.FunctionType) + + +def test_code_parser_parse_imports_import(): + """ + Test the parse_imports method of the CodeParser class with an import statement. + """ + parser = CodeParser(code_default) + tree = parser._CodeParser__get_tree() + for node in ast.walk(tree): + if isinstance(node, ast.Import): + parser.parse_imports(node) + assert "requests" in parser.data["imports"] + + +def test_code_parser_parse_imports_importfrom(): + """ + Test the parse_imports method of the CodeParser class with an import from statement. + """ + parser = CodeParser("from os import path") + tree = parser._CodeParser__get_tree() + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom): + parser.parse_imports(node) + assert ("os", "path") in parser.data["imports"] + + +def test_code_parser_parse_functions(): + """ + Test the parse_functions method of the CodeParser class. + """ + parser = CodeParser("def test(): pass") + tree = parser._CodeParser__get_tree() + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + parser.parse_functions(node) + assert len(parser.data["functions"]) == 1 + assert parser.data["functions"][0]["name"] == "test" + + +def test_code_parser_parse_classes(): + """ + Test the parse_classes method of the CodeParser class. + """ + parser = CodeParser("class Test: pass") + tree = parser._CodeParser__get_tree() + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + parser.parse_classes(node) + assert len(parser.data["classes"]) == 1 + assert parser.data["classes"][0]["name"] == "Test" + + +def test_code_parser_parse_global_vars(): + """ + Test the parse_global_vars method of the CodeParser class. + """ + parser = CodeParser("x = 1") + tree = parser._CodeParser__get_tree() + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + parser.parse_global_vars(node) + assert len(parser.data["global_vars"]) == 1 + assert parser.data["global_vars"][0]["targets"] == ["x"] + + +def test_component_get_function_valid(): + """ + Test the get_function method of the Component class with valid code and function_entrypoint_name. + """ + component = Component(code="def build(): pass", function_entrypoint_name="build") + function = component.get_function() + assert callable(function) + + +def test_custom_component_get_function_entrypoint_args(): + """ + Test the get_function_entrypoint_args property of the CustomComponent class. + """ + custom_component = CustomComponent( + code=code_default, function_entrypoint_name="build" + ) + args = custom_component.get_function_entrypoint_args + assert len(args) == 4 + assert args[0]["name"] == "self" + assert args[1]["name"] == "url" + assert args[2]["name"] == "llm" + + +def test_custom_component_get_function_entrypoint_return_type(): + """ + Test the get_function_entrypoint_return_type property of the CustomComponent class. + """ + custom_component = CustomComponent( + code=code_default, function_entrypoint_name="build" + ) + return_type = custom_component.get_function_entrypoint_return_type assert return_type == "Document" - assert template_config == { - "description": "Your description", - "display_name": "Your Component", - "field_config": {"url": {"multiline": True, "required": True}}, - } -# Test the _build_template_config method -def test__build_template_config(): - attributes = { - "field_config": "'field_config_value'", - "display_name": "'display_name_value'", - "description": "'description_value'", - } - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - template_config = component._build_template_config(attributes) - - assert template_config == { - "field_config": "field_config_value", - "display_name": "display_name_value", - "description": "description_value", - } +def test_custom_component_get_main_class_name(): + """ + Test the get_main_class_name property of the CustomComponent class. + """ + custom_component = CustomComponent( + code=code_default, function_entrypoint_name="build" + ) + class_name = custom_component.get_main_class_name + assert class_name == "YourComponent" -# Test the _class_template_validation method with a valid class template -def test__class_template_validation_valid(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - assert component._class_template_validation(code=component.data) is True +def test_custom_component_get_function_valid(): + """ + Test the get_function property of the CustomComponent class with valid code and function_entrypoint_name. + """ + custom_component = CustomComponent( + code="def build(): pass", function_entrypoint_name="build" + ) + my_function = custom_component.get_function + assert callable(my_function) -# Test the _class_template_validation method with an invalid class template -def test__class_template_validation_invalid(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - class_template = {} - - with pytest.raises(Exception) as e: - component._class_template_validation(class_template) - - exception = e.value - assert exception.status_code == 400 - assert exception.detail["error"] == "The main class must have a valid name." +def test_code_parser_parse_arg_no_annotation(): + """ + Test the parse_arg method of the CodeParser class without an annotation. + """ + parser = CodeParser("") + arg = ast.arg(arg="x", annotation=None) + result = parser.parse_arg(arg, None) + assert result["name"] == "x" + assert "type" not in result -# Test the build method -def test_build(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - with pytest.raises(Exception) as e: - component.build() - - assert e.type == NotImplementedError +def test_code_parser_parse_arg_with_annotation(): + """ + Test the parse_arg method of the CodeParser class with an annotation. + """ + parser = CodeParser("") + arg = ast.arg(arg="x", annotation=ast.Name(id="int", ctx=ast.Load())) + result = parser.parse_arg(arg, None) + assert result["name"] == "x" + assert result["type"] == "int" -# Test the data property -def test_data(): - code = DEFAULT_CUSTOM_COMPONENT_CODE - component = CustomComponent(field_config={}, code=code) - class_info = component.data - assert "requests" in class_info["imports"] - assert class_info["class"]["name"] == "YourComponent" - function_data = { - "name": "build", - "arguments": ["self", "url: str", "llm: BaseLLM", "template: Prompt"], - "return_type": "Document", - } - assert function_data in class_info["functions"] +def test_code_parser_parse_callable_details_no_args(): + """ + Test the parse_callable_details method of the CodeParser class with a function with no arguments. + """ + parser = CodeParser("") + node = ast.FunctionDef( + name="test", + args=ast.arguments( + args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] + ), + body=[], + decorator_list=[], + returns=None, + ) + result = parser.parse_callable_details(node) + assert result["name"] == "test" + assert len(result["args"]) == 0 -# Test the is_check_valid method -def test_is_check_valid(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) - - assert component.is_check_valid() is True +def test_code_parser_parse_assign(): + """ + Test the parse_assign method of the CodeParser class. + """ + parser = CodeParser("") + stmt = ast.Assign(targets=[ast.Name(id="x", ctx=ast.Store())], value=ast.Num(n=1)) + result = parser.parse_assign(stmt) + assert result["name"] == "x" + assert result["value"] == "1" -# Test the args_and_return_type property -def test_args_and_return_type(): - component = CustomComponent(field_config={}, code=DEFAULT_CUSTOM_COMPONENT_CODE) +def test_code_parser_parse_ann_assign(): + """ + Test the parse_ann_assign method of the CodeParser class. + """ + parser = CodeParser("") + stmt = ast.AnnAssign( + target=ast.Name(id="x", ctx=ast.Store()), + annotation=ast.Name(id="int", ctx=ast.Load()), + value=ast.Num(n=1), + simple=1, + ) + result = parser.parse_ann_assign(stmt) + assert result["name"] == "x" + assert result["value"] == "1" + assert result["annotation"] == "int" - function_args, return_type, template_config = component.args_and_return_type - assert function_args == [ - ["self", None], - ["url", "str"], - ["llm", "BaseLLM"], - ["template", "Prompt"], - ] +def test_code_parser_parse_function_def_not_init(): + """ + Test the parse_function_def method of the CodeParser class with a function that is not __init__. + """ + parser = CodeParser("") + stmt = ast.FunctionDef( + name="test", + args=ast.arguments( + args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] + ), + body=[], + decorator_list=[], + returns=None, + ) + result, is_init = parser.parse_function_def(stmt) + assert result["name"] == "test" + assert not is_init - assert return_type == "Document" - assert template_config == { - "description": "Your description", - "display_name": "Your Component", - "field_config": {"url": {"multiline": True, "required": True}}, - } + +def test_code_parser_parse_function_def_init(): + """ + Test the parse_function_def method of the CodeParser class with an __init__ function. + """ + parser = CodeParser("") + stmt = ast.FunctionDef( + name="__init__", + args=ast.arguments( + args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] + ), + body=[], + decorator_list=[], + returns=None, + ) + result, is_init = parser.parse_function_def(stmt) + assert result["name"] == "__init__" + assert is_init + + +def test_component_get_code_tree_syntax_error(): + """ + Test the get_code_tree method of the Component class raises the CodeSyntaxError when given incorrect syntax. + """ + component = Component(code="import os as", function_entrypoint_name="build") + with pytest.raises(CodeSyntaxError): + component.get_code_tree(component.code) + + +def test_custom_component_class_template_validation_no_code(): + """ + Test the _class_template_validation method of the CustomComponent class raises the HTTPException when the code is None. + """ + custom_component = CustomComponent(code=None, function_entrypoint_name="build") + with pytest.raises(HTTPException): + custom_component._class_template_validation(custom_component.code) + + +def test_custom_component_get_code_tree_syntax_error(): + """ + Test the get_code_tree method of the CustomComponent class raises the CodeSyntaxError when given incorrect syntax. + """ + custom_component = CustomComponent( + code="import os as", function_entrypoint_name="build" + ) + with pytest.raises(CodeSyntaxError): + custom_component.get_code_tree(custom_component.code) + + +def test_custom_component_get_function_entrypoint_args_no_args(): + """ + Test the get_function_entrypoint_args property of the CustomComponent class with a build method with no arguments. + """ + my_code = """ +class MyMainClass(CustomComponent): + def build(): + pass""" + + custom_component = CustomComponent(code=my_code, function_entrypoint_name="build") + args = custom_component.get_function_entrypoint_args + assert len(args) == 0 + + +def test_custom_component_get_function_entrypoint_return_type_no_return_type(): + """ + Test the get_function_entrypoint_return_type property of the CustomComponent class with a build method with no return type. + """ + my_code = """ +class MyClass(CustomComponent): + def build(): + pass""" + + custom_component = CustomComponent(code=my_code, function_entrypoint_name="build") + return_type = custom_component.get_function_entrypoint_return_type + assert return_type is None + + +def test_custom_component_get_main_class_name_no_main_class(): + """ + Test the get_main_class_name property of the CustomComponent class when there is no main class. + """ + my_code = """ +def build(): + pass""" + + custom_component = CustomComponent(code=my_code, function_entrypoint_name="build") + class_name = custom_component.get_main_class_name + assert class_name == "" + + +def test_custom_component_build_not_implemented(): + """ + Test the build method of the CustomComponent class raises the NotImplementedError. + """ + custom_component = CustomComponent( + code="def build(): pass", function_entrypoint_name="build" + ) + with pytest.raises(NotImplementedError): + custom_component.build() + + +# ------------------------------------------------------- +# @pytest.fixture +# def custom_chain(): +# return ''' +# from __future__ import annotations +# from typing import Any, Dict, List, Optional + +# from pydantic import Extra + +# from langchain.schema import BaseLanguageModel, Document +# from langchain.callbacks.manager import ( +# AsyncCallbackManagerForChainRun, +# CallbackManagerForChainRun, +# ) +# from langchain.chains.base import Chain +# from langchain.prompts import StringPromptTemplate +# from langflow.interface.custom.base import CustomComponent + +# class MyCustomChain(Chain): +# """ +# An example of a custom chain. +# """ + +# from typing import Any, Dict, List, Optional + +# from pydantic import Extra + +# from langchain.schema import BaseLanguageModel, Document +# from langchain.callbacks.manager import ( +# AsyncCallbackManagerForChainRun, +# CallbackManagerForChainRun, +# ) +# from langchain.chains.base import Chain +# from langchain.prompts import StringPromptTemplate +# from langflow.interface.custom.base import CustomComponent + +# class MyCustomChain(Chain): +# """ +# An example of a custom chain. +# """ + +# prompt: StringPromptTemplate +# """Prompt object to use.""" +# llm: BaseLanguageModel +# output_key: str = "text" #: :meta private: + +# class Config: +# """Configuration for this pydantic object.""" + +# extra = Extra.forbid +# arbitrary_types_allowed = True + +# @property +# def input_keys(self) -> List[str]: +# """Will be whatever keys the prompt expects. + +# :meta private: +# """ +# return self.prompt.input_variables + +# @property +# def output_keys(self) -> List[str]: +# """Will always return text key. + +# :meta private: +# """ +# return [self.output_key] + +# def _call( +# self, +# inputs: Dict[str, Any], +# run_manager: Optional[CallbackManagerForChainRun] = None, +# ) -> Dict[str, str]: +# # Your custom chain logic goes here +# # This is just an example that mimics LLMChain +# prompt_value = self.prompt.format_prompt(**inputs) + +# # Whenever you call a language model, or another chain, you should pass +# # a callback manager to it. This allows the inner run to be tracked by +# # any callbacks that are registered on the outer run. +# # You can always obtain a callback manager for this by calling +# # `run_manager.get_child()` as shown below. +# response = self.llm.generate_prompt( +# [prompt_value], +# callbacks=run_manager.get_child() if run_manager else None, +# ) + +# # If you want to log something about this run, you can do so by calling +# # methods on the `run_manager`, as shown below. This will trigger any +# # callbacks that are registered for that event. +# if run_manager: +# run_manager.on_text("Log something about this run") + +# return {self.output_key: response.generations[0][0].text} + +# async def _acall( +# self, +# inputs: Dict[str, Any], +# run_manager: Optional[AsyncCallbackManagerForChainRun] = None, +# ) -> Dict[str, str]: +# # Your custom chain logic goes here +# # This is just an example that mimics LLMChain +# prompt_value = self.prompt.format_prompt(**inputs) + +# # Whenever you call a language model, or another chain, you should pass +# # a callback manager to it. This allows the inner run to be tracked by +# # any callbacks that are registered on the outer run. +# # You can always obtain a callback manager for this by calling +# # `run_manager.get_child()` as shown below. +# response = await self.llm.agenerate_prompt( +# [prompt_value], +# callbacks=run_manager.get_child() if run_manager else None, +# ) + +# # If you want to log something about this run, you can do so by calling +# # methods on the `run_manager`, as shown below. This will trigger any +# # callbacks that are registered for that event. +# if run_manager: +# await run_manager.on_text("Log something about this run") + +# return {self.output_key: response.generations[0][0].text} + +# @property +# def _chain_type(self) -> str: +# return "my_custom_chain" + +# class CustomChain(CustomComponent): +# display_name: str = "Custom Chain" +# field_config = { +# "prompt": {"field_type": "prompt"}, +# "llm": {"field_type": "BaseLanguageModel"}, +# } + +# def build(self, prompt, llm, input: str) -> Document: +# chain = MyCustomChain(prompt=prompt, llm=llm) +# return chain(input) +# ''' + + +# @pytest.fixture +# def data_processing(): +# return """ +# import pandas as pd +# from langchain.schema import Document +# from langflow.interface.custom.base import CustomComponent + +# class CSVLoaderComponent(CustomComponent): +# display_name: str = "CSV Loader" +# field_config = { +# "filename": {"field_type": "str", "required": True}, +# "column_name": {"field_type": "str", "required": True}, +# } + +# def build(self, filename: str, column_name: str) -> Document: +# # Load the CSV file +# df = pd.read_csv(filename) + +# # Verify the column exists +# if column_name not in df.columns: +# raise ValueError(f"Column '{column_name}' not found in the CSV file") + +# # Convert each row of the specified column to a document object +# documents = [] +# for content in df[column_name]: +# metadata = {"filename": filename} +# documents.append(Document(page_content=str(content), metadata=metadata)) + +# return documents +# """ + + +# @pytest.fixture +# def filter_docs(): +# return """ +# from langchain.schema import Document +# from langflow.interface.custom.base import CustomComponent +# from typing import List + +# class DocumentFilterByLengthComponent(CustomComponent): +# display_name: str = "Document Filter By Length" +# field_config = { +# "documents": {"field_type": "Document", "required": True}, +# "max_length": {"field_type": "int", "required": True}, +# } + +# def build(self, documents: List[Document], max_length: int) -> List[Document]: +# # Filter the documents by length +# filtered_documents = [doc for doc in documents if len(doc.page_content) <= max_length] + +# return filtered_documents +# """ + + +# @pytest.fixture +# def get_request(): +# return """ +# import requests +# from typing import Dict, Union +# from langchain.schema import Document +# from langflow.interface.custom.base import CustomComponent + +# class GetRequestComponent(CustomComponent): +# display_name: str = "GET Request" +# field_config = { +# "url": {"field_type": "str", "required": True}, +# } + +# def build(self, url: str) -> Document: +# # Send a GET request to the URL +# response = requests.get(url) + +# # Raise an exception if the request was not successful +# if response.status_code != 200: +# raise ValueError(f"GET request failed: {response.status_code} status code") + +# # Create a document with the response text and the URL as metadata +# document = Document(page_content=response.text, metadata={"url": url}) + +# return document +# """ + + +# @pytest.fixture +# def post_request(): +# return """ +# import requests +# from typing import Dict, Union +# from langchain.schema import Document +# from langflow.interface.custom.base import CustomComponent + +# class PostRequestComponent(CustomComponent): +# display_name: str = "POST Request" +# field_config = { +# "url": {"field_type": "str", "required": True}, +# "data": {"field_type": "dict", "required": True}, +# } + +# def build(self, url: str, data: Dict[str, Union[str, int]]) -> Document: +# # Send a POST request to the URL +# response = requests.post(url, data=data) + +# # Raise an exception if the request was not successful +# if response.status_code != 200: +# raise ValueError(f"POST request failed: {response.status_code} status code") + +# # Create a document with the response text and the URL and data as metadata +# document = Document(page_content=response.text, metadata={"url": url, "data": data}) + +# return document +# """ + + +# @pytest.fixture +# def code_default(): +# return """ +# from langflow import Prompt +# from langflow.interface.custom.custom_component import CustomComponent + +# from langchain.llms.base import BaseLLM +# from langchain.chains import LLMChain +# from langchain import PromptTemplate +# from langchain.schema import Document + +# import requests + +# class YourComponent(CustomComponent): +# #display_name: str = "Your Component" +# #description: str = "Your description" +# #field_config = { "url": { "multiline": True, "required": True } } + +# def build(self, url: str, llm: BaseLLM, template: Prompt) -> Document: +# response = requests.get(url) +# prompt = PromptTemplate.from_template(template) +# chain = LLMChain(llm=llm, prompt=prompt) +# result = chain.run(response.text[:300]) +# return Document(page_content=str(result)) +# """ + + +# @pytest.fixture(params=[ +# 'code_default', 'custom_chain', 'data_processing', +# 'filter_docs', 'get_request', 'post_request']) +# def component_code( +# request, code_default, custom_chain, data_processing, +# filter_docs, get_request, post_request): +# return locals()[request.param] + + +# def test_empty_code_tree(component_code): +# """ +# Test the situation when the code tree is empty. +# """ +# cc = CustomComponent(code=component_code) +# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: +# mocked_get_code_tree.return_value = {} +# assert cc.get_function_entrypoint_args == '' +# assert cc.get_function_entrypoint_return_type == '' +# assert cc.get_main_class_name == '' +# assert cc.build_template_config == {} + + +# def test_class_template_validation(component_code): +# """ +# Test the _class_template_validation method. +# """ +# cc = CustomComponent(code=component_code) +# assert cc._class_template_validation(component_code) == True +# with pytest.raises(HTTPException): +# cc._class_template_validation(None) + + +# def test_get_code_tree(component_code): +# """ +# Test the get_code_tree method. +# """ +# cc = CustomComponent(code=component_code) +# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: +# mocked_get_code_tree.return_value = {'classes': []} +# assert cc.get_code_tree(component_code) == {'classes': []} + + +# def test_get_function_entrypoint_args(component_code): +# """ +# Test the get_function_entrypoint_args method. +# """ +# cc = CustomComponent(code=component_code) +# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: +# mocked_get_code_tree.return_value = {'classes': []} +# assert cc.get_function_entrypoint_args == '' + + +# def test_get_function_entrypoint_return_type(component_code): +# """ +# Test the get_function_entrypoint_return_type method. +# """ +# cc = CustomComponent(code=component_code) +# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: +# mocked_get_code_tree.return_value = {'classes': []} +# assert cc.get_function_entrypoint_return_type == '' + + +# def test_get_main_class_name(component_code): +# """ +# Test the get_main_class_name method. +# """ +# cc = CustomComponent(code=component_code) +# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: +# mocked_get_code_tree.return_value = {'classes': []} +# assert cc.get_main_class_name == '' + + +# def test_build_template_config(component_code): +# """ +# Test the build_template_config method. +# """ +# cc = CustomComponent(code=component_code) +# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: +# mocked_get_code_tree.return_value = { +# 'classes': [{'name': '', 'attributes': []}]} +# assert cc.build_template_config == {} + + +# def test_get_function(component_code): +# """ +# Test the get_function method. +# """ +# cc = CustomComponent(code=component_code, function_entrypoint_name='build') +# assert callable(cc.get_function) + + +# def test_build(component_code): +# """ +# Test the build method. +# """ +# cc = CustomComponent(code=component_code) +# with pytest.raises(NotImplementedError): +# cc.build() + + +# @pytest.mark.parametrize("entrypoint_name", ["build", "non_exist_method"]) +# def test_set_non_existing_function_entrypoint_name(component_code, entrypoint_name): +# """ +# Test setting a non-existing function entrypoint name. +# """ +# cc = CustomComponent( +# code=component_code, +# function_entrypoint_name=entrypoint_name +# ) +# with pytest.raises(AttributeError): +# cc.get_function + + +# @pytest.mark.parametrize("base_class", ["CustomComponent", "NonExistingClass"]) +# def test_set_non_existing_base_class(component_code, base_class): +# """ +# Test setting a non-existing base class. +# """ +# cc = CustomComponent(code=component_code) +# cc.code_class_base_inheritance = base_class +# with pytest.raises(AttributeError): +# cc.get_main_class_name + + +# def test_class_with_no_methods(component_code): +# """ +# Test a component class with no methods. +# """ +# cc = CustomComponent(code=component_code) +# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: +# mocked_get_code_tree.return_value = { +# 'classes': [ +# { +# 'name': 'CustomComponent', +# 'methods': [], +# 'bases': ['CustomComponent'] +# } +# ] +# } +# assert cc.get_function_entrypoint_args == '' +# assert cc.get_function_entrypoint_return_type == '' + + +# def test_class_with_no_bases(component_code): +# """ +# Test a component class with no bases. +# """ +# cc = CustomComponent(code=component_code) +# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: +# mocked_get_code_tree.return_value = { +# 'classes': [ +# { +# 'name': 'CustomComponent', +# 'methods': [], +# 'bases': [] +# } +# ] +# } +# assert cc.get_function_entrypoint_args == '' +# assert cc.get_function_entrypoint_return_type == '' + + +# def test_class_with_no_name(component_code): +# """ +# Test a component class with no name. +# """ +# cc = CustomComponent(code=component_code) +# with patch.object(cc, 'get_code_tree') as mocked_get_code_tree: +# mocked_get_code_tree.return_value = {'classes': [ +# {'name': '', 'methods': [], 'bases': ['CustomComponent']}]} +# assert cc.get_main_class_name == '' + + +# @pytest.mark.parametrize("input_code", ["", "not a valid python code"]) +# def test_invalid_input_code(input_code): +# """ +# Test inputting an invalid Python code. +# """ +# with pytest.raises(SyntaxError): +# cc = CustomComponent(code=input_code) diff --git a/tests/test_llms_template.py b/tests/test_llms_template.py index 7679ba9c0..6bb1bc28d 100644 --- a/tests/test_llms_template.py +++ b/tests/test_llms_template.py @@ -113,6 +113,7 @@ def test_openai(client: TestClient): assert template["cache"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -125,6 +126,7 @@ def test_openai(client: TestClient): } assert template["verbose"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -137,6 +139,7 @@ def test_openai(client: TestClient): } assert template["client"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -149,6 +152,7 @@ def test_openai(client: TestClient): } assert template["model_name"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -170,6 +174,7 @@ def test_openai(client: TestClient): # Add more assertions for other properties here assert template["temperature"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -183,6 +188,7 @@ def test_openai(client: TestClient): } assert template["max_tokens"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -196,6 +202,7 @@ def test_openai(client: TestClient): } assert template["top_p"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -209,6 +216,7 @@ def test_openai(client: TestClient): } assert template["frequency_penalty"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -222,6 +230,7 @@ def test_openai(client: TestClient): } assert template["presence_penalty"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -235,6 +244,7 @@ def test_openai(client: TestClient): } assert template["n"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -248,6 +258,7 @@ def test_openai(client: TestClient): } assert template["best_of"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -261,6 +272,7 @@ def test_openai(client: TestClient): } assert template["model_kwargs"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -273,6 +285,7 @@ def test_openai(client: TestClient): } assert template["openai_api_key"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -287,6 +300,7 @@ def test_openai(client: TestClient): } assert template["batch_size"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -300,6 +314,7 @@ def test_openai(client: TestClient): } assert template["request_timeout"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -312,6 +327,7 @@ def test_openai(client: TestClient): } assert template["logit_bias"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -324,6 +340,7 @@ def test_openai(client: TestClient): } assert template["max_retries"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -337,6 +354,7 @@ def test_openai(client: TestClient): } assert template["streaming"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -361,6 +379,7 @@ def test_chat_open_ai(client: TestClient): assert template["verbose"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -374,6 +393,7 @@ def test_chat_open_ai(client: TestClient): } assert template["client"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -386,6 +406,7 @@ def test_chat_open_ai(client: TestClient): } assert template["model_name"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -409,6 +430,7 @@ def test_chat_open_ai(client: TestClient): } assert template["temperature"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -422,6 +444,7 @@ def test_chat_open_ai(client: TestClient): } assert template["model_kwargs"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -434,6 +457,7 @@ def test_chat_open_ai(client: TestClient): } assert template["openai_api_key"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, @@ -448,6 +472,7 @@ def test_chat_open_ai(client: TestClient): } assert template["request_timeout"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -460,6 +485,7 @@ def test_chat_open_ai(client: TestClient): } assert template["max_retries"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -473,6 +499,7 @@ def test_chat_open_ai(client: TestClient): } assert template["streaming"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -486,6 +513,7 @@ def test_chat_open_ai(client: TestClient): } assert template["n"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -500,6 +528,7 @@ def test_chat_open_ai(client: TestClient): assert template["max_tokens"] == { "required": False, + "dynamic": False, "placeholder": "", "show": True, "multiline": False, diff --git a/tests/test_prompts_template.py b/tests/test_prompts_template.py index 5486f3034..afc595a41 100644 --- a/tests/test_prompts_template.py +++ b/tests/test_prompts_template.py @@ -20,6 +20,7 @@ def test_prompt_template(client: TestClient): template = prompt["template"] assert template["input_variables"] == { "required": True, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -30,8 +31,10 @@ def test_prompt_template(client: TestClient): "advanced": False, "info": "", } + assert template["output_parser"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -42,8 +45,10 @@ def test_prompt_template(client: TestClient): "advanced": False, "info": "", } + assert template["partial_variables"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -54,8 +59,10 @@ def test_prompt_template(client: TestClient): "advanced": False, "info": "", } + assert template["template"] == { "required": True, + "dynamic": False, "placeholder": "", "show": True, "multiline": True, @@ -66,8 +73,10 @@ def test_prompt_template(client: TestClient): "advanced": False, "info": "", } + assert template["template_format"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False, @@ -79,8 +88,10 @@ def test_prompt_template(client: TestClient): "advanced": False, "info": "", } + assert template["validate_template"] == { "required": False, + "dynamic": False, "placeholder": "", "show": False, "multiline": False,