ref: Add ALL ruff rules for tests (#4183)

Add ALL ruff rules for tests
This commit is contained in:
Christophe Bornet 2024-10-19 22:41:37 +02:00 committed by GitHub
commit f96f2eaf8a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
67 changed files with 421 additions and 361 deletions

View file

@ -216,6 +216,43 @@ directory = "coverage"
exclude = ["src/backend/langflow/alembic/*"]
line-length = 120
[tool.ruff.lint]
pydocstyle.convention = "google"
select = ["ALL"]
ignore = [
"C90", # McCabe complexity
"CPY", # Missing copyright
"COM812", # Messes with the formatter
"ERA", # Eradicate commented-out code
"FIX002", # Line contains TODO
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
"PLR09", # Too many something (arg, statements, etc)
"RUF012", # Pydantic models are currently not well detected. See https://github.com/astral-sh/ruff/issues/13630
"TD002", # Missing author in TODO
"TD003", # Missing issue link in TODO
"TRY301", # A bit too harsh (Abstract `raise` to an inner function)
# Rules that are TODOs
"ANN",
]
# Preview rules that are not yet activated
external = ["RUF027"]
[tool.ruff.lint.per-file-ignores]
"scripts/*" = [
"D1",
"INP",
"T201",
]
"src/backend/tests/*" = [
"D1",
"PLR2004",
"S101",
"SLF001",
]
[tool.mypy]
plugins = ["pydantic.mypy"]
follow_imports = "skip"

View file

@ -12,8 +12,10 @@ PYPI_LANGFLOW_NIGHTLY_URL = "https://pypi.org/pypi/langflow-nightly/json"
PYPI_LANGFLOW_BASE_URL = "https://pypi.org/pypi/langflow-base/json"
PYPI_LANGFLOW_BASE_NIGHTLY_URL = "https://pypi.org/pypi/langflow-base-nightly/json"
ARGUMENT_NUMBER = 2
def get_latest_published_version(build_type: str, is_nightly: bool) -> Version:
def get_latest_published_version(build_type: str, *, is_nightly: bool) -> Version:
import requests
url = ""
@ -25,12 +27,12 @@ def get_latest_published_version(build_type: str, is_nightly: bool) -> Version:
msg = f"Invalid build type: {build_type}"
raise ValueError(msg)
res = requests.get(url)
res = requests.get(url, timeout=10)
try:
version_str = res.json()["info"]["version"]
except Exception as e:
msg = "Got unexpected response from PyPI"
raise RuntimeError(msg, e)
raise RuntimeError(msg) from e
return Version(version_str)
@ -74,9 +76,9 @@ def create_tag(build_type: str):
if __name__ == "__main__":
if len(sys.argv) != 2:
if len(sys.argv) != ARGUMENT_NUMBER:
msg = "Specify base or main"
raise Exception(msg)
raise ValueError(msg)
build_type = sys.argv[1]
tag = create_tag(build_type)

View file

@ -1,3 +1,5 @@
#!/usr/bin/env python
import re
import sys
from pathlib import Path
@ -5,6 +7,7 @@ from pathlib import Path
import packaging.version
BASE_DIR = Path(__file__).parent.parent.parent
ARGUMENT_NUMBER = 2
def update_base_dep(pyproject_path: str, new_version: str) -> None:
@ -18,7 +21,7 @@ def update_base_dep(pyproject_path: str, new_version: str) -> None:
pattern = re.compile(r'langflow-base = \{ path = "\./src/backend/base", develop = true \}')
if not pattern.search(content):
msg = f'langflow-base poetry dependency not found in "{filepath}"'
raise Exception(msg)
raise ValueError(msg)
content = pattern.sub(replacement, content)
filepath.write_text(content, encoding="utf-8")
@ -28,16 +31,13 @@ def verify_pep440(version):
https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191
"""
try:
return packaging.version.Version(version)
except packaging.version.InvalidVersion:
raise
return packaging.version.Version(version)
def main() -> None:
if len(sys.argv) != 2:
if len(sys.argv) != ARGUMENT_NUMBER:
msg = "New version not specified"
raise Exception(msg)
raise ValueError(msg)
base_version = sys.argv[1]
# Strip "v" prefix from version if present

View file

@ -1,8 +1,11 @@
#!/usr/bin/env python
import re
import sys
from pathlib import Path
BASE_DIR = Path(__file__).parent.parent.parent
ARGUMENT_NUMBER = 3
def update_pyproject_name(pyproject_path: str, new_project_name: str) -> None:
@ -15,7 +18,7 @@ def update_pyproject_name(pyproject_path: str, new_project_name: str) -> None:
if not pattern.search(content):
msg = f'Project name not found in "{filepath}"'
raise Exception(msg)
raise ValueError(msg)
content = pattern.sub(new_project_name, content)
filepath.write_text(content, encoding="utf-8")
@ -39,15 +42,15 @@ def update_uv_dep(pyproject_path: str, new_project_name: str) -> None:
# Updates the dependency name for uv
if not pattern.search(content):
msg = f"{replacement} uv dependency not found in {filepath}"
raise Exception(msg)
raise ValueError(msg)
content = pattern.sub(replacement, content)
filepath.write_text(content, encoding="utf-8")
def main() -> None:
if len(sys.argv) != 3:
if len(sys.argv) != ARGUMENT_NUMBER:
msg = "Must specify project name and build type, e.g. langflow-nightly base"
raise Exception(msg)
raise ValueError(msg)
new_project_name = sys.argv[1]
build_type = sys.argv[2]

View file

@ -1,3 +1,5 @@
#!/usr/bin/env python
import re
import sys
from pathlib import Path
@ -5,6 +7,7 @@ from pathlib import Path
import packaging.version
BASE_DIR = Path(__file__).parent.parent.parent
ARGUMENT_NUMBER = 3
def update_pyproject_version(pyproject_path: str, new_version: str) -> None:
@ -17,7 +20,7 @@ def update_pyproject_version(pyproject_path: str, new_version: str) -> None:
if not pattern.search(content):
msg = f'Project version not found in "{filepath}"'
raise Exception(msg)
raise ValueError(msg)
content = pattern.sub(new_version, content)
@ -29,16 +32,13 @@ def verify_pep440(version):
https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191
"""
try:
return packaging.version.Version(version)
except packaging.version.InvalidVersion:
raise
return packaging.version.Version(version)
def main() -> None:
if len(sys.argv) != 3:
if len(sys.argv) != ARGUMENT_NUMBER:
msg = "New version not specified"
raise Exception(msg)
raise ValueError(msg)
new_version = sys.argv[1]
# Strip "v" prefix from version if present

View file

@ -1,8 +1,11 @@
#!/usr/bin/env python
import re
import sys
from pathlib import Path
BASE_DIR = Path(__file__).parent.parent.parent
ARGUMENT_NUMBER = 2
def update_uv_dep(base_version: str) -> None:
@ -19,7 +22,7 @@ def update_uv_dep(base_version: str) -> None:
# Check if the pattern is found
if not pattern.search(content):
msg = f"{pattern} UV dependency not found in {pyproject_path}"
raise Exception(msg)
raise ValueError(msg)
# Replace the matched pattern with the new one
content = pattern.sub(replacement, content)
@ -29,9 +32,9 @@ def update_uv_dep(base_version: str) -> None:
def main() -> None:
if len(sys.argv) != 2:
if len(sys.argv) != ARGUMENT_NUMBER:
msg = "specify base version"
raise Exception(msg)
raise ValueError(msg)
base_version = sys.argv[1]
base_version = base_version.lstrip("v")
update_uv_dep(base_version)

View file

@ -12,7 +12,7 @@ from pydantic import BaseModel
from langflow.base.tools.constants import TOOL_OUTPUT_NAME
from langflow.custom.tree_visitor import RequiredInputsVisitor
from langflow.field_typing import Tool # noqa: TCH001 Needed by add_toolkit_output
from langflow.field_typing import Tool # noqa: TCH001 Needed by _add_toolkit_output
from langflow.graph.state.model import create_state_model
from langflow.helpers.custom import format_type
from langflow.schema.artifact import get_artifact_type, post_process_raw

View file

@ -0,0 +1 @@
"""Version package."""

View file

@ -1,8 +1,11 @@
"""Module for package versioning."""
import contextlib
def get_version() -> str:
"""Retrieves the version of the package from a possible list of package names.
This accounts for after package names are updated for -nightly builds.
Returns:
@ -32,7 +35,9 @@ def get_version() -> str:
def is_pre_release(v: str) -> bool:
"""Returns a boolean indicating whether the version is a pre-release version,
"""Returns a boolean indicating whether the version is a pre-release version.
Returns a boolean indicating whether the version is a pre-release version,
as per the definition of a pre-release segment from PEP 440.
"""
return any(label in v for label in ["a", "b", "rc"])

View file

@ -28,7 +28,6 @@ from langflow.services.database.models.vertex_builds.crud import delete_vertex_b
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from loguru import logger
from pytest import LogCaptureFixture
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
@ -102,7 +101,7 @@ def _delete_transactions_and_vertex_builds(session, user: User):
@pytest.fixture
def caplog(caplog: LogCaptureFixture):
def caplog(caplog: pytest.LogCaptureFixture):
handler_id = logger.add(
caplog.handler,
format="{message}",
@ -144,7 +143,7 @@ def load_flows_dir():
@pytest.fixture(name="distributed_env")
def setup_env(monkeypatch):
def _setup_env(monkeypatch):
monkeypatch.setenv("LANGFLOW_CACHE_TYPE", "redis")
monkeypatch.setenv("LANGFLOW_REDIS_HOST", "result_backend")
monkeypatch.setenv("LANGFLOW_REDIS_PORT", "6379")
@ -158,7 +157,11 @@ def setup_env(monkeypatch):
@pytest.fixture(name="distributed_client")
def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
def distributed_client_fixture(
session: Session, # noqa: ARG001
monkeypatch,
distributed_env, # noqa: ARG001
):
# Here we load the .env from ../deploy/.env
from langflow.core import celery_app
@ -273,7 +276,12 @@ def json_memory_chatbot_no_llm():
@pytest.fixture(name="client")
async def client_fixture(session: Session, monkeypatch, request, load_flows_dir):
async def client_fixture(
session: Session, # noqa: ARG001
monkeypatch,
request,
load_flows_dir,
):
# Set the database url to a test database
if "noclient" in request.keywords:
yield
@ -296,9 +304,11 @@ async def client_fixture(session: Session, monkeypatch, request, load_flows_dir)
db_service.database_url = f"sqlite:///{db_path}"
db_service.reload_engine()
# app.dependency_overrides[get_session] = get_session_override
async with LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager:
async with AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client:
yield client
async with (
LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager,
AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client,
):
yield client
# app.dependency_overrides.clear()
monkeypatch.undo()
# clear the temp db
@ -308,7 +318,7 @@ async def client_fixture(session: Session, monkeypatch, request, load_flows_dir)
# create a fixture for session_getter above
@pytest.fixture(name="session_getter")
def session_getter_fixture(client):
def session_getter_fixture(client): # noqa: ARG001
@contextmanager
def blank_session_getter(db_service: "DatabaseService"):
with Session(db_service.engine) as session:
@ -326,7 +336,7 @@ def runner():
async def test_user(client):
user_data = UserCreate(
username="testuser",
password="testpassword",
password="testpassword", # noqa: S106
)
response = await client.post("api/v1/users/", json=user_data.model_dump())
assert response.status_code == 201
@ -337,7 +347,7 @@ async def test_user(client):
@pytest.fixture
def active_user(client):
def active_user(client): # noqa: ARG001
db_manager = get_db_service()
with db_manager.with_session() as session:
user = User(
@ -375,7 +385,11 @@ async def logged_in_headers(client, active_user):
@pytest.fixture
def flow(client, json_flow: str, active_user):
def flow(
client, # noqa: ARG001
json_flow: str,
active_user,
):
from langflow.services.database.models.flow.model import FlowCreate
loaded_json = json.loads(json_flow)

View file

@ -7,7 +7,7 @@ class TestComponent(CustomComponent):
def refresh_values(self):
# This is a function that will be called every time the component is updated
# and should return a list of random strings
return [f"Random {random.randint(1, 100)}" for _ in range(5)]
return [f"Random {random.randint(1, 100)}" for _ in range(5)] # noqa: S311
def build_config(self):
return {"param": {"display_name": "Param", "options": self.refresh_values}}

View file

@ -16,7 +16,7 @@ class MultipleOutputsComponent(Component):
]
def certain_output(self) -> int:
return randint(0, self.number)
return randint(0, self.number) # noqa: S311
def other_output(self) -> int:
return self.certain_output()

View file

@ -8,7 +8,7 @@ class TestComponent(CustomComponent):
def refresh_values(self):
# This is a function that will be called every time the component is updated
# and should return a list of random strings
return [f"Random {random.randint(1, 100)}" for _ in range(5)]
return [f"Random {random.randint(1, 100)}" for _ in range(5)] # noqa: S311
def build_config(self):
return {"param": Input(display_name="Param", options=self.refresh_values)}

View file

@ -1,5 +1,6 @@
import pytest
from langflow.schema.message import Message
from tests.api_keys import get_openai_api_key
from tests.integration.utils import download_flow_from_github, run_json_flow

View file

@ -1,4 +1,5 @@
import pytest
from tests.integration.utils import run_single_component

View file

@ -6,6 +6,7 @@ from langchain_core.documents import Document
from langflow.components.embeddings import OpenAIEmbeddingsComponent
from langflow.components.vectorstores import AstraVectorStoreComponent
from langflow.schema.data import Data
from tests.api_keys import get_astradb_api_endpoint, get_astradb_application_token, get_openai_api_key
from tests.integration.components.mock_components import TextToData
from tests.integration.utils import ComponentInputHandle, run_single_component
@ -27,7 +28,7 @@ ALL_COLLECTIONS = [
@pytest.fixture
def astradb_client(request):
def astradb_client():
client = AstraDB(api_endpoint=get_astradb_api_endpoint(), token=get_astradb_application_token())
yield client
for collection in ALL_COLLECTIONS:

View file

@ -2,6 +2,7 @@ import pytest
from langflow.components.helpers.ParseJSONData import ParseJSONDataComponent
from langflow.components.inputs import ChatInput
from langflow.schema import Data
from tests.integration.components.mock_components import TextToData
from tests.integration.utils import ComponentInputHandle, run_single_component

View file

@ -2,6 +2,7 @@ import pytest
from langflow.components.inputs import ChatInput
from langflow.memory import get_messages
from langflow.schema.message import Message
from tests.integration.utils import run_single_component

View file

@ -1,6 +1,7 @@
import pytest
from langflow.components.inputs import TextInputComponent
from langflow.schema.message import Message
from tests.integration.utils import run_single_component

View file

@ -4,6 +4,7 @@ import pytest
from langflow.components.models.OpenAIModel import OpenAIModelComponent
from langflow.components.output_parsers.OutputParser import OutputParserComponent
from langflow.components.prompts.Prompt import PromptComponent
from tests.integration.utils import ComponentInputHandle, run_single_component

View file

@ -2,6 +2,7 @@ import pytest
from langflow.components.outputs import ChatOutput
from langflow.memory import get_messages
from langflow.schema.message import Message
from tests.integration.utils import run_single_component

View file

@ -1,6 +1,7 @@
import pytest
from langflow.components.outputs import TextOutputComponent
from langflow.schema.message import Message
from tests.integration.utils import run_single_component

View file

@ -1,6 +1,7 @@
import pytest
from langflow.components.prompts import PromptComponent
from langflow.schema.message import Message
from tests.integration.utils import run_single_component

View file

@ -4,18 +4,19 @@ from langflow.components.outputs import ChatOutput
from langflow.components.prompts import PromptComponent
from langflow.graph import Graph
from langflow.schema.message import Message
from tests.integration.utils import run_flow
@pytest.mark.asyncio
async def test_simple_no_llm():
graph = Graph()
input = graph.add_component(ChatInput())
output = graph.add_component(ChatOutput())
flow_input = graph.add_component(ChatInput())
flow_output = graph.add_component(ChatOutput())
component = PromptComponent(template="This is the message: {var1}", var1="")
prompt = graph.add_component(component)
graph.add_component_edge(input, ("message", "var1"), prompt)
graph.add_component_edge(prompt, ("prompt", "input_value"), output)
graph.add_component_edge(flow_input, ("message", "var1"), prompt)
graph.add_component_edge(prompt, ("prompt", "input_value"), flow_output)
outputs = await run_flow(graph, run_input="hello!")
assert isinstance(outputs["message"], Message)
assert outputs["message"].text == "This is the message: hello!"

View file

@ -12,26 +12,26 @@ from langflow.graph import Graph
from langflow.processing.process import run_graph_internal
def check_env_vars(*vars):
def check_env_vars(*env_vars):
"""Check if all specified environment variables are set.
Args:
*vars (str): The environment variables to check.
*env_vars (str): The environment variables to check.
Returns:
bool: True if all environment variables are set, False otherwise.
"""
return all(os.getenv(var) for var in vars)
return all(os.getenv(var) for var in env_vars)
def valid_nvidia_vectorize_region(api_endpoint: str) -> bool:
"""Check if the specified region is valid.
Args:
region (str): The region to check.
api_endpoint: The API endpoint to check.
Returns:
bool: True if the region is contains hosted nvidia models, False otherwise.
True if the region contains hosted nvidia models, False otherwise.
"""
parsed_endpoint = parse_api_endpoint(api_endpoint)
if not parsed_endpoint:
@ -63,12 +63,12 @@ class JSONFlow:
json: dict
def get_components_by_type(self, component_type):
result = []
for node in self.json["data"]["nodes"]:
if node["data"]["type"] == component_type:
result.append(node["id"])
result = [node["id"] for node in self.json["data"]["nodes"] if node["data"]["type"] == component_type]
if not result:
msg = f"Component of type {component_type} not found, available types: {', '.join({node['data']['type'] for node in self.json['data']['nodes']})}"
msg = (
f"Component of type {component_type} not found, "
f"available types: {', '.join({node['data']['type'] for node in self.json['data']['nodes']})}"
)
raise ValueError(msg)
return result
@ -97,7 +97,8 @@ class JSONFlow:
def download_flow_from_github(name: str, version: str) -> JSONFlow:
response = requests.get(
f"https://raw.githubusercontent.com/langflow-ai/langflow/v{version}/src/backend/base/langflow/initial_setup/starter_projects/{name}.json"
f"https://raw.githubusercontent.com/langflow-ai/langflow/v{version}/src/backend/base/langflow/initial_setup/starter_projects/{name}.json",
timeout=10,
)
response.raise_for_status()
as_json = response.json()
@ -151,7 +152,7 @@ async def run_single_component(
raw_inputs[key] = value
if isinstance(value, Component):
msg = "Component inputs must be wrapped in ComponentInputHandle"
raise ValueError(msg)
raise TypeError(msg)
component = clazz(**raw_inputs, _user_id=user_id)
component_id = graph.add_component(component)
if inputs:

View file

@ -60,7 +60,7 @@ class NameTest(FastHttpUser):
@task
def send_name_and_check(self):
name = random.choice(self.names)
name = random.choice(self.names) # noqa: S311
payload1 = {
"inputs": {"text": f"Hello, My name is {name}"},

View file

@ -17,7 +17,10 @@ def test_get_suggestion_message():
# Test case 3: Multiple outdated components
outdated_components = ["component1", "component2", "component3"]
expected_message = "The flow contains 3 outdated components. We recommend updating the following components: component1, component2, component3."
expected_message = (
"The flow contains 3 outdated components. "
"We recommend updating the following components: component1, component2, component3."
)
assert get_suggestion_message(outdated_components) == expected_message

View file

@ -75,7 +75,7 @@ async def test_create_variable__variable_value_cannot_be_empty(client: AsyncClie
@pytest.mark.usefixtures("active_user")
async def test_create_variable__HTTPException(client: AsyncClient, body, logged_in_headers):
async def test_create_variable__httpexception(client: AsyncClient, body, logged_in_headers):
status_code = 418
generic_message = "I'm a teapot"
@ -89,7 +89,7 @@ async def test_create_variable__HTTPException(client: AsyncClient, body, logged_
@pytest.mark.usefixtures("active_user")
async def test_create_variable__Exception(client: AsyncClient, body, logged_in_headers):
async def test_create_variable__exception(client: AsyncClient, body, logged_in_headers):
generic_message = "Generic error message"
with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m:
@ -133,16 +133,10 @@ async def test_read_variables__empty(client: AsyncClient, logged_in_headers):
async def test_read_variables__(client: AsyncClient, logged_in_headers):
generic_message = "Generic error message"
with pytest.raises(Exception) as exc, mock.patch("sqlmodel.Session.exec") as m:
with mock.patch("sqlmodel.Session.exec") as m:
m.side_effect = Exception(generic_message)
response = await client.get("api/v1/variables/", headers=logged_in_headers)
result = response.json()
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert generic_message in result["detail"]
assert generic_message in str(exc.value)
with pytest.raises(Exception, match=generic_message):
await client.get("api/v1/variables/", headers=logged_in_headers)
@pytest.mark.usefixtures("active_user")
@ -165,7 +159,7 @@ async def test_update_variable(client: AsyncClient, body, logged_in_headers):
@pytest.mark.usefixtures("active_user")
async def test_update_variable__Exception(client: AsyncClient, body, logged_in_headers):
async def test_update_variable__exception(client: AsyncClient, body, logged_in_headers):
wrong_id = uuid4()
body["id"] = str(wrong_id)
@ -186,7 +180,7 @@ async def test_delete_variable(client: AsyncClient, body, logged_in_headers):
@pytest.mark.usefixtures("active_user")
async def test_delete_variable__Exception(client: AsyncClient, logged_in_headers):
async def test_delete_variable__exception(client: AsyncClient, logged_in_headers):
wrong_id = uuid4()
response = await client.delete(f"api/v1/variables/{wrong_id}", headers=logged_in_headers)

View file

@ -26,4 +26,5 @@ def test_run_flow_from_json_params():
params = func_spec.args + func_spec.kwonlyargs
assert expected_params.issubset(params), "Not all expected parameters are present in run_flow_from_json"
# TODO: Add tests by loading a flow and running it need to text with fake llm and check if it returns the correct output
# TODO: Add tests by loading a flow and running it need to text with fake llm and check if it returns the
# correct output

View file

@ -12,7 +12,7 @@ from langflow.services.settings.feature_flags import FEATURE_FLAGS
@pytest.fixture
def add_toolkit_output():
def _add_toolkit_output():
FEATURE_FLAGS.add_toolkit_output = True
yield
FEATURE_FLAGS.add_toolkit_output = False
@ -81,7 +81,7 @@ def test_component_tool():
@pytest.mark.api_key_required
@pytest.mark.usefixtures("add_toolkit_output", "client")
@pytest.mark.usefixtures("_add_toolkit_output", "client")
def test_component_tool_with_api_key():
chat_output = ChatOutput()
openai_llm = OpenAIModelComponent()

View file

@ -1,34 +1,60 @@
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.language_models import BaseLanguageModel
from langflow.components.helpers.structured_output import StructuredOutputComponent
from langflow.schema.data import Data
from pydantic import BaseModel
@pytest.fixture
def client():
pass
from typing_extensions import override
class TestStructuredOutputComponent:
# Ensure that the structured output is successfully generated with the correct BaseModel instance returned by the mock function
# Ensure that the structured output is successfully generated with the correct BaseModel instance returned by
# the mock function
def test_successful_structured_output_generation_with_patch_with_config(self):
from unittest.mock import patch
class MockLanguageModel:
def with_structured_output(self, schema):
class MockLanguageModel(BaseLanguageModel):
@override
def with_structured_output(self, *args, **kwargs):
return self
def with_config(self, config):
@override
def with_config(self, *args, **kwargs):
return self
def invoke(self, inputs):
@override
def invoke(self, *args, **kwargs):
return self
def mock_get_chat_result(runnable, input_value, config):
@override
def generate_prompt(self, *args, **kwargs):
raise NotImplementedError
@override
async def agenerate_prompt(self, *args, **kwargs):
raise NotImplementedError
@override
def predict(self, *args, **kwargs):
raise NotImplementedError
@override
def predict_messages(self, *args, **kwargs):
raise NotImplementedError
@override
async def apredict(self, *args, **kwargs):
raise NotImplementedError
@override
async def apredict_messages(self, *args, **kwargs):
raise NotImplementedError
def mock_get_chat_result(runnable, input_value, config): # noqa: ARG001
class MockBaseModel(BaseModel):
def model_dump(self):
@override
def model_dump(self, **kwargs):
return {"field": "value"}
return MockBaseModel()

View file

@ -112,7 +112,7 @@ def test_update_build_config_keep_alive(component):
"langchain_community.chat_models.ChatOllama",
return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"),
)
def test_build_model(_mock_chat_ollama, component):
def test_build_model(_mock_chat_ollama, component): # noqa: PT019
component.base_url = "http://localhost:11434"
component.model_name = "llama3.1"
component.mirostat = "Mirostat 2.0"

View file

@ -1,4 +1,4 @@
from langflow.components.prompts.Prompt import PromptComponent # type: ignore
from langflow.components.prompts.Prompt import PromptComponent
class TestPromptComponent:

View file

@ -109,9 +109,5 @@ def test_validate_text_key_invalid(create_data_component):
create_data_component.text_key = "invalid_key"
# Act & Assert
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValueError, match="Text Key: 'invalid_key' not found in the Data keys: 'key1, key2'"):
create_data_component.validate_text_key()
# Check for the exact error message
expected_error_message = f"Text Key: '{create_data_component.text_key}' not found in the Data keys: '{', '.join(create_data_component.get_data().keys())}'"
assert str(exc_info.value) == expected_error_message

View file

@ -96,10 +96,5 @@ def test_validate_text_key_invalid(update_data_component):
data = Data(data={"key1": "value1", "key2": "value2"}, text_key="key1")
update_data_component.text_key = "invalid_key"
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValueError, match="Text Key: invalid_key not found in the Data keys: key1,key2"):
update_data_component.validate_text_key(data)
expected_error_message = (
f"Text Key: {update_data_component.text_key} not found in the Data keys: {','.join(data.data.keys())}"
)
assert str(exc_info.value) == expected_error_message

View file

@ -11,7 +11,7 @@ from langflow.template.field.base import Output
def test_set_invalid_output():
chatinput = ChatInput()
chatoutput = ChatOutput()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Method build_config is not a valid output of ChatInput"):
chatoutput.set(input_value=chatinput.build_config)

View file

@ -90,9 +90,9 @@ class TestEventManager:
queue = asyncio.Queue()
manager = EventManager(queue)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Event name cannot be empty"):
manager.register_event("", "test_type", mock_callback)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Event name must start with 'on_'"):
manager.register_event("invalid_name", "test_type", mock_callback)
# Sending an event with complex data and verifying successful event transmission

View file

@ -16,28 +16,32 @@ def test_api_exception():
}
# Expected result
with patch(
"langflow.services.database.models.flow.utils.get_outdated_components", return_value=mock_outdated_components
with (
patch(
"langflow.services.database.models.flow.utils.get_outdated_components",
return_value=mock_outdated_components,
),
patch("langflow.api.utils.get_suggestion_message", return_value=mock_suggestion_message),
patch(
"langflow.services.database.models.flow.utils.get_components_versions",
return_value=mock_component_versions,
),
):
with patch("langflow.api.utils.get_suggestion_message", return_value=mock_suggestion_message):
with patch(
"langflow.services.database.models.flow.utils.get_components_versions",
return_value=mock_component_versions,
):
# Create an APIException instance
api_exception = APIException(mock_exception, mock_flow)
# Create an APIException instance
api_exception = APIException(mock_exception, mock_flow)
# Expected body
expected_body = ExceptionBody(
message="Test exception",
suggestion="The flow contains 2 outdated components. We recommend updating the following components: component1, component2.",
)
# Expected body
expected_body = ExceptionBody(
message="Test exception",
suggestion="The flow contains 2 outdated components. "
"We recommend updating the following components: component1, component2.",
)
# Assert the status code
assert api_exception.status_code == 500
# Assert the status code
assert api_exception.status_code == 500
# Assert the detail
assert api_exception.detail == expected_body.model_dump_json()
# Assert the detail
assert api_exception.detail == expected_body.model_dump_json()
def test_api_exception_no_flow():

View file

@ -22,27 +22,27 @@ class TestCreateStateModel:
# Successfully create a model with valid method return type annotations
def test_create_model_with_valid_return_type_annotations(self, chat_input_component):
StateModel = create_state_model(method_one=chat_input_component.message_response)
state_model = create_state_model(method_one=chat_input_component.message_response)
state_instance = StateModel()
state_instance = state_model()
assert state_instance.method_one is UNDEFINED
chat_input_component.set_output_value("message", "test")
assert state_instance.method_one == "test"
def test_create_model_and_assign_values_fails(self, chat_input_component):
StateModel = create_state_model(method_one=chat_input_component.message_response)
state_model = create_state_model(method_one=chat_input_component.message_response)
state_instance = StateModel()
state_instance = state_model()
state_instance.method_one = "test"
assert state_instance.method_one == "test"
def test_create_with_multiple_components(self, chat_input_component, chat_output_component):
NewStateModel = create_state_model(
new_state_model = create_state_model(
model_name="NewStateModel",
first_method=chat_input_component.message_response,
second_method=chat_output_component.message_response,
)
state_instance = NewStateModel()
state_instance = new_state_model()
assert state_instance.first_method is UNDEFINED
assert state_instance.second_method is UNDEFINED
state_instance.first_method = "test"
@ -51,9 +51,9 @@ class TestCreateStateModel:
assert state_instance.second_method == 123
def test_create_with_pydantic_field(self, chat_input_component):
StateModel = create_state_model(method_one=chat_input_component.message_response, my_attribute=Field(None))
state_model = create_state_model(method_one=chat_input_component.message_response, my_attribute=Field(None))
state_instance = StateModel()
state_instance = state_model()
state_instance.method_one = "test"
state_instance.my_attribute = "test"
assert state_instance.method_one == "test"
@ -64,8 +64,8 @@ class TestCreateStateModel:
# Creates a model with fields based on provided keyword arguments
def test_create_model_with_fields_from_kwargs(self):
StateModel = create_state_model(field_one=(str, "default"), field_two=(int, 123))
state_instance = StateModel()
state_model = create_state_model(field_one=(str, "default"), field_two=(int, 123))
state_instance = state_model()
assert state_instance.field_one == "default"
assert state_instance.field_two == 123
@ -81,16 +81,16 @@ class TestCreateStateModel:
# Handles empty keyword arguments gracefully
def test_handle_empty_kwargs_gracefully(self):
StateModel = create_state_model()
state_instance = StateModel()
state_model = create_state_model()
state_instance = state_model()
assert state_instance is not None
# Ensures model name defaults to "State" if not provided
def test_default_model_name_to_state(self):
StateModel = create_state_model()
assert StateModel.__name__ == "State"
OtherNameModel = create_state_model(model_name="OtherName")
assert OtherNameModel.__name__ == "OtherName"
state_model = create_state_model()
assert state_model.__name__ == "State"
other_name_model = create_state_model(model_name="OtherName")
assert other_name_model.__name__ == "OtherName"
# Validates that callable values are properly type-annotated
@ -110,8 +110,7 @@ class TestCreateStateModel:
chat_input = ChatInput(_id="chat_input")
chat_output = ChatOutput(input_value="test", _id="chat_output")
chat_output.set(sender_name=chat_input.message_response)
ChatStateModel = create_state_model(model_name="ChatState", message=chat_output.message_response)
chat_state_model = ChatStateModel()
chat_state_model = create_state_model(model_name="ChatState", message=chat_output.message_response)()
assert chat_state_model.__class__.__name__ == "ChatState"
assert chat_state_model.message is UNDEFINED
@ -121,9 +120,7 @@ class TestCreateStateModel:
# and check that the graph is running
# correctly
ids = ["chat_input", "chat_output"]
results = []
for result in graph.start():
results.append(result)
results = list(graph.start())
assert len(results) == 3
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))

View file

@ -9,7 +9,6 @@ from langflow.components.outputs.TextOutput import TextOutputComponent
from langflow.components.tools.YfinanceTool import YfinanceToolComponent
from langflow.graph.graph.base import Graph
from langflow.graph.graph.constants import Finish
from pytest import LogCaptureFixture
@pytest.mark.asyncio
@ -19,12 +18,12 @@ async def test_graph_not_prepared():
graph = Graph()
graph.add_component(chat_input)
graph.add_component(chat_output)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Graph not prepared"):
await graph.astep()
@pytest.mark.asyncio
async def test_graph(caplog: LogCaptureFixture):
async def test_graph(caplog: pytest.LogCaptureFixture):
chat_input = ChatInput()
chat_output = ChatOutput()
graph = Graph()
@ -83,9 +82,7 @@ async def test_graph_functional_async_start():
# and check that the graph is running
# correctly
ids = ["chat_input", "chat_output"]
results = []
async for result in graph.async_start():
results.append(result)
results = [result async for result in graph.async_start()]
assert len(results) == 3
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
@ -102,9 +99,7 @@ def test_graph_functional_start():
# and check that the graph is running
# correctly
ids = ["chat_input", "chat_output"]
results = []
for result in graph.start():
results.append(result)
results = list(graph.start())
assert len(results) == 3
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
@ -123,9 +118,7 @@ def test_graph_functional_start_end():
# and check that the graph is running
# correctly
ids = ["chat_input", "text_output"]
results = []
for result in graph.start():
results.append(result)
results = list(graph.start())
assert len(results) == len(ids) + 1
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))

View file

@ -103,11 +103,9 @@ def test_cycle_in_graph_max_iterations():
# Run queue should contain chat_input and not router
assert "chat_input" in graph._run_queue
assert "router" not in graph._run_queue
results = []
with pytest.raises(ValueError, match="Max iterations reached"):
for result in graph.start(max_iterations=2, config={"output": {"cache": False}}):
results.append(result)
list(graph.start(max_iterations=2, config={"output": {"cache": False}}))
def test_that_outputs_cache_is_set_to_false_in_cycle():
@ -149,7 +147,10 @@ def test_updated_graph_with_prompts():
# First prompt: Guessing game with hints
prompt_component_1 = PromptComponent(_id="prompt_component_1").set(
template="Try to guess a word. I will give you hints if you get it wrong.\nHint: {hint}\nLast try: {last_try}\nAnswer:",
template="Try to guess a word. I will give you hints if you get it wrong.\n"
"Hint: {hint}\n"
"Last try: {last_try}\n"
"Answer:",
)
# First OpenAI LLM component (Processes the guessing prompt)
@ -168,7 +169,10 @@ def test_updated_graph_with_prompts():
# Second prompt: After the last try, provide a new hint
prompt_component_2 = PromptComponent(_id="prompt_component_2")
prompt_component_2.set(
template="Given the following word and the following last try. Give the guesser a new hint.\nLast try: {last_try}\nWord: {word}\nHint:",
template="Given the following word and the following last try. Give the guesser a new hint.\n"
"Last try: {last_try}\n"
"Word: {word}\n"
"Hint:",
word=chat_input.message_response,
last_try=router.false_response,
)

View file

@ -38,9 +38,9 @@ AI: """
graph = Graph(chat_input, chat_output)
GraphStateModel = create_state_model_from_graph(graph)
assert GraphStateModel.__name__ == "GraphStateModel"
assert list(GraphStateModel.model_computed_fields.keys()) == [
graph_state_model = create_state_model_from_graph(graph)
assert graph_state_model.__name__ == "GraphStateModel"
assert list(graph_state_model.model_computed_fields.keys()) == [
"chat_input",
"chat_output",
"openai",
@ -60,12 +60,9 @@ def test_graph_functional_start_graph_state_update():
# Now iterate through the graph
# and check that the graph is running
# correctly
GraphStateModel = create_state_model_from_graph(graph)
graph_state_model = GraphStateModel()
graph_state_model = create_state_model_from_graph(graph)()
ids = ["chat_input", "chat_output"]
results = []
for result in graph.start():
results.append(result)
results = list(graph.start())
assert len(results) == 3
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
@ -87,12 +84,9 @@ def test_graph_state_model_serialization():
# Now iterate through the graph
# and check that the graph is running
# correctly
GraphStateModel = create_state_model_from_graph(graph)
graph_state_model = GraphStateModel()
graph_state_model = create_state_model_from_graph(graph)()
ids = ["chat_input", "chat_output"]
results = []
for result in graph.start():
results.append(result)
results = list(graph.start())
assert len(results) == 3
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
@ -116,8 +110,7 @@ def test_graph_state_model_json_schema():
graph = Graph(chat_input, chat_output)
graph.prepare()
GraphStateModel = create_state_model_from_graph(graph)
graph_state_model: BaseModel = GraphStateModel()
graph_state_model: BaseModel = create_state_model_from_graph(graph)()
json_schema = graph_state_model.model_json_schema(mode="serialization")
# Test main schema structure

View file

@ -66,7 +66,7 @@ def test_pickle(data):
manager = RunnableVerticesManager.from_dict(data)
binary = pickle.dumps(manager)
result = pickle.loads(binary)
result = pickle.loads(binary) # noqa: S301
assert result.run_map == manager.run_map
assert result.run_predecessors == manager.run_predecessors

View file

@ -119,7 +119,7 @@ def test_sort_up_to_vertex_a(graph):
def test_sort_up_to_vertex_invalid_vertex(graph):
vertex_id = "7"
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Parent node map is required to find the root of a group node"):
utils.sort_up_to_vertex(graph, vertex_id)
@ -432,7 +432,7 @@ class TestFindCycleVertices:
assert sorted(result) == sorted(expected_output)
@pytest.mark.parametrize("_", range(5))
def test_handle_two_inputs_in_cycle(self, _):
def test_handle_two_inputs_in_cycle(self, _): # noqa: PT019
edges = [
("chat_input", "router"),
("chat_input", "concatenate"),

View file

@ -79,8 +79,8 @@ def test_invalid_node_types():
],
"edges": [],
}
with pytest.raises(Exception):
g = Graph()
g = Graph()
with pytest.raises(KeyError):
g.add_nodes_and_edges(graph_data["nodes"], graph_data["edges"])

View file

@ -79,13 +79,13 @@ class TestBuildModelFromSchema:
{"name": "field1", "type": "str", "default": "default_value1"},
{"name": "field2", "type": "unknown_type", "default": "default_value2"},
]
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Invalid type: unknown_type"):
build_model_from_schema(schema)
# Confirms that the function raises a specific exception for invalid input
def test_raises_error_for_invalid_input_different_exception_with_specific_exception(self):
with pytest.raises(ValueError):
schema = [{"name": "field1", "type": "invalid_type", "default": "default_value"}]
schema = [{"name": "field1", "type": "invalid_type", "default": "default_value"}]
with pytest.raises(ValueError, match="Invalid type: invalid_type"):
build_model_from_schema(schema)
# Processes schemas with missing optional keys like description or multiple

View file

@ -43,6 +43,7 @@ AI: """
return graph
@pytest.mark.usefixtures("client")
def test_memory_chatbot(memory_chatbot_graph):
# Now we run step by step
expected_order = deque(["chat_input", "chat_memory", "prompt", "openai", "chat_output"])

View file

@ -34,7 +34,7 @@ def ingestion_graph():
embedding=openai_embeddings.build_embeddings,
ingest_data=text_splitter.split_text,
api_endpoint="https://astra.example.com",
token="token",
token="token", # noqa: S106
)
vector_store.set_on_output(name="vector_store", value="mock_vector_store", cache=True)
vector_store.set_on_output(name="base_retriever", value="mock_retriever", cache=True)
@ -53,7 +53,7 @@ def rag_graph():
rag_vector_store.set(
search_input=chat_input.message_response,
api_endpoint="https://astra.example.com",
token="token",
token="token", # noqa: S106
embedding=openai_embeddings.build_embeddings,
)
# Mock search_documents
@ -110,9 +110,7 @@ def test_vector_store_rag(ingestion_graph, rag_graph):
"openai-embeddings-124",
]
for ids, graph, len_results in [(ingestion_ids, ingestion_graph, 5), (rag_ids, rag_graph, 8)]:
results = []
for result in graph.start():
results.append(result)
results = list(graph.start())
assert len(results) == len_results
vids = [result.vertex.id for result in results if hasattr(result, "vertex")]
@ -217,12 +215,14 @@ def test_vector_store_rag_add(ingestion_graph: Graph, rag_graph: Graph):
rag_graph_copy = copy.deepcopy(rag_graph)
ingestion_graph_copy += rag_graph_copy
assert (
len(ingestion_graph_copy.vertices) == len(ingestion_graph.vertices) + len(rag_graph.vertices)
), f"Vertices mismatch: {len(ingestion_graph_copy.vertices)} != {len(ingestion_graph.vertices)} + {len(rag_graph.vertices)}"
assert len(ingestion_graph_copy.edges) == len(ingestion_graph.edges) + len(
rag_graph.edges
), f"Edges mismatch: {len(ingestion_graph_copy.edges)} != {len(ingestion_graph.edges)} + {len(rag_graph.edges)}"
assert len(ingestion_graph_copy.vertices) == len(ingestion_graph.vertices) + len(rag_graph.vertices), (
f"Vertices mismatch: {len(ingestion_graph_copy.vertices)} "
f"!= {len(ingestion_graph.vertices)} + {len(rag_graph.vertices)}"
)
assert len(ingestion_graph_copy.edges) == len(ingestion_graph.edges) + len(rag_graph.edges), (
f"Edges mismatch: {len(ingestion_graph_copy.edges)} "
f"!= {len(ingestion_graph.edges)} + {len(rag_graph.edges)}"
)
combined_graph_dump = ingestion_graph_copy.dump(
name="Combined Graph", description="Graph for data ingestion and RAG", endpoint_name="combined"

View file

@ -70,7 +70,7 @@ def test_instantiate_input_valid():
def test_instantiate_input_invalid():
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Invalid input type: InvalidInput"):
instantiate_input("InvalidInput", {"name": "invalid_input", "value": "This is a string"})
@ -224,5 +224,5 @@ def test_instantiate_input_comprehensive():
input_instance = instantiate_input(input_type, data)
assert isinstance(input_instance, InputTypesMap[input_type])
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Invalid input type: InvalidInput"):
instantiate_input("InvalidInput", {"name": "invalid_input", "value": "Invalid"})

View file

@ -180,7 +180,7 @@ class TestCreateInputSchema:
input_instance = StrInput(name="test_field", is_list=True)
schema = create_input_schema([input_instance])
field_info = schema.model_fields["test_field"]
assert field_info.annotation == list[str] # type: ignore
assert field_info.annotation == list[str]
# Converting FieldTypes to corresponding Python types
def test_field_types_conversion(self):

View file

@ -33,7 +33,7 @@ class TestColumn:
# Invalid formatter raises ValueError
def test_invalid_formatter_raises_value_error(self):
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="'invalid' is not a valid FormatterType"):
Column(display_name="Invalid Column", name="invalid_column", formatter="invalid")
# Formatter is None when not provided

