switch-pico/switch_pico_uart.py
jojomawswan aef37123fd Refactor
2025-12-01 13:29:04 -07:00

470 lines
16 KiB
Python

#!/usr/bin/env python3
"""
Lightweight helpers for talking to the switch-pico firmware over UART.
This module exposes the raw report structure plus a small convenience wrapper
so other scripts can do things like "press a button" or "move a stick" without
depending on SDL. It mirrors the framing in ``switch-pico.cpp``:
Host -> Pico : 0xAA, buttons (LE16), hat, lx, ly, rx, ry
Pico -> Host : 0xBB, 0x01, 8 rumble bytes, checksum (sum of first 10 bytes)
"""
from __future__ import annotations
import struct
import time
import threading
from dataclasses import dataclass, field
from enum import IntEnum, IntFlag
from typing import Iterable, Mapping, Optional, Tuple, Union, List, Dict
import serial
from serial.tools import list_ports, list_ports_common
# Start-of-frame byte for host -> Pico controller reports (packed by SwitchReport.to_bytes).
UART_HEADER = 0xAA
# Start-of-frame byte for Pico -> host frames (searched for by PicoUART.read_rumble_payload).
RUMBLE_HEADER = 0xBB
# Frame-type byte (offset 1 of a Pico -> host frame) that identifies a rumble frame.
RUMBLE_TYPE_RUMBLE = 0x01
# Default UART baud rate used by PicoUART and SwitchUARTClient.
UART_BAUD = 921600
class SwitchButton(IntFlag):
    """Single-bit button masks that are OR-ed into a report's ``buttons`` field."""

    # Mirrors the masks defined in switch_pro_descriptors.h
    Y = 0x0001
    B = 0x0002
    A = 0x0004
    X = 0x0008
    L = 0x0010
    R = 0x0020
    ZL = 0x0040
    ZR = 0x0080
    MINUS = 0x0100
    PLUS = 0x0200
    LCLICK = 0x0400
    RCLICK = 0x0800
    HOME = 0x1000
    CAPTURE = 0x2000
class SwitchDpad(IntEnum):
    """Hat/D-pad values: 0-7 step clockwise starting from UP, 8 means released."""

    UP = 0
    UP_RIGHT = 1
    RIGHT = 2
    DOWN_RIGHT = 3
    DOWN = 4
    DOWN_LEFT = 5
    LEFT = 6
    UP_LEFT = 7
    CENTER = 8
def _is_usb_serial_path(path: str) -> bool:
"""Heuristic for USB serial path prefixes."""
lower = path.lower()
usb_prefixes = (
"/dev/ttyusb", # Linux USB serial
"/dev/ttyacm", # Linux CDC ACM
"/dev/cu.usb", # macOS cu/tty USB adapters
"/dev/tty.usb",
)
if lower.startswith(usb_prefixes):
return True
# Windows COM ports don't clearly indicate USB; treat as unknown here.
return False
def _is_usb_serial_port(port: list_ports_common.ListPortInfo) -> bool:
"""Heuristic: prefer ports with USB VID/PID; fall back to path hints."""
if getattr(port, "vid", None) is not None or getattr(port, "pid", None) is not None:
return True
path = port.device or ""
manufacturer = (getattr(port, "manufacturer", "") or "").upper()
if "USB" in manufacturer:
return True
return _is_usb_serial_path(path)
def discover_serial_ports(
    include_non_usb: bool = False,
    ignore_descriptions: Optional[List[str]] = None,
    include_descriptions: Optional[List[str]] = None,
) -> List[Dict[str, str]]:
    """
    List serial ports with simple filtering similar to controller_uart_bridge.

    Args:
        include_non_usb: Include ports that don't look USB-based (e.g., onboard UARTs).
        ignore_descriptions: Substrings (case-insensitive) to exclude by description.
        include_descriptions: If provided, only include ports whose description contains
            one of these substrings (case-insensitive).

    Returns:
        One ``{"device": path, "description": text}`` dict per surviving port, in the
        order pyserial enumerates them. A missing description is reported as "Unknown".
    """
    skip_terms = [term.lower() for term in ignore_descriptions or ()]
    keep_terms = [term.lower() for term in include_descriptions or ()]

    found: List[Dict[str, str]] = []
    for info in list_ports.comports():
        device = info.device or ""
        if not device:
            continue
        if not (include_non_usb or _is_usb_serial_port(info)):
            continue
        description = (info.description or "").lower()
        # The allow-list is applied first; the deny-list can still veto an allowed port.
        if keep_terms and all(term not in description for term in keep_terms):
            continue
        if any(term in description for term in skip_terms):
            continue
        found.append({"device": device, "description": info.description or "Unknown"})
    return found
