#!/usr/bin/env python3 """ Lightweight helpers for talking to the switch-pico firmware over UART. This module exposes the raw report structure plus a small convenience wrapper so other scripts can do things like "press a button" or "move a stick" without depending on SDL. It mirrors the framing in ``switch-pico.cpp``: Host -> Pico : 0xAA, buttons (LE16), hat, lx, ly, rx, ry Pico -> Host : 0xBB, 0x01, 8 rumble bytes, checksum (sum of first 10 bytes) """ from __future__ import annotations import struct import time import threading from dataclasses import dataclass, field from enum import IntEnum, IntFlag from typing import Iterable, Mapping, Optional, Tuple, Union, List, Dict import serial from serial.tools import list_ports, list_ports_common UART_HEADER = 0xAA RUMBLE_HEADER = 0xBB RUMBLE_TYPE_RUMBLE = 0x01 UART_BAUD = 921600 class SwitchButton(IntFlag): # Mirrors the masks defined in switch_pro_descriptors.h Y = 1 << 0 B = 1 << 1 A = 1 << 2 X = 1 << 3 L = 1 << 4 R = 1 << 5 ZL = 1 << 6 ZR = 1 << 7 MINUS = 1 << 8 PLUS = 1 << 9 LCLICK = 1 << 10 RCLICK = 1 << 11 HOME = 1 << 12 CAPTURE = 1 << 13 class SwitchDpad(IntEnum): UP = 0x00 UP_RIGHT = 0x01 RIGHT = 0x02 DOWN_RIGHT = 0x03 DOWN = 0x04 DOWN_LEFT = 0x05 LEFT = 0x06 UP_LEFT = 0x07 CENTER = 0x08 def _is_usb_serial_path(path: str) -> bool: """Heuristic for USB serial path prefixes.""" lower = path.lower() usb_prefixes = ( "/dev/ttyusb", # Linux USB serial "/dev/ttyacm", # Linux CDC ACM "/dev/cu.usb", # macOS cu/tty USB adapters "/dev/tty.usb", ) if lower.startswith(usb_prefixes): return True # Windows COM ports don't clearly indicate USB; treat as unknown here. return False def _is_usb_serial_port(port: list_ports_common.ListPortInfo) -> bool: """Heuristic: prefer ports with USB VID/PID; fall back to path hints.""" if getattr(port, "vid", None) is not None or getattr(port, "pid", None) is not None: return True path = port.device or "" manufacturer = (getattr(port, "manufacturer", "") or "").upper() if "USB" in manufacturer: return True return _is_usb_serial_path(path) def discover_serial_ports( include_non_usb: bool = False, ignore_descriptions: Optional[List[str]] = None, include_descriptions: Optional[List[str]] = None, include_manufacturers: Optional[List[str]] = None, ) -> List[Dict[str, str]]: """ List serial ports with simple filtering similar to controller_uart_bridge. Args: include_non_usb: Include ports that don't look USB-based (e.g., onboard UARTs). ignore_descriptions: Substrings (case-insensitive) to exclude by description. include_descriptions: If provided, only include ports whose description contains one of these substrings. include_manufacturers: If provided, only include ports whose manufacturer contains one of these substrings. """ ignored = [d.lower() for d in (ignore_descriptions or [])] includes = [d.lower() for d in (include_descriptions or [])] include_mfrs = [m.lower() for m in (include_manufacturers or [])] results: List[Dict[str, str]] = [] for port in list_ports.comports(): path = port.device or "" if not path: continue if not include_non_usb and not _is_usb_serial_port(port): continue desc_lower = (port.description or "").lower() mfr_lower = (port.manufacturer or "").lower() if include_mfrs and not any(keep in mfr_lower for keep in include_mfrs): continue if includes and not any(keep in desc_lower for keep in includes): continue if any(skip in desc_lower for skip in ignored): continue results.append( { "device": path, "description": port.description or "Unknown", "manufacturer": port.manufacturer or "", } ) return results def first_serial_port( include_non_usb: bool = False, ignore_descriptions: Optional[List[str]] = None, include_descriptions: Optional[List[str]] = None, include_manufacturers: Optional[List[str]] = None, ) -> Optional[str]: """Return the first discovered serial port path (or None if none are found).""" ports = discover_serial_ports( include_non_usb, ignore_descriptions, include_descriptions, include_manufacturers, ) if not ports: return None return ports[0]["device"] def clamp_byte(value: Union[int, float]) -> int: """Clamp a numeric value to the 0-255 byte range.""" return max(0, min(255, int(value))) def normalize_stick_value(value: Union[int, float]) -> int: """ Convert a normalized float (-1..1) or raw byte (0..255) to the stick range. Floats are treated as -1.0 = full negative deflection, 0.0 = center, 1.0 = full positive deflection. Integers are assumed to already be in the 0-255 range. """ if isinstance(value, float): value = max(-1.0, min(1.0, value)) value = int(round((value + 1.0) * 255 / 2.0)) return clamp_byte(value) def axis_to_stick(value: int, deadzone: int) -> int: """Convert a signed axis value to 0-255 stick range with deadzone.""" if abs(value) < deadzone: value = 0 scaled = int((value + 32768) * 255 / 65535) return clamp_byte(scaled) def trigger_to_button(value: int, threshold: int) -> bool: """Return True if analog trigger crosses digital threshold.""" return value >= threshold def str_to_dpad(flags: Mapping[str, bool]) -> SwitchDpad: """Translate DPAD button flags into a Switch hat/DPAD value.""" up = flags.get("up", False) down = flags.get("down", False) left = flags.get("left", False) right = flags.get("right", False) if up and right: return SwitchDpad.UP_RIGHT if up and left: return SwitchDpad.UP_LEFT if down and right: return SwitchDpad.DOWN_RIGHT if down and left: return SwitchDpad.DOWN_LEFT if up: return SwitchDpad.UP if down: return SwitchDpad.DOWN if right: return SwitchDpad.RIGHT if left: return SwitchDpad.LEFT return SwitchDpad.CENTER @dataclass class SwitchReport: buttons: int = 0 hat: SwitchDpad = SwitchDpad.CENTER lx: int = 128 ly: int = 128 rx: int = 128 ry: int = 128 def to_bytes(self) -> bytes: """Serialize the report into the UART packet format.""" return struct.pack( " None: """Open a UART connection to the Pico with non-blocking IO.""" self.serial = serial.Serial( port=port, baudrate=baudrate, bytesize=serial.EIGHTBITS, stopbits=serial.STOPBITS_ONE, parity=serial.PARITY_NONE, timeout=0.0, write_timeout=0.0, xonxoff=False, rtscts=False, dsrdtr=False, ) self._buffer = bytearray() def send_report(self, report: SwitchReport) -> None: """Send a controller report to the Pico.""" self.serial.write(report.to_bytes()) def read_rumble_payload(self) -> Optional[bytes]: """ Drain available UART bytes into an internal buffer, then extract one rumble frame. Frame format: 0: 0xBB (RUMBLE_HEADER) 1: type (0x01 for rumble) 2-9: 8-byte rumble payload 10: checksum (sum of first 10 bytes) & 0xFF """ waiting = self.serial.in_waiting if waiting: self._buffer.extend(self.serial.read(waiting)) while True: if not self._buffer: return None start = self._buffer.find(bytes([RUMBLE_HEADER])) if start < 0: self._buffer.clear() return None if len(self._buffer) - start < 11: if start > 0: del self._buffer[:start] return None frame = self._buffer[start:start + 11] checksum = sum(frame[:10]) & 0xFF if frame[1] == RUMBLE_TYPE_RUMBLE and checksum == frame[10]: payload = bytes(frame[2:10]) del self._buffer[:start + 11] return payload del self._buffer[:start + 1] def close(self) -> None: """Close the UART connection.""" self.serial.close() def decode_rumble(payload: bytes) -> Tuple[float, float]: """Return normalized rumble amplitudes (0.0-1.0) for left/right.""" if len(payload) < 8: return 0.0, 0.0 if payload == b"\x00\x01\x40\x40\x00\x01\x40\x40": return 0.0, 0.0 right_raw = ((payload[1] & 0x03) << 8) | payload[0] left_raw = ((payload[5] & 0x03) << 8) | payload[4] if left_raw < 8 and right_raw < 8: return 0.0, 0.0 left = min(max(left_raw / 1023.0, 0.0), 1.0) right = min(max(right_raw / 1023.0, 0.0), 1.0) return left, right @dataclass class SwitchControllerState: """Mutable controller state with helpers for building reports.""" report: SwitchReport = field(default_factory=SwitchReport) def press(self, *buttons_or_hat: Union[SwitchButton, SwitchDpad, int]) -> None: """Press one or more buttons, or set the hat if a SwitchDpad is provided.""" for item in buttons_or_hat: if isinstance(item, SwitchDpad): # If multiple hats are provided, the last one wins. self.report.hat = SwitchDpad(int(item) & 0xFF) else: self.report.buttons |= int(item) def release(self, *buttons_or_hat: Union[SwitchButton, SwitchDpad, int]) -> None: """Release one or more buttons, or center the hat if a SwitchDpad is provided.""" for item in buttons_or_hat: if isinstance(item, SwitchDpad): self.report.hat = SwitchDpad.CENTER else: self.report.buttons &= ~int(item) def set_buttons(self, buttons: Iterable[Union[SwitchButton, int]]) -> None: """Replace the current button bitmask with the provided buttons.""" self.report.buttons = 0 self.press(*buttons) def set_hat(self, hat: Union[SwitchDpad, int]) -> None: """Set the DPAD/hat value directly.""" self.report.hat = SwitchDpad(int(hat) & 0xFF) def move_left_stick(self, x: Union[int, float], y: Union[int, float]) -> None: """Move the left stick using normalized floats (-1..1) or raw bytes (0-255).""" self.report.lx = normalize_stick_value(x) self.report.ly = normalize_stick_value(y) def move_right_stick(self, x: Union[int, float], y: Union[int, float]) -> None: """Move the right stick using normalized floats (-1..1) or raw bytes (0-255).""" self.report.rx = normalize_stick_value(x) self.report.ry = normalize_stick_value(y) def neutral(self) -> None: """Clear all input back to the neutral controller state.""" self.report.buttons = 0 self.report.hat = SwitchDpad.CENTER self.report.lx = 128 self.report.ly = 128 self.report.rx = 128 self.report.ry = 128 class SwitchUARTClient: """ High-level helper to send controller actions to the Pico and poll rumble. Example: with SwitchUARTClient("/dev/cu.usbserial-0001") as client: client.press(SwitchButton.A) time.sleep(0.1) client.release(SwitchButton.A) client.move_left_stick(0.0, -1.0) # push up """ def __init__( self, port: str, baud: int = UART_BAUD, send_interval: float = 1.0 / 500.0, auto_send: bool = True, ) -> None: """ Args: port: Serial port path (e.g., 'COM5' or '/dev/cu.usbserial-0001'). baud: UART baud rate. send_interval: Minimum interval between sends in seconds (defaults to 500 Hz). auto_send: If True, keep sending the current state in a background thread so the Pico continuously sees the latest input (mirrors controller_uart_bridge). """ self.uart = PicoUART(port, baud) self.state = SwitchControllerState() self.send_interval = max(0.0, send_interval) self._last_send = 0.0 self._auto_send = auto_send self._stop_event = threading.Event() self._auto_thread: Optional[threading.Thread] = None if self._auto_send: self._start_auto_send_thread() def send(self) -> None: """Send the current state to the Pico, throttled by send_interval if set.""" now = time.monotonic() if self.send_interval and (now - self._last_send) < self.send_interval: return self.uart.send_report(self.state.report) self._last_send = now def _start_auto_send_thread(self) -> None: """Continuously send the current state so the Pico stays active.""" if self._auto_thread is not None: return sleep_time = self.send_interval if self.send_interval > 0 else 0.002 def loop() -> None: while not self._stop_event.is_set(): self.send() self._stop_event.wait(sleep_time) self._auto_thread = threading.Thread(target=loop, daemon=True) self._auto_thread.start() def press(self, *buttons: SwitchButton | SwitchDpad | int) -> None: """Press buttons or set hat using SwitchButton/SwitchDpad (ints also allowed).""" self.state.press(*buttons) self.send() def release(self, *buttons: SwitchButton | SwitchDpad | int) -> None: """Release buttons or center hat when given a SwitchDpad.""" self.state.release(*buttons) self.send() def set_buttons(self, buttons: Iterable[SwitchButton | int]) -> None: self.state.set_buttons(buttons) self.send() def set_hat(self, hat: SwitchDpad | int) -> None: self.state.set_hat(hat) self.send() def move_left_stick(self, x: Union[int, float], y: Union[int, float]) -> None: self.state.move_left_stick(x, y) self.send() def move_right_stick(self, x: Union[int, float], y: Union[int, float]) -> None: self.state.move_right_stick(x, y) self.send() def press_for(self, duration: float, *buttons: SwitchButton | SwitchDpad | int) -> None: """Press buttons/hat for a duration, then release.""" self.press(*buttons) time.sleep(max(0.0, duration)) self.release(*buttons) def move_left_stick_for( self, x: Union[int, float], y: Union[int, float], duration: float, neutral_after: bool = True ) -> None: """Move left stick for a duration, optionally returning it to neutral afterward.""" self.move_left_stick(x, y) time.sleep(max(0.0, duration)) if neutral_after: self.state.move_left_stick(128, 128) self.send() def move_right_stick_for( self, x: Union[int, float], y: Union[int, float], duration: float, neutral_after: bool = True ) -> None: """Move right stick for a duration, optionally returning it to neutral afterward.""" self.move_right_stick(x, y) time.sleep(max(0.0, duration)) if neutral_after: self.state.move_right_stick(128, 128) self.send() def neutral(self) -> None: self.state.neutral() self.send() def poll_rumble(self) -> Optional[Tuple[float, float]]: """ Poll for the latest rumble payload and return normalized amplitudes. Returns None if no rumble frame was available. """ payload = self.uart.read_rumble_payload() if payload: return decode_rumble(payload) return None def close(self) -> None: if self._auto_thread: self._stop_event.set() self._auto_thread.join(timeout=0.5) self._auto_thread = None self.uart.close() def __enter__(self) -> "SwitchUARTClient": return self def __exit__(self, exc_type, exc, tb) -> None: self.close()