View file

@ -72,19 +72,16 @@ def test_get_variable(service, session):
assert result == value
def test_get_variable__ValueError(service, session):
def test_get_variable__valueerror(service, session):
user_id = uuid4()
name = "name"
field = ""
with pytest.raises(ValueError) as exc:
with pytest.raises(ValueError, match=f"{name} variable not found."):
service.get_variable(user_id, name, field, session)
assert name in str(exc.value)
assert "variable not found" in str(exc.value)
def test_get_variable__TypeError(service, session):
def test_get_variable__typeerror(service, session):
user_id = uuid4()
name = "name"
value = "value"
@ -142,17 +139,14 @@ def test_update_variable(service, session):
assert isinstance(result.updated_at, datetime)
def test_update_variable__ValueError(service, session):
def test_update_variable__valueerror(service, session):
user_id = uuid4()
name = "name"
value = "value"
with pytest.raises(ValueError) as exc:
with pytest.raises(ValueError, match=f"{name} variable not found."):
service.update_variable(user_id, name, value, session=session)
assert name in str(exc.value)
assert "variable not found" in str(exc.value)
def test_update_variable_fields(service, session):
user_id = uuid4()
@ -192,26 +186,21 @@ def test_delete_variable(service, session):
service.create_variable(user_id, name, value, session=session)
recovered = service.get_variable(user_id, name, field, session=session)
service.delete_variable(user_id, name, session=session)
with pytest.raises(ValueError) as exc:
with pytest.raises(ValueError, match=f"{name} variable not found."):
service.get_variable(user_id, name, field, session)
assert recovered == value
assert name in str(exc.value)
assert "variable not found" in str(exc.value)
def test_delete_variable__ValueError(service, session):
def test_delete_variable__valueerror(service, session):
user_id = uuid4()
name = "name"
with pytest.raises(ValueError) as exc:
with pytest.raises(ValueError, match=f"{name} variable not found."):
service.delete_variable(user_id, name, session=session)
assert name in str(exc.value)
assert "variable not found" in str(exc.value)
def test_delete_varaible_by_id(service, session):
def test_delete_variable_by_id(service, session):
user_id = uuid4()
name = "name"
value = "value"
@ -220,24 +209,19 @@ def test_delete_varaible_by_id(service, session):
saved = service.create_variable(user_id, name, value, session=session)
recovered = service.get_variable(user_id, name, field, session=session)
service.delete_variable_by_id(user_id, saved.id, session=session)
with pytest.raises(ValueError) as exc:
with pytest.raises(ValueError, match=f"{name} variable not found."):
service.get_variable(user_id, name, field, session)
assert recovered == value
assert name in str(exc.value)
assert "variable not found" in str(exc.value)
def test_delete_variable_by_id__ValueError(service, session):
def test_delete_variable_by_id__valueerror(service, session):
user_id = uuid4()
variable_id = uuid4()
with pytest.raises(ValueError) as exc:
with pytest.raises(ValueError, match=f"{variable_id} variable not found."):
service.delete_variable_by_id(user_id, variable_id, session=session)
assert str(variable_id) in str(exc.value)
assert "variable not found" in str(exc.value)
def test_create_variable(service, session):
user_id = uuid4()