def first_serial_port(
    include_non_usb: bool = False,
    ignore_descriptions: Optional[List[str]] = None,
    include_descriptions: Optional[List[str]] = None,
) -> Optional[str]:
    """
    Return the device path of the first port ``discover_serial_ports`` reports.

    The arguments are forwarded unchanged. Returns None when no port survives
    the filters.
    """
    candidates = discover_serial_ports(include_non_usb, ignore_descriptions, include_descriptions)
    return candidates[0]["device"] if candidates else None
def clamp_byte(value: Union[int, float]) -> int:
    """
    Convert ``value`` to int and pin it to the 0-255 byte range.

    Floats are truncated toward zero by ``int()`` before clamping.
    """
    as_int = int(value)
    if as_int < 0:
        return 0
    if as_int > 255:
        return 255
    return as_int
def normalize_stick_value(value: Union[int, float]) -> int:
    """
    Convert a normalized float (-1..1) or raw byte (0..255) to the stick range.

    Float inputs are clamped to [-1.0, 1.0] and mapped linearly: -1.0 -> 0
    (full negative deflection), 0.0 -> 128 (center), 1.0 -> 255 (full positive
    deflection). Any non-float input is assumed to already be a 0-255 value
    and is only clamped.
    """
    if not isinstance(value, float):
        return clamp_byte(value)
    bounded = max(-1.0, min(1.0, value))
    return clamp_byte(int(round((bounded + 1.0) * 255 / 2.0)))
def axis_to_stick(value: int, deadzone: int) -> int:
    """
    Convert a signed 16-bit axis value (-32768..32767) to the 0-255 stick range.

    Values whose magnitude is below ``deadzone`` are treated as centered. The
    linear mapping is rounded to the nearest byte rather than truncated. As a
    result, a centered axis maps to 128, which is the neutral stick value used by
    the ``SwitchReport`` defaults and produced by ``normalize_stick_value(0.0)``.
    Truncation gave 127 instead. Results outside the byte range are clamped.
    """
    if abs(value) < deadzone:
        value = 0
    # (0 + 32768) * 255 / 65535 == 127.50..., which must round up to 128 to be neutral.
    scaled = round((value + 32768) * 255 / 65535)
    return clamp_byte(scaled)
def trigger_to_button(value: int, threshold: int) -> bool:
    """Report an analog trigger as pressed once it reaches ``threshold`` (inclusive)."""
    return threshold <= value
def str_to_dpad(flags: Mapping[str, bool]) -> SwitchDpad:
    """
    Translate D-pad flags ("up"/"down"/"left"/"right") into a SwitchDpad value.

    Missing keys count as not pressed. When both directions of an axis are held,
    "up" beats "down" and "right" beats "left". The two resolved axes then pick
    one of the eight directions, or CENTER when neither axis is active.
    """
    up, down, left, right = (flags.get(key, False) for key in ("up", "down", "left", "right"))
    # dy: -1 = up, +1 = down; dx: -1 = left, +1 = right.
    dy = -1 if up else (1 if down else 0)
    dx = 1 if right else (-1 if left else 0)
    by_offset = {
        (-1, 0): SwitchDpad.UP,
        (-1, 1): SwitchDpad.UP_RIGHT,
        (0, 1): SwitchDpad.RIGHT,
        (1, 1): SwitchDpad.DOWN_RIGHT,
        (1, 0): SwitchDpad.DOWN,
        (1, -1): SwitchDpad.DOWN_LEFT,
        (0, -1): SwitchDpad.LEFT,
        (-1, -1): SwitchDpad.UP_LEFT,
        (0, 0): SwitchDpad.CENTER,
    }
    return by_offset[(dy, dx)]
@dataclass
class SwitchReport:
    """One host -> Pico controller report: button mask, hat, and four stick bytes."""

    buttons: int = 0  # SwitchButton bitmask; only the low 16 bits are transmitted
    hat: SwitchDpad = SwitchDpad.CENTER
    lx: int = 128  # stick axes: one byte each, 128 by default
    ly: int = 128
    rx: int = 128
    ry: int = 128

    def to_bytes(self) -> bytes:
        """
        Serialize the report into the 8-byte UART packet.

        Layout (little-endian): header 0xAA, buttons (u16), hat (u8), lx, ly, rx, ry (u8).
        The stick fields are packed as-is, so struct.error is raised if one falls outside
        0-255.
        """
        masked_buttons = self.buttons & 0xFFFF
        masked_hat = self.hat & 0xFF
        sticks = (self.lx, self.ly, self.rx, self.ry)
        return struct.pack("<BHB4B", UART_HEADER, masked_buttons, masked_hat, *sticks)
