From 59baa407d2d04d2786fd50b82a68f84eec43b50f Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Thu, 6 Apr 2023 14:53:44 -0300 Subject: [PATCH] feat: added agents tests --- tests/test_agents_template.py | 176 ++++++++++++++++++++++++++++++++++ tests/test_chains_template.py | 9 ++ 2 files changed, 185 insertions(+) create mode 100644 tests/test_agents_template.py diff --git a/tests/test_agents_template.py b/tests/test_agents_template.py new file mode 100644 index 000000000..d63365fdc --- /dev/null +++ b/tests/test_agents_template.py @@ -0,0 +1,176 @@ +from fastapi.testclient import TestClient +from langflow.settings import settings + + +# check that all agents are in settings.agents +# are in json_response["agents"] +def test_agents_settings(client: TestClient): + response = client.get("/all") + assert response.status_code == 200 + json_response = response.json() + agents = json_response["agents"] + assert set(agents.keys()) == set(settings.agents) + + +def test_zero_shot_agent(client: TestClient): + response = client.get("/all") + assert response.status_code == 200 + json_response = response.json() + agents = json_response["agents"] + + zero_shot_agent = agents["ZeroShotAgent"] + assert set(zero_shot_agent["base_classes"]) == { + "ZeroShotAgent", + "BaseSingleActionAgent", + "Agent", + "function", + } + template = zero_shot_agent["template"] + + assert template["llm_chain"] == { + "required": True, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "llm_chain", + "type": "LLMChain", + "list": False, + } + assert template["allowed_tools"] == { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "allowed_tools", + "type": "Tool", + "list": True, + } + + +def test_json_agent(client: TestClient): + response = client.get("/all") + assert response.status_code == 200 + json_response = response.json() + agents = json_response["agents"] + + json_agent = agents["JsonAgent"] + assert json_agent["base_classes"] == ["AgentExecutor"] + template = json_agent["template"] + + assert template["toolkit"] == { + "required": True, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "toolkit", + "type": "BaseToolkit", + "list": False, + } + assert template["llm"] == { + "required": True, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "llm", + "type": "BaseLanguageModel", + "list": False, + } + + +def test_csv_agent(client: TestClient): + response = client.get("/all") + assert response.status_code == 200 + json_response = response.json() + agents = json_response["agents"] + + csv_agent = agents["CSVAgent"] + assert csv_agent["base_classes"] == ["AgentExecutor"] + template = csv_agent["template"] + + assert template["path"] == { + "required": True, + "placeholder": "", + "show": True, + "multiline": False, + "value": "", + "suffixes": [".csv"], + "fileTypes": ["csv"], + "password": False, + "name": "path", + "type": "file", + "list": False, + "content": None, + } + assert template["llm"] == { + "required": True, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "llm", + "type": "BaseLanguageModel", + "list": False, + } + + +def test_initialize_agent(client: TestClient): + response = client.get("/all") + assert response.status_code == 200 + json_response = response.json() + agents = json_response["agents"] + + initialize_agent = agents["initialize_agent"] + assert initialize_agent["base_classes"] == ["AgentExecutor"] + template = initialize_agent["template"] + + assert template["agent"] == { + "required": True, + "placeholder": "", + "show": True, + "multiline": False, + "value": "zero-shot-react-description", + "password": False, + "options": [ + "zero-shot-react-description", + "react-docstore", + "self-ask-with-search", + "conversational-react-description", + ], + "name": "agent", + "type": "str", + "list": True, + } + assert template["memory"] == { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "memory", + "type": "BaseChatMemory", + "list": False, + } + assert template["tools"] == { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "tools", + "type": "Tool", + "list": True, + } + assert template["llm"] == { + "required": True, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "llm", + "type": "BaseLanguageModel", + "list": False, + } diff --git a/tests/test_chains_template.py b/tests/test_chains_template.py index 6b692ee7d..cbecfdc15 100644 --- a/tests/test_chains_template.py +++ b/tests/test_chains_template.py @@ -1,4 +1,13 @@ from fastapi.testclient import TestClient +from langflow.settings import settings + + +def test_chains_settings(client: TestClient): + response = client.get("/all") + assert response.status_code == 200 + json_response = response.json() + chains = json_response["chains"] + assert set(chains.keys()) == set(settings.chains) # Test the ConversationChain object