View file

@ -4,7 +4,11 @@ from langflow.services.database.models.api_key import ApiKeyCreate
@pytest.fixture
async def api_key(client, logged_in_headers, active_user):
async def api_key(
client,
logged_in_headers,
active_user, # noqa: ARG001
):
api_key = ApiKeyCreate(name="test-api-key")
response = await client.post("api/v1/api_key/", data=api_key.model_dump_json(), headers=logged_in_headers)

View file

@ -51,9 +51,7 @@ def test_code_parser_get_tree():
def test_code_parser_syntax_error():
"""Test the __get_tree method raises the
CodeSyntaxError when given incorrect syntax.
"""
"""Test the __get_tree method raises the CodeSyntaxError when given incorrect syntax."""
code_syntax_error = "zzz import os"
parser = CodeParser(code_syntax_error)
@ -76,9 +74,7 @@ def test_component_get_code_tree():
def test_component_code_null_error():
"""Test the get_function method raises the
ComponentCodeNullError when the code is empty.
"""
"""Test the get_function method raises the ComponentCodeNullError when the code is empty."""
component = BaseComponent(_code="", _function_entrypoint_name="")
with pytest.raises(ComponentCodeNullError):
component.get_function()
@ -108,9 +104,7 @@ def test_custom_component_get_function():
def test_code_parser_parse_imports_import():
"""Test the parse_imports method of the CodeParser
class with an import statement.
"""
"""Test the parse_imports method of the CodeParser class with an import statement."""
parser = CodeParser(code_default)
tree = parser.get_tree()
for node in ast.walk(tree):
@ -120,9 +114,7 @@ def test_code_parser_parse_imports_import():
def test_code_parser_parse_imports_importfrom():
"""Test the parse_imports method of the CodeParser
class with an import from statement.
"""
"""Test the parse_imports method of the CodeParser class with an import from statement."""
parser = CodeParser("from os import path")
tree = parser.get_tree()
for node in ast.walk(tree):
@ -157,9 +149,9 @@ def test_code_parser_parse_classes_raises():
"""Test the parse_classes method of the CodeParser class."""
parser = CodeParser("class Test: pass")
tree = parser.get_tree()
with pytest.raises(TypeError):
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
with pytest.raises(TypeError):
parser.parse_classes(node)
@ -175,18 +167,14 @@ def test_code_parser_parse_global_vars():
def test_component_get_function_valid():
"""Test the get_function method of the Component
class with valid code and function_entrypoint_name.
"""
"""Test the get_function method of the Component class with valid code and function_entrypoint_name."""
component = BaseComponent(_code="def build(): pass", _function_entrypoint_name="build")
my_function = component.get_function()
assert callable(my_function)
def test_custom_component_get_function_entrypoint_args():
"""Test the get_function_entrypoint_args
property of the CustomComponent class.
"""
"""Test the get_function_entrypoint_args property of the CustomComponent class."""
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
args = custom_component.get_function_entrypoint_args
assert len(args) == 3
@ -196,9 +184,7 @@ def test_custom_component_get_function_entrypoint_args():
def test_custom_component_get_function_entrypoint_return_type():
"""Test the get_function_entrypoint_return_type
property of the CustomComponent class.
"""
"""Test the get_function_entrypoint_return_type property of the CustomComponent class."""
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
return_type = custom_component.get_function_entrypoint_return_type
assert return_type == [Document]
@ -212,9 +198,7 @@ def test_custom_component_get_main_class_name():
def test_custom_component_get_function_valid():
"""Test the get_function property of the CustomComponent
class with valid code and function_entrypoint_name.
"""
"""Test the get_function property of the CustomComponent class with valid code and function_entrypoint_name."""
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
my_function = custom_component.get_function
assert callable(my_function)
@ -239,9 +223,7 @@ def test_code_parser_parse_arg_with_annotation():
def test_code_parser_parse_callable_details_no_args():
"""Test the parse_callable_details method of the
CodeParser class with a function with no arguments.
"""
"""Test the parse_callable_details method of the CodeParser class with a function with no arguments."""
parser = CodeParser("")
node = ast.FunctionDef(
name="test",
@ -280,9 +262,7 @@ def test_code_parser_parse_ann_assign():
def test_code_parser_parse_function_def_not_init():
"""Test the parse_function_def method of the
CodeParser class with a function that is not __init__.
"""
"""Test the parse_function_def method of the CodeParser class with a function that is not __init__."""
parser = CodeParser("")
stmt = ast.FunctionDef(
name="test",
@ -297,9 +277,7 @@ def test_code_parser_parse_function_def_not_init():
def test_code_parser_parse_function_def_init():
"""Test the parse_function_def method of the
CodeParser class with an __init__ function.
"""
"""Test the parse_function_def method of the CodeParser class with an __init__ function."""
parser = CodeParser("")
stmt = ast.FunctionDef(
name="__init__",
@ -314,36 +292,28 @@ def test_code_parser_parse_function_def_init():
def test_component_get_code_tree_syntax_error():
"""Test the get_code_tree method of the Component class
raises the CodeSyntaxError when given incorrect syntax.
"""
"""Test the get_code_tree method of the Component class raises the CodeSyntaxError when given incorrect syntax."""
component = BaseComponent(_code="import os as", _function_entrypoint_name="build")
with pytest.raises(CodeSyntaxError):
component.get_code_tree(component._code)
def test_custom_component_class_template_validation_no_code():
"""Test the _class_template_validation method of the CustomComponent class
raises the HTTPException when the code is None.
"""
"""Test CustomComponent._class_template_validation raises the HTTPException when the code is None."""
custom_component = CustomComponent(_code=None, _function_entrypoint_name="build")
with pytest.raises(TypeError):
custom_component.get_function()
def test_custom_component_get_code_tree_syntax_error():
"""Test the get_code_tree method of the CustomComponent class
raises the CodeSyntaxError when given incorrect syntax.
"""
"""Test CustomComponent.get_code_tree raises the CodeSyntaxError when given incorrect syntax."""
custom_component = CustomComponent(_code="import os as", _function_entrypoint_name="build")
with pytest.raises(CodeSyntaxError):
custom_component.get_code_tree(custom_component._code)
def test_custom_component_get_function_entrypoint_args_no_args():
"""Test the get_function_entrypoint_args property of
the CustomComponent class with a build method with no arguments.
"""
"""Test CustomComponent.get_function_entrypoint_args with a build method with no arguments."""
my_code = """
from langflow.custom import CustomComponent
class MyMainClass(CustomComponent):
@ -356,9 +326,7 @@ class MyMainClass(CustomComponent):
def test_custom_component_get_function_entrypoint_return_type_no_return_type():
"""Test the get_function_entrypoint_return_type property of the
CustomComponent class with a build method with no return type.
"""
"""Test CustomComponent.get_function_entrypoint_return_type with a build method with no return type."""
my_code = """
from langflow.custom import CustomComponent
class MyClass(CustomComponent):
@ -371,9 +339,7 @@ class MyClass(CustomComponent):
def test_custom_component_get_main_class_name_no_main_class():
"""Test the get_main_class_name property of the
CustomComponent class when there is no main class.
"""
"""Test the get_main_class_name property of the CustomComponent class when there is no main class."""
my_code = """
def build():
pass"""
@ -384,9 +350,7 @@ def build():
def test_custom_component_build_not_implemented():
"""Test the build method of the CustomComponent
class raises the NotImplementedError.
"""
"""Test the build method of the CustomComponent class raises the NotImplementedError."""
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
with pytest.raises(NotImplementedError):
custom_component.build()

View file

@ -15,7 +15,10 @@ def code_component_with_multiple_outputs():
@pytest.fixture
def component(client, active_user):
def component(
client, # noqa: ARG001
active_user,
):
return CustomComponent(
user_id=active_user.id,
field_config={

View file

@ -1,5 +1,5 @@
import json
from collections import namedtuple
from typing import NamedTuple
from uuid import UUID, uuid4
import orjson
@ -13,7 +13,6 @@ from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
from langflow.services.database.models.folder.model import FolderCreate
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from sqlmodel import Session
@pytest.fixture(scope="module")
@ -179,11 +178,11 @@ async def test_delete_flows_with_transaction_and_build(client: TestClient, logge
assert response.status_code == 201
flow_ids.append(response.json()["id"])
class VertexTuple(NamedTuple):
id: str
# Create a transaction for each flow
for flow_id in flow_ids:
VertexTuple = namedtuple("VertexTuple", ["id"])
await log_transaction(
str(flow_id), source=VertexTuple(id="vid"), target=VertexTuple(id="tid"), status="success"
)
@ -249,10 +248,11 @@ async def test_delete_folder_with_flows_with_transaction_and_build(client: TestC
assert response.status_code == 201
flow_ids.append(response.json()["id"])
class VertexTuple(NamedTuple):
id: str
# Create a transaction for each flow
for flow_id in flow_ids:
VertexTuple = namedtuple("VertexTuple", ["id"])
await log_transaction(
str(flow_id), source=VertexTuple(id="vid"), target=VertexTuple(id="tid"), status="success"
)
@ -400,9 +400,9 @@ async def test_upload_file(client: TestClient, json_flow: str, logged_in_headers
assert response_data[1]["data"] == data
@pytest.mark.usefixtures("session")
async def test_download_file(
client: TestClient,
session: Session,
json_flow,
active_user,
logged_in_headers,
@ -419,14 +419,14 @@ async def test_download_file(
]
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
with session_getter(db_manager) as _session:
saved_flows = []
for flow in flow_list.flows:
flow.user_id = active_user.id
db_flow = Flow.model_validate(flow, from_attributes=True)
session.add(db_flow)
_session.add(db_flow)
saved_flows.append(db_flow)
session.commit()
_session.commit()
# Make request to endpoint inside the session context
flow_ids = [str(db_flow.id) for db_flow in saved_flows] # Convert UUIDs to strings
flow_ids_json = json.dumps(flow_ids)

View file

@ -1,4 +1,4 @@
import time
import asyncio
from uuid import UUID, uuid4
import pytest
@ -27,7 +27,7 @@ async def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1)
)
if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS":
return task_status_response.json()
time.sleep(sleep_time)
await asyncio.sleep(sleep_time)
return None # Return None if task did not complete in time

View file

@ -26,7 +26,13 @@ def mock_storage_service():
@pytest.fixture(name="files_client")
async def files_client_fixture(session: Session, monkeypatch, request, load_flows_dir, mock_storage_service):
async def files_client_fixture(
session: Session, # noqa: ARG001
monkeypatch,
request,
load_flows_dir,
mock_storage_service,
):
# Set the database url to a test database
if "noclient" in request.keywords:
yield
@ -47,9 +53,11 @@ async def files_client_fixture(session: Session, monkeypatch, request, load_flow
app = create_app()
app.dependency_overrides[get_storage_service] = lambda: mock_storage_service
async with LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager:
async with AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client:
yield client
async with (
LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager,
AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client,
):
yield client
# app.dependency_overrides.clear()
monkeypatch.undo()
# clear the temp db

View file

@ -59,7 +59,8 @@ async def test_create_or_update_starter_projects():
assert folder is not None
num_db_projects = len(folder.flows)
# Check that the number of projects in the database is the same as the number of projects returned by load_starter_projects
# Check that the number of projects in the database is the same as the number of projects returned by
# load_starter_projects
assert num_db_projects == num_projects
@ -76,7 +77,8 @@ async def test_create_or_update_starter_projects():
# # Get the number of projects in the database
# num_db_projects = session.exec(select(func.count(Flow.id)).where(Flow.folder == STARTER_FOLDER_NAME)).one()
# # Check that the number of projects in the database is the same as the number of projects returned by load_starter_projects
# # Check that the number of projects in the database is the same as the number of projects returned by
# # load_starter_projects
# assert num_db_projects == num_projects
# # Get all the starter projects
@ -99,7 +101,7 @@ async def test_create_or_update_starter_projects():
# delete_messages(session_id="test")
def find_componeny_by_name(components, name):
def find_component_by_name(components, name):
for children in components.values():
if name in children:
return children[name]
@ -111,17 +113,17 @@ def set_value(component, input_name, value):
component["template"][input_name]["value"] = value
def component_to_node(id, type, component):
return {"id": type + id, "data": {"node": component, "type": type, "id": id}}
def component_to_node(node_id, node_type, component):
return {"id": node_type + node_id, "data": {"node": component, "type": node_type, "id": node_id}}
def add_edge(input, output, from_output, to_input):
def add_edge(source, target, from_output, to_input):
return {
"source": input,
"target": output,
"source": source,
"target": target,
"data": {
"sourceHandle": {"dataType": "ChatInput", "id": input, "name": from_output, "output_types": ["Message"]},
"targetHandle": {"fieldName": to_input, "id": output, "inputTypes": ["Message"], "type": "str"},
"sourceHandle": {"dataType": "ChatInput", "id": source, "name": from_output, "output_types": ["Message"]},
"targetHandle": {"fieldName": to_input, "id": target, "inputTypes": ["Message"], "type": "str"},
},
}
@ -131,8 +133,8 @@ async def test_refresh_starter_projects():
data_path = str(Path(__file__).parent.parent.parent.absolute() / "base" / "langflow" / "components")
components = build_custom_component_list_from_path(data_path)
chat_input = find_componeny_by_name(components, "ChatInput")
chat_output = find_componeny_by_name(components, "ChatOutput")
chat_input = find_component_by_name(components, "ChatInput")
chat_output = find_component_by_name(components, "ChatOutput")
chat_output["template"]["code"]["value"] = "changed !"
del chat_output["template"]["should_store_message"]
graph_data = {

View file

@ -8,13 +8,13 @@ from langflow.services.variable.kubernetes_secrets import KubernetesSecretManage
@pytest.fixture
def mock_kube_config(mocker):
def _mock_kube_config(mocker):
mocker.patch("kubernetes.config.load_kube_config")
mocker.patch("kubernetes.config.load_incluster_config")
@pytest.fixture
def secret_manager(mock_kube_config):
def secret_manager(_mock_kube_config):
return KubernetesSecretManager(namespace="test-namespace")

View file

@ -19,15 +19,15 @@ def created_message():
@pytest.fixture
def created_messages(session):
with session_scope() as session:
def created_messages(session): # noqa: ARG001
with session_scope() as _session:
messages = [
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
]
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
messagetables = add_messagetables(messagetables, session)
messagetables = add_messagetables(messagetables, _session)
return [MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables]

View file

@ -20,15 +20,15 @@ async def created_message():
@pytest.fixture
def created_messages(session):
with session_scope() as session:
def created_messages(session): # noqa: ARG001
with session_scope() as _session:
messages = [
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
]
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
return add_messagetables(messagetables, session)
return add_messagetables(messagetables, _session)
@pytest.mark.api_key_required

View file

@ -282,12 +282,16 @@ async def test_load_langchain_object_with_cached_session(basic_graph_data):
# session_service = get_session_service()
# session_id1 = "non-existent-session-id"
# session_id = session_service.build_key(session_id1, basic_graph_data)
# graph1, artifacts1 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id")
# graph1, artifacts1 = await session_service.load_session(
# session_id, data_graph=basic_graph_data, flow_id="flow_id"
# )
# # Clear the cache
# await session_service.clear_session(session_id)
# # Use the new session_id to get the graph again
# graph2, artifacts2 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id")
# graph2, artifacts2 = await session_service.load_session(
# session_id, data_graph=basic_graph_data, flow_id="flow_id"
# )
#
# # Since the cache was cleared, objects should be different
# assert id(graph1) != id(graph2)
@ -297,8 +301,12 @@ async def test_load_langchain_object_with_cached_session(basic_graph_data):
# # Provide a non-existent session_id
# session_service = get_session_service()
# session_id1 = None
# graph1, artifacts1 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id")
# graph1, artifacts1 = await session_service.load_session(
# session_id1, data_graph=basic_graph_data, flow_id="flow_id"
# )
# # Use the new session_id to get the langchain_object again
# graph2, artifacts2 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id")
# graph2, artifacts2 = await session_service.load_session(
# session_id1, data_graph=basic_graph_data, flow_id="flow_id"
# )
#
# assert graph1 == graph2

View file

@ -45,52 +45,52 @@ class TestInput:
assert set(post_process_type(SequenceABC[float])) == {float}
# Union types
assert set(post_process_type(Union[int, str])) == {int, str}
assert set(post_process_type(Union[int, SequenceABC[str]])) == {int, str}
assert set(post_process_type(Union[int, SequenceABC[int]])) == {int}
assert set(post_process_type(Union[int, str])) == {int, str} # noqa: UP007
assert set(post_process_type(Union[int, SequenceABC[str]])) == {int, str} # noqa: UP007
assert set(post_process_type(Union[int, SequenceABC[int]])) == {int} # noqa: UP007
# Nested Union with lists
assert set(post_process_type(Union[list[int], list[str]])) == {int, str}
assert set(post_process_type(Union[int, list[str], list[float]])) == {int, str, float}
assert set(post_process_type(Union[list[int], list[str]])) == {int, str} # noqa: UP007
assert set(post_process_type(Union[int, list[str], list[float]])) == {int, str, float} # noqa: UP007
# Custom data types
assert set(post_process_type(Data)) == {Data}
assert set(post_process_type(list[Data])) == {Data}
# Union with custom types
assert set(post_process_type(Union[Data, str])) == {Data, str}
assert set(post_process_type(Union[Data, int, list[str]])) == {Data, int, str}
assert set(post_process_type(Union[Data, str])) == {Data, str} # noqa: UP007
assert set(post_process_type(Union[Data, int, list[str]])) == {Data, int, str} # noqa: UP007
# Empty lists and edge cases
assert set(post_process_type(list)) == {list}
assert set(post_process_type(Union[int, None])) == {int, NoneType}
assert set(post_process_type(Union[None, list[None]])) == {None, NoneType}
assert set(post_process_type(Union[int, None])) == {int, NoneType} # noqa: UP007
assert set(post_process_type(Union[None, list[None]])) == {None, NoneType} # noqa: UP007
# Handling complex nested structures
assert set(post_process_type(Union[SequenceABC[int | str], list[float]])) == {int, str, float}
assert set(post_process_type(Union[int | list[str] | list[float], str])) == {int, str, float}
assert set(post_process_type(Union[SequenceABC[int | str], list[float]])) == {int, str, float} # noqa: UP007
assert set(post_process_type(Union[int | list[str] | list[float], str])) == {int, str, float} # noqa: UP007
# Non-generic types should return as is
assert set(post_process_type(dict)) == {dict}
assert set(post_process_type(tuple)) == {tuple}
# Union with custom types
assert set(post_process_type(Union[Data, str])) == {Data, str}
assert set(post_process_type(Union[Data, str])) == {Data, str} # noqa: UP007
assert set(post_process_type(Data | str)) == {Data, str}
assert set(post_process_type(Data | int | list[str])) == {Data, int, str}
# More complex combinations with Data
assert set(post_process_type(Data | list[float])) == {Data, float}
assert set(post_process_type(Data | Union[int, str])) == {Data, int, str}
assert set(post_process_type(Data | Union[int, str])) == {Data, int, str} # noqa: UP007
assert set(post_process_type(Data | list[int] | None)) == {Data, int, type(None)}
assert set(post_process_type(Data | Union[float, None])) == {Data, float, type(None)}
assert set(post_process_type(Data | Union[float, None])) == {Data, float, type(None)} # noqa: UP007
# Multiple Data types combined
assert set(post_process_type(Union[Data, str | float])) == {Data, str, float}
assert set(post_process_type(Union[Data | float | str, int])) == {Data, int, float, str}
assert set(post_process_type(Union[Data, str | float])) == {Data, str, float} # noqa: UP007
assert set(post_process_type(Union[Data | float | str, int])) == {Data, int, float, str} # noqa: UP007
# Testing with nested unions and lists
assert set(post_process_type(Union[list[Data], list[int | str]])) == {Data, int, str}
assert set(post_process_type(Union[list[Data], list[int | str]])) == {Data, int, str} # noqa: UP007
assert set(post_process_type(Data | list[float | str])) == {Data, float, str}
def test_input_to_dict(self):
@ -157,7 +157,7 @@ class TestPostProcessType:
assert post_process_type(list[int]) == [int]
def test_union_type(self):
assert set(post_process_type(Union[int, str])) == {int, str}
assert set(post_process_type(Union[int, str])) == {int, str} # noqa: UP007
def test_custom_type(self):
class CustomType:
@ -175,4 +175,4 @@ class TestPostProcessType:
class CustomType:
pass
assert set(post_process_type(Union[CustomType, int])) == {CustomType, int}
assert set(post_process_type(Union[CustomType, int])) == {CustomType, int} # noqa: UP007

View file

@ -112,11 +112,11 @@ def test_teardown_superuser_default_superuser(mock_get_session, mock_get_setting
@patch("langflow.services.deps.get_settings_service")
@patch("langflow.services.deps.get_session")
def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service):
ADMIN_USER_NAME = "admin_user"
admin_user_name = "admin_user"
mock_settings_service = MagicMock()
mock_settings_service.auth_settings.AUTO_LOGIN = False
mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password"
mock_settings_service.auth_settings.SUPERUSER = admin_user_name
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password" # noqa: S105
mock_get_settings_service.return_value = mock_settings_service
mock_session = MagicMock()

View file

@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timezone
import pytest
from httpx import AsyncClient
@ -11,7 +11,7 @@ from sqlmodel import select
@pytest.fixture
def super_user(client):
def super_user(client): # noqa: ARG001
settings_manager = get_settings_service()
auth_settings = settings_manager.auth_settings
with session_getter(get_db_service()) as session:
@ -23,7 +23,10 @@ def super_user(client):
@pytest.fixture
async def super_user_headers(client: AsyncClient, super_user):
async def super_user_headers(
client: AsyncClient,
super_user, # noqa: ARG001
):
settings_service = get_settings_service()
auth_settings = settings_service.auth_settings
login_data = {
@ -38,14 +41,14 @@ async def super_user_headers(client: AsyncClient, super_user):
@pytest.fixture
def deactivated_user(client):
def deactivated_user(client): # noqa: ARG001
with session_getter(get_db_service()) as session:
user = User(
username="deactivateduser",
password=get_password_hash("testpassword"),
is_active=False,
is_superuser=False,
last_login_at=datetime.now(),
last_login_at=datetime.now(tz=timezone.utc),
)
session.add(user)
session.commit()
@ -55,7 +58,7 @@ def deactivated_user(client):
async def test_user_waiting_for_approval(client):
username = "waitingforapproval"
password = "testpassword"
password = "testpassword" # noqa: S105
# Debug: Check if the user already exists
with session_getter(get_db_service()) as session:
@ -140,7 +143,7 @@ async def test_inactive_user(client: AsyncClient):
username="inactiveuser",
password=get_password_hash("testpassword"),
is_active=False,
last_login_at=datetime(2023, 1, 1, 0, 0, 0),
last_login_at=datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
)
session.add(user)
session.commit()
@ -205,7 +208,7 @@ async def test_patch_user(client: AsyncClient, active_user, logged_in_headers):
async def test_patch_reset_password(client: AsyncClient, active_user, logged_in_headers):
user_id = active_user.id
update_data = UserUpdate(
password="newpassword",
password="newpassword", # noqa: S106
)
response = await client.patch(

View file

@ -5,7 +5,7 @@ import pytest
@pytest.fixture(autouse=True)
def check_openai_api_key_in_environment_variables():
def _check_openai_api_key_in_environment_variables():
pass