class PicoUART:
    """Raw UART link to the switch-pico firmware: writes reports, parses rumble frames."""

    def __init__(self, port: str, baudrate: int = UART_BAUD) -> None:
        """Open a UART connection to the Pico with non-blocking IO (8N1, no flow control)."""
        self.serial = serial.Serial(
            port=port,
            baudrate=baudrate,
            bytesize=serial.EIGHTBITS,
            parity=serial.PARITY_NONE,
            stopbits=serial.STOPBITS_ONE,
            timeout=0.0,        # reads return immediately with whatever is available
            write_timeout=0.0,  # writes do not block
            xonxoff=False,
            rtscts=False,
            dsrdtr=False,
        )
        # Bytes received but not yet consumed as a complete rumble frame.
        self._buffer = bytearray()

    def send_report(self, report: SwitchReport) -> None:
        """Write one serialized controller report to the Pico."""
        packet = report.to_bytes()
        self.serial.write(packet)

    def read_rumble_payload(self) -> Optional[bytes]:
        """
        Drain available UART bytes into an internal buffer, then extract one rumble frame.

        Frame format:
            0: 0xBB (RUMBLE_HEADER)
            1: type (0x01 for rumble)
            2-9: 8-byte rumble payload
            10: checksum (sum of first 10 bytes) & 0xFF

        Bytes before a header are discarded. If a header fails validation, only that
        header byte is skipped and scanning resumes. An incomplete trailing frame is
        kept for the next call. Returns the 8 payload bytes, or None if no complete,
        valid frame is buffered.
        """
        pending = self.serial.in_waiting
        if pending:
            self._buffer.extend(self.serial.read(pending))

        frame_len = 11
        buf = self._buffer
        while buf:
            head = buf.find(RUMBLE_HEADER)
            if head == -1:
                # No header anywhere: nothing in the buffer can start a frame.
                buf.clear()
                break
            if len(buf) - head < frame_len:
                # Partial frame: drop leading garbage, keep the rest for later.
                del buf[:head]
                break
            frame = buf[head:head + frame_len]
            valid = frame[1] == RUMBLE_TYPE_RUMBLE and frame[10] == sum(frame[:10]) & 0xFF
            if valid:
                del buf[:head + frame_len]
                return bytes(frame[2:10])
            # False header (or corrupted frame): skip just this byte and rescan.
            del buf[:head + 1]
        return None

    def close(self) -> None:
        """Close the UART connection."""
        self.serial.close()
def decode_rumble(payload: bytes) -> Tuple[float, float]:
    """
    Return normalized rumble amplitudes (0.0-1.0) as ``(left, right)``.

    The "right" value is taken from the 10-bit little-endian field in bytes 0-1, and
    the "left" value from bytes 4-5. Each is divided by 1023. Both are reported as
    0.0 when the payload is shorter than 8 bytes, when it equals the fixed pattern
    00 01 40 40 00 01 40 40, or when both raw values are below 8.
    """
    if len(payload) < 8 or payload == b"\x00\x01\x40\x40\x00\x01\x40\x40":
        return 0.0, 0.0
    right_raw = int.from_bytes(payload[0:2], "little") & 0x3FF
    left_raw = int.from_bytes(payload[4:6], "little") & 0x3FF
    if max(left_raw, right_raw) < 8:
        return 0.0, 0.0
    # 10-bit raw values lie in 0..1023, so the quotients are already within [0, 1].
    return left_raw / 1023.0, right_raw / 1023.0
