Refactor into lib

This commit is contained in:
Joey Yakimowich-Payne 2025-11-25 06:23:01 -07:00
commit 396d8ba33d
No known key found for this signature in database
GPG key ID: 6BFE655FA5ABD1E1
3 changed files with 359 additions and 184 deletions

View file

@ -66,6 +66,21 @@ Options:
Hotplugging: controllers and UARTs can be plugged/unplugged while running; the bridge will auto reconnect when possible.
### Using the lightweight UART helper (no SDL needed)
For simple scripts or tests you can skip SDL and drive the Pico directly with `switch_pico_uart.py`:
```python
from switch_pico_uart import SwitchUARTClient, SwitchButton, SwitchHat
with SwitchUARTClient("/dev/cu.usbserial-0001") as client:
client.press(SwitchButton.A)
client.release(SwitchButton.A)
client.move_left_stick(0.0, -1.0) # push up
client.set_hat(SwitchHat.TOP_RIGHT)
print(client.poll_rumble()) # returns (left, right) amplitudes 0.0-1.0 or None
```
- `SwitchButton` is an `IntFlag` (bitwise friendly) and `SwitchHat` is an `IntEnum` for the DPAD/hat values.
- The helper only depends on `pyserial`; SDL is not required.
### macOS tips
- Ensure the USBserial adapter shows up (use `/dev/cu.usb*` for TX).
- Some controllers Guide/Home buttons are intercepted by macOS; using XInput/DInput mode or disabling Steams controller handling helps.

View file

