📝 (tests/conftest.py): Remove duplicate imports and organize imports for better readability
♻️ (tests/test_custom_component.py): Refactor CustomComponent to Component for better naming consistency ♻️ (tests/test_endpoints.py): Refactor test functions to improve readability and maintainability by simplifying assertions and organizing code
This commit is contained in:
parent
433ea80ab6
commit
e4f4401d75
3 changed files with 23 additions and 16 deletions
|
|
@ -12,6 +12,10 @@ import orjson
|
|||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.initial_setup.setup import STARTER_FOLDER_NAME
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
|
|
@ -21,9 +25,6 @@ from langflow.services.database.models.folder.model import Folder
|
|||
from langflow.services.database.models.user.model import User, UserCreate
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.service import DatabaseService
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from uuid import uuid4
|
|||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.custom import CustomComponent, Component
|
||||
from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
|
||||
from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError
|
||||
from langflow.custom.utils import build_custom_component_template
|
||||
|
|
@ -15,7 +15,7 @@ from langflow.services.database.models.flow import Flow, FlowCreate
|
|||
def code_component_with_multiple_outputs():
|
||||
with open("tests/data/component_multiple_outputs.py", "r") as f:
|
||||
code = f.read()
|
||||
return CustomComponent(code=code)
|
||||
return Component(code=code)
|
||||
|
||||
|
||||
code_default = """
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from uuid import UUID, uuid4
|
|||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from langflow.custom.directory_reader.directory_reader import DirectoryReader
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
|
|
@ -445,9 +446,10 @@ def test_successful_run_no_payload(client, starter_project, created_api_key):
|
|||
assert all(["ChatOutput" in _id for _id in ids])
|
||||
display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")]
|
||||
assert all([name in display_names for name in ["Chat Output"]])
|
||||
inner_results = [output.get("results").get("result") for output in outputs_dict.get("outputs")]
|
||||
output_results_has_results = all("results" in output.get("results") for output in outputs_dict.get("outputs"))
|
||||
inner_results = [output.get("results") for output in outputs_dict.get("outputs")]
|
||||
|
||||
assert all([result is not None for result in inner_results]), inner_results
|
||||
assert all([result is not None for result in inner_results]), (outputs_dict, output_results_has_results)
|
||||
|
||||
|
||||
def test_successful_run_with_output_type_text(client, starter_project, created_api_key):
|
||||
|
|
@ -475,9 +477,9 @@ def test_successful_run_with_output_type_text(client, starter_project, created_a
|
|||
assert all(["ChatOutput" in _id for _id in ids]), ids
|
||||
display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")]
|
||||
assert all([name in display_names for name in ["Chat Output"]]), display_names
|
||||
inner_results = [output.get("results").get("result") for output in outputs_dict.get("outputs")]
|
||||
expected_result = ""
|
||||
assert all([expected_result in result for result in inner_results]), inner_results
|
||||
inner_results = [output.get("results") for output in outputs_dict.get("outputs")]
|
||||
expected_keys = ["Record", "Message"]
|
||||
assert all([key in result for result in inner_results for key in expected_keys]), outputs_dict
|
||||
|
||||
|
||||
def test_successful_run_with_output_type_any(client, starter_project, created_api_key):
|
||||
|
|
@ -506,9 +508,9 @@ def test_successful_run_with_output_type_any(client, starter_project, created_ap
|
|||
assert all(["ChatOutput" in _id or "TextOutput" in _id for _id in ids]), ids
|
||||
display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")]
|
||||
assert all([name in display_names for name in ["Chat Output"]]), display_names
|
||||
inner_results = [output.get("results").get("result") for output in outputs_dict.get("outputs")]
|
||||
expected_result = ""
|
||||
assert all([expected_result in result for result in inner_results]), inner_results
|
||||
inner_results = [output.get("results") for output in outputs_dict.get("outputs")]
|
||||
expected_keys = ["Record", "Message"]
|
||||
assert all([key in result for result in inner_results for key in expected_keys]), outputs_dict
|
||||
|
||||
|
||||
def test_successful_run_with_output_type_debug(client, starter_project, created_api_key):
|
||||
|
|
@ -564,7 +566,7 @@ def test_successful_run_with_input_type_text(client, starter_project, created_ap
|
|||
text_input_outputs = [output for output in outputs_dict.get("outputs") if "TextInput" in output.get("component_id")]
|
||||
assert len(text_input_outputs) == 0
|
||||
# Now we check if the input_value is correct
|
||||
assert all([output.get("results").get("result") == "value1" for output in text_input_outputs]), text_input_outputs
|
||||
assert all([output.get("results") == "value1" for output in text_input_outputs]), text_input_outputs
|
||||
|
||||
|
||||
# Now do the same for "chat" input type
|
||||
|
|
@ -595,7 +597,9 @@ def test_successful_run_with_input_type_chat(client, starter_project, created_ap
|
|||
chat_input_outputs = [output for output in outputs_dict.get("outputs") if "ChatInput" in output.get("component_id")]
|
||||
assert len(chat_input_outputs) == 1
|
||||
# Now we check if the input_value is correct
|
||||
assert all([output.get("results").get("result") == "value1" for output in chat_input_outputs]), chat_input_outputs
|
||||
assert all(
|
||||
[output.get("results").get("Message").get("result") == "value1" for output in chat_input_outputs]
|
||||
), chat_input_outputs
|
||||
|
||||
|
||||
def test_successful_run_with_input_type_any(client, starter_project, created_api_key):
|
||||
|
|
@ -629,7 +633,9 @@ def test_successful_run_with_input_type_any(client, starter_project, created_api
|
|||
]
|
||||
assert len(any_input_outputs) == 1
|
||||
# Now we check if the input_value is correct
|
||||
assert all([output.get("results").get("result") == "value1" for output in any_input_outputs]), any_input_outputs
|
||||
assert all(
|
||||
[output.get("results").get("Message").get("result") == "value1" for output in any_input_outputs]
|
||||
), any_input_outputs
|
||||
|
||||
|
||||
def test_run_with_inputs_and_outputs(client, starter_project, created_api_key):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue