From e4f4401d751db8e6083a617a9afe4b4f5540cdf7 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Tue, 4 Jun 2024 09:25:27 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20(tests/conftest.py):=20Remove=20?= =?UTF-8?q?duplicate=20imports=20and=20organize=20imports=20for=20better?= =?UTF-8?q?=20readability=20=E2=99=BB=EF=B8=8F=20(tests/test=5Fcustom=5Fco?= =?UTF-8?q?mponent.py):=20Refactor=20CustomComponent=20to=20Component=20fo?= =?UTF-8?q?r=20better=20naming=20consistency=20=E2=99=BB=EF=B8=8F=20(tests?= =?UTF-8?q?/test=5Fendpoints.py):=20Refactor=20test=20functions=20to=20imp?= =?UTF-8?q?rove=20readability=20and=20maintainability=20by=20simplifying?= =?UTF-8?q?=20assertions=20and=20organizing=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/conftest.py | 7 ++++--- tests/test_custom_component.py | 4 ++-- tests/test_endpoints.py | 28 +++++++++++++++++----------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6e12c56f2..076babde6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,10 @@ import orjson import pytest from fastapi.testclient import TestClient from httpx import AsyncClient +from sqlmodel import Session, SQLModel, create_engine, select +from sqlmodel.pool import StaticPool +from typer.testing import CliRunner + from langflow.graph.graph.base import Graph from langflow.initial_setup.setup import STARTER_FOLDER_NAME from langflow.services.auth.utils import get_password_hash @@ -21,9 +25,6 @@ from langflow.services.database.models.folder.model import Folder from langflow.services.database.models.user.model import User, UserCreate from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service -from sqlmodel import Session, SQLModel, create_engine, select -from sqlmodel.pool import StaticPool -from typer.testing import CliRunner if TYPE_CHECKING: from langflow.services.database.service import DatabaseService diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index 929a4712e..cf8964c86 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -4,7 +4,7 @@ from uuid import uuid4 import pytest from langchain_core.documents import Document -from langflow.custom import CustomComponent +from langflow.custom import CustomComponent, Component from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError from langflow.custom.utils import build_custom_component_template @@ -15,7 +15,7 @@ from langflow.services.database.models.flow import Flow, FlowCreate def code_component_with_multiple_outputs(): with open("tests/data/component_multiple_outputs.py", "r") as f: code = f.read() - return CustomComponent(code=code) + return Component(code=code) code_default = """ diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index d23cfd06e..1557a7510 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -4,6 +4,7 @@ from uuid import UUID, uuid4 import pytest from fastapi import status from fastapi.testclient import TestClient + from langflow.custom.directory_reader.directory_reader import DirectoryReader from langflow.services.deps import get_settings_service @@ -445,9 +446,10 @@ def test_successful_run_no_payload(client, starter_project, created_api_key): assert all(["ChatOutput" in _id for _id in ids]) display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")] assert all([name in display_names for name in ["Chat Output"]]) - inner_results = [output.get("results").get("result") for output in outputs_dict.get("outputs")] + output_results_has_results = all("results" in output.get("results") for output in outputs_dict.get("outputs")) + inner_results = [output.get("results") for output in outputs_dict.get("outputs")] - assert all([result is not None for result in inner_results]), inner_results + assert all([result is not None for result in inner_results]), (outputs_dict, output_results_has_results) def test_successful_run_with_output_type_text(client, starter_project, created_api_key): @@ -475,9 +477,9 @@ def test_successful_run_with_output_type_text(client, starter_project, created_a assert all(["ChatOutput" in _id for _id in ids]), ids display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")] assert all([name in display_names for name in ["Chat Output"]]), display_names - inner_results = [output.get("results").get("result") for output in outputs_dict.get("outputs")] - expected_result = "" - assert all([expected_result in result for result in inner_results]), inner_results + inner_results = [output.get("results") for output in outputs_dict.get("outputs")] + expected_keys = ["Record", "Message"] + assert all([key in result for result in inner_results for key in expected_keys]), outputs_dict def test_successful_run_with_output_type_any(client, starter_project, created_api_key): @@ -506,9 +508,9 @@ def test_successful_run_with_output_type_any(client, starter_project, created_ap assert all(["ChatOutput" in _id or "TextOutput" in _id for _id in ids]), ids display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")] assert all([name in display_names for name in ["Chat Output"]]), display_names - inner_results = [output.get("results").get("result") for output in outputs_dict.get("outputs")] - expected_result = "" - assert all([expected_result in result for result in inner_results]), inner_results + inner_results = [output.get("results") for output in outputs_dict.get("outputs")] + expected_keys = ["Record", "Message"] + assert all([key in result for result in inner_results for key in expected_keys]), outputs_dict def test_successful_run_with_output_type_debug(client, starter_project, created_api_key): @@ -564,7 +566,7 @@ def test_successful_run_with_input_type_text(client, starter_project, created_ap text_input_outputs = [output for output in outputs_dict.get("outputs") if "TextInput" in output.get("component_id")] assert len(text_input_outputs) == 0 # Now we check if the input_value is correct - assert all([output.get("results").get("result") == "value1" for output in text_input_outputs]), text_input_outputs + assert all([output.get("results") == "value1" for output in text_input_outputs]), text_input_outputs # Now do the same for "chat" input type @@ -595,7 +597,9 @@ def test_successful_run_with_input_type_chat(client, starter_project, created_ap chat_input_outputs = [output for output in outputs_dict.get("outputs") if "ChatInput" in output.get("component_id")] assert len(chat_input_outputs) == 1 # Now we check if the input_value is correct - assert all([output.get("results").get("result") == "value1" for output in chat_input_outputs]), chat_input_outputs + assert all( + [output.get("results").get("Message").get("result") == "value1" for output in chat_input_outputs] + ), chat_input_outputs def test_successful_run_with_input_type_any(client, starter_project, created_api_key): @@ -629,7 +633,9 @@ def test_successful_run_with_input_type_any(client, starter_project, created_api ] assert len(any_input_outputs) == 1 # Now we check if the input_value is correct - assert all([output.get("results").get("result") == "value1" for output in any_input_outputs]), any_input_outputs + assert all( + [output.get("results").get("Message").get("result") == "value1" for output in any_input_outputs] + ), any_input_outputs def test_run_with_inputs_and_outputs(client, starter_project, created_api_key):