@ -16,15 +16,12 @@ Features inspired by ``host/controller_bridge.py``:
from __future__ import annotations
import argparse
import struct
import threading
import time
from dataclasses import dataclass, field
from ctypes import create_string_buffer
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import serial
from serial import SerialException
from serial.tools import list_ports
from serial.tools import list_ports_common
@ -35,46 +32,24 @@ from rich.console import Console
from rich.prompt import Prompt
from rich.table import Table
UART_HEADER = 0xAA
RUMBLE_HEADER = 0xBB
RUMBLE_TYPE_RUMBLE = 0x01
UART_BAUD = 921600
from switch_pico_uart import (
UART_BAUD,
PicoUART,
SwitchButton,
SwitchHat,
SwitchReport,
axis_to_stick,
dpad_to_hat,
decode_rumble,
trigger_to_button,
)
RUMBLE_IDLE_TIMEOUT = 0.25 # seconds without packets before forcing rumble off
RUMBLE_STUCK_TIMEOUT = 0.60 # continuous same-energy rumble will be stopped after this
RUMBLE_MIN_ACTIVE = 0.50 # below this, rumble is treated as off/noise
RUMBLE_SCALE = 0.8
class SwitchButton:
# Mirrors the masks defined in switch_pro_descriptors.h
Y = 1 << 0
B = 1 << 1
A = 1 << 2
X = 1 << 3
L = 1 << 4
R = 1 << 5
ZL = 1 << 6
ZR = 1 << 7
MINUS = 1 << 8
PLUS = 1 << 9
LCLICK = 1 << 10
RCLICK = 1 << 11
HOME = 1 << 12
CAPTURE = 1 << 13
class SwitchHat:
TOP = 0x00
TOP_RIGHT = 0x01
RIGHT = 0x02
BOTTOM_RIGHT = 0x03
BOTTOM = 0x04
BOTTOM_LEFT = 0x05
LEFT = 0x06
TOP_LEFT = 0x07
CENTER = 0x08
def parse_mapping(value: str) -> Tuple[int, str]:
"""Parse 'index:serial_port' CLI mapping argument."""
if ":" not in value:
@ -89,19 +64,6 @@ def parse_mapping(value: str) -> Tuple[int, str]:
return idx, port.strip()
def axis_to_stick(value: int, deadzone: int) -> int:
"""Convert a signed SDL axis value to 0-255 stick range with deadzone."""
if abs(value) < deadzone:
value = 0
scaled = int((value + 32768) * 255 / 65535)
return max(0, min(255, scaled))
def trigger_to_button(value: int, threshold: int) -> bool:
"""Return True if analog trigger crosses digital threshold."""
return value >= threshold
def set_hint(name: str, value: str) -> None:
"""Set an SDL hint safely even if the constant is missing in PySDL2."""
try:
@ -134,32 +96,6 @@ DPAD_BUTTONS = {
}
def dpad_to_hat(flags: Dict[str, bool]) -> int:
"""Translate DPAD button flags into a Switch hat value."""
up = flags["up"]
down = flags["down"]
left = flags["left"]
right = flags["right"]
if up and right:
return SwitchHat.TOP_RIGHT
if up and left:
return SwitchHat.TOP_LEFT
if down and right:
return SwitchHat.BOTTOM_RIGHT
if down and left:
return SwitchHat.BOTTOM_LEFT
if up:
return SwitchHat.TOP
if down:
return SwitchHat.BOTTOM
if right:
return SwitchHat.RIGHT
if left:
return SwitchHat.LEFT
return SwitchHat.CENTER
def is_usb_serial(path: str) -> bool:
"""Heuristic for USB serial path prefixes."""
if path.startswith("/dev/tty.") and not path.startswith("/dev/tty.usb"):
@ -244,110 +180,6 @@ def interactive_pairing(
return mappings
@dataclass
class SwitchReport:
buttons: int = 0
hat: int = SwitchHat.CENTER
lx: int = 128
ly: int = 128
rx: int = 128
ry: int = 128
def to_bytes(self) -> bytes:
"""Serialize the report into the UART packet format."""
return struct.pack(
"<BHBBBBB", UART_HEADER, self.buttons & 0xFFFF, self.hat & 0xFF, self.lx, self.ly, self.rx, self.ry
)
class PicoUART:
def __init__(self, port: str, baudrate: int = UART_BAUD) -> None:
"""Open a UART connection to the Pico with non-blocking IO."""
self.serial = serial.Serial(
port=port,
baudrate=baudrate,
bytesize=serial.EIGHTBITS,
stopbits=serial.STOPBITS_ONE,
parity=serial.PARITY_NONE,
timeout=0.0,
write_timeout=0.0,
xonxoff=False,
rtscts=False,
dsrdtr=False,
)
self._buffer = bytearray()
def send_report(self, report: SwitchReport) -> None:
"""Send a controller report to the Pico."""
# Non-blocking write; no flush to avoid sync stalls.
self.serial.write(report.to_bytes())
def read_rumble_payload(self) -> Optional[bytes]:
"""
Drain all currently available UART bytes into an internal buffer,
then try to extract a single valid rumble frame.
Frame format:
0: 0xBB (RUMBLE_HEADER)
1: type (0x01 for rumble)
2-9: 8-byte rumble payload
10: checksum (sum of first 10 bytes) & 0xFF
"""
# Read whatever is waiting in OS buffer
waiting = self.serial.in_waiting
if waiting:
self._buffer.extend(self.serial.read(waiting))
while True:
if not self._buffer:
return None
start = self._buffer.find(bytes([RUMBLE_HEADER]))
if start < 0:
# No header at all, drop garbage
self._buffer.clear()
return None
# Not enough data for a full frame yet
if len(self._buffer) - start < 11:
if start > 0:
del self._buffer[:start]
return None
frame = self._buffer[start:start + 11]
checksum = sum(frame[:10]) & 0xFF
if frame[1] == RUMBLE_TYPE_RUMBLE and checksum == frame[10]:
payload = bytes(frame[2:10])
del self._buffer[:start + 11]
return payload
# Bad frame, drop this header and resync to the next candidate
del self._buffer[:start + 1]
def close(self) -> None:
"""Close the UART connection."""
self.serial.close()
def decode_rumble(payload: bytes) -> Tuple[float, float]:
"""Return normalized rumble amplitudes (0.0-1.0) for left/right."""
if len(payload) < 8:
return 0.0, 0.0
# Neutral/idle pattern used by Switch: no rumble energy.
if payload == b"\x00\x01\x40\x40\x00\x01\x40\x40":
return 0.0, 0.0
# Rumble amp is 10 bits across bytes 0/1 and 4/5.
# Switch format is right rumble first, then left rumble (4 bytes each).
right_raw = ((payload[1] & 0x03) << 8) | payload[0]
left_raw = ((payload[5] & 0x03) << 8) | payload[4]
if left_raw < 8 and right_raw < 8:
return 0.0, 0.0
left = min(max(left_raw / 1023.0, 0.0), 1.0)
right = min(max(right_raw / 1023.0, 0.0), 1.0)
return left, right
def apply_rumble(controller: sdl2.SDL_GameController, payload: bytes) -> float:
"""Apply rumble payload to SDL controller and return max normalized energy."""
left_norm, right_norm = decode_rumble(payload)
@ -515,7 +347,7 @@ def build_arg_parser() -> argparse.ArgumentParser:
return parser
def poll_controller_buttons(ctx: ControllerContext, button_map: Dict[int, int]) -> None:
def poll_controller_buttons(ctx: ControllerContext, button_map: Dict[int, SwitchButton]) -> None:
"""Update button/hat state based on current SDL controller readings."""
changed = False
for sdl_button, switch_bit in button_map.items():
@ -547,8 +379,8 @@ class BridgeConfig:
interval: float
deadzone_raw: int
trigger_threshold: int
button_map_default: Dict[int, int]
button_map_swapped: Dict[int, int]
button_map_default: Dict[int, SwitchButton]
button_map_swapped: Dict[int, SwitchButton]
swap_abxy_indices: set[int]
swap_abxy_ids: set[str]
@ -565,7 +397,7 @@ class PairingState:
include_port_desc: List[str] = field(default_factory=list)
def load_button_maps(console: Console, args: argparse.Namespace) -> Tuple[Dict[int, int], Dict[int, int], set[int]]:
def load_button_maps(console: Console, args: argparse.Namespace) -> Tuple[Dict[int, SwitchButton], Dict[int, SwitchButton], set[int]]:
"""Load SDL controller mappings and return button map variants."""
default_mapping = Path(__file__).parent / "controller_db" / "gamecontrollerdb.txt"
mappings_to_load: List[str] = []

328
switch_pico_uart.py Normal file
View file

@ -0,0 +1,328 @@
#!/usr/bin/env python3
"""
Lightweight helpers for talking to the switch-pico firmware over UART.
This module exposes the raw report structure plus a small convenience wrapper
so other scripts can do things like "press a button" or "move a stick" without
depending on SDL. It mirrors the framing in ``switch-pico.cpp``:
Host -> Pico : 0xAA, buttons (LE16), hat, lx, ly, rx, ry
Pico -> Host : 0xBB, 0x01, 8 rumble bytes, checksum (sum of first 10 bytes)
"""
from __future__ import annotations
import struct
import time
from dataclasses import dataclass, field
from enum import IntEnum, IntFlag
from typing import Iterable, Mapping, Optional, Tuple, Union
import serial
UART_HEADER = 0xAA
RUMBLE_HEADER = 0xBB
RUMBLE_TYPE_RUMBLE = 0x01
UART_BAUD = 921600
class SwitchButton(IntFlag):
# Mirrors the masks defined in switch_pro_descriptors.h
Y = 1 << 0
B = 1 << 1
A = 1 << 2
X = 1 << 3
L = 1 << 4
R = 1 << 5
ZL = 1 << 6
ZR = 1 << 7
MINUS = 1 << 8
PLUS = 1 << 9
LCLICK = 1 << 10
RCLICK = 1 << 11
HOME = 1 << 12
CAPTURE = 1 << 13
class SwitchHat(IntEnum):
TOP = 0x00
TOP_RIGHT = 0x01
RIGHT = 0x02
BOTTOM_RIGHT = 0x03
BOTTOM = 0x04
BOTTOM_LEFT = 0x05
LEFT = 0x06
TOP_LEFT = 0x07
CENTER = 0x08
def clamp_byte(value: Union[int, float]) -> int:
"""Clamp a numeric value to the 0-255 byte range."""
return max(0, min(255, int(value)))
def normalize_stick_value(value: Union[int, float]) -> int:
"""
Convert a normalized float (-1..1) or raw byte (0..255) to the stick range.
Floats are treated as -1.0 = full negative deflection, 0.0 = center,
1.0 = full positive deflection. Integers are assumed to already be in the
0-255 range.
"""
if isinstance(value, float):
value = max(-1.0, min(1.0, value))
value = int(round((value + 1.0) * 255 / 2.0))
return clamp_byte(value)
def axis_to_stick(value: int, deadzone: int) -> int:
"""Convert a signed axis value to 0-255 stick range with deadzone."""
if abs(value) < deadzone:
value = 0
scaled = int((value + 32768) * 255 / 65535)
return clamp_byte(scaled)
def trigger_to_button(value: int, threshold: int) -> bool:
"""Return True if analog trigger crosses digital threshold."""
return value >= threshold
def dpad_to_hat(flags: Mapping[str, bool]) -> SwitchHat:
"""Translate DPAD button flags into a Switch hat value."""
up = flags.get("up", False)
down = flags.get("down", False)
left = flags.get("left", False)
right = flags.get("right", False)
if up and right:
return SwitchHat.TOP_RIGHT
if up and left:
return SwitchHat.TOP_LEFT
if down and right:
return SwitchHat.BOTTOM_RIGHT
if down and left:
return SwitchHat.BOTTOM_LEFT
if up:
return SwitchHat.TOP
if down:
return SwitchHat.BOTTOM
if right:
return SwitchHat.RIGHT
if left:
return SwitchHat.LEFT
return SwitchHat.CENTER
@dataclass
class SwitchReport:
buttons: int = 0
hat: SwitchHat = SwitchHat.CENTER
lx: int = 128
ly: int = 128
rx: int = 128
ry: int = 128
def to_bytes(self) -> bytes:
"""Serialize the report into the UART packet format."""
return struct.pack(
"<BHBBBBB", UART_HEADER, self.buttons & 0xFFFF, self.hat & 0xFF, self.lx, self.ly, self.rx, self.ry
)
class PicoUART:
def __init__(self, port: str, baudrate: int = UART_BAUD) -> None:
"""Open a UART connection to the Pico with non-blocking IO."""
self.serial = serial.Serial(
port=port,
baudrate=baudrate,
bytesize=serial.EIGHTBITS,
stopbits=serial.STOPBITS_ONE,
parity=serial.PARITY_NONE,
timeout=0.0,
write_timeout=0.0,
xonxoff=False,
rtscts=False,
dsrdtr=False,
)
self._buffer = bytearray()
def send_report(self, report: SwitchReport) -> None:
"""Send a controller report to the Pico."""
self.serial.write(report.to_bytes())
def read_rumble_payload(self) -> Optional[bytes]:
"""
Drain available UART bytes into an internal buffer, then extract one rumble frame.
Frame format:
0: 0xBB (RUMBLE_HEADER)
1: type (0x01 for rumble)
2-9: 8-byte rumble payload
10: checksum (sum of first 10 bytes) & 0xFF
"""
waiting = self.serial.in_waiting
if waiting:
self._buffer.extend(self.serial.read(waiting))
while True:
if not self._buffer:
return None
start = self._buffer.find(bytes([RUMBLE_HEADER]))
if start < 0:
self._buffer.clear()
return None
if len(self._buffer) - start < 11:
if start > 0:
del self._buffer[:start]
return None
frame = self._buffer[start:start + 11]
checksum = sum(frame[:10]) & 0xFF
if frame[1] == RUMBLE_TYPE_RUMBLE and checksum == frame[10]:
payload = bytes(frame[2:10])
del self._buffer[:start + 11]
return payload
del self._buffer[:start + 1]
def close(self) -> None:
"""Close the UART connection."""
self.serial.close()
def decode_rumble(payload: bytes) -> Tuple[float, float]:
"""Return normalized rumble amplitudes (0.0-1.0) for left/right."""
if len(payload) < 8:
return 0.0, 0.0
if payload == b"\x00\x01\x40\x40\x00\x01\x40\x40":
return 0.0, 0.0
right_raw = ((payload[1] & 0x03) << 8) | payload[0]
left_raw = ((payload[5] & 0x03) << 8) | payload[4]
if left_raw < 8 and right_raw < 8:
return 0.0, 0.0
left = min(max(left_raw / 1023.0, 0.0), 1.0)
right = min(max(right_raw / 1023.0, 0.0), 1.0)
return left, right
@dataclass
class SwitchControllerState:
"""Mutable controller state with helpers for building reports."""
report: SwitchReport = field(default_factory=SwitchReport)
def press(self, *buttons: Union[SwitchButton, int]) -> None:
"""Set one or more buttons as pressed."""
for button in buttons:
self.report.buttons |= int(button)
def release(self, *buttons: Union[SwitchButton, int]) -> None:
"""Release one or more buttons."""
for button in buttons:
self.report.buttons &= ~int(button)
def set_buttons(self, buttons: Iterable[Union[SwitchButton, int]]) -> None:
"""Replace the current button bitmask with the provided buttons."""
self.report.buttons = 0
self.press(*buttons)
def set_hat(self, hat: Union[SwitchHat, int]) -> None:
"""Set the DPAD/hat value directly."""
self.report.hat = int(hat) & 0xFF
def move_left_stick(self, x: Union[int, float], y: Union[int, float]) -> None:
"""Move the left stick using normalized floats (-1..1) or raw bytes (0-255)."""
self.report.lx = normalize_stick_value(x)
self.report.ly = normalize_stick_value(y)
def move_right_stick(self, x: Union[int, float], y: Union[int, float]) -> None:
"""Move the right stick using normalized floats (-1..1) or raw bytes (0-255)."""
self.report.rx = normalize_stick_value(x)
self.report.ry = normalize_stick_value(y)
def neutral(self) -> None:
"""Clear all input back to the neutral controller state."""
self.report.buttons = 0
self.report.hat = SwitchHat.CENTER
self.report.lx = 128
self.report.ly = 128
self.report.rx = 128
self.report.ry = 128
class SwitchUARTClient:
"""
High-level helper to send controller actions to the Pico and poll rumble.
Example:
with SwitchUARTClient("/dev/cu.usbserial-0001") as client:
client.press(SwitchButton.A)
time.sleep(0.1)
client.release(SwitchButton.A)
client.move_left_stick(0.0, -1.0) # push up
"""
def __init__(self, port: str, baud: int = UART_BAUD, send_interval: float = 0.0) -> None:
self.uart = PicoUART(port, baud)
self.state = SwitchControllerState()
self.send_interval = max(0.0, send_interval)
self._last_send = 0.0
def send(self) -> None:
"""Send the current state to the Pico, throttled by send_interval if set."""
now = time.monotonic()
if self.send_interval and (now - self._last_send) < self.send_interval:
return
self.uart.send_report(self.state.report)
self._last_send = now
def press(self, *buttons: int) -> None:
self.state.press(*buttons)
self.send()
def release(self, *buttons: int) -> None:
self.state.release(*buttons)
self.send()
def set_buttons(self, buttons: Iterable[int]) -> None:
self.state.set_buttons(buttons)
self.send()
def set_hat(self, hat: int) -> None:
self.state.set_hat(hat)
self.send()
def move_left_stick(self, x: Union[int, float], y: Union[int, float]) -> None:
self.state.move_left_stick(x, y)
self.send()
def move_right_stick(self, x: Union[int, float], y: Union[int, float]) -> None:
self.state.move_right_stick(x, y)
self.send()
def neutral(self) -> None:
self.state.neutral()
self.send()
def poll_rumble(self) -> Optional[Tuple[float, float]]:
"""
Poll for the latest rumble payload and return normalized amplitudes.
Returns None if no rumble frame was available.
"""
payload = self.uart.read_rumble_payload()
if payload:
return decode_rumble(payload)
return None
def close(self) -> None:
self.uart.close()
def __enter__(self) -> "SwitchUARTClient":
return self
def __exit__(self, exc_type, exc, tb) -> None:
self.close()