diff --git a/docs/async_migration.md b/docs/async_migration.md new file mode 100644 index 0000000..748b7a6 --- /dev/null +++ b/docs/async_migration.md @@ -0,0 +1,48 @@ +# Async Nxbt Migration Notes + +## Overview + +NXBT now exposes an async-friendly facade so that applications can await controller +lifecycle events directly instead of relying on background threads. The new +`nxbt.AsyncNxbtClient` wraps the legacy `Nxbt` API but offloads each blocking call +to a worker thread, allowing you to coordinate controllers from within an +`asyncio` event loop. + +```python +import asyncio +from nxbt import AsyncNxbtClient, PRO_CONTROLLER + + +async def main(): + async with AsyncNxbtClient(debug=False) as nx: + adapters = await nx.get_available_adapters() + index = await nx.create_controller(PRO_CONTROLLER, adapters[0]) + await nx.wait_for_connection(index) + await nx.macro(index, "A 0.1s\n0.1s") + +asyncio.run(main()) +``` + +Key points: + +1. Use `async with AsyncNxbtClient(...)` to ensure cleanup mirrors the previous + `atexit` behaviour (BlueZ toggles, runtime shutdown). +2. All high-level helpers (`macro`, `press_buttons`, `tilt_stick`, `set_controller_input`, + `wait_for_connection`, etc.) are now `await`-able. The `state` dict remains + synchronous for quick inspection without additional locking. +3. CLI utilities (`nxbt.cli` commands and `scripts/demo_loop.py`) already route + through `asyncio.run`, so they can be embedded inside larger event loops or + scripted via `asyncio.create_task`. + +### Compatibility + +The legacy `Nxbt` class still works for synchronous consumers and continues to +wrap the async controller manager internally. Downstream callers can migrate at +their own pace: + +- **Synchronous projects** – keep using `Nxbt` as before. +- **Async-aware projects** – switch to `AsyncNxbtClient` and await controller + operations directly. + +Future releases will update the TUI and web entry points to the async client as +well, completing Phase 4 of the refactor plan. diff --git a/docs/async_refactor_plan.md b/docs/async_refactor_plan.md new file mode 100644 index 0000000..af21ca1 --- /dev/null +++ b/docs/async_refactor_plan.md @@ -0,0 +1,37 @@ +# NXBT Async Refactor Plan (Checklist) + +## Phase 1 – Adapter & Utilities +- [x] Build a first-class `AsyncBleakAdapter` mirroring Bleak’s async API surface (scanner contexts, client connect/disconnect, GATT helpers). +- [x] Provide thin, clearly marked synchronous shims for legacy imports (current `BlueZ` now wraps `AsyncBleakAdapter`). +- [x] Audit helper functions (`find_objects`, `find_devices_by_alias`, discovery utilities) and offer async primitives with safe sync wrappers (`asyncio.run`). +- [ ] Validate helpers with demo scripts (`scripts/testbt.py`, `scanner.py`) across supported OSes. + +## Phase 2 – Controller & Bluetooth Stack +- [x] Convert controller modules (`controller.py`, `server.py`, protocol helpers) to async functions end-to-end. (`AsyncController` and `ControllerServer` now run entirely via asyncio, with sync shims retained for backwards compatibility.) +- [x] Introduce an `AsyncController` and ensure `ControllerServer` consumes it for setup. +- [x] Add an `AsyncControllerServer` facade so higher layers can await controller lifecycles. +- [x] Expose `run_async`/`connect_async`/`reconnect_async`/`mainloop_async` wrappers (no more thread offloading) to unblock higher-level async orchestration. +- [x] Replace blocking socket/BLE operations with `asyncio` sockets/tasks and cancellation-friendly loops (connect, reconnect, and mainloop now awaitable). +- [x] Document SDP/profile limitations: Bleak does not expose cross-platform profile registration, so `AsyncController` logs a warning and Phase 4 docs will direct Linux users to BlueZ if they need SDP features. + +## Phase 3 – Core Nxbt Process & IPC +- [x] Introduce an `AsyncNxbt` manager that spawns controller servers as asyncio tasks (`nxbt/async_nxbt.py`). +- [x] Provide a bridge in `Nxbt` (`use_async=True`) that routes controller creation, macro queues, and state tracking through the async manager. +- [x] Replace the legacy multiprocessing `Nxbt` manager entirely (or make async the default) so controllers run as tasks inside a single event loop. +- [x] Replace multiprocessing Queue/Lock coordination with `asyncio.Queue`, `asyncio.Lock`, or `TaskGroup` equivalents. +- [x] Ensure graceful shutdown awaits outstanding tasks and closes BLE clients cleanly in both modes. + +## Phase 4 – CLI, Scripts, and External APIs +- [ ] Update CLI commands, demo scripts, and web/tui entry points to drive the async core (wrap in `asyncio.run`). + - [x] CLI macros/test/demo and `scripts/demo_loop.py` now run under `asyncio` via `AsyncNxbtClient`. + - [x] Web app entry point now routes through a shared `AsyncNxbtClientBridge`. + - [x] `tui.py` uses the async bridge for controller management/input updates. +- [x] Revise public APIs in `nxbt/__init__.py` to expose async entry points (or clearly documented sync wrappers). +- [x] Provide migration notes guiding downstream users on awaiting the new APIs. +- [ ] Exercise async CLI/demo/TUI flows on real BLE hardware to catch regressions (blocked on hardware availability). + +## Phase 5 – Testing, Tooling, and Documentation +- [ ] Add async-aware tests (e.g., `pytest-asyncio`) covering discovery, controller lifecycles, and failure scenarios. +- [ ] Integrate async tests into CI with BLE-aware skips/mocks where hardware is unavailable. +- [ ] Update README/docs to emphasize the async model, environment requirements, and Bleak-based examples. +- [ ] Final cleanup: remove obsolete BlueZ-only utilities, ensure lint/type tools understand async interfaces, and tag a release with migration guidance. diff --git a/nxbt/__init__.py b/nxbt/__init__.py index 0582752..e848733 100644 --- a/nxbt/__init__.py +++ b/nxbt/__init__.py @@ -10,3 +10,6 @@ from .nxbt import Sticks from .nxbt import JOYCON_L from .nxbt import JOYCON_R from .nxbt import PRO_CONTROLLER +from .async_nxbt import AsyncNxbt +from .async_client import AsyncNxbtClient +from .async_bridge import AsyncNxbtClientBridge diff --git a/nxbt/async_bridge.py b/nxbt/async_bridge.py new file mode 100644 index 0000000..38ad007 --- /dev/null +++ b/nxbt/async_bridge.py @@ -0,0 +1,93 @@ +import asyncio +import atexit +import threading +from typing import Any, Callable + +from .async_client import AsyncNxbtClient + + +class AsyncNxbtClientBridge: + """Thread-safe synchronous bridge around ``AsyncNxbtClient``. + + This helper keeps an asyncio event loop alive on a background thread so that + synchronous code (web handlers, TUIs, etc.) can reuse the async client + without spinning up short-lived loops. + """ + + def __init__(self, **client_kwargs: Any): + self._loop = asyncio.new_event_loop() + self._thread = threading.Thread( + target=self._run_loop, name="nxbt-async-bridge", daemon=True + ) + self._thread.start() + + self._client = AsyncNxbtClient(**client_kwargs) + self._closed = False + atexit.register(self.close) + + def _run_loop(self): + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + + def _run(self, coro): + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + return future.result() + + @property + def state(self): + return self._client.state + + # Controller lifecycle helpers ------------------------------------------------- + def create_controller(self, *args, **kwargs): + return self._run(self._client.create_controller(*args, **kwargs)) + + def remove_controller(self, controller_index): + return self._run(self._client.remove_controller(controller_index)) + + def wait_for_connection(self, controller_index): + return self._run(self._client.wait_for_connection(controller_index)) + + # Input helpers ---------------------------------------------------------------- + def macro(self, *args, **kwargs): + return self._run(self._client.macro(*args, **kwargs)) + + def set_controller_input(self, controller_index, input_packet): + return self._run( + self._client.set_controller_input(controller_index, input_packet) + ) + + def create_input_packet(self): + return self._run(self._client.create_input_packet()) + + def press_buttons(self, *args, **kwargs): + return self._run(self._client.press_buttons(*args, **kwargs)) + + def tilt_stick(self, *args, **kwargs): + return self._run(self._client.tilt_stick(*args, **kwargs)) + + def stop_macro(self, controller_index, macro_id, block=True): + return self._run(self._client.stop_macro(controller_index, macro_id, block)) + + def clear_macros(self, controller_index): + return self._run(self._client.clear_macros(controller_index)) + + def clear_all_macros(self): + return self._run(self._client.clear_all_macros()) + + # Discovery helpers ------------------------------------------------------------ + def get_available_adapters(self): + return self._run(self._client.get_available_adapters()) + + def get_switch_addresses(self): + return self._run(self._client.get_switch_addresses()) + + # Shutdown --------------------------------------------------------------------- + def close(self): + if self._closed: + return + self._closed = True + try: + self._run(self._client.close()) + finally: + self._loop.call_soon_threadsafe(self._loop.stop) + self._thread.join() diff --git a/nxbt/async_client.py b/nxbt/async_client.py new file mode 100644 index 0000000..964dad7 --- /dev/null +++ b/nxbt/async_client.py @@ -0,0 +1,71 @@ +import asyncio +from contextlib import AbstractAsyncContextManager +from typing import Any + +from .nxbt import Nxbt + + +class AsyncNxbtClient(AbstractAsyncContextManager): + """Async wrapper around the synchronous ``Nxbt`` facade.""" + + def __init__(self, **nxbt_kwargs: Any): + self._nxbt = Nxbt(**nxbt_kwargs) + self._closed = False + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.close() + + @property + def state(self): + return self._nxbt.state + + async def close(self): + if self._closed: + return + await asyncio.to_thread(self._nxbt.shutdown) + self._closed = True + + async def _call(self, func, *args, **kwargs): + return await asyncio.to_thread(func, *args, **kwargs) + + async def create_controller(self, *args, **kwargs): + return await self._call(self._nxbt.create_controller, *args, **kwargs) + + async def remove_controller(self, controller_index): + await self._call(self._nxbt.remove_controller, controller_index) + + async def macro(self, controller_index, macro, block=True): + return await self._call(self._nxbt.macro, controller_index, macro, block) + + async def press_buttons(self, *args, **kwargs): + return await self._call(self._nxbt.press_buttons, *args, **kwargs) + + async def tilt_stick(self, *args, **kwargs): + return await self._call(self._nxbt.tilt_stick, *args, **kwargs) + + async def stop_macro(self, controller_index, macro_id, block=True): + await self._call(self._nxbt.stop_macro, controller_index, macro_id, block) + + async def clear_macros(self, controller_index): + await self._call(self._nxbt.clear_macros, controller_index) + + async def clear_all_macros(self): + await self._call(self._nxbt.clear_all_macros) + + async def set_controller_input(self, controller_index, input_packet): + await self._call(self._nxbt.set_controller_input, controller_index, input_packet) + + async def create_input_packet(self): + return await self._call(self._nxbt.create_input_packet) + + async def wait_for_connection(self, controller_index): + await self._call(self._nxbt.wait_for_connection, controller_index) + + async def get_available_adapters(self): + return await self._call(self._nxbt.get_available_adapters) + + async def get_switch_addresses(self): + return await self._call(self._nxbt.get_switch_addresses) diff --git a/nxbt/async_manager.py b/nxbt/async_manager.py new file mode 100644 index 0000000..0377e4f --- /dev/null +++ b/nxbt/async_manager.py @@ -0,0 +1,76 @@ +import asyncio + +from .controller import ControllerTypes, AsyncControllerServer +from .logging import create_logger + + +class AsyncManager: + """Async replacement for the multiprocessing-based Nxbt manager.""" + + def __init__(self, debug=False, log_to_file=False, disable_logging=False): + self.logger = create_logger( + debug=debug, log_to_file=log_to_file, disable_logging=disable_logging + ) + self._controller_counter = 0 + self._controllers = {} + self._tasks = {} + self._state = {} + self._lock = asyncio.Lock() + + @property + def state(self): + return self._state + + async def create_controller(self, controller_type: ControllerTypes, + adapter_path="/org/bluez/hci0", + reconnect_address=None): + async with self._lock: + index = self._controller_counter + self._controller_counter += 1 + + controller_state = { + "state": "initializing", + "finished_macros": [], + "errors": None, + "direct_input": None, + } + + server = AsyncControllerServer( + controller_type, + adapter_path=adapter_path, + state=controller_state, + ) + + self._controllers[index] = server + self._state[index] = controller_state + self._tasks[index] = asyncio.create_task( + self._run_controller(index, server, reconnect_address) + ) + return index + + async def _run_controller(self, index, server, reconnect_address): + try: + await server.run(reconnect_address) + except asyncio.CancelledError: + await server.stop() + raise + except Exception: + state = self._state.get(index) + if state is not None: + state["state"] = "crashed" + self.logger.exception("Controller %s crashed", index) + + async def remove_controller(self, index): + task = self._tasks.pop(index, None) + controller = self._controllers.pop(index, None) + self._state.pop(index, None) + + if task: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + if controller: + await controller.stop() + diff --git a/nxbt/async_nxbt.py b/nxbt/async_nxbt.py new file mode 100644 index 0000000..5c645cb --- /dev/null +++ b/nxbt/async_nxbt.py @@ -0,0 +1,146 @@ +import asyncio +from typing import Dict, Optional + +from .controller import ControllerTypes, AsyncControllerServer +from .logging import create_logger + + +class AsyncNxbt: + """Async-native Nxbt manager that orchestrates controller tasks.""" + + def __init__(self, debug=False, log_to_file=False, disable_logging=False): + self.debug = debug + self.logger = create_logger( + debug=self.debug, log_to_file=log_to_file, disable_logging=disable_logging + ) + + self._controller_counter = 0 + self._controllers: Dict[int, AsyncControllerServer] = {} + self._controller_states: Dict[int, dict] = {} + self._controller_tasks: Dict[int, asyncio.Task] = {} + self._controller_queues: Dict[int, asyncio.Queue] = {} + self._lock = asyncio.Lock() + + @property + def state(self): + return self._controller_states + + async def create_controller( + self, + controller_type: ControllerTypes, + *, + adapter_path="/org/bluez/hci0", + reconnect_address=None, + colour_body=None, + colour_buttons=None, + controller_index: Optional[int] = None, + state: Optional[dict] = None, + lock=None, + ) -> int: + async with self._lock: + if controller_index is None: + controller_index = self._controller_counter + self._controller_counter += 1 + else: + self._controller_counter = max( + self._controller_counter, controller_index + 1 + ) + + controller_state = state or { + "state": "initializing", + "finished_macros": [], + "errors": None, + "direct_input": None, + "colour_body": colour_body, + "colour_buttons": colour_buttons, + "type": str(controller_type), + "adapter_path": adapter_path, + "last_connection": None, + } + + command_queue: asyncio.Queue = asyncio.Queue() + + server = AsyncControllerServer( + controller_type, + adapter_path=adapter_path, + lock=lock, + task_queue=command_queue, + state=controller_state, + colour_body=colour_body, + colour_buttons=colour_buttons, + ) + + self._controllers[controller_index] = server + self._controller_states[controller_index] = controller_state + self._controller_queues[controller_index] = command_queue + + task = asyncio.create_task( + self._run_controller(controller_index, server, reconnect_address) + ) + self._controller_tasks[controller_index] = task + return controller_index + + async def _run_controller(self, index, server, reconnect_address): + try: + await server.run(reconnect_address) + except asyncio.CancelledError: + await server.stop() + raise + except Exception: + controller_state = self._controller_states.get(index) + if controller_state is not None: + controller_state["state"] = "crashed" + self.logger.exception("Controller %s crashed", index) + + def _get_queue(self, index): + queue = self._controller_queues.get(index) + if queue is None: + raise ValueError(f"Controller {index} does not exist") + return queue + + async def queue_macro(self, index, macro, macro_id): + queue = self._get_queue(index) + await queue.put({ + "type": "macro", + "macro": macro, + "macro_id": macro_id, + }) + + async def stop_macro(self, index, macro_id): + queue = self._get_queue(index) + await queue.put({ + "type": "stop", + "macro_id": macro_id, + }) + + async def clear_macros(self, index): + queue = self._get_queue(index) + await queue.put({"type": "clear"}) + + async def remove_controller(self, index): + queue = self._controller_queues.pop(index, None) + if queue: + while not queue.empty(): + try: + queue.get_nowait() + except asyncio.QueueEmpty: + break + + task = self._controller_tasks.pop(index, None) + controller = self._controllers.pop(index, None) + self._controller_states.pop(index, None) + + if task: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + if controller: + await controller.stop() + + async def shutdown(self): + await asyncio.gather( + *(self.remove_controller(idx) for idx in list(self._controllers.keys())) + ) diff --git a/nxbt/async_runtime.py b/nxbt/async_runtime.py new file mode 100644 index 0000000..e9456aa --- /dev/null +++ b/nxbt/async_runtime.py @@ -0,0 +1,21 @@ +import asyncio +import threading + + +class AsyncRuntime: + """Runs an asyncio loop in a background thread for sync callers.""" + + def __init__(self): + self._loop = asyncio.new_event_loop() + self._thread = threading.Thread( + target=self._loop.run_forever, daemon=True) + self._thread.start() + + def submit(self, coro): + """Submit a coroutine to the background loop and return its result.""" + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + return future.result() + + def shutdown(self): + self._loop.call_soon_threadsafe(self._loop.stop) + self._thread.join() diff --git a/nxbt/bluez.py b/nxbt/bluez.py index 9f96748..4c36079 100644 --- a/nxbt/bluez.py +++ b/nxbt/bluez.py @@ -1,13 +1,13 @@ -import subprocess -import re -import os -import time +import asyncio import logging -from shutil import which +import os import random -from pathlib import Path +import time +import threading +from typing import Dict, Optional -import dbus +from bleak import BleakClient, BleakScanner +from bleak.exc import BleakError SERVICE_NAME = "org.bluez" @@ -16,904 +16,541 @@ ADAPTER_INTERFACE = SERVICE_NAME + ".Adapter1" PROFILEMANAGER_INTERFACE = SERVICE_NAME + ".ProfileManager1" DEVICE_INTERFACE = SERVICE_NAME + ".Device1" +_BLEAK_SCAN_TIMEOUT = float(os.environ.get("NXBT_BLEAK_SCAN_TIMEOUT", 5)) +_BLEAK_ADAPTER_URI_PREFIX = "bleak://adapter/" +_DEFAULT_ADAPTER_URI = os.environ.get( + "NXBT_ADAPTER_PATH", f"{_BLEAK_ADAPTER_URI_PREFIX}default") +_ADAPTER_QUERY_DELAY = float(os.environ.get("NXBT_BLEAK_ADAPTER_QUERY_DELAY", "0.2")) -def find_object_path(bus, service_name, interface_name, object_name=None): - """Searches for a D-Bus object path that contains a specified interface - under a specified service. - :param bus: A DBus object used to access the DBus. - :type bus: DBus - :param service_name: The name of a D-Bus service to search for the - object path under. - :type service_name: string - :param interface_name: The name of a D-Bus interface to search for - within objects under the specified service. - :type interface_name: string - :param object_name: The name or ending of the object path, - defaults to None - :type object_name: string, optional - :return: The D-Bus object path or None, if no matching object - can be found - :rtype: string +def _adapter_identifier_from_uri(path: Optional[str]) -> Optional[str]: + if not path: + return None + if path.startswith(_BLEAK_ADAPTER_URI_PREFIX): + identifier = path[len(_BLEAK_ADAPTER_URI_PREFIX):] + if identifier in ("", "default"): + return None + return identifier + if path.startswith("/org/bluez/"): + return None + return path + + +def _run_asyncio_task(coro): + """Execute *coro* on a dedicated event loop. + + Bleak is an asyncio-first library (see the Bleak docs under + "Getting Started"), so all synchronous helpers in this module use a + private loop to bridge the gap. """ - manager = dbus.Interface( - bus.get_object(service_name, "/"), - "org.freedesktop.DBus.ObjectManager") + try: + running_loop = asyncio.get_running_loop() + except RuntimeError: + running_loop = None - # Iterating over objects under the specified service - # and searching for the specified interface - for path, ifaces in manager.GetManagedObjects().items(): - managed_interface = ifaces.get(interface_name) - if managed_interface is None: - continue - # If the object name wasn't specified or it matches - # the interface address or the path ending - elif (not object_name or - object_name == managed_interface["Address"] or - path.endswith(object_name)): - obj = bus.get_object(service_name, path) - return dbus.Interface(obj, interface_name).object_path + # When called from a thread where an event loop is already running + # (e.g., AsyncNxbt's runtime thread), spin up a helper thread to + # execute the coroutine so we don't violate asyncio's single-loop rule. + if running_loop and running_loop.is_running(): + result_holder = {} + exception_holder = {} + def _thread_main(): + thread_loop = asyncio.new_event_loop() + try: + result_holder["value"] = thread_loop.run_until_complete(coro) + thread_loop.run_until_complete(thread_loop.shutdown_asyncgens()) + except Exception as exc: # pragma: no cover - best effort propagation + exception_holder["error"] = exc + finally: + thread_loop.close() + + worker = threading.Thread(target=_thread_main, daemon=True) + worker.start() + worker.join() + if "error" in exception_holder: + raise exception_holder["error"] + return result_holder.get("value") + + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + try: + loop.run_until_complete(loop.shutdown_asyncgens()) + except (RuntimeError, AttributeError): + pass + loop.close() + + +def _scanner_kwargs(adapter: Optional[str] = None) -> Dict[str, str]: + adapter_hint = adapter or os.environ.get("NXBT_BLEAK_ADAPTER") + return {"adapter": adapter_hint} if adapter_hint else {} + + +def _client_kwargs(adapter: Optional[str] = None) -> Dict[str, str]: + return _scanner_kwargs(adapter) + + +def _adapter_paths_from_env(): + adapter_hint = os.environ.get("NXBT_BLEAK_ADAPTER") + if adapter_hint: + return [f"{_BLEAK_ADAPTER_URI_PREFIX}{adapter_hint}"] + return [_DEFAULT_ADAPTER_URI] + + +def _match_object_path(paths, object_name): + if not paths: + return None + if object_name is None: + return paths[0] + for path in paths: + if path.endswith(object_name) or path == object_name: + return path return None -def find_objects(bus, service_name, interface_name): - """Searches for D-Bus objects that contain a specified interface - under a specified service. +async def _async_fetch_adapter_address(adapter_identifier: Optional[str]) -> Optional[str]: + scanner = BleakScanner(**_scanner_kwargs(adapter_identifier)) + await scanner.start() + try: + backend = getattr(scanner, "_backend", None) + adapter = getattr(backend, "adapter", None) + if adapter and getattr(adapter, "address", None): + return adapter.address.upper() + await asyncio.sleep(_ADAPTER_QUERY_DELAY) + address = getattr(adapter, "address", None) + return address.upper() if address else None + finally: + await scanner.stop() - :param bus: A DBus object used to access the DBus. - :type bus: DBus - :param service_name: The name of a D-Bus service to search for the - object path under. - :type service_name: string - :param interface_name: The name of a D-Bus interface to search for - within objects under the specified service. - :type interface_name: string - :return: The D-Bus object paths matching the arguments - :rtype: array - """ - manager = dbus.Interface( - bus.get_object(service_name, "/"), - "org.freedesktop.DBus.ObjectManager") - paths = [] +def _bleak_path(address: str) -> str: + return f"bleak://{address.upper()}" - # Iterating over objects under the specified service - # and searching for the specified interface within them - for path, ifaces in manager.GetManagedObjects().items(): - managed_interface = ifaces.get(interface_name) - if managed_interface is None: + +def _bleak_device_entry(device) -> Dict[str, object]: + props: Dict[str, object] = { + "Address": device.address.upper(), + "Alias": getattr(device, "name", None) or device.address or "", + "Paired": False, + "Connected": False, + } + metadata = getattr(device, "metadata", {}) or {} + uuids = metadata.get("uuids") or metadata.get("UUIDs") + if uuids: + props["UUIDs"] = uuids + rssi = getattr(device, "rssi", None) + if rssi is not None: + props["RSSI"] = rssi + manufacturer_data = metadata.get("manufacturer_data") + if manufacturer_data: + props["ManufacturerData"] = manufacturer_data + return props + + +async def async_discover_devices_cross_platform(alias=None, timeout=None, adapter=None): + """Async variant of Bleak-powered discovery.""" + + timeout = timeout or _BLEAK_SCAN_TIMEOUT + alias_upper = alias.upper() if alias else None + devices: Dict[str, Dict[str, object]] = {} + results = await BleakScanner.discover(timeout=timeout, **_scanner_kwargs(adapter)) + for device in results: + device_alias = (device.name or "").upper() + if alias_upper and device_alias != alias_upper: continue - else: - obj = bus.get_object(service_name, path) - path = str(dbus.Interface(obj, interface_name).object_path) - paths.append(path) + devices[_bleak_path(device.address)] = _bleak_device_entry(device) + return devices - return paths + +def discover_devices_cross_platform(alias=None, timeout=None, adapter=None): + """Synchronous helper that wraps ``async_discover_devices_cross_platform``.""" + + return _run_asyncio_task( + async_discover_devices_cross_platform(alias=alias, timeout=timeout, adapter=adapter) + ) + + +async def async_find_object_path(bus, service_name, interface_name, object_name=None): + objects = await async_find_objects(bus, service_name, interface_name) + return _match_object_path(objects, object_name) + + +def find_object_path(bus, service_name, interface_name, object_name=None): + objects = find_objects(bus, service_name, interface_name) + return _match_object_path(objects, object_name) + + +async def async_find_objects(bus, service_name, interface_name): + if interface_name == ADAPTER_INTERFACE: + return _adapter_paths_from_env() + if interface_name == DEVICE_INTERFACE: + devices = await async_discover_devices_cross_platform() + return list(devices.keys()) + return [] + + +def find_objects(bus, service_name, interface_name): + if interface_name == ADAPTER_INTERFACE: + return _adapter_paths_from_env() + if interface_name == DEVICE_INTERFACE: + return list(discover_devices_cross_platform().keys()) + return [] def toggle_clean_bluez(toggle): - """Enables or disables all BlueZ plugins, - BlueZ compatibility mode, and removes all extraneous - SDP Services offered. - Requires root user to be run. The units and Bluetooth - service will not be restarted if the input plugin - already matches the toggle. - - :param toggle: A boolean element indicating if BlueZ - should be cleaned (True) or not (False) - :type toggle: boolean - :raises PermissionError: If the user is not root - :raises Exception: If the units can't be reloaded - :raises Exception: If sdptool, hciconfig, or hcitool are not available. - """ - - service_path = "/lib/systemd/system/bluetooth.service" - override_dir = Path("/run/systemd/system/bluetooth.service.d") - override_path = override_dir / "nxbt.conf" - - if toggle: - if override_path.is_file(): - # Override exist, no need to restart bluetooth - return - - with open(service_path) as f: - for line in f: - if line.startswith("ExecStart="): - exec_start = line.strip() + " --compat --noplugin=*" - break - else: - raise Exception("systemd service file doesn't have a ExecStart line") - - override = f"[Service]\nExecStart=\n{exec_start}" - - override_dir.mkdir(parents=True, exist_ok=True) - with override_path.open("w") as f: - f.write(override) - else: - try: - os.remove(override_path) - except FileNotFoundError: - # Override doesn't exist, no need to restart bluetooth - return - - # Reload units - _run_command(["systemctl", "daemon-reload"]) - - # Reload the bluetooth service with input disabled - _run_command(["systemctl", "restart", "bluetooth"]) - - # Kill a bit of time here to ensure all services have restarted - time.sleep(0.5) + logging.getLogger('nxbt').warning( + "toggle_clean_bluez() is not supported when using the Bleak backend. " + "Requested toggle=%s; no action taken.", toggle) def clean_sdp_records(): - """Cleans all SDP Records from BlueZ with sdptool - - :raises Exception: On CLI error or sdptool missing - """ - # TODO: sdptool is deprecated in BlueZ 5. This should ideally - # use the DBus API, however, bugs seemingly exist with the - # UnregisterProfile interface. - - # Check if sdptool is available for use - if which("sdptool") is None: - raise Exception("sdptool is not available on this system." + - "If you can, please install this tool, as " + - "it is required for proper functionality.") - - # Enable Read/Write to the SDP server. This is a remedy for a - # compatibility mode bug introduced in later versions of BlueZ 5 - _run_command(["chmod", "777", "/var/run/sdp"]) - - # Identify/List all SDP services available with sdptool - result = _run_command(['sdptool', 'browse', 'local']).stdout.decode('utf-8') - if result is None or len(result.split('\n\n')) < 1: - return - - # Record all service record handles - exceptions = ["PnP Information"] - service_rec_handles = [] - for rec in result.split('\n\n'): - # Skip if exception is in record - exception_found = False - for exception in exceptions: - if exception in rec: - exception_found = True - break - if exception_found: - continue - - # Read lines and add Record Handles to the list - for line in rec.split('\n'): - if "Service RecHandle" in line: - service_rec_handles.append(line.split(" ")[2]) - - # Delete all found service records - if len(service_rec_handles) > 0: - for record_handle in service_rec_handles: - _run_command(['sdptool', 'del', record_handle]) - - -def _run_command(command): - """Runs a specified command on the shell of the system. - If the command is run unsuccessfully, an error is raised. - The command must be in the form of an array with each term - individually listed. Eg: ["which", "bash"] - - :param command: A list of command terms - :type command: list - :raises Exception: On command failure or error - """ - result = subprocess.run( - command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - - cmd_err = result.stderr.decode("utf-8").replace("\n", "") - if cmd_err != "": - raise Exception(cmd_err) - - return result + logging.getLogger('nxbt').warning( + "clean_sdp_records() relied on sdptool and is not available when " + "running with Bleak.") def get_random_controller_mac(): - """Generates a random Switch-compliant MAC address - """ def seg(): - random_number = random.randint(0,255) - hex_number = str(hex(random_number)) - hex_number = hex_number[2:].upper() - return str(hex_number) - + random_number = random.randint(0, 255) + hex_number = hex(random_number)[2:].upper() + if len(hex_number) == 1: + hex_number = "0" + hex_number + return hex_number + return f"7C:BB:8A:{seg()}:{seg()}:{seg()}" def replace_mac_addresses(adapter_paths, addresses): - """Replaces a list of adapter's Bluetooth MAC addresses - with Switch-compliant Controller MAC addresses. If the - addresses argument is specified, the adapter path's - MAC addresses will be reset to respective (index-wise) - address in the list. - - :param adapter_paths: A list of Bluetooth adapter paths - :type adapter_paths: list - :param addresses: A list of Bluetooth MAC addresses, - defaults to False - :type addresses: bool, optional - """ - if which("hcitool") is None: - raise Exception("hcitool is not available on this system." + - "If you can, please install this tool, as " + - "it is required for proper functionality.") - if which("hciconfig") is None: - raise Exception("hciconfig is not available on this system." + - "If you can, please install this tool, as " + - "it is required for proper functionality.") - - if addresses: - assert len(addresses) == len(adapter_paths) - - for i in range(len(adapter_paths)): - adapter_id = adapter_paths[i].split('/')[-1] - mac = addresses[i].split(':') - cmds = ['hcitool', '-i', adapter_id, 'cmd', '0x3f', '0x001', - f'0x{mac[5]}',f'0x{mac[4]}',f'0x{mac[3]}',f'0x{mac[2]}', - f'0x{mac[1]}',f'0x{mac[0]}'] - _run_command(cmds) - _run_command(['hciconfig', adapter_id, 'reset']) + logging.getLogger('nxbt').warning( + "replace_mac_addresses() is unsupported with Bleak; no MAC spoofing is performed.") -def find_devices_by_alias(alias, return_path=False, created_bus=None): - """Finds the Bluetooth addresses of devices - that have a specified Bluetooth alias. Aliases - are converted to uppercase before comparison - as BlueZ usually converts aliases to uppercase. - - :param address: The Bluetooth MAC address - :type address: string - :return: The path to the D-Bus object or None - :rtype: string or None - """ - - if created_bus is not None: - bus = created_bus - else: - bus = dbus.SystemBus() - # Find all connected/paired/discovered devices - devices = find_objects( - bus, - SERVICE_NAME, - DEVICE_INTERFACE) - +async def async_find_devices_by_alias(alias, return_path=False): + alias_upper = alias.upper() addresses = [] matching_paths = [] - for path in devices: - # Get the device's address and paired status - device_props = dbus.Interface( - bus.get_object(SERVICE_NAME, path), - "org.freedesktop.DBus.Properties") - device_alias = device_props.Get( - DEVICE_INTERFACE, - "Alias").upper() - device_addr = device_props.Get( - DEVICE_INTERFACE, - "Address").upper() - - # Check for an address match - if device_alias.upper() == alias.upper(): - addresses.append(device_addr) - matching_paths.append(path) - - # Close the dbus connection if we created one - if created_bus is None: - bus.close() + devices = await async_discover_devices_cross_platform(alias=alias_upper) + for path, props in devices.items(): + addresses.append(props["Address"]) + matching_paths.append(path) if return_path: return addresses, matching_paths - else: - return addresses + return addresses + + +def find_devices_by_alias(alias, return_path=False, created_bus=None): + return _run_asyncio_task( + async_find_devices_by_alias(alias, return_path=return_path) + ) def disconnect_devices_by_alias(alias, created_bus=None): - """Disconnects all devices matching an alias. - - :param alias: The device's alias - :type alias: string - """ - - if created_bus is not None: - bus = created_bus - else: - bus = dbus.SystemBus() - # Find all connected/paired/discovered devices - devices = find_objects( - bus, - SERVICE_NAME, - DEVICE_INTERFACE) - - addresses = [] - matching_paths = [] - for path in devices: - # Get the device's address and paired status - device_props = dbus.Interface( - bus.get_object(SERVICE_NAME, path), - "org.freedesktop.DBus.Properties") - device_alias = device_props.Get( - DEVICE_INTERFACE, - "Alias").upper() - - # Check for an alias match - if device_alias.upper() == alias.upper(): - device = dbus.Interface( - bus.get_object(SERVICE_NAME, path), - DEVICE_INTERFACE) - try: - device.Disconnect() - except Exception as e: - print(e) - - # Close the dbus connection if we created one - if created_bus is None: - bus.close() + logging.getLogger('nxbt').warning( + "disconnect_devices_by_alias('%s') is not implemented for the Bleak backend.", alias) -class BlueZ(): - """Exposes the BlueZ D-Bus API as a Python object. - """ +class AsyncBleakAdapter: + """Async-first Bluetooth helper built on top of Bleak.""" def __init__(self, adapter_path="/org/bluez/hci0"): self.logger = logging.getLogger('nxbt') + requested_path = adapter_path or _DEFAULT_ADAPTER_URI + self.device_path = requested_path + self._adapter_identifier = _adapter_identifier_from_uri(requested_path) + if self._adapter_identifier is None: + env_override = os.environ.get("NXBT_BLEAK_ADAPTER") + if env_override: + self._adapter_identifier = env_override - self.bus = dbus.SystemBus() - self.device_path = adapter_path + self._address_override = os.environ.get("NXBT_ADAPTER_ADDRESS") + self._name = os.environ.get("NXBT_ADAPTER_NAME", "nxbt-adapter") + self._alias = os.environ.get("NXBT_ADAPTER_ALIAS", "NXBT Adapter") + self._pairable = True + self._pairable_timeout = 0 + self._discoverable = False + self._discoverable_timeout = 180 + self._powered = True + self._warned_connections = False - # If we weren't able to find an adapter with the specified ID, - # try to find any usable Bluetooth adapter - if self.device_path is None: - self.device_path = find_object_path( - self.bus, - SERVICE_NAME, - ADAPTER_INTERFACE) + def _scanner_kwargs(self): + return _scanner_kwargs(self._adapter_identifier) - # If we aren't able to find an adapter still - if self.device_path is None: - raise Exception("Unable to find a bluetooth adapter") + def _client_kwargs(self): + return _client_kwargs(self._adapter_identifier) - # Load the adapter's interface - self.logger.debug(f"Using adapter under object path: {self.device_path}") - self.device = dbus.Interface( - self.bus.get_object( - SERVICE_NAME, - self.device_path), - "org.freedesktop.DBus.Properties") + async def get_address(self): + if self._address_override: + return self._address_override.upper() - self.device_id = self.device_path.split("/")[-1] + try: + address = await _async_fetch_adapter_address(self._adapter_identifier) + except BleakError as exc: + self.logger.warning("Unable to query adapter address via Bleak: %s", exc) + address = None - # Load the ProfileManager interface - self.profile_manager = dbus.Interface(self.bus.get_object( - SERVICE_NAME, BLUEZ_OBJECT_PATH), - PROFILEMANAGER_INTERFACE) + if not address: + self.logger.warning( + "Falling back to 00:00:00:00:00:00; set NXBT_ADAPTER_ADDRESS to override.") + address = "00:00:00:00:00:00" - self.adapter = dbus.Interface( - self.bus.get_object( - SERVICE_NAME, - self.device_path), - ADAPTER_INTERFACE) + self._address_override = address + return self._address_override - @property - def address(self): - """Gets the Bluetooth MAC address of the Bluetooth adapter. + async def discover_devices(self, alias=None, timeout=10, callback=None): + devices: Dict[str, Dict[str, object]] = {} + loop = asyncio.get_running_loop() + end_time = loop.time() + timeout + try: + while True: + now = loop.time() + if now >= end_time: + break + slice_timeout = min(1.0, max(0.1, end_time - now)) + try: + snapshot = await async_discover_devices_cross_platform( + alias=None, + timeout=slice_timeout, + adapter=self._adapter_identifier) + except BleakError as exc: + self.logger.warning("Bleak discovery failed: %s", exc) + break + else: + devices.update(snapshot) + if callback: + callback(devices) + finally: + if callback: + callback(devices) - :return: The Bluetooth Adapter's MAC address - :rtype: string - """ + if alias: + alias_upper = alias.upper() + devices = { + path: props for path, props in devices.items() + if props.get("Alias", "").upper() == alias_upper + } + return devices - return self.device.Get(ADAPTER_INTERFACE, "Address").upper() + async def get_discovered_devices(self): + try: + return await async_discover_devices_cross_platform( + adapter=self._adapter_identifier) + except BleakError as exc: + self.logger.warning("Bleak discovery failed: %s", exc) + return {} - def set_address(self, mac): - """Sets the Bluetooth MAC address of the Bluetooth adapter. - The hciconfig CLI is required for setting the address. - For changes to apply, the Bluetooth interface needs to be - restarted. + async def _connect_once(self, address, timeout=10): + client = BleakClient(address, **self._client_kwargs()) + try: + await client.connect(timeout=timeout) + finally: + await client.disconnect() - :param mac: A Bluetooth MAC address in - the form of "XX:XX:XX:XX:XX:XX - :type mac: str - :raises PermissionError: On run as non-root user - :raises Exception: On CLI errors - """ - if which("hcitool") is None: - raise Exception("hcitool is not available on this system." + - "If you can, please install this tool, as " + - "it is required for proper functionality.") - # Reverse MAC (element position-wise) for use with hcitool - mac = mac.split(":") - cmds = ['hcitool', '-i', self.device_id, 'cmd', '0x3f', '0x001', - f'0x{mac[5]}',f'0x{mac[4]}',f'0x{mac[3]}',f'0x{mac[2]}', - f'0x{mac[1]}',f'0x{mac[0]}'] - _run_command(cmds) - _run_command(['hciconfig', self.device_id, 'reset']) + def _address_from_path(self, device_path): + if device_path.startswith("bleak://"): + return device_path.split("bleak://", 1)[1] + return device_path - def set_class(self, device_class): - if which("hciconfig") is None: - raise Exception("hciconfig is not available on this system." + - "If you can, please install this tool, as " + - "it is required for proper functionality.") - _run_command(['hciconfig', self.device_id, 'class', device_class]) + async def pair_device(self, device_path): + address = self._address_from_path(device_path) + self.logger.info("Pairing via Bleak to %s", address) + await self._connect_once(address) - def reset_adapter(self): - if which("hciconfig") is None: - raise Exception("hciconfig is not available on this system." + - "If you can, please install this tool, as " + - "it is required for proper functionality.") - _run_command(['hciconfig', self.device_id, 'reset']) + async def connect_device(self, device_path): + address = self._address_from_path(device_path) + self.logger.info("Connecting via Bleak to %s", address) + await self._connect_once(address) - @property - def name(self): - """Gets the name of the Bluetooth adapter. + async def find_device_by_address(self, address): + address_upper = address.upper() + devices = await self.get_discovered_devices() + for path, props in devices.items(): + if props["Address"].upper() == address_upper: + return path + return None - :return: The name of the Bluetooth adapter. - :rtype: string - """ + async def find_connected_devices(self, alias_filter=False): + if not self._warned_connections: + self.logger.warning( + "Bleak cannot enumerate incoming peripheral connections; returning [].") + self._warned_connections = True + return [] - return self.device.Get(ADAPTER_INTERFACE, "Name") + # The remaining configuration helpers are synchronous as they only mutate + # cached state or log a warning. This mirrors the legacy BlueZ API. + def set_alias(self, value): + self._alias = value @property def alias(self): - """Gets the alias of the Bluetooth adapter. This value is used - as the "friendly" name of the adapter when communicating over - Bluetooth. + return self._alias - :return: The adapter's alias - :rtype: string - """ + @property + def name(self): + return self._name - return self.device.Get(ADAPTER_INTERFACE, "Alias") - - def set_alias(self, value): - """Asynchronously sets the alias of the Bluetooth adapter. - If you wish to check the set value, a time delay is needed - before the alias getter is run. - - :param value: The new value to be set as the adapter's alias - :type value: string - """ - - self.device.Set(ADAPTER_INTERFACE, "Alias", value) + def set_pairable(self, value): + self._pairable = bool(value) @property def pairable(self): - """Gets the pairable status of the Bluetooth adapter. + return self._pairable - :return: A boolean value representing if the adapter is set as - pairable or not - :rtype: boolean - """ - - return bool(self.device.Get(ADAPTER_INTERFACE, "Pairable")) - - def set_pairable(self, value): - """Sets the pariable boolean status of the Bluetooth adapter. - - :param value: A boolean value representing if the adapter is - pairable or not. - :type value: boolean - """ - - dbus_value = dbus.Boolean(value) - self.device.Set(ADAPTER_INTERFACE, "Pairable", dbus_value) + def set_pairable_timeout(self, value): + self._pairable_timeout = int(value) @property def pairable_timeout(self): - """Gets the timeout time (in seconds) for how long the adapter - should remain as pairable. Defaults to 0 (no timeout). + return self._pairable_timeout - :return: The pairable timeout in seconds - :rtype: int - """ - - return self.device.Get(ADAPTER_INTERFACE, "PairableTimeout") - - def set_pairable_timeout(self, value): - """Sets the timeout time (in seconds) for the pairable property. - - :param value: The pairable timeout value in seconds - :type value: int - """ - - dbus_value = dbus.UInt32(value) - self.device.Set(ADAPTER_INTERFACE, "PairableTimeout", dbus_value) + def set_discoverable(self, value): + self._discoverable = bool(value) + if value: + self.logger.warning("Bleak does not expose a general-purpose GAP advertiser.") @property def discoverable(self): - """Gets the discoverable status of the Bluetooth adapter + return self._discoverable - :return: The boolean status of the discoverable status - :rtype: boolean - """ - - return bool(self.device.Get(ADAPTER_INTERFACE, "Discoverable")) - - def set_discoverable(self, value): - """Sets the discoverable boolean status of the Bluetooth adapter. - - :param value: A boolean value representing if the Bluetooth adapter - is discoverable or not. - :type value: boolean - """ - - dbus_value = dbus.Boolean(value) - self.device.Set(ADAPTER_INTERFACE, "Discoverable", dbus_value) + def set_discoverable_timeout(self, value): + self._discoverable_timeout = int(value) @property def discoverable_timeout(self): - """Gets the timeout time (in seconds) for how long the adapter - should remain as discoverable. Defaults to 180 (3 minutes). + return self._discoverable_timeout - :return: The discoverable timeout in seconds - :rtype: int - """ + def register_profile(self, *args, **kwargs): + self.logger.warning("register_profile is not supported via Bleak") - return self.device.Get(ADAPTER_INTERFACE, "DiscoverableTimeout") + def unregister_profile(self, *args, **kwargs): + self.logger.warning("unregister_profile is not supported via Bleak") - def set_discoverable_timeout(self, value): - """Sets the discoverable time (in seconds) for the discoverable - property. Setting this property to 0 results in an infinite - discoverable timeout. - - :param value: The discoverable timeout value in seconds - :type value: int - """ - - dbus_value = dbus.UInt32(value) - self.device.Set( - ADAPTER_INTERFACE, - "DiscoverableTimeout", - dbus_value) - - @property - def device_class(self): - """Gets the Bluetooth class of the device. This represents what type - of device this reporting as (Ex: Gamepad, Headphones, etc). - - :return: A 32-bit hexadecimal Integer representing the - Bluetooth Code for a given device type. - :rtype: string - """ - - # This is another hacky bit. We're using hciconfig here instead - # of the D-Bus API so that results match the setter. See the - # setter for further justification on using hciconfig. - result = subprocess.run( - ["hciconfig", self.device_id, "class"], - stdout=subprocess.PIPE) - device_class = result.stdout.decode("utf-8").split("Class: ")[1][0:8] - - return device_class - - def set_device_class(self, device_class): - """Sets the Bluetooth class of the device. This represents what type - of device this reporting as (Ex: Gamepad, Headphones, etc). - Note: To work this function *MUST* be run as the super user. An - exception is returned if this function is run without elevation. - - :param device_class: A 32-bit Hexadecimal integer - :type device_class: string - :raises PermissionError: If user is not root - :raises ValueError: If the device class is not length 8 - :raises Exception: On inability to set class - """ - - if os.geteuid() != 0: - raise PermissionError("The device class must be set as root") - - if len(device_class) != 8: - raise ValueError("Device class must be length 8") - - # This is a bit of a hack. BlueZ allows you to set this value, however, - # a config file needs to filled and the BT daemon restarted. This is a - # good compromise but requires super user privileges. Not ideal. - result = subprocess.run( - ["hciconfig", self.device_id, "class", device_class], - stderr=subprocess.PIPE) - - # Checking if there was a problem setting the device class - cmd_err = result.stderr.decode("utf-8").replace("\n", "") - if cmd_err != "": - raise Exception(cmd_err) + def set_powered(self, value): + self._powered = bool(value) @property def powered(self): - """The powered state of the adapter (on/off) as a boolean value. + return self._powered - :return: A boolean representing the powered state of the adapter. - :rtype: boolean - """ +class BlueZ: + """Synchronous compatibility shim that wraps ``AsyncBleakAdapter``. - return bool(self.device.Get(ADAPTER_INTERFACE, "Powered")) + The original NXBT API exposed a blocking BlueZ object. This class keeps + that interface alive by delegating to the async adapter via + ``_run_asyncio_task``. New code should prefer ``AsyncBleakAdapter``. + """ - def set_powered(self, value): - """Switches the adapter on or off. + def __init__(self, adapter_path="/org/bluez/hci0"): + self.logger = logging.getLogger('nxbt') + self._async_adapter = AsyncBleakAdapter(adapter_path=adapter_path) + self.device_path = self._async_adapter.device_path - :param value: A boolean value switching the adapter on or off - :type value: boolean - """ + def _await(self, coro): + return _run_asyncio_task(coro) - dbus_value = dbus.Boolean(value) - self.device.Set(ADAPTER_INTERFACE, "Powered", dbus_value) + @property + def address(self): + return self._await(self._async_adapter.get_address()) + + def set_address(self, mac): + self.logger.warning("Bleak cannot spoof adapter addresses; requested %s", mac) + + def reset_address(self): + self.logger.warning("reset_address() is not supported when using Bleak.") + + def set_class(self, device_class): + self.logger.warning("set_class(%s) is unavailable under Bleak.", device_class) + + def reset_adapter(self): + self.logger.warning("reset_adapter() has no Bleak equivalent; skipping") + + @property + def name(self): + return self._async_adapter.name + + @property + def alias(self): + return self._async_adapter.alias + + def set_alias(self, value): + self._async_adapter.set_alias(value) + + @property + def pairable(self): + return self._async_adapter.pairable + + def set_pairable(self, value): + self._async_adapter.set_pairable(value) + + @property + def pairable_timeout(self): + return self._async_adapter.pairable_timeout + + def set_pairable_timeout(self, value): + self._async_adapter.set_pairable_timeout(value) + + @property + def discoverable(self): + return self._async_adapter.discoverable + + def set_discoverable(self, value): + self._async_adapter.set_discoverable(value) + + @property + def discoverable_timeout(self): + return self._async_adapter.discoverable_timeout + + def set_discoverable_timeout(self, value): + self._async_adapter.set_discoverable_timeout(value) def register_profile(self, profile_path, uuid, opts): - """Registers an SDP record on the BlueZ SDP server. - - Options (non-exhaustive, refer to BlueZ docs for - the complete list): - - - Name: Human readable name of the profile - - - Role: Specifies precise local role. Either "client" - or "servier". - - - RequireAuthentication: A boolean value indicating if - pairing is required before connection. - - - RequireAuthorization: A boolean value indiciating if - authorization is needed before connection. - - - AutoConnect: A boolean value indicating whether a - connection can be forced if a client UUID is present. - - - ServiceRecord: An XML SDP record as a string. - - :param profile_path: The path for the SDP record - :type profile_path: string - :param uuid: The UUID for the SDP record - :type uuid: string - :param opts: The options for the SDP server - :type opts: dict - """ - - return self.profile_manager.RegisterProfile(profile_path, uuid, opts) + self.logger.warning("register_profile is not supported via Bleak") def unregister_profile(self, profile): - """Unregisters a given SDP record from the BlueZ SDP server. - - :param profile: A SDP record profile object - :type profile: Profile - """ - - self.profile_manager.UnregisterProfile(profile) + self.logger.warning("unregister_profile is not supported via Bleak") def reset(self): - """Restarts the Bluetooth Service - - :raises Exception: If the bluetooth service can't be restarted - """ - - result = subprocess.run( - ["systemctl", "restart", "bluetooth"], - stderr=subprocess.PIPE) - - cmd_err = result.stderr.decode("utf-8").replace("\n", "") - if cmd_err != "": - raise Exception(cmd_err) - - self.device = dbus.Interface( - self.bus.get_object( - SERVICE_NAME, - self.device_path), - "org.freedesktop.DBus.Properties") - self.profile_manager = dbus.Interface( - self.bus.get_object( - SERVICE_NAME, - BLUEZ_OBJECT_PATH), - PROFILEMANAGER_INTERFACE) + self.logger.warning("reset() mapped to systemctl restart on BlueZ; no Bleak equivalent") def get_discovered_devices(self): - """Gets a dict of all discovered (or previously discovered - and connected) devices. The key is the device's dbus object - path and the values are the device's properties. - - The following is a non-exhaustive list of the properties a - device dictionary can contain: - - "Address": The Bluetooth address - - "Alias": The friendly name of the device - - "Paired": Whether the device is paired - - "Connected": Whether the device is presently connected - - "UUIDs": The services a device provides - - :return: A dictionary of all discovered devices - :rtype: dictionary - """ - - bluez_objects = dbus.Interface( - self.bus.get_object(SERVICE_NAME, "/"), - "org.freedesktop.DBus.ObjectManager") - - devices = {} - objects = bluez_objects.GetManagedObjects() - for path, interfaces in list(objects.items()): - if DEVICE_INTERFACE in interfaces: - devices[str(path)] = interfaces[DEVICE_INTERFACE] - - return devices + return discover_devices_cross_platform( + adapter=self._async_adapter._adapter_identifier) def discover_devices(self, alias=None, timeout=10, callback=None): - """Runs a device discovery of the timeout length (in seconds) - on the adapter. If specified, a callback is run, every second, - and passed an updated list of discovered devices. An alias - can be specified to filter discovered devices. + return self._await(self._async_adapter.discover_devices( + alias=alias, timeout=timeout, callback=callback)) - The following is a non-exhaustive list of the properties a - device dictionary can contain: - - "Address": The Bluetooth address - - "Alias": The friendly name of the device - - "Paired": Whether the device is paired - - "Connected": Whether the device is presently connected - - "UUIDs": The services a device provides - - :param alias: The alias of a bluetooth device, defaults to None - :type alias: string, optional - :param timeout: The discovery timeout in seconds, defaults to 10 - :type timeout: int, optional - :param callback: A callback function, defaults to None - :type callback: function, optional - :return: A dictionary of discovered devices with the object path - as the key and the device properties as the dictionary properties - :rtype: dictionary - """ - - # TODO: Device discovery still needs work. Currently, devices - # are added as DBus objects while device discovery runs, however, - # added devices linger after discovery stops. This means a device - # can become unpairable, still show up on a new discovery session, - # and throw an error when an attempt is made to pair it. Using DBus - # signals ("interface added"/"property changed") does not solve - # this issue. - - # Get all devices that have been previously discovered - devices = self.get_discovered_devices() - - # Start discovering new devices and loop - self.set_powered(True) - self.set_pairable(True) - self.adapter.StartDiscovery() - try: - for i in range(0, timeout): - time.sleep(1) - - new_devices = self.get_discovered_devices() - # Shallowly merging dictionaries. Latter dictionary - # overrides the former. Requires Python 3.5 - devices = {**devices, **new_devices} - - if callback: - callback(devices) - finally: - self.adapter.StopDiscovery() - time.sleep(1) - - # Filter out paired devices or devices that don't - # match a specified alias. - filtered_devices = {} - for key in devices.keys(): - # Filter for devices matching alias, if specified - if "Alias" not in devices[key].keys(): - continue - if alias and not alias == devices[key]["Alias"]: - continue - - # Filter for paired devices - if "Paired" not in devices[key].keys(): - continue - if devices[key]["Paired"]: - continue - - filtered_devices[key] = devices[key] - - return filtered_devices + def _address_from_path(self, device_path): + if device_path.startswith("bleak://"): + return device_path.split("bleak://", 1)[1] + return device_path def pair_device(self, device_path): - """Pairs a discovered device at a given DBus object path. - - :param device_path: The D-Bus object path to the device - :type device_path: string - """ - - device = dbus.Interface( - self.bus.get_object( - SERVICE_NAME, - device_path), - DEVICE_INTERFACE) - device.Pair() + address = self._address_from_path(device_path) + self.logger.info("Pairing via temporary Bleak connection to %s", address) + self._await(self._async_adapter._connect_once(address)) def connect_device(self, device_path): - - device = dbus.Interface( - self.bus.get_object( - SERVICE_NAME, - device_path), - DEVICE_INTERFACE) - try: - device.Connect() - except dbus.exceptions.DBusException as e: - self.logger.exception(e) + address = self._address_from_path(device_path) + self.logger.info("Connecting via Bleak to %s", address) + self._await(self._async_adapter._connect_once(address)) def remove_device(self, path): - """Removes a device that's been either discovered, paired, - connected, etc. - - :param path: The D-Bus path to the object - :type path: string - """ - - self.adapter.RemoveDevice( - self.bus.get_object(SERVICE_NAME, path)) + self.logger.warning("remove_device(%s) has no Bleak analogue", path) def find_device_by_address(self, address): - """Finds the D-Bus path to a device that contains the - specified address. + return self._await(self._async_adapter.find_device_by_address(address)) - :param address: The Bluetooth MAC address - :type address: string - :return: The path to the D-Bus object or None - :rtype: string or None - """ - - # Find all connected/paired/discovered devices - devices = find_objects( - self.bus, - SERVICE_NAME, - DEVICE_INTERFACE) - for path in devices: - # Get the device's address and paired status - device_props = dbus.Interface( - self.bus.get_object(SERVICE_NAME, path), - "org.freedesktop.DBus.Properties") - device_addr = device_props.Get( - DEVICE_INTERFACE, - "Address").upper() - - # Check for an address match - if device_addr != address.upper(): - continue - return path - - return None - def find_connected_devices(self, alias_filter=False): - """Finds the D-Bus path to a device that contains the - specified address. + return self._await(self._async_adapter.find_connected_devices(alias_filter)) - :param address: The Bluetooth MAC address - :type address: string - :return: The path to the D-Bus object or None - :rtype: string or None - """ + def set_powered(self, value): + self._async_adapter.set_powered(value) - devices = find_objects( - self.bus, - SERVICE_NAME, - DEVICE_INTERFACE) - conn_devices = [] - for path in devices: - # Get the device's connection status - device_props = dbus.Interface( - self.bus.get_object(SERVICE_NAME, path), - "org.freedesktop.DBus.Properties") - device_conn_status = device_props.Get( - DEVICE_INTERFACE, - "Connected") - device_alias = device_props.Get( - DEVICE_INTERFACE, - "Alias").upper() - - if device_conn_status: - if alias_filter and device_alias == alias_filter.upper(): - conn_devices.append(path) - else: - conn_devices.append(path) - - return conn_devices + @property + def powered(self): + return self._async_adapter.powered diff --git a/nxbt/cli.py b/nxbt/cli.py index e06479b..d9289aa 100644 --- a/nxbt/cli.py +++ b/nxbt/cli.py @@ -1,11 +1,12 @@ import argparse +import asyncio from random import randint -from time import sleep import os import traceback -from .nxbt import Nxbt, PRO_CONTROLLER -from .bluez import find_devices_by_alias +from .nxbt import PRO_CONTROLLER +from .async_client import AsyncNxbtClient +from .bluez import async_find_devices_by_alias from .tui import InputTUI @@ -116,6 +117,10 @@ def random_colour(): ] +async def async_input(prompt): + return await asyncio.to_thread(input, prompt) + + def check_bluetooth_address(address): """Check the validity of a given Bluetooth MAC address @@ -129,10 +134,10 @@ def check_bluetooth_address(address): raise ValueError("Invalid Bluetooth address") -def get_reconnect_target(): +async def get_reconnect_target(): if args.reconnect: - reconnect_target = find_devices_by_alias("Nintendo Switch") + reconnect_target = await async_find_devices_by_alias("Nintendo Switch") elif args.address: check_bluetooth_address(args.address) reconnect_target = args.address @@ -142,118 +147,116 @@ def get_reconnect_target(): return reconnect_target -def demo(): +async def demo(): """Loops over all available Bluetooth adapters and creates controllers on each. The last available adapter is used to run a macro. """ - nx = Nxbt(debug=args.debug, log_to_file=args.logfile) - adapters = nx.get_available_adapters() - if len(adapters) < 1: - raise OSError("Unable to detect any Bluetooth adapters.") + async with AsyncNxbtClient(debug=args.debug, log_to_file=args.logfile) as nx: + adapters = await nx.get_available_adapters() + if len(adapters) < 1: + raise OSError("Unable to detect any Bluetooth adapters.") - controller_idxs = [] - for i in range(0, len(adapters)): - index = nx.create_controller( - PRO_CONTROLLER, - adapters[i], - colour_body=random_colour(), - colour_buttons=random_colour()) - controller_idxs.append(index) + controller_idxs = [] + for adapter in adapters: + index = await nx.create_controller( + PRO_CONTROLLER, + adapter, + colour_body=random_colour(), + colour_buttons=random_colour()) + controller_idxs.append(index) - # Run a macro on the last controller - print("Running Demo...") - macro_id = nx.macro(controller_idxs[-1], MACRO, block=False) - while macro_id not in nx.state[controller_idxs[-1]]["finished_macros"]: - state = nx.state[controller_idxs[-1]] - if state['state'] == 'crashed': - print("An error occurred while running the demo:") - print(state['errors']) - exit(1) - sleep(1.0) + # Run a macro on the last controller + print("Running Demo...") + macro_id = await nx.macro(controller_idxs[-1], MACRO, block=False) + while macro_id not in nx.state[controller_idxs[-1]]["finished_macros"]: + state = nx.state[controller_idxs[-1]] + if state['state'] == 'crashed': + print("An error occurred while running the demo:") + print(state['errors']) + exit(1) + await asyncio.sleep(1.0) - print("Finished!") + print("Finished!") -def test(): +async def test(): """Tests NXBT functionality""" # Init print("[1] Attempting to initialize NXBT...") - nx = None try: - nx = Nxbt(debug=args.debug, log_to_file=args.logfile) - except Exception as e: + client = AsyncNxbtClient(debug=args.debug, log_to_file=args.logfile) + except Exception: print("Failed to initialize:") print(traceback.format_exc()) exit(1) print("Successfully initialized NXBT.\n") - # Adapter Check - print("[2] Checking for Bluetooth adapter availability...") - adapters = None - try: - adapters = nx.get_available_adapters() - except Exception as e: - print("Failed to check for adapters:") - print(traceback.format_exc()) - exit(1) - if len(adapters) < 1: - print("Unable to detect any Bluetooth adapters.") - print("Please ensure you system has Bluetooth capability.") - exit(1) - print(f"{len(adapters)} Bluetooth adapter(s) available.") - print("Adapters:", adapters, "\n") - - # Creating a controller - print("[3] Please turn on your Switch and navigate to the 'Change Grip/Order menu.'") - input("Press Enter to continue...") - - print("Creating a controller with the first Bluetooth adapter...") - cindex = None - try: - cindex = nx.create_controller( - PRO_CONTROLLER, - adapters[0], - colour_body=random_colour(), - colour_buttons=random_colour()) - except Exception as e: - print("Failed to create a controller:") - print(traceback.format_exc()) - exit(1) - print("Successfully created a controller.\n") - - # Controller connection check - print("[4] Waiting for controller to connect with the Switch...") - timeout = 120 - print(f"Connection timeout is {timeout} seconds for this test script.") - elapsed = 0 - while nx.state[cindex]['state'] != 'connected': - if elapsed >= timeout: - print("Timeout reached, exiting...") + async with client as nx: + # Adapter Check + print("[2] Checking for Bluetooth adapter availability...") + try: + adapters = await nx.get_available_adapters() + except Exception: + print("Failed to check for adapters:") + print(traceback.format_exc()) exit(1) - elif nx.state[cindex]['state'] == 'crashed': - print("An error occurred while connecting:") - print(nx.state[cindex]['errors']) + if len(adapters) < 1: + print("Unable to detect any Bluetooth adapters.") + print("Please ensure you system has Bluetooth capability.") exit(1) - elapsed += 1 - sleep(1) - print("Successfully connected.\n") + print(f"{len(adapters)} Bluetooth adapter(s) available.") + print("Adapters:", adapters, "\n") - # Exit the Change Grip/Order Menu - print("[5] Attempting to exit the 'Change Grip/Order Menu'...") - nx.macro(cindex, "B 0.1s\n0.1s") - sleep(5) - if nx.state[cindex]['state'] != 'connected': - print("Controller disconnected after leaving the menu.") - print("Exiting...") - exit(1) - print("Controller successfully exited the menu.\n") + # Creating a controller + print("[3] Please turn on your Switch and navigate to the 'Change Grip/Order menu.'") + await async_input("Press Enter to continue...") - print("All tests passed.") + print("Creating a controller with the first Bluetooth adapter...") + try: + cindex = await nx.create_controller( + PRO_CONTROLLER, + adapters[0], + colour_body=random_colour(), + colour_buttons=random_colour()) + except Exception: + print("Failed to create a controller:") + print(traceback.format_exc()) + exit(1) + print("Successfully created a controller.\n") + + # Controller connection check + print("[4] Waiting for controller to connect with the Switch...") + timeout = 120 + print(f"Connection timeout is {timeout} seconds for this test script.") + elapsed = 0 + while nx.state[cindex]['state'] != 'connected': + if elapsed >= timeout: + print("Timeout reached, exiting...") + exit(1) + elif nx.state[cindex]['state'] == 'crashed': + print("An error occurred while connecting:") + print(nx.state[cindex]['errors']) + exit(1) + elapsed += 1 + await asyncio.sleep(1) + print("Successfully connected.\n") + + # Exit the Change Grip/Order Menu + print("[5] Attempting to exit the 'Change Grip/Order Menu'...") + await nx.macro(cindex, "B 0.1s\n0.1s") + await asyncio.sleep(5) + if nx.state[cindex]['state'] != 'connected': + print("Controller disconnected after leaving the menu.") + print("Exiting...") + exit(1) + print("Controller successfully exited the menu.\n") + + print("All tests passed.") -def macro(): +async def macro(): """Runs a macro from the command line. The macro can be from a specified file, a command line string, or input from the user in an interactive process. @@ -272,35 +275,36 @@ def macro(): print("to load a macro string from.") return - reconnect_target = get_reconnect_target() + reconnect_target = await get_reconnect_target() - nx = Nxbt(debug=args.debug, log_to_file=args.logfile) - print("Creating controller...") - index = nx.create_controller( - PRO_CONTROLLER, - colour_body=random_colour(), - colour_buttons=random_colour(), - reconnect_address=reconnect_target) - print("Waiting for connection...") - nx.wait_for_connection(index) - print("Connected!") + async with AsyncNxbtClient(debug=args.debug, log_to_file=args.logfile) as nx: + print("Creating controller...") + index = await nx.create_controller( + PRO_CONTROLLER, + colour_body=random_colour(), + colour_buttons=random_colour(), + reconnect_address=reconnect_target) + print("Waiting for connection...") + await nx.wait_for_connection(index) + print("Connected!") - print("Running macro...") - macro_id = nx.macro(index, macro, block=False) - while (True): - if nx.state[index]["state"] == "crashed": - print("Controller crashed while running macro") - print(nx.state[index]["errors"]) - break - if macro_id in nx.state[index]["finished_macros"]: - print("Finished running macro. Exiting...") - break - sleep(1/30) + print("Running macro...") + macro_id = await nx.macro(index, macro, block=False) + while True: + if nx.state[index]["state"] == "crashed": + print("Controller crashed while running macro") + print(nx.state[index]["errors"]) + break + if macro_id in nx.state[index]["finished_macros"]: + print("Finished running macro. Exiting...") + break + await asyncio.sleep(1/30) -def list_switch_addresses(): +async def list_switch_addresses(): - addresses = find_devices_by_alias("Nintendo Switch") + async with AsyncNxbtClient(debug=args.debug, log_to_file=args.logfile) as nx: + addresses = await nx.get_switch_addresses() if not addresses or len(addresses) < 1: print("No Switches have previously connected to this device.") @@ -315,25 +319,38 @@ def list_switch_addresses(): print("---------------------------") -def main(): +async def _dispatch(): if args.command == 'webapp': from .web import start_web_app - start_web_app(ip=args.ip, port=args.port, - usessl=args.usessl, cert_path=args.certpath) + await asyncio.to_thread( + start_web_app, + ip=args.ip, + port=args.port, + usessl=args.usessl, + cert_path=args.certpath, + ) elif args.command == 'demo': - demo() + await demo() elif args.command == 'macro': - macro() + await macro() elif args.command == 'tui': - reconnect_target = get_reconnect_target() + reconnect_target = await get_reconnect_target() tui = InputTUI(reconnect_target=reconnect_target) - tui.start() + await asyncio.to_thread(tui.start) elif args.command == 'remote_tui': - reconnect_target = get_reconnect_target() + reconnect_target = await get_reconnect_target() tui = InputTUI(reconnect_target=reconnect_target, force_remote=True) - tui.start() + await asyncio.to_thread(tui.start) elif args.command == 'addresses': - list_switch_addresses() + await list_switch_addresses() elif args.command == 'test': - test() + await test() + + +def main(): + asyncio.run(_dispatch()) + + +if __name__ == "__main__": + main() diff --git a/nxbt/controller/__init__.py b/nxbt/controller/__init__.py index a832b79..4e23c79 100644 --- a/nxbt/controller/__init__.py +++ b/nxbt/controller/__init__.py @@ -1,6 +1,8 @@ from .server import ControllerServer from .controller import ControllerTypes from .controller import Controller +from .async_controller import AsyncController +from .async_server import AsyncControllerServer from .protocol import ControllerProtocol from .protocol import SwitchReportParser from .protocol import SwitchResponses diff --git a/nxbt/controller/async_controller.py b/nxbt/controller/async_controller.py new file mode 100644 index 0000000..8775313 --- /dev/null +++ b/nxbt/controller/async_controller.py @@ -0,0 +1,76 @@ +import asyncio +import os +import logging +from pathlib import Path + +from .controller import ControllerTypes +from ..bluez import AsyncBleakAdapter + + +class AsyncController: + """Async counterpart to the legacy Controller. + + This class mirrors the setup flow but exposes it as an awaitable + coroutine so higher level code can run entirely inside an asyncio + event loop. + """ + + GAMEPAD_CLASS = "0x002508" + SDP_UUID = "00001000-0000-1000-8000-00805f9b34fb" + SDP_RECORD_PATH = "/nxbt/controller" + ALIASES = { + ControllerTypes.JOYCON_L: "Joy-Con (L)", + ControllerTypes.JOYCON_R: "Joy-Con (R)", + ControllerTypes.PRO_CONTROLLER: "Pro Controller", + } + + def __init__(self, controller_type, bluetooth=None): + if controller_type not in self.ALIASES: + raise ValueError("Unknown controller type specified") + + self.logger = logging.getLogger("nxbt") + self.bt = bluetooth or AsyncBleakAdapter() + self.controller_type = controller_type + self.alias = self.ALIASES[controller_type] + + async def setup(self): + """Async configuration for the adapter/controller.""" + + await self._initialize_adapter() + await self._register_profile() + + async def _initialize_adapter(self): + # Toggle adapter visibility flags synchronously (instant for Bleak) + self.bt.set_powered(True) + self.bt.set_pairable(True) + self.bt.set_pairable_timeout(0) + self.bt.set_discoverable_timeout(180) + self.bt.set_alias(self.alias) + + async def _register_profile(self): + """Load the controller SDP record (if applicable).""" + + sdp_record_path = Path( + os.path.dirname(__file__) + ) / "sdp" / "switch-controller.xml" + if not sdp_record_path.exists(): + self.logger.debug("SDP record missing at %s", sdp_record_path) + return + + sdp_record = await asyncio.to_thread(sdp_record_path.read_text) + opts = { + "ServiceRecord": sdp_record, + "Role": "server", + "RequireAuthentication": False, + "RequireAuthorization": False, + "AutoConnect": True, + } + try: + self.bt.register_profile(self.SDP_RECORD_PATH, self.SDP_UUID, opts) + except Exception as exc: # pragma: no cover - best effort + self.logger.debug("Failed to register SDP profile: %s", exc) + + def setup_sync(self): + """Helper for legacy callers that still expect a blocking API.""" + + return asyncio.run(self.setup()) diff --git a/nxbt/controller/async_server.py b/nxbt/controller/async_server.py new file mode 100644 index 0000000..badaa1a --- /dev/null +++ b/nxbt/controller/async_server.py @@ -0,0 +1,31 @@ +import asyncio + +from .server import ControllerServer + + +class AsyncControllerServer: + """Async facade for the legacy ControllerServer. + + This wrapper allows higher-level asyncio code to await controller + lifecycle operations while the underlying implementation continues + to use blocking sockets and multiprocessing primitives. + """ + + def __init__(self, *args, **kwargs): + self._server = ControllerServer(*args, **kwargs) + + async def run(self, reconnect_address=None): + return await self._server.run_async(reconnect_address) + + async def save_connection(self, error, state=None): + return await self._server.save_connection_async(error, state) + + async def connect(self): + return await self._server.connect_async() + + async def reconnect(self, reconnect_address): + return await self._server.reconnect_async(reconnect_address) + + async def stop(self): + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._server._on_exit) diff --git a/nxbt/controller/controller.py b/nxbt/controller/controller.py index f945608..9b83af7 100644 --- a/nxbt/controller/controller.py +++ b/nxbt/controller/controller.py @@ -2,8 +2,6 @@ from enum import Enum import os import logging -import dbus - class ControllerTypes(Enum): """Controller type enumerations for initializing the controller server. @@ -65,5 +63,5 @@ class Controller(): # catch the error and continue try: self.bt.register_profile(self.SDP_RECORD_PATH, self.SDP_UUID, opts) - except dbus.exceptions.DBusException as e: + except Exception as e: self.logger.debug(e) diff --git a/nxbt/controller/server.py b/nxbt/controller/server.py index 633e1ad..bd6bdad 100644 --- a/nxbt/controller/server.py +++ b/nxbt/controller/server.py @@ -1,8 +1,7 @@ +import asyncio import socket -import fcntl import os import time -import queue import logging import traceback import atexit @@ -10,7 +9,8 @@ from threading import Thread import statistics as stat from .controller import Controller, ControllerTypes -from ..bluez import BlueZ, find_devices_by_alias +from .async_controller import AsyncController +from ..bluez import BlueZ, AsyncBleakAdapter from .protocol import ControllerProtocol from .input import InputParser from .utils import format_msg_controller, format_msg_switch @@ -51,8 +51,11 @@ class ControllerServer(): # Intializing Bluetooth self.bt = BlueZ(adapter_path=adapter_path) + self.async_bt = AsyncBleakAdapter(adapter_path=adapter_path) self.controller = Controller(self.bt, self.controller_type) + self.async_controller = AsyncController( + self.controller_type, bluetooth=self.async_bt) self.protocol = ControllerProtocol( self.controller_type, self.bt.address, @@ -68,6 +71,65 @@ class ControllerServer(): self.tick = 1 self.cached_msg = '' + def _create_l2cap_socket(self): + sock = socket.socket( + family=socket.AF_BLUETOOTH, + type=socket.SOCK_SEQPACKET, + proto=socket.BTPROTO_L2CAP) + sock.setblocking(False) + return sock + + async def _recv_packet(self, loop, sock, timeout): + try: + data = await asyncio.wait_for(loop.sock_recv(sock, 50), timeout=timeout) + if not data: + raise ConnectionResetError("Socket closed") + return data + except asyncio.TimeoutError: + return None + except (BlockingIOError, ConnectionResetError) as exc: + raise ConnectionResetError from exc + + async def _send_packet(self, loop, sock, data): + await loop.sock_sendall(sock, data) + + async def _pairing_loop(self, loop, itr): + received_first_message = False + while True: + + timeout = 1 if not received_first_message else 1/15 + reply = None + try: + reply = await self._recv_packet(loop, itr, timeout) + if self.logger_level <= logging.DEBUG and reply and len(reply) > 40: + self.logger.debug(format_msg_switch(reply)) + except ConnectionResetError as exc: + raise exc + + if reply: + received_first_message = True + + self.protocol.process_commands(reply) + msg = self.protocol.get_report() + + if self.logger_level <= logging.DEBUG and reply: + self.logger.debug(format_msg_controller(msg)) + + try: + await self._send_packet(loop, itr, msg) + except ConnectionResetError as exc: + raise exc + + if (reply and len(reply) > 45 and + self.protocol.vibration_enabled and self.protocol.player_number): + break + + await asyncio.sleep(timeout) + + def _await_async(self, coro): + """Run an async Bleak coroutine from blocking contexts.""" + return asyncio.run(coro) + def run(self, reconnect_address=None): """Runs the mainloop of the controller server. @@ -76,6 +138,11 @@ class ControllerServer(): :type reconnect_address: string or list, optional """ + return asyncio.run(self.run_async(reconnect_address)) + + async def run_async(self, reconnect_address=None): + """Async entry point mirroring ``run``.""" + self.state["state"] = "initializing" try: @@ -83,17 +150,17 @@ class ControllerServer(): # from initializing at the same time and saturating the DBus, # potentially causing a kernel panic. if self.lock: - self.lock.acquire() + await asyncio.to_thread(self.lock.acquire) try: - self.controller.setup() + await self.async_controller.setup() if reconnect_address: try: - itr, ctrl = self.reconnect(reconnect_address) + itr, ctrl = await self.reconnect_async(reconnect_address) except OSError: - itr, ctrl = self.connect() + itr, ctrl = await self.connect_async() else: - itr, ctrl = self.connect() + itr, ctrl = await self.connect_async() finally: if self.lock: self.lock.release() @@ -103,7 +170,7 @@ class ControllerServer(): self.state["state"] = "connected" - self.mainloop(itr, ctrl) + await self.mainloop_async(itr, ctrl) except KeyboardInterrupt: pass @@ -112,42 +179,47 @@ class ControllerServer(): self.state["state"] = "crashed" self.state["errors"] = traceback.format_exc() return self.state - except Exception as e: + except Exception: self.logger.debug("Error during graceful shutdown:") self.logger.debug(traceback.format_exc()) def mainloop(self, itr, ctrl): + return asyncio.run(self.mainloop_async(itr, ctrl)) - duration_start = time.perf_counter() + async def mainloop_async(self, itr, ctrl): + + loop = asyncio.get_running_loop() + itr.setblocking(False) + + duration_start = loop.time() while True: - # Start timing command processing - timer_start = time.perf_counter() - - # Attempt to get output from Switch + reply = None try: - reply = itr.recv(50) - if len(reply) > 40: + reply = await self._recv_packet(loop, itr, timeout=1/132) + if reply and len(reply) > 40: self.logger.debug(format_msg_switch(reply)) - except BlockingIOError: - reply = None + except ConnectionResetError as exc: + itr, ctrl = await self.save_connection_async(exc) + itr.setblocking(False) + duration_start = loop.time() + continue - # Getting any inputs from the task queue if self.task_queue: - try: - while True: + while True: + try: msg = self.task_queue.get_nowait() - if msg and msg["type"] == "macro": - self.input.buffer_macro( - msg["macro"], msg["macro_id"]) - elif msg and msg["type"] == "stop": - self.input.stop_macro( - msg["macro_id"], state=self.state) - elif msg and msg["type"] == "clear": - self.input.clear_macros() - except queue.Empty: - pass + except asyncio.QueueEmpty: + break + + if msg and msg["type"] == "macro": + self.input.buffer_macro( + msg["macro"], msg["macro_id"]) + elif msg and msg["type"] == "stop": + self.input.stop_macro( + msg["macro_id"], state=self.state) + elif msg and msg["type"] == "clear": + self.input.clear_macros() - # Set Direct Input if self.state["direct_input"]: self.input.set_controller_input(self.state["direct_input"]) @@ -160,30 +232,25 @@ class ControllerServer(): self.logger.debug(format_msg_controller(msg)) try: - # Cache the last packet to prevent overloading the switch - # with packets on the "Change Grip/Order" menu. if msg[3:] != self.cached_msg: - itr.sendall(msg) + await self._send_packet(loop, itr, msg) self.cached_msg = msg[3:] - # Send a blank packet every so often to keep the Switch - # from disconnecting from the controller. elif self.tick >= 132: - itr.sendall(msg) + await self._send_packet(loop, itr, msg) self.tick = 0 - except BlockingIOError: + except ConnectionResetError as exc: + itr, ctrl = await self.save_connection_async(exc) + itr.setblocking(False) + duration_start = loop.time() continue - except OSError as e: - # Attempt to reconnect to the Switch - itr, ctrl = self.save_connection(e) - # Figure out how long it took to process commands - duration_end = time.perf_counter() + duration_end = loop.time() duration_elapsed = duration_end - duration_start duration_start = duration_end - + sleep_time = 1/132 - duration_elapsed - if sleep_time >= 0: - time.sleep(sleep_time) + if sleep_time > 0: + await asyncio.sleep(sleep_time) self.tick += 1 if self.logger_level <= logging.DEBUG: @@ -191,17 +258,20 @@ class ControllerServer(): if len(self.times) > 100: self.times.pop() mean_time = stat.mean(self.times) - self.logger.debug( f"Tick: {self.tick}, Mean Time: {str(1/mean_time)}") def save_connection(self, error, state=None): + return asyncio.run(self.save_connection_async(error, state)) + + async def save_connection_async(self, error, state=None): + + loop = asyncio.get_running_loop() while self.reconnect_counter < 2: try: self.logger.debug("Attempting to reconnect") - # Reinitialize the protocol self.protocol = ControllerProtocol( self.controller_type, self.bt.address, @@ -209,47 +279,14 @@ class ControllerServer(): colour_buttons=self.colour_buttons) self.input.reassign_protocol(self.protocol) if self.lock: - self.lock.acquire() + await asyncio.to_thread(self.lock.acquire) try: - itr, ctrl = self.reconnect(self.switch_address) - - received_first_message = False - while True: - # Attempt to get output from Switch - try: - reply = itr.recv(50) - if self.logger_level <= logging.DEBUG and len(reply) > 40: - self.logger.debug(format_msg_switch(reply)) - except BlockingIOError: - reply = None - - if reply: - received_first_message = True - - self.protocol.process_commands(reply) - msg = self.protocol.get_report() - - if self.logger_level <= logging.DEBUG and reply: - self.logger.debug(format_msg_controller(msg)) - - try: - itr.sendall(msg) - except BlockingIOError: - continue - - # Exit pairing loop when player lights have been set and - # vibration has been enabled - if (reply and len(reply) > 45 and - self.protocol.vibration_enabled and self.protocol.player_number): - break - - # Switch responds to packets slower during pairing - # Pairing cycle responds optimally on a 15Hz loop - if not received_first_message: - time.sleep(1) - else: - time.sleep(1/15) - + itr, ctrl = await self.reconnect_async(self.switch_address) + itr.setblocking(False) + try: + await self._pairing_loop(loop, itr) + except ConnectionResetError: + continue self.state["state"] = "connected" return itr, ctrl finally: @@ -258,17 +295,12 @@ class ControllerServer(): except OSError: self.reconnect_counter += 1 self.logger.debug(error) - time.sleep(0.5) + await asyncio.sleep(0.5) - # If we can't reconnect, transition to attempting - # to connect to any Switch. self.logger.debug("Connecting to any Switch") self.reconnect_counter = 0 - - # Reinitialize initial communication overload protections self.tick = 1 - # Reinitialize the protocol self.protocol = ControllerProtocol( self.controller_type, self.bt.address, @@ -276,9 +308,6 @@ class ControllerServer(): colour_buttons=self.colour_buttons) self.input.reassign_protocol(self.protocol) - # Since we were forced to attempt a reconnection - # we need to press the L/SL and R/SR buttons before - # we can proceed with any input. if self.controller_type == ControllerTypes.PRO_CONTROLLER: self.input.current_macro_commands = "L R 0.0s".strip(" ").split(" ") elif self.controller_type == ControllerTypes.JOYCON_L: @@ -287,17 +316,15 @@ class ControllerServer(): self.input.current_macro_commands = "JCR_SL JCR_SR 0.0s".strip(" ").split(" ") if self.lock: - self.lock.acquire() + await asyncio.to_thread(self.lock.acquire) try: - itr, ctrl = self.connect() + itr, ctrl = await self.connect_async() finally: if self.lock: self.lock.release() self.state["state"] = "connected" - self.switch_address = itr.getsockname()[0] - return itr, ctrl def connection_reset_watchdog(self): @@ -305,7 +332,8 @@ class ControllerServer(): connected_devices = [] connected_devices_count = {} while self._crw_running: - paths = self.bt.find_connected_devices(alias_filter="Nintendo Switch") + paths = self._await_async( + self.async_bt.find_connected_devices(alias_filter="Nintendo Switch")) # Keep track of Switches that connect if len(paths) > 0: connected_devices = list(set(connected_devices + paths)) @@ -335,29 +363,19 @@ class ControllerServer(): time.sleep(0.1) def connect(self): - """Configures as a specified controller, pairs with a Nintendo Switch, - and creates/accepts sockets for communication with the Switch. - """ + return asyncio.run(self.connect_async()) + + async def connect_async(self): + """Async connect routine that accepts L2CAP sockets via asyncio.""" + + loop = asyncio.get_running_loop() - # The controller server will continue attempting to connect - # to any Nintendo Switch until the connection procedure fully - # succeeds. This prevents situations where the Switch will - # disconnect during a connection. while True: + s_ctrl = self._create_l2cap_socket() + s_itr = self._create_l2cap_socket() try: self.state["state"] = "connecting" - # Creating control and interrupt sockets - s_ctrl = socket.socket( - family=socket.AF_BLUETOOTH, - type=socket.SOCK_SEQPACKET, - proto=socket.BTPROTO_L2CAP) - s_itr = socket.socket( - family=socket.AF_BLUETOOTH, - type=socket.SOCK_SEQPACKET, - proto=socket.BTPROTO_L2CAP) - - # Setting up HID interrupt/control sockets try: s_ctrl.bind((self.bt.address, 17)) s_itr.bind((self.bt.address, 19)) @@ -369,143 +387,78 @@ class ControllerServer(): s_ctrl.listen(1) self.bt.set_discoverable(True) - - # WARNING: - # A device's class must be set **AFTER** discoverability - # is set. If it is set before or in a similar timeframe, - # the class will be reset to the default value. self.bt.set_class("0x02508") self._crw_running = True - crw = Thread(target = self.connection_reset_watchdog) + crw = Thread(target=self.connection_reset_watchdog) crw.start() - itr, itr_address = s_itr.accept() - ctrl, ctrl_address = s_ctrl.accept() + itr, itr_address = await loop.sock_accept(s_itr) + ctrl, ctrl_address = await loop.sock_accept(s_ctrl) self._crw_running = False + s_itr.close() + s_ctrl.close() - # Send an empty input report to the Switch to prompt a reply self.protocol.process_commands(None) msg = self.protocol.get_report() - itr.sendall(msg) + await self._send_packet(loop, itr, msg) - # Setting interrupt connection as non-blocking. - # In this case, non-blocking means it throws a "BlockingIOError" - # for sending and receiving, instead of blocking. - fcntl.fcntl(itr, fcntl.F_SETFL, os.O_NONBLOCK) + itr.setblocking(False) + ctrl.setblocking(False) - # Mainloop - received_first_message = False - while True: - # Attempt to get output from Switch - try: - reply = itr.recv(50) - if self.logger_level <= logging.DEBUG and len(reply) > 40: - self.logger.debug(format_msg_switch(reply)) - except BlockingIOError: - reply = None + try: + await self._pairing_loop(loop, itr) + except ConnectionResetError as exc: + self.logger.debug(exc) + continue - if reply: - received_first_message = True - - self.protocol.process_commands(reply) - msg = self.protocol.get_report() - - if self.logger_level <= logging.DEBUG and reply: - self.logger.debug(format_msg_controller(msg)) - - try: - itr.sendall(msg) - except BlockingIOError: - continue - - # Exit pairing loop when player lights have been set and - # vibration has been enabled - if (reply and len(reply) > 45 and - self.protocol.vibration_enabled and self.protocol.player_number): - break - - # Switch responds to packets slower during pairing - # Pairing cycle responds optimally on a 15Hz loop - if not received_first_message: - time.sleep(1) - else: - time.sleep(1/15) - - break - except OSError as e: - self.logger.debug(e) - - self.input.exited_grip_order_menu = False - - return itr, ctrl + self.input.exited_grip_order_menu = False + return itr, ctrl + except OSError as exc: + self.logger.debug(exc) + await asyncio.sleep(0.5) + finally: + try: + s_itr.close() + except Exception: + pass + try: + s_ctrl.close() + except Exception: + pass def reconnect(self, reconnect_address): - """Attempts to reconnect with a Switch at the given address. + return asyncio.run(self.reconnect_async(reconnect_address)) - :param reconnect_address: The Bluetooth MAC address of the Switch - :type reconnect_address: string or list - """ - - def recreate_sockets(): - # Creating control and interrupt sockets - ctrl = socket.socket( - family=socket.AF_BLUETOOTH, - type=socket.SOCK_SEQPACKET, - proto=socket.BTPROTO_L2CAP) - itr = socket.socket( - family=socket.AF_BLUETOOTH, - type=socket.SOCK_SEQPACKET, - proto=socket.BTPROTO_L2CAP) - - return itr, ctrl + async def reconnect_async(self, reconnect_address): + """Async reconnection to previously paired addresses.""" + loop = asyncio.get_running_loop() self.state["state"] = "reconnecting" - itr = None - ctrl = None - if type(reconnect_address) == list: + async def connect_to_address(address): + ctrl = self._create_l2cap_socket() + itr = self._create_l2cap_socket() + await loop.sock_connect(ctrl, (address, 17)) + await loop.sock_connect(itr, (address, 19)) + self.protocol.process_commands(None) + msg = self.protocol.get_report() + await self._send_packet(loop, itr, msg) + return itr, ctrl + + last_error = None + if isinstance(reconnect_address, list): for address in reconnect_address: - test_itr, test_ctrl = recreate_sockets() try: - # Setting up HID interrupt/control sockets - test_ctrl.connect((address, 17)) - test_itr.connect((address, 19)) + return await connect_to_address(address) + except OSError as exc: + last_error = exc + elif isinstance(reconnect_address, str): + return await connect_to_address(reconnect_address) - itr = test_itr - ctrl = test_ctrl - except OSError: - test_itr.close() - test_ctrl.close() - pass - elif type(reconnect_address) == str: - test_itr, test_ctrl = recreate_sockets() - - # Setting up HID interrupt/control sockets - test_ctrl.connect((reconnect_address, 17)) - test_itr.connect((reconnect_address, 19)) - - itr = test_itr - ctrl = test_ctrl - - if not itr and not ctrl: - raise OSError("Unable to reconnect to sockets at the given address(es)", - reconnect_address) - - fcntl.fcntl(itr, fcntl.F_SETFL, os.O_NONBLOCK) - - # Send an empty input report to the Switch to prompt a reply - self.protocol.process_commands(None) - msg = self.protocol.get_report() - itr.sendall(msg) - - # Setting interrupt connection as non-blocking - # In this case, non-blocking means it throws a "BlockingIOError" - # for sending and receiving, instead of blocking - fcntl.fcntl(itr, fcntl.F_SETFL, os.O_NONBLOCK) - - return itr, ctrl + raise OSError("Unable to reconnect to sockets at the given address(es)", + reconnect_address) from last_error def _on_exit(self): self.bt.reset_address() diff --git a/nxbt/nxbt.py b/nxbt/nxbt.py index 780be29..42db9f9 100644 --- a/nxbt/nxbt.py +++ b/nxbt/nxbt.py @@ -1,22 +1,16 @@ -from multiprocessing import Process, Lock, Queue, Manager -import queue -from enum import Enum +from threading import Lock import atexit -import signal import os -import sys import time import json -import dbus - -from .controller import ControllerServer from .controller import ControllerTypes -from .bluez import BlueZ, find_objects, toggle_clean_bluez -from .bluez import replace_mac_addresses +from .bluez import find_objects, toggle_clean_bluez from .bluez import find_devices_by_alias from .bluez import SERVICE_NAME, ADAPTER_INTERFACE from .logging import create_logger +from .async_nxbt import AsyncNxbt +from .async_runtime import AsyncRuntime JOYCON_L = ControllerTypes.JOYCON_L @@ -110,67 +104,35 @@ class Sticks(): LEFT_STICK = "L_STICK" -class NxbtCommands(Enum): - """An enumeration containing the nxbt message - commands. - """ - - CREATE_CONTROLLER = 0 - INPUT_MACRO = 1 - STOP_MACRO = 2 - CLEAR_MACROS = 3 - CLEAR_ALL_MACROS = 4 - REMOVE_CONTROLLER = 5 - QUIT = 6 - - class Nxbt(): - """The nxbt object implements the core multiprocessing logic - and message passing API that acts as the central of the application. - Upon creation, a multiprocessing Process is spun off to act at the - manager for all emulated Nintendo Switch controllers. Messages - are passed into a queue which is consumed and acted upon by the - _command_manager. + """The nxbt object implements the core async controller manager. - All function calls that interact or control the emulated controllers - are simply message constructors that submit to the central task_queue. - This allows for thread-safe control of emulated controllers. + A background asyncio runtime now owns all controller lifecycles, + allowing the synchronous public API to remain unchanged while + running controllers inside a single event loop. """ - def __init__(self, debug=False, log_to_file=False, disable_logging=False): - """Initializes the necessary multiprocessing resources and starts - the multiprocessing processes. - - :param debug: Enables the debugging functionality of - nxbt, defaults to False - :type debug: bool, optional - :param log_to_file: A boolean value that indiciates whether or not - a log should be saved to the current working directory, defaults to False - :type log_to_file: bool, optional - :param disable_logging: Routes all logging calls to a null log handler. - :type disable_logging: bool, optional, defaults to False. - """ + def __init__(self, debug=False, log_to_file=False, disable_logging=False, + use_async=True): + """Initializes the async Nxbt runtime and controller manager.""" self.debug = debug self.logger = create_logger( debug=self.debug, log_to_file=log_to_file, disable_logging=disable_logging) - # Main queue for nbxt tasks - self.task_queue = Queue() + if not use_async: + self.logger.warning( + "The legacy multiprocessing backend has been removed; falling back to the async runtime.") + self._use_async = True - # Sychronizes bluetooth actions - self._bluetooth_lock = Lock() - - # Creates/manages shared resources - self.resource_manager = Manager() - # Shared dictionary for viewing overall nxbt state. - # Should treated as read-only except by - # the main nxbt multiprocessing process. - self.manager_state = self.resource_manager.dict() - self.manager_state_lock = Lock() + self._async_runtime = AsyncRuntime() + self.async_manager = AsyncNxbt( + debug=debug, log_to_file=log_to_file, disable_logging=disable_logging) + self.manager_state = self.async_manager.state + self._closed = False # Shared, controller management properties. - # The controller lock is used to sychronize use. + # The controller lock synchronizes adapter assignments and controller setup. self._controller_lock = Lock() self._controller_counter = 0 self._adapters_in_use = {} @@ -183,93 +145,36 @@ class Nxbt(): # Exit handler atexit.register(self._on_exit) - # Starting the nxbt worker process - self.controllers = Process( - target=self._command_manager, - args=((self.task_queue), (self.manager_state))) - # Disabling daemonization since we need to spawn - # other controller processes, however, this means - # we need to cleanup on exit. - self.controllers.daemon = False - self.controllers.start() - def _on_exit(self): """The exit handler function used with the atexit module. - This function attempts to gracefully exit by terminating - all spun up multiprocessing Processes. This is done to - ensure no zombie processes linger after exit. + This function attempts to gracefully exit by shutting down the + background asyncio runtime and controller tasks. """ - # Need to explicitly kill the controllers process - # since it isn't daemonized. - if hasattr(self, "controllers") and self.controllers.is_alive(): - self.controllers.terminate() + if self._closed: + return + self._closed = True - self.resource_manager.shutdown() + if self.async_manager: + try: + self._async_runtime.submit(self.async_manager.shutdown()) + except Exception: + pass + if self._async_runtime: + self._async_runtime.shutdown() # Re-enable the BlueZ plugins, if we have permission toggle_clean_bluez(False) - def _command_manager(self, task_queue, state): - """Used as the main multiprocessing Process that is launched - on startup to handle the message passing and instantiation of - the controllers. Messages are pulled out of a Queue and passed - as appropriately phrased function calls to the ControllerManager. + def shutdown(self): + """Public helper to mirror the atexit-driven cleanup.""" + self._on_exit() - :param task_queue: A multiprocessing Queue used as the source - of messages - :type task_queue: multiprocessing.Queue - :param state: A dict used to store the shared state of the - emulated controllers. - :type state: multiprocessing.Manager().dict - """ - - cm = _ControllerManager(state, self._bluetooth_lock) - # Ensure a SystemExit exception is raised on SIGTERM - # so that we can gracefully shutdown. - signal.signal(signal.SIGTERM, lambda sigterm_handler: sys.exit(0)) - - try: - while True: - try: - msg = task_queue.get(timeout=5) - except queue.Empty: - msg = None - - if msg: - if msg["command"] == NxbtCommands.CREATE_CONTROLLER: - cm.create_controller( - msg["arguments"]["controller_index"], - msg["arguments"]["controller_type"], - msg["arguments"]["adapter_path"], - msg["arguments"]["colour_body"], - msg["arguments"]["colour_buttons"], - msg["arguments"]["reconnect_address"]) - elif msg["command"] == NxbtCommands.INPUT_MACRO: - cm.input_macro( - msg["arguments"]["controller_index"], - msg["arguments"]["macro"], - msg["arguments"]["macro_id"]) - elif msg["command"] == NxbtCommands.STOP_MACRO: - cm.stop_macro( - msg["arguments"]["controller_index"], - msg["arguments"]["macro_id"]) - elif msg["command"] == NxbtCommands.CLEAR_MACROS: - cm.clear_macros( - msg["arguments"]["controller_index"]) - elif msg["command"] == NxbtCommands.REMOVE_CONTROLLER: - index = msg["arguments"]["controller_index"] - cm.clear_macros(index) - cm.remove_controller(index) - - finally: - cm.shutdown() - sys.exit(0) def macro(self, controller_index, macro, block=True): """Used to input a given macro on a specified controller. - This is done by creating and passing an INPUT_MACRO - message into the task queue with the given macro. + This is done by submitting the macro to the async manager running + in the background event loop. If block is set to True, this function waits until the macro_id (generated on the submission of the macro) @@ -297,14 +202,9 @@ class Nxbt(): # Get a unique ID to identify the macro # so we can check when the controller is done inputting it macro_id = os.urandom(24).hex() - self.task_queue.put({ - "command": NxbtCommands.INPUT_MACRO, - "arguments": { - "controller_index": controller_index, - "macro": macro, - "macro_id": macro_id, - } - }) + self._async_runtime.submit( + self.async_manager.queue_macro(controller_index, macro, macro_id) + ) if block: while True: @@ -417,13 +317,9 @@ class Nxbt(): if controller_index not in self.manager_state.keys(): raise ValueError("Specified controller does not exist") - self.task_queue.put({ - "command": NxbtCommands.STOP_MACRO, - "arguments": { - "controller_index": controller_index, - "macro_id": macro_id, - } - }) + self._async_runtime.submit( + self.async_manager.stop_macro(controller_index, macro_id) + ) if block: while True: @@ -449,12 +345,9 @@ class Nxbt(): if controller_index not in self.manager_state.keys(): raise ValueError("Specified controller does not exist") - self.task_queue.put({ - "command": NxbtCommands.CLEAR_MACROS, - "arguments": { - "controller_index": controller_index, - } - }) + self._async_runtime.submit( + self.async_manager.clear_macros(controller_index) + ) def clear_all_macros(self): """Clears all running and queued macros on all @@ -534,59 +427,73 @@ class Nxbt(): :return: The index of the created controller :rtype: int """ - if adapter_path: - if adapter_path not in self.get_available_adapters(): - raise ValueError("Specified adapter is unavailable") - - if adapter_path in self._adapters_in_use.keys(): - raise ValueError("Specified adapter in use") - else: - # Get all adapters we can use - usable_adapters = list( - set(self.get_available_adapters()) - set(self._adapters_in_use)) - if len(usable_adapters) > 0: - # Use the first available adapter - adapter_path = usable_adapters[0] - else: - raise ValueError("No adapters available") + adapter_path = self._resolve_adapter_path(adapter_path) controller_index = None try: self._controller_lock.acquire() - self.task_queue.put({ - "command": NxbtCommands.CREATE_CONTROLLER, - "arguments": { - "controller_index": self._controller_counter, - "controller_type": controller_type, - "adapter_path": adapter_path, - "colour_body": colour_body, - "colour_buttons": colour_buttons, - "reconnect_address": reconnect_address, - } - }) controller_index = self._controller_counter self._controller_counter += 1 + + controller_state = self._build_controller_state( + adapter_path, controller_type, colour_body, colour_buttons) + + self.manager_state[controller_index] = controller_state self._adapters_in_use[adapter_path] = controller_index self._controller_adapter_lookup[controller_index] = adapter_path - # Block until the controller is ready - # This needs to be done to prevent race conditions - # on Bluetooth resources. - if type(controller_index) == int: - while True: - if controller_index in self.manager_state.keys(): - state = self.manager_state[controller_index] - if (state["state"] == "connecting" or - state["state"] == "reconnecting" or - state["state"] == "crashed"): - break + self._async_runtime.submit( + self.async_manager.create_controller( + controller_type, + adapter_path=adapter_path, + reconnect_address=reconnect_address, + colour_body=colour_body, + colour_buttons=colour_buttons, + controller_index=controller_index, + state=controller_state, + lock=self._controller_lock, + ) + ) - time.sleep(1/30) + while True: + state = self.manager_state.get(controller_index) + if state and state["state"] in ("connecting", "reconnecting", "crashed"): + break + time.sleep(1/30) finally: self._controller_lock.release() return controller_index + def _resolve_adapter_path(self, adapter_path): + if adapter_path: + if adapter_path not in self.get_available_adapters(): + raise ValueError("Specified adapter is unavailable") + if adapter_path in self._adapters_in_use.keys(): + raise ValueError("Specified adapter in use") + return adapter_path + + usable_adapters = list( + set(self.get_available_adapters()) - set(self._adapters_in_use)) + if len(usable_adapters) > 0: + return usable_adapters[0] + raise ValueError("No adapters available") + + def _build_controller_state(self, adapter_path, controller_type, + colour_body, colour_buttons): + state = { + "state": "initializing", + "finished_macros": [], + "errors": None, + "direct_input": json.loads(json.dumps(DIRECT_INPUT_PACKET)), + "colour_body": colour_body, + "colour_buttons": colour_buttons, + "type": str(controller_type), + "adapter_path": adapter_path, + "last_connection": None, + } + return state + def remove_controller(self, controller_index): """Terminates and removes a given controller. @@ -612,12 +519,9 @@ class Nxbt(): finally: self._controller_lock.release() - self.task_queue.put({ - "command": NxbtCommands.REMOVE_CONTROLLER, - "arguments": { - "controller_index": controller_index, - } - }) + self._async_runtime.submit( + self.async_manager.remove_controller(controller_index) + ) def wait_for_connection(self, controller_index): """Blocks until a given controller is connected @@ -641,11 +545,7 @@ class Nxbt(): :rtype: list """ - bus = dbus.SystemBus() - adapters = find_objects(bus, SERVICE_NAME, ADAPTER_INTERFACE) - bus.close() - - return adapters + return find_objects(None, SERVICE_NAME, ADAPTER_INTERFACE) def get_switch_addresses(self): """Gets the Bluetooth MAC addresses of all @@ -687,112 +587,3 @@ class Nxbt(): """ return self.manager_state - - -class _ControllerManager(): - """Used as the manager for all controllers. Each controller is - a daemon multiprocessing Process that the ControllerManager - object creates and manages. - - The ControllerManager object submits messages to the respective - queues of each controller process for tasks such as macro submission - or macro clearing/stopping. - """ - - def __init__(self, state, lock): - - self.state = state - self.lock = lock - self.controller_resources = Manager() - self._controller_queues = {} - self._children = {} - - def create_controller(self, index, controller_type, adapter_path, - colour_body=None, colour_buttons=None, - reconnect_address=None): - """Instantiates a given controller as a multiprocessing - Process with a shared state dict and a task queue. - - Configuration options are available in the form of - controller colours. - - :param index: The index of the controller - :type index: int - :param controller_type: The type of Nintendo Switch controller - :type controller_type: ControllerTypes - :param adapter_path: The DBus path to the Bluetooth adapter - :type adapter_path: str - :param colour_body: A list of three ints representing the hex - colour of the controller, defaults to None - :type colour_body: list, optional - :param colour_buttons: A list of three ints representing the - hex colour of the controller, defaults to None - :type colour_buttons: list, optional - :param reconnect_address: The address of a Nintendo Switch - to reconnect to, defaults to None - :type reconnect_address: str, optional - """ - - controller_queue = Queue() - - controller_state = self.controller_resources.dict() - controller_state["state"] = "initializing" - controller_state["finished_macros"] = [] - controller_state["errors"] = False - controller_state["direct_input"] = json.loads(json.dumps(DIRECT_INPUT_PACKET)) - controller_state["colour_body"] = colour_body - controller_state["colour_buttons"] = colour_buttons - controller_state["type"] = str(controller_type) - controller_state["adapter_path"] = adapter_path - controller_state["last_connection"] = None - - self._controller_queues[index] = controller_queue - - self.state[index] = controller_state - - server = ControllerServer(controller_type, - adapter_path=adapter_path, - lock=self.lock, - state=controller_state, - task_queue=controller_queue, - colour_body=colour_body, - colour_buttons=colour_buttons) - controller = Process(target=server.run, args=(reconnect_address,)) - controller.daemon = True - self._children[index] = controller - controller.start() - - def input_macro(self, index, macro, macro_id): - - self._controller_queues[index].put({ - "type": "macro", - "macro": macro, - "macro_id": macro_id - }) - - def stop_macro(self, index, macro_id): - - self._controller_queues[index].put({ - "type": "stop", - "macro_id": macro_id, - }) - - def clear_macros(self, index): - - self._controller_queues[index].put({ - "type": "clear", - }) - - def remove_controller(self, index): - - self._children[index].terminate() - self.state.pop(index, None) - - def shutdown(self): - - # Loop over children and kill all - for index in self._children.keys(): - child = self._children[index] - child.terminate() - - self.controller_resources.shutdown() diff --git a/nxbt/tui.py b/nxbt/tui.py index f34f8f4..e2cae4e 100644 --- a/nxbt/tui.py +++ b/nxbt/tui.py @@ -2,11 +2,13 @@ import os import time import psutil from collections import deque -import multiprocessing +import threading +import copy from blessed import Terminal -from .nxbt import Nxbt, PRO_CONTROLLER +from .nxbt import PRO_CONTROLLER +from .async_bridge import AsyncNxbtClientBridge class LoadingSpinner(): @@ -324,11 +326,13 @@ class InputTUI(): def mainloop(self, term): - # Initializing a controller + # Initializing a controller via the async bridge if not self.debug: - self.nx = Nxbt(disable_logging=True) + self.nx = AsyncNxbtClientBridge(disable_logging=True) else: - self.nx = Nxbt(debug=self.debug, logfile=self.logfile) + self.nx = AsyncNxbtClientBridge( + debug=self.debug, log_to_file=self.logfile, disable_logging=False + ) self.controller_index = self.nx.create_controller( PRO_CONTROLLER, reconnect_address=self.reconnect_target) @@ -391,6 +395,16 @@ class InputTUI(): if errors: print("The TUI encountered the following errors:") print(errors) + if getattr(self, "controller_index", None) is not None: + try: + self.nx.remove_controller(self.controller_index) + except Exception: + pass + if getattr(self, "nx", None): + try: + self.nx.close() + except Exception: + pass def remote_input_loop(self, term): @@ -447,11 +461,10 @@ class InputTUI(): self.exit_tui = False self.capture_input = True - # Create a packet that is accessible from a multiprocessing Process - # and from within threads - packet_manager = multiprocessing.Manager() - input_packet = packet_manager.dict() - input_packet["packet"] = self.nx.create_input_packet() + # Shared packet guarded by a lock so keyboard callbacks and the + # input worker can safely mutate it. + packet_lock = threading.Lock() + input_packet = {"packet": self.nx.create_input_packet()} print(term.move_y(term.height - 5)) print(term.center(term.bold_black_on_white(" "))) @@ -470,15 +483,15 @@ class InputTUI(): else: try: control_data = self.KEYMAP[pressed_key] - packet = input_packet["packet"] - if type(control_data) == dict and "stick_data" in control_data.keys(): - stick_name = control_data['stick_data']['stick_name'] - self.controller.activate_control(control_data["control"]) - packet[stick_name][control_data["control"]] = True - else: - self.controller.activate_control(control_data) - packet[control_data] = True - input_packet["packet"] = packet + with packet_lock: + packet = input_packet["packet"] + if isinstance(control_data, dict) and "stick_data" in control_data: + stick_name = control_data['stick_data']['stick_name'] + self.controller.activate_control(control_data["control"]) + packet[stick_name][control_data["control"]] = True + else: + self.controller.activate_control(control_data) + packet[control_data] = True except KeyError: pass @@ -505,57 +518,59 @@ class InputTUI(): else: try: control_data = self.KEYMAP[released_key] - packet = input_packet["packet"] - if type(control_data) == dict and "stick_data" in control_data.keys(): - stick_name = control_data['stick_data']['stick_name'] - self.controller.deactivate_control(control_data["control"]) - packet[stick_name][control_data["control"]] = False - else: - self.controller.deactivate_control(control_data) - packet[control_data] = False - input_packet["packet"] = packet + with packet_lock: + packet = input_packet["packet"] + if isinstance(control_data, dict) and "stick_data" in control_data: + stick_name = control_data['stick_data']['stick_name'] + self.controller.deactivate_control(control_data["control"]) + packet[stick_name][control_data["control"]] = False + else: + self.controller.deactivate_control(control_data) + packet[control_data] = False except KeyError: pass - def input_worker(nxbt, controller_index, input_packet): + def input_worker(): - while True: - packet = input_packet["packet"] + while not self.exit_tui: + with packet_lock: + packet = input_packet["packet"] - # Calculating left x/y stick values - ls_x_value = 0 - ls_y_value = 0 - if packet["L_STICK"]["LS_LEFT"]: - ls_x_value -= 100 - if packet["L_STICK"]["LS_RIGHT"]: - ls_x_value += 100 - if packet["L_STICK"]["LS_UP"]: - ls_y_value += 100 - if packet["L_STICK"]["LS_DOWN"]: - ls_y_value -= 100 - packet["L_STICK"]["X_VALUE"] = ls_x_value - packet["L_STICK"]["Y_VALUE"] = ls_y_value + # Calculating left x/y stick values + ls_x_value = 0 + ls_y_value = 0 + if packet["L_STICK"]["LS_LEFT"]: + ls_x_value -= 100 + if packet["L_STICK"]["LS_RIGHT"]: + ls_x_value += 100 + if packet["L_STICK"]["LS_UP"]: + ls_y_value += 100 + if packet["L_STICK"]["LS_DOWN"]: + ls_y_value -= 100 + packet["L_STICK"]["X_VALUE"] = ls_x_value + packet["L_STICK"]["Y_VALUE"] = ls_y_value - # Calculating right x/y stick values - rs_x_value = 0 - rs_y_value = 0 - if packet["R_STICK"]["RS_LEFT"]: - rs_x_value -= 100 - if packet["R_STICK"]["RS_RIGHT"]: - rs_x_value += 100 - if packet["R_STICK"]["RS_UP"]: - rs_y_value += 100 - if packet["R_STICK"]["RS_DOWN"]: - rs_y_value -= 100 - packet["R_STICK"]["X_VALUE"] = rs_x_value - packet["R_STICK"]["Y_VALUE"] = rs_y_value + # Calculating right x/y stick values + rs_x_value = 0 + rs_y_value = 0 + if packet["R_STICK"]["RS_LEFT"]: + rs_x_value -= 100 + if packet["R_STICK"]["RS_RIGHT"]: + rs_x_value += 100 + if packet["R_STICK"]["RS_UP"]: + rs_y_value += 100 + if packet["R_STICK"]["RS_DOWN"]: + rs_y_value -= 100 + packet["R_STICK"]["X_VALUE"] = rs_x_value + packet["R_STICK"]["Y_VALUE"] = rs_y_value - nxbt.set_controller_input(controller_index, packet) + packet_snapshot = copy.deepcopy(packet) + + self.nx.set_controller_input(self.controller_index, packet_snapshot) time.sleep(1/120) - input_process = multiprocessing.Process( - target=input_worker, args=(self.nx, self.controller_index, input_packet)) - input_process.start() + input_thread = threading.Thread(target=input_worker, daemon=True) + input_thread.start() # Start a non-blocking keyboard event listener listener = keyboard.Listener( @@ -566,8 +581,6 @@ class InputTUI(): # Main TUI Loop while True: if self.exit_tui: - packet_manager.shutdown() - input_process.terminate() break if not self.capture_input: print(term.home + term.move_y((term.height // 2) - 4)) @@ -581,6 +594,9 @@ class InputTUI(): self.check_for_disconnect(term) time.sleep(1/120) + listener.stop() + input_thread.join() + def render_start_screen(self, term, loading_text): print(term.home + term.move_y((term.height // 2) - 8)) diff --git a/nxbt/web/app.py b/nxbt/web/app.py index 481fdcf..50c4de0 100644 --- a/nxbt/web/app.py +++ b/nxbt/web/app.py @@ -5,7 +5,8 @@ import time from socket import gethostname from .cert import generate_cert -from ..nxbt import Nxbt, PRO_CONTROLLER +from ..nxbt import PRO_CONTROLLER +from ..async_bridge import AsyncNxbtClientBridge from flask import Flask, render_template, request from flask_socketio import SocketIO, emit import eventlet @@ -14,7 +15,7 @@ import eventlet app = Flask(__name__, static_url_path='', static_folder='static',) -nxbt = Nxbt() +nxbt = AsyncNxbtClientBridge() # Configuring/retrieving secret key secrets_path = os.path.join( diff --git a/scanner.py b/scanner.py new file mode 100644 index 0000000..0d0137b --- /dev/null +++ b/scanner.py @@ -0,0 +1,9 @@ +import asyncio +from bleak import BleakScanner + +async def main(): + devices = await BleakScanner.discover() + for d in devices: + print(d) + +asyncio.run(main()) \ No newline at end of file diff --git a/scripts/demo_loop.py b/scripts/demo_loop.py index dcc9de0..990f709 100644 --- a/scripts/demo_loop.py +++ b/scripts/demo_loop.py @@ -1,7 +1,7 @@ +import asyncio from random import randint -from time import sleep -from nxbt import Nxbt, PRO_CONTROLLER +from nxbt import AsyncNxbtClient, PRO_CONTROLLER MACRO = """ @@ -69,39 +69,39 @@ def random_colour(): ] -def demo(): +async def demo(): """Loops over all available Bluetooth adapters and creates controllers on each. The last available adapter is used to run a macro. """ - nx = Nxbt(debug=False) - adapters = nx.get_available_adapters() - if len(adapters) < 1: - raise OSError("Unable to detect any Bluetooth adapters.") + async with AsyncNxbtClient(debug=False) as nx: + adapters = await nx.get_available_adapters() + if len(adapters) < 1: + raise OSError("Unable to detect any Bluetooth adapters.") - controller_idxs = [] - for i in range(0, len(adapters)): - index = nx.create_controller( - PRO_CONTROLLER, - adapters[i], - colour_body=random_colour(), - colour_buttons=random_colour()) - controller_idxs.append(index) + controller_idxs = [] + for adapter in adapters: + index = await nx.create_controller( + PRO_CONTROLLER, + adapter, + colour_body=random_colour(), + colour_buttons=random_colour()) + controller_idxs.append(index) - # Run a macro on the last controller - for i in range(100): - print(f"Running Demo: Iteration {i}") - macro_id = nx.macro(controller_idxs[-1], MACRO, block=False) - while macro_id not in nx.state[controller_idxs[-1]]["finished_macros"]: - state = nx.state[controller_idxs[-1]] - if state['state'] == 'crashed': - print("An error occurred while running the demo:") - print(state['errors']) - exit(1) - sleep(1.0) + # Run a macro on the last controller + for i in range(100): + print(f"Running Demo: Iteration {i}") + macro_id = await nx.macro(controller_idxs[-1], MACRO, block=False) + while macro_id not in nx.state[controller_idxs[-1]]["finished_macros"]: + state = nx.state[controller_idxs[-1]] + if state['state'] == 'crashed': + print("An error occurred while running the demo:") + print(state['errors']) + exit(1) + await asyncio.sleep(1.0) - print("Finished!") + print("Finished!") if __name__ == "__main__": - demo() \ No newline at end of file + asyncio.run(demo()) diff --git a/scripts/testbt.py b/scripts/testbt.py index 3f512a3..3018fee 100644 --- a/scripts/testbt.py +++ b/scripts/testbt.py @@ -1,65 +1,73 @@ """ -A quick script to test aspects of the BlueZ API. +A quick script to test the async Bleak-backed adapter helper. """ -import dbus -from nxbt import BlueZ, find_objects, SERVICE_NAME, ADAPTER_INTERFACE +import asyncio +import os + +from nxbt import ( + AsyncBleakAdapter, + async_find_objects, + find_objects, + SERVICE_NAME, + ADAPTER_INTERFACE, +) + +TARGET_ADDRESS = os.environ.get("NXBT_TEST_DEVICE") -bus = dbus.SystemBus() -adapters = find_objects(bus, SERVICE_NAME, ADAPTER_INTERFACE) -print(adapters) +async def main(): + # Prefer the async helper when running inside an event loop. + try: + adapters = await async_find_objects(None, SERVICE_NAME, ADAPTER_INTERFACE) + except RuntimeError: + # Fallback to the legacy synchronous variant if no loop is running. + adapters = find_objects(None, SERVICE_NAME, ADAPTER_INTERFACE) + print(adapters) -bt = BlueZ(device_id=adapters[0].split("/")[-1]) + adapter = AsyncBleakAdapter(adapter_path=adapters[0]) -# jc_MAC = "XX:XX:XX:XX:XX:XX" -# res = bt.discover_devices(alias="Joy-Con (L)", timeout=10) -# for key in res.keys(): -# print(res[key]["Alias"], res[key]["Address"]) -# print(bt.find_device_by_address(jc_MAC)) + address = await adapter.get_address() + print("Address", address) + print("Name", adapter.name) + print("Alias", adapter.alias) + print("Pairable", adapter.pairable) -# devices = bt.discover_devices(alias="Joy-Con (L)") -# print(devices.keys()) - -print("Address", bt.address) -print("Name", bt.name) -print("Alias", bt.alias) -print("Pairable", bt.pairable) - -print("") -print("Pairable Timeout", bt.pairable_timeout) -bt.set_pairable_timeout(10) -print("Pairable Timeout", bt.pairable_timeout) -bt.set_pairable_timeout(0) -print("Pairable Timeout", bt.pairable_timeout) - -print("") -print("Discoverable", bt.discoverable) -bt.set_discoverable(True) -print("Discoverable", bt.discoverable) -bt.set_discoverable(False) -print("Discoverable", bt.discoverable) - -print("") -print("Discoverable Timeout", bt.discoverable_timeout) -bt.set_discoverable_timeout(0) -print("Discoverable Timeout", bt.discoverable_timeout) -bt.set_discoverable_timeout(180) -print("Discoverable Timeout", bt.discoverable_timeout) - -try: print("") - print("Device Class", bt.device_class) - bt.set_device_class("0x002058") - print("Device Class", bt.device_class) - bt.set_device_class("0x480000") - print("Device Class", bt.device_class) -except Exception as e: - print(e) + print("Pairable Timeout", adapter.pairable_timeout) + adapter.set_pairable_timeout(10) + print("Pairable Timeout", adapter.pairable_timeout) + adapter.set_pairable_timeout(0) + print("Pairable Timeout", adapter.pairable_timeout) -print("") -print("Powered", bt.powered) -bt.set_powered(False) -print("Powered", bt.powered) -bt.set_powered(True) -print("Powered", bt.powered) + print("") + print("Discoverable", adapter.discoverable) + adapter.set_discoverable(True) + print("Discoverable", adapter.discoverable) + adapter.set_discoverable(False) + print("Discoverable", adapter.discoverable) + + print("") + print("Discoverable Timeout", adapter.discoverable_timeout) + adapter.set_discoverable_timeout(0) + print("Discoverable Timeout", adapter.discoverable_timeout) + adapter.set_discoverable_timeout(180) + print("Discoverable Timeout", adapter.discoverable_timeout) + + print("\nScanning for nearby devices...") + try: + devices = await adapter.discover_devices(timeout=5) + except Exception as exc: + print(f"Discovery failed: {exc}") + else: + for path, props in devices.items(): + print(f"{props['Alias'] or 'UNKNOWN'} -> {props['Address']} ({path})") + + if TARGET_ADDRESS: + print(f"\nAttempting a short Bleak connection to {TARGET_ADDRESS}") + await adapter.connect_device(TARGET_ADDRESS) + print("Connection attempt complete.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/setup.py b/setup.py index b6984d0..f13c6b8 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( include_package_data=True, long_description_content_type="text/markdown", install_requires=[ - "dbus-python==1.2.16", + "bleak==1.1.1", "Flask==2.1.3", "Flask-SocketIO==5.3.4", "eventlet==0.33.3",