diff --git a/pyproject.toml b/pyproject.toml index 7164e4247..bd16677a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,7 @@ dev-dependencies = [ "asgi-lifespan>=2.1.0", "pytest-github-actions-annotate-failures>=0.2.0", "pytest-codspeed>=3.0.0", - "forbiddenfruit>=0.1.4", + "blockbuster>=1.1.0,<1.2", ] diff --git a/src/backend/tests/blockbuster.py b/src/backend/tests/blockbuster.py deleted file mode 100644 index c3b3a5f9d..000000000 --- a/src/backend/tests/blockbuster.py +++ /dev/null @@ -1,144 +0,0 @@ -import asyncio -import inspect -import io -import os -import socket -import ssl -import sys -import time -from importlib.abc import FileLoader - -import forbiddenfruit - - -class BlockingError(Exception): ... - - -def _blocking_error(func): - if inspect.isbuiltin(func): - msg = f"Blocking call to {func.__qualname__} ({func.__self__})" - elif inspect.ismethoddescriptor(func): - msg = f"Blocking call to {func}" - else: - msg = f"Blocking call to {func.__module__}.{func.__qualname__}" - return BlockingError(msg) - - -def _wrap_blocking(func): - def wrapper(*args, **kwargs): - try: - asyncio.get_running_loop() - except RuntimeError: - return func(*args, **kwargs) - raise _blocking_error(func) - - return wrapper - - -def _wrap_time_blocking(func): - def wrapper(*args, **kwargs): - try: - asyncio.get_running_loop() - except RuntimeError: - return func(*args, **kwargs) - for frame_info in inspect.stack(): - if frame_info.filename.endswith("pydev/pydevd.py") and frame_info.function == "_do_wait_suspend": - return func(*args, **kwargs) - - raise _blocking_error(func) - - return wrapper - - -def _wrap_os_blocking(func): - def os_op(fd, *args, **kwargs): - try: - asyncio.get_running_loop() - except RuntimeError: - return func(fd, *args, **kwargs) - if os.get_blocking(fd): - raise _blocking_error(func) - return func(fd, *args, **kwargs) - - return os_op - - -def _wrap_socket_blocking(func): - def socket_op(self, *args, **kwargs): - try: - asyncio.get_running_loop() - except RuntimeError: - return func(self, *args, **kwargs) - if self.getblocking(): - raise _blocking_error(func) - return func(self, *args, **kwargs) - - return socket_op - - -def _wrap_file_read_blocking(func): - def file_op(self, *args, **kwargs): - try: - asyncio.get_running_loop() - except RuntimeError: - return func(self, *args, **kwargs) - for frame_info in inspect.stack(): - if isinstance(frame_info.frame.f_locals.get("self"), FileLoader): - return func(self, *args, **kwargs) - if frame_info.filename.endswith("_pytest/assertion/rewrite.py") and frame_info.function in { - "_rewrite_test", - "_read_pyc", - }: - return func(self, *args, **kwargs) - if frame_info.filename.endswith("settings/service.py") and frame_info.function == "initialize": - return func(self, *args, **kwargs) - raise _blocking_error(func) - - return file_op - - -def _wrap_file_write_blocking(func): - def file_op(self, *args, **kwargs): - try: - asyncio.get_running_loop() - except RuntimeError: - return func(self, *args, **kwargs) - for frame_info in inspect.stack(): - if frame_info.filename.endswith("_pytest/assertion/rewrite.py") and frame_info.function == "_write_pyc": - return func(self, *args, **kwargs) - if frame_info.filename.endswith("settings/service.py") and frame_info.function == "initialize": - return func(self, *args, **kwargs) - if self not in {sys.stdout, sys.stderr}: - raise _blocking_error(func) - return func(self, *args, **kwargs) - - return file_op - - -def init(): - time.sleep = _wrap_time_blocking(time.sleep) - - os.read = _wrap_os_blocking(os.read) - os.write = _wrap_os_blocking(os.write) - - socket.socket.send = _wrap_socket_blocking(socket.socket.send) - socket.socket.sendall = _wrap_socket_blocking(socket.socket.sendall) - socket.socket.sendto = _wrap_socket_blocking(socket.socket.sendto) - socket.socket.recv = _wrap_socket_blocking(socket.socket.recv) - socket.socket.recv_into = _wrap_socket_blocking(socket.socket.recv_into) - socket.socket.recvfrom = _wrap_socket_blocking(socket.socket.recvfrom) - socket.socket.recvfrom_into = _wrap_socket_blocking(socket.socket.recvfrom_into) - socket.socket.recvmsg = _wrap_socket_blocking(socket.socket.recvmsg) - socket.socket.recvmsg_into = _wrap_socket_blocking(socket.socket.recvmsg_into) - - ssl.SSLSocket.write = _wrap_socket_blocking(ssl.SSLSocket.write) - ssl.SSLSocket.send = _wrap_socket_blocking(ssl.SSLSocket.send) - ssl.SSLSocket.read = _wrap_socket_blocking(ssl.SSLSocket.read) - ssl.SSLSocket.recv = _wrap_socket_blocking(ssl.SSLSocket.recv) - - forbiddenfruit.curse(io.BufferedReader, "read", _wrap_file_read_blocking(io.BufferedReader.read)) - forbiddenfruit.curse(io.BufferedWriter, "write", _wrap_file_write_blocking(io.BufferedWriter.write)) - forbiddenfruit.curse(io.BufferedRandom, "read", _wrap_blocking(io.BufferedRandom.read)) - forbiddenfruit.curse(io.BufferedRandom, "write", _wrap_file_write_blocking(io.BufferedRandom.write)) - forbiddenfruit.curse(io.TextIOWrapper, "read", _wrap_file_read_blocking(io.TextIOWrapper.read)) - forbiddenfruit.curse(io.TextIOWrapper, "write", _wrap_file_write_blocking(io.TextIOWrapper.write)) diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index 7605c349c..ea71e2db6 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -13,6 +13,7 @@ from uuid import UUID import orjson import pytest from asgi_lifespan import LifespanManager +from blockbuster import blockbuster_ctx from dotenv import load_dotenv from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient @@ -34,7 +35,6 @@ from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.pool import StaticPool from typer.testing import CliRunner -from tests import blockbuster from tests.api_keys import get_openai_api_key if TYPE_CHECKING: @@ -42,7 +42,22 @@ if TYPE_CHECKING: load_dotenv() -blockbuster.init() + + +@pytest.fixture(autouse=True) +def blockbuster(): + with blockbuster_ctx() as bb: + for func in [ + "io.BufferedReader.read", + "io.BufferedWriter.write", + "io.TextIOWrapper.read", + "io.TextIOWrapper.write", + ]: + bb.functions[func].can_block_functions.append(("settings/service.py", {"initialize"})) + for func in bb.functions: + if func.startswith("sqlite3."): + bb.functions[func].deactivate() + yield bb def pytest_configure(config): diff --git a/uv.lock b/uv.lock index 984f05f15..cee245848 100644 --- a/uv.lock +++ b/uv.lock @@ -447,6 +447,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/da/43b15f28fe5f9e027b41c539abc5469052e9d48fd75f8ff094ba2a0ae767/billiard-4.2.1-py3-none-any.whl", hash = "sha256:40b59a4ac8806ba2c2369ea98d876bc6108b051c227baffd928c644d15d8f3cb", size = 86766 }, ] +[[package]] +name = "blockbuster" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "forbiddenfruit" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/72/14f688cb48f37758666adfc92f078b26ca4fac696d42903ed47be00d2ad0/blockbuster-1.1.0.tar.gz", hash = "sha256:f812f516eb6b91bc0255d99b7a88890a3a03d68303637c214f228b0c2d0c13ce", size = 9068 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/6e/22d3aabd041915de0e9e695d0744d0b0dc0f70c98301a96210ad3f67cc1c/blockbuster-1.1.0-py3-none-any.whl", hash = "sha256:f8861d1ace11053ebd80b2e98459af69c077e3c3e18a6ddd81830a8f3776d5b2", size = 7901 }, +] + [[package]] name = "boto3" version = "1.34.162" @@ -3614,8 +3626,8 @@ local = [ [package.dev-dependencies] dev = [ { name = "asgi-lifespan" }, + { name = "blockbuster" }, { name = "dictdiffer" }, - { name = "forbiddenfruit" }, { name = "httpx" }, { name = "ipykernel" }, { name = "mypy" }, @@ -3744,8 +3756,8 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "asgi-lifespan", specifier = ">=2.1.0" }, + { name = "blockbuster", specifier = ">=1.1.0,<1.2" }, { name = "dictdiffer", specifier = ">=0.9.0" }, - { name = "forbiddenfruit", specifier = ">=0.1.4" }, { name = "httpx", specifier = ">=0.27.0" }, { name = "ipykernel", specifier = ">=6.29.0" }, { name = "mypy", specifier = ">=1.11.0" },