@dataclass
class SwitchControllerState:
    """
    Mutable controller state with helpers for building a ``SwitchReport``.

    Buttons are OR-ed into, or cleared from, the report's bitmask. Any ``SwitchDpad``
    argument is routed to the hat instead.
    """

    report: SwitchReport = field(default_factory=SwitchReport)

    def press(self, *buttons_or_hat: Union[SwitchButton, SwitchDpad, int]) -> None:
        """
        Press one or more buttons, or set the hat if a SwitchDpad is provided.

        Plain ints are treated as button masks. If several SwitchDpad values are
        given, the last one wins.
        """
        report = self.report
        for entry in buttons_or_hat:
            if not isinstance(entry, SwitchDpad):
                report.buttons |= int(entry)
                continue
            report.hat = SwitchDpad(int(entry) & 0xFF)

    def release(self, *buttons_or_hat: Union[SwitchButton, SwitchDpad, int]) -> None:
        """
        Release one or more buttons, or center the hat if a SwitchDpad is provided.

        Any SwitchDpad centers the hat, regardless of which direction is currently set.
        """
        report = self.report
        for entry in buttons_or_hat:
            if not isinstance(entry, SwitchDpad):
                report.buttons &= ~int(entry)
                continue
            report.hat = SwitchDpad.CENTER

    def set_buttons(self, buttons: Iterable[Union[SwitchButton, int]]) -> None:
        """Replace the current button bitmask with exactly the provided buttons."""
        self.report.buttons = 0  # start from an empty mask, then press each entry
        self.press(*buttons)

    def set_hat(self, hat: Union[SwitchDpad, int]) -> None:
        """Set the DPAD/hat value directly (ValueError if it is not a SwitchDpad value)."""
        self.report.hat = SwitchDpad(int(hat) & 0xFF)

    def move_left_stick(self, x: Union[int, float], y: Union[int, float]) -> None:
        """Move the left stick using normalized floats (-1..1) or raw bytes (0-255)."""
        report = self.report
        report.lx = normalize_stick_value(x)
        report.ly = normalize_stick_value(y)

    def move_right_stick(self, x: Union[int, float], y: Union[int, float]) -> None:
        """Move the right stick using normalized floats (-1..1) or raw bytes (0-255)."""
        report = self.report
        report.rx = normalize_stick_value(x)
        report.ry = normalize_stick_value(y)

    def neutral(self) -> None:
        """Clear all input: no buttons, centered hat, every stick axis at 128."""
        report = self.report
        report.buttons = 0
        report.hat = SwitchDpad.CENTER
        for axis in ("lx", "ly", "rx", "ry"):
            setattr(report, axis, 128)
class SwitchUARTClient:
    """
    High-level helper to send controller actions to the Pico and poll rumble.

    Example:
        with SwitchUARTClient("/dev/cu.usbserial-0001") as client:
            client.press(SwitchButton.A)
            time.sleep(0.1)
            client.release(SwitchButton.A)
            client.move_left_stick(0.0, -1.0)  # push up

    The action methods mutate ``self.state`` and transmit it while holding an internal
    lock, so the optional auto-send thread never transmits a half-applied update (for
    example, the momentarily cleared bitmask inside ``set_buttons``). These are
    ``press``, ``release``, ``set_buttons``, ``set_hat``, ``move_*_stick``,
    ``neutral`` and the ``*_for`` variants. Code that edits ``self.state`` directly
    bypasses that lock.
    """

    def __init__(
        self,
        port: str,
        baud: int = UART_BAUD,
        send_interval: float = 1.0 / 500.0,
        auto_send: bool = True,
    ) -> None:
        """
        Args:
            port: Serial port path (e.g., 'COM5' or '/dev/cu.usbserial-0001').
            baud: UART baud rate.
            send_interval: Minimum interval between sends in seconds (defaults to 500 Hz).
            auto_send: If True, keep sending the current state in a background thread so the
                Pico continuously sees the latest input (mirrors controller_uart_bridge).
        """
        self.uart = PicoUART(port, baud)
        self.state = SwitchControllerState()
        self.send_interval = max(0.0, send_interval)
        self._last_send = 0.0
        self._auto_send = auto_send
        # Guards self.state / self._last_send and serializes UART writes between the
        # caller's thread and the auto-send thread. It must exist before that thread starts.
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self._auto_thread: Optional[threading.Thread] = None
        if self._auto_send:
            self._start_auto_send_thread()

    def send(self) -> None:
        """Send the current state to the Pico, throttled by send_interval if set."""
        with self._lock:
            now = time.monotonic()
            if self.send_interval and (now - self._last_send) < self.send_interval:
                return
            self.uart.send_report(self.state.report)
            self._last_send = now

    def _start_auto_send_thread(self) -> None:
        """Continuously send the current state so the Pico stays active."""
        if self._auto_thread is not None:
            return
        sleep_time = self.send_interval if self.send_interval > 0 else 0.002

        def loop() -> None:
            while not self._stop_event.is_set():
                self.send()
                self._stop_event.wait(sleep_time)

        self._auto_thread = threading.Thread(target=loop, daemon=True)
        self._auto_thread.start()

    def _push_update(self) -> None:
        """
        Transmit a state change made by one of the action methods.

        While the auto-send thread is alive, a throttled send() is harmless because the
        thread re-sends the latest state on its next tick. Without it, a throttled send
        would silently drop the change. For example, ``release`` right after ``press``
        would leave the button held on the Pico. In that case, wait out the remaining
        interval and then send unconditionally.
        """
        thread = self._auto_thread
        if thread is not None and thread.is_alive():
            self.send()
            return
        with self._lock:
            remaining = self.send_interval - (time.monotonic() - self._last_send)
        if remaining > 0:
            time.sleep(remaining)
        with self._lock:
            self.uart.send_report(self.state.report)
            self._last_send = time.monotonic()

    def _apply(self, action, *args) -> None:
        """Run a state mutation atomically with respect to the sender, then transmit it."""
        with self._lock:
            action(*args)
        self._push_update()

    def press(self, *buttons: SwitchButton | SwitchDpad | int) -> None:
        """Press buttons or set hat using SwitchButton/SwitchDpad (ints also allowed)."""
        self._apply(self.state.press, *buttons)

    def release(self, *buttons: SwitchButton | SwitchDpad | int) -> None:
        """Release buttons or center hat when given a SwitchDpad."""
        self._apply(self.state.release, *buttons)

    def set_buttons(self, buttons: Iterable[SwitchButton | int]) -> None:
        """Replace the held buttons with exactly ``buttons`` and transmit."""
        self._apply(self.state.set_buttons, buttons)

    def set_hat(self, hat: SwitchDpad | int) -> None:
        """Set the hat/D-pad value and transmit."""
        self._apply(self.state.set_hat, hat)

    def move_left_stick(self, x: Union[int, float], y: Union[int, float]) -> None:
        """Move the left stick (floats -1..1 or raw bytes 0-255) and transmit."""
        self._apply(self.state.move_left_stick, x, y)

    def move_right_stick(self, x: Union[int, float], y: Union[int, float]) -> None:
        """Move the right stick (floats -1..1 or raw bytes 0-255) and transmit."""
        self._apply(self.state.move_right_stick, x, y)

    def press_for(self, duration: float, *buttons: SwitchButton | SwitchDpad | int) -> None:
        """
        Press buttons/hat for a duration, then release.

        The release is sent even if the wait is interrupted (e.g. KeyboardInterrupt),
        so inputs are not left held on the Pico.
        """
        self.press(*buttons)
        try:
            time.sleep(max(0.0, duration))
        finally:
            self.release(*buttons)

    def move_left_stick_for(
        self, x: Union[int, float], y: Union[int, float], duration: float, neutral_after: bool = True
    ) -> None:
        """
        Move the left stick for a duration, optionally re-centering it (128, 128) afterward.

        With ``neutral_after``, the stick is re-centered even if the wait is interrupted.
        """
        self.move_left_stick(x, y)
        try:
            time.sleep(max(0.0, duration))
        finally:
            if neutral_after:
                self._apply(self.state.move_left_stick, 128, 128)

    def move_right_stick_for(
        self, x: Union[int, float], y: Union[int, float], duration: float, neutral_after: bool = True
    ) -> None:
        """
        Move the right stick for a duration, optionally re-centering it (128, 128) afterward.

        With ``neutral_after``, the stick is re-centered even if the wait is interrupted.
        """
        self.move_right_stick(x, y)
        try:
            time.sleep(max(0.0, duration))
        finally:
            if neutral_after:
                self._apply(self.state.move_right_stick, 128, 128)

    def neutral(self) -> None:
        """Reset every input to neutral and transmit."""
        self._apply(self.state.neutral)

    def poll_rumble(self) -> Optional[Tuple[float, float]]:
        """
        Poll for the latest rumble payload and return normalized amplitudes.

        Every complete frame already buffered is consumed, and only the most recent one
        is decoded. This keeps a slow poller from falling further and further behind the
        Pico. Returns None if no rumble frame was available.
        """
        latest: Optional[bytes] = None
        while True:
            payload = self.uart.read_rumble_payload()
            if payload is None:
                break
            latest = payload
        if latest is None:
            return None
        return decode_rumble(latest)

    def close(self) -> None:
        """Stop the auto-send thread (waiting up to 0.5 s) and close the UART."""
        thread = self._auto_thread
        if thread is not None:
            self._stop_event.set()
            thread.join(timeout=0.5)
            self._auto_thread = None
        self.uart.close()

    def __enter__(self) -> "SwitchUARTClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()