Compare commits

...
Sign in to create a new pull request.

3 commits

7 changed files with 774 additions and 98 deletions

5
.gitignore vendored
View file

@ -7,4 +7,9 @@ debug
*.egg-info
hid-nintendo.c
__pycache__
Nintendo_Switch_Reverse_Engineering
nxbt
hid-nintendo.c
*.png
switch
!.vscode/*

49
gyro_support_plan.md Normal file
View file

@ -0,0 +1,49 @@
# Gyro Support Plan
## Checklist
- [x] Review current UART input/rumble plumbing (`switch-pico.cpp` and `controller_uart_bridge.py`) to confirm the existing 0xAA input frame (buttons/hat/sticks) and 0xBB rumble return path.
- [x] Map the current firmware IMU hooks (`switch_pro_driver.cpp`/`switch_pro_descriptors.h` `imuData` field, `is_imu_enabled` flag, report construction cadence) and note where IMU payloads should be injected.
- [x] Pull expected IMU packet format, sample cadence (3 samples/report), and unit scaling from references: `Nintendo_Switch_Reverse_Engineering/imu_sensor_notes.md`, `hid-nintendo.c`, and NXBT (`nxbt/controller/protocol.py`).
- [x] Examine GP2040-CEs motion implementation (e.g., `GP2040-CE/src/addons/imu` and Switch report handling) for framing, calibration defaults, and rate control patterns to reuse.
- [x] Decide on UART motion framing: header/length/checksum scheme, sample packing (likely 3 samples of accel+gyro int16 LE), endian/order, and compatibility with existing 8-byte frames (avoid breaking current host builds).
- [ ] Define Switch-facing IMU payload layout inside `SwitchProReport` (axis order, sign conventions, zero/neutral sample) and ensure it matches the reverse-engineered descriptors.
- [ ] Add firmware-side data structures/buffers for incoming IMU samples (triple-buffer if mirroring Joy-Con 3-sample bursts) and default zeroing when IMU is disabled/missing.
- [ ] Extend UART parser in `switch-pico.cpp` to accept the new motion frame(s), validate checksum, and stash samples atomically alongside button state.
- [ ] Gate IMU injection on the hosts `TOGGLE_IMU` feature flag (`is_imu_enabled`) and ensure reports carry motion data only when enabled; default to zeros otherwise.
- [ ] Apply calibration/scaling constants: choose defaults from references (e.g., 0.06103 dps/LSB gyro, accel per imu_sensor_notes) and document where to adjust for sensor-specific offsets.
- [ ] Update host bridge to enable SDL sensor support (`SDL_GameControllerSetSensorEnabled`, `SDL_CONTROLLERAXIS` vs sensor events) and capture gyro (and accel if needed) at the required rate.
- [ ] Buffer and pack host IMU readings into the agreed UART motion frame, including timestamping/rate smoothing so the firmware sees stable 200 Hz-ish samples (3 per 5 ms report).
- [ ] Keep backward compatibility: allow running without IMU-capable controllers (send zero motion) and keep rumble unchanged.
- [ ] Add logging/metrics: lightweight counters for dropped/late IMU frames and a debug toggle to inspect raw samples.
- [ ] Test matrix: host-only sensor capture sanity check; loopback UART frame validator; firmware USB capture with Switch (or nxbt PC host) verifying IMU report contents and that `TOGGLE_IMU` on/off behaves; regression check that buttons/rumble remain stable.
## Findings to date
- Firmware: `SwitchProReport.imuData[36]` exists but is always zero; `is_imu_enabled` is set only via `TOGGLE_IMU` feature report; `switch_pro_task` always sends `switch_report`, so IMU injection should happen before the memcmp/send path.
- IMU payload layout (standard 0x30/31/32/33): bytes 13-24 are accel_x, accel_y, accel_z, gyro_x, gyro_y, gyro_z (all Int16 LE); bytes 25-48 repeat two more samples (3 samples total, ~5 ms apart). Matches `imuData[36]` size and `hid-nintendo.c` parsing (`imu_raw_bytes` split into 3 `joycon_imu_data` structs).
- Scaling from references: accel ≈ 0.000244 G/LSB (±8000 mG), gyro ≈ 0.06103 dps/LSB (±2000 dps) or 0.070 dps/LSB with STs +15% headroom; Switch internally also uses rev/s conversions. Typical packet cadence ~15 ms with 3 samples (≈200 Hz sampling).
- NXBT reference: only injects IMU when `imu_enabled` flag is true; drops a 36-byte sample block at offset 14 (0-based) in the report. Example data is static; good for offset confirmation.
- GP2040-CE reference: Switch Pro driver mirrors this project—`imuData` zeroed, no motion handling yet. No reusable IMU framing, but report/keepalive cadence matches.
## UART IMU framing decision (breaking change OK)
- New host→Pico frame (versioned) replaces the old 8-byte 0xAA packet:
- Byte0: `0xAA` header
- Byte1: `0x02` version
- Byte2: `payload_len` (44 for the layout below)
- Byte3-4: buttons LE (same masks as before)
- Byte5: hat
- Byte6-9: sticks `lx, ly, rx, ry` (0-255 as before)
- Byte10: `imu_sample_count` (host should send 3; firmware may accept 0-3)
- Byte11-46: IMU samples, 3 blocks of 12 bytes each:
- For sample i: `accel_x, accel_y, accel_z, gyro_x, gyro_y, gyro_z` (all int16 LE, Pro Controller axis/sign convention)
- Byte47: checksum = (sum of bytes 0 through the last payload byte) & 0xFF
- Host behavior: always populate 3 samples per packet (~200 Hz, 5 ms spacing) with reference/default scaling; send zeros and `imu_sample_count=0` if IMU unavailable/disabled. Buttons/sticks unchanged.
- Firmware behavior: parse the new frame, validate `payload_len` and checksum, then atomically store button/stick plus up to `imu_sample_count` samples. If `is_imu_enabled` is true, copy samples into `switch_report.imuData` in the 3× sample layout; otherwise zero the IMU block.
- Axis orientation: match Pro Controller orientation (no additional flips beyond standard report ordering).
## Open Questions
- All answered:
- Include both accelerometer and gyro (full IMU).
- Reference/default scaling is acceptable (no per-device calibration).
- Mirror 3-sample bursts (~200 Hz, 5 ms spacing).
- Use Pro Controller axis orientation/sign.
- Breaking UART framing change is acceptable (use the versioned packet above).

View file

@ -21,7 +21,7 @@ import sys
import time
import urllib.request
from dataclasses import dataclass, field
from ctypes import create_string_buffer
from ctypes import create_string_buffer, c_float
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -34,6 +34,12 @@ from rich.table import Table
from .switch_pico_uart import (
UART_BAUD,
MS2_PER_G,
RAD_TO_DEG,
SENSOR_ACCEL,
SENSOR_GYRO,
IMU_SAMPLES_PER_REPORT,
IMUSample,
PicoUART,
SwitchButton,
SwitchDpad,
@ -49,9 +55,14 @@ RUMBLE_IDLE_TIMEOUT = 0.25 # seconds without packets before forcing rumble off
RUMBLE_STUCK_TIMEOUT = 0.60 # continuous same-energy rumble will be stopped after this
RUMBLE_MIN_ACTIVE = 0.40 # below this, rumble is treated as off/noise
RUMBLE_SCALE = 1.0
CONTROLLER_DB_URL_DEFAULT = (
"https://raw.githubusercontent.com/mdqinc/SDL_GameControllerDB/refs/heads/master/gamecontrollerdb.txt"
)
CONTROLLER_DB_URL_DEFAULT = "https://raw.githubusercontent.com/mdqinc/SDL_GameControllerDB/refs/heads/master/gamecontrollerdb.txt"
SDL_TRUE = getattr(sdl2, "SDL_TRUE", 1)
SDL_CONTROLLERSENSORUPDATE = getattr(sdl2, "SDL_CONTROLLERSENSORUPDATE", 0x658)
GYRO_BIAS_SAMPLES = 200
GYRO_DEADZONE_COUNTS = 0
# Keep a small window of IMU samples to pack into the next report
IMU_BUFFER_SIZE = 32
def parse_mapping(value: str) -> Tuple[int, str]:
@ -62,7 +73,9 @@ def parse_mapping(value: str) -> Tuple[int, str]:
try:
idx = int(idx_str, 10)
except ValueError as exc:
raise argparse.ArgumentTypeError(f"Invalid controller index '{idx_str}'") from exc
raise argparse.ArgumentTypeError(
f"Invalid controller index '{idx_str}'"
) from exc
if not port:
raise argparse.ArgumentTypeError("Serial port cannot be empty")
return idx, port.strip()
@ -84,9 +97,13 @@ def download_controller_db(console: Console, destination: Path, url: str) -> boo
try:
destination.write_bytes(data)
except Exception as exc:
console.print(f"[red]Failed to write controller database to {destination}: {exc}[/red]")
console.print(
f"[red]Failed to write controller database to {destination}: {exc}[/red]"
)
return False
console.print(f"[green]Updated controller database ({len(data)} bytes) at {destination}[/green]")
console.print(
f"[green]Updated controller database ({len(data)} bytes) at {destination}[/green]"
)
return True
@ -98,7 +115,9 @@ def parse_hotkey(value: str) -> str:
if not value:
return ""
if len(value) != 1:
raise argparse.ArgumentTypeError("Hotkeys must be a single character (or empty to disable).")
raise argparse.ArgumentTypeError(
"Hotkeys must be a single character (or empty to disable)."
)
return value
@ -151,7 +170,9 @@ def interactive_pairing(
mappings: List[Tuple[int, str]] = []
for controller_idx in controller_info:
if not available:
console.print("[bold red]No more UART devices available for pairing.[/bold red]")
console.print(
"[bold red]No more UART devices available for pairing.[/bold red]"
)
break
table = Table(
@ -174,7 +195,9 @@ def interactive_pairing(
idx = int(selection)
port = available.pop(idx)
mappings.append((controller_idx, port["device"]))
console.print(f"[bold green]Paired controller {controller_idx} with {port['device']}[/bold green]")
console.print(
f"[bold green]Paired controller {controller_idx} with {port['device']}[/bold green]"
)
return mappings
@ -188,7 +211,7 @@ def apply_rumble(controller: sdl2.SDL_GameController, payload: bytes) -> float:
return 0.0
# Attenuate to feel closer to a real controller; cap at ~25% strength.
scale = RUMBLE_SCALE
low = int(min(1.0, left_norm * scale) * 0xFFFF) # SDL: low_frequency_rumble
low = int(min(1.0, left_norm * scale) * 0xFFFF) # SDL: low_frequency_rumble
high = int(min(1.0, right_norm * scale) * 0xFFFF) # SDL: high_frequency_rumble
duration = 10
sdl2.SDL_GameControllerRumble(controller, low, high, duration)
@ -204,9 +227,18 @@ class ControllerContext:
port: Optional[str]
uart: Optional[PicoUART]
report: SwitchReport = field(default_factory=SwitchReport)
dpad: Dict[str, bool] = field(default_factory=lambda: {"up": False, "down": False, "left": False, "right": False})
dpad: Dict[str, bool] = field(
default_factory=lambda: {
"up": False,
"down": False,
"left": False,
"right": False,
}
)
button_state: Dict[int, bool] = field(default_factory=dict)
last_trigger_state: Dict[str, bool] = field(default_factory=lambda: {"left": False, "right": False})
last_trigger_state: Dict[str, bool] = field(
default_factory=lambda: {"left": False, "right": False}
)
last_send: float = 0.0
last_reopen_attempt: float = 0.0
last_rumble: float = 0.0
@ -214,6 +246,16 @@ class ControllerContext:
last_rumble_energy: float = 0.0
rumble_active: bool = False
axis_offsets: Dict[int, int] = field(default_factory=dict)
sensors_supported: bool = False
sensors_enabled: bool = False
imu_samples: List[IMUSample] = field(default_factory=list)
last_accel: Tuple[float, float, float] = (0.0, 0.0, 0.0)
gyro_bias_x: float = 0.0
gyro_bias_y: float = 0.0
gyro_bias_z: float = 0.0
gyro_bias_samples: int = 0
gyro_bias_locked: bool = False
last_debug_imu_print: float = 0.0
def capture_stick_offsets(controller: sdl2.SDL_GameController) -> Dict[int, int]:
@ -226,7 +268,9 @@ def capture_stick_offsets(controller: sdl2.SDL_GameController) -> Dict[int, int]
def format_axis_offsets(offsets: Dict[int, int]) -> str:
"""Return a human-friendly summary of per-axis offsets (for logging)."""
return ", ".join(f"{label}={offsets.get(axis, 0):+d}" for axis, label in STICK_AXIS_LABELS)
return ", ".join(
f"{label}={offsets.get(axis, 0):+d}" for axis, label in STICK_AXIS_LABELS
)
def calibrate_axis_value(value: int, axis: int, ctx: ControllerContext) -> int:
@ -242,7 +286,9 @@ def calibrate_axis_value(value: int, axis: int, ctx: ControllerContext) -> int:
class HotkeyMonitor:
"""Platform-aware helper that watches for configured hotkeys without blocking the main loop."""
def __init__(self, console: Console, key_messages: Optional[Dict[str, str]] = None) -> None:
def __init__(
self, console: Console, key_messages: Optional[Dict[str, str]] = None
) -> None:
self.console = console
self._platform = os.name
self._fd: Optional[int] = None
@ -271,7 +317,9 @@ class HotkeyMonitor:
try:
import msvcrt # type: ignore
except ImportError:
self.console.print("[yellow]Hotkeys disabled: msvcrt unavailable.[/yellow]")
self.console.print(
"[yellow]Hotkeys disabled: msvcrt unavailable.[/yellow]"
)
return False
self._msvcrt = msvcrt
self._active = True
@ -296,7 +344,11 @@ class HotkeyMonitor:
def suspend(self) -> None:
if not self._active:
return
if self._platform != "nt" and self._fd is not None and self._orig_termios is not None:
if (
self._platform != "nt"
and self._fd is not None
and self._orig_termios is not None
):
import termios
termios.tcsetattr(self._fd, termios.TCSADRAIN, self._orig_termios)
@ -316,7 +368,11 @@ class HotkeyMonitor:
self._active = True
def stop(self) -> None:
if self._platform != "nt" and self._fd is not None and self._orig_termios is not None:
if (
self._platform != "nt"
and self._fd is not None
and self._orig_termios is not None
):
import termios
termios.tcsetattr(self._fd, termios.TCSADRAIN, self._orig_termios)
@ -339,7 +395,9 @@ class HotkeyMonitor:
def _print_instructions(self) -> None:
if not self._keys:
return
instructions = " | ".join(f"'{key.upper()}' to {message}" for key, message in self._keys.items())
instructions = " | ".join(
f"'{key.upper()}' to {message}" for key, message in self._keys.items()
)
self.console.print(f"[magenta]Hotkeys active: {instructions}[/magenta]")
def _read_key(self) -> Optional[str]:
@ -361,7 +419,11 @@ class HotkeyMonitor:
return ch
def zero_context_sticks(ctx: ControllerContext, console: Optional[Console] = None, reason: str = "Zeroed stick centers") -> None:
def zero_context_sticks(
ctx: ControllerContext,
console: Optional[Console] = None,
reason: str = "Zeroed stick centers",
) -> None:
"""Capture and store the current stick positions for a controller."""
offsets = capture_stick_offsets(ctx.controller)
ctx.axis_offsets = offsets
@ -371,7 +433,9 @@ def zero_context_sticks(ctx: ControllerContext, console: Optional[Console] = Non
)
def zero_all_context_sticks(contexts: Dict[int, ControllerContext], console: Console) -> None:
def zero_all_context_sticks(
contexts: Dict[int, ControllerContext], console: Console
) -> None:
"""Zero every connected controller's sticks."""
if not contexts:
console.print("[yellow]No controllers available to zero right now.[/yellow]")
@ -390,10 +454,14 @@ def controller_display_name(ctx: ControllerContext) -> str:
return str(name)
def toggle_abxy_for_context(ctx: ControllerContext, config: BridgeConfig, console: Console) -> None:
def toggle_abxy_for_context(
ctx: ControllerContext, config: BridgeConfig, console: Console
) -> None:
"""Toggle the ABXY layout for a single controller."""
if config.swap_abxy_global:
console.print("[yellow]Global --swap-abxy is enabled; disable it to use per-controller toggles.[/yellow]")
console.print(
"[yellow]Global --swap-abxy is enabled; disable it to use per-controller toggles.[/yellow]"
)
return
swapped = ctx.stable_id in config.swap_abxy_ids
action = "standard" if swapped else "swapped"
@ -414,9 +482,13 @@ def prompt_swap_abxy_controller(
) -> None:
"""Prompt the user to choose a controller whose ABXY layout should be toggled."""
if not contexts:
console.print("[yellow]No controllers connected to toggle ABXY layout.[/yellow]")
console.print(
"[yellow]No controllers connected to toggle ABXY layout.[/yellow]"
)
return
controllers = sorted(contexts.values(), key=lambda ctx: (ctx.controller_index, ctx.instance_id))
controllers = sorted(
contexts.values(), key=lambda ctx: (ctx.controller_index, ctx.instance_id)
)
table = Table(title="Toggle ABXY layout for a controller")
table.add_column("Choice", justify="center")
table.add_column("SDL Index", justify="center")
@ -461,7 +533,9 @@ def open_controller(index: int) -> Tuple[sdl2.SDL_GameController, int, str]:
"""Open an SDL GameController by index and return it with instance ID and GUID string."""
controller = sdl2.SDL_GameControllerOpen(index)
if not controller:
raise RuntimeError(f"Failed to open controller {index}: {sdl2.SDL_GetError().decode()}")
raise RuntimeError(
f"Failed to open controller {index}: {sdl2.SDL_GetError().decode()}"
)
joystick = sdl2.SDL_GameControllerGetJoystick(controller)
instance_id = sdl2.SDL_JoystickInstanceID(joystick)
guid_str = guid_string_from_joystick(joystick)
@ -503,7 +577,9 @@ def open_uart_or_warn(port: str, baud: int, console: Console) -> Optional[PicoUA
def build_arg_parser() -> argparse.ArgumentParser:
"""Construct the CLI argument parser for the bridge."""
parser = argparse.ArgumentParser(description="Bridge SDL2 controllers to switch-pico UART (with rumble)")
parser = argparse.ArgumentParser(
description="Bridge SDL2 controllers to switch-pico UART (with rumble)"
)
parser.add_argument(
"--map",
action="append",
@ -516,13 +592,21 @@ def build_arg_parser() -> argparse.ArgumentParser:
nargs="+",
help="Serial ports to auto-pair with controllers in ascending index order.",
)
parser.add_argument("--interactive", action="store_true", help="Launch an interactive pairing UI using Rich.")
parser.add_argument("--all-ports", action="store_true", help="Include non-USB serial ports when listing devices.")
parser.add_argument(
"--interactive",
action="store_true",
help="Launch an interactive pairing UI using Rich.",
)
parser.add_argument(
"--all-ports",
action="store_true",
help="Include non-USB serial ports when listing devices.",
)
parser.add_argument(
"--frequency",
type=float,
default=500.0,
help="Report send frequency per controller (Hz, default 500)",
default=66.7,
help="Report send frequency per controller (Hz, default ~66.7 => ~15ms)",
)
parser.add_argument(
"--deadzone",
@ -624,10 +708,28 @@ def build_arg_parser() -> argparse.ArgumentParser:
default=[],
help="Path to an SDL2 controller mapping database (e.g. controllerdb.txt). Repeatable.",
)
parser.add_argument(
"--debug-imu",
action="store_true",
help="Print raw IMU readings (float and converted int16) for debugging.",
)
parser.add_argument(
"--gyro-scale",
type=float,
default=1.0,
help="Scale factor for gyro sensitivity (default 1.0). Reduce to < 1.0 if camera moves too fast.",
)
parser.add_argument(
"--no-gyro-bias",
action="store_true",
help="Disable automatic gyro bias calibration at startup.",
)
return parser
def poll_controller_buttons(ctx: ControllerContext, button_map: Dict[int, SwitchButton]) -> None:
def poll_controller_buttons(
ctx: ControllerContext, button_map: Dict[int, SwitchButton]
) -> None:
"""Update button/hat state based on current SDL controller readings."""
changed = False
for sdl_button, switch_bit in button_map.items():
@ -667,6 +769,9 @@ class BridgeConfig:
swap_abxy_indices: set[int]
swap_abxy_ids: set[str]
swap_abxy_global: bool
debug_imu: bool
gyro_scale: float
no_gyro_bias: bool
@dataclass
@ -682,7 +787,153 @@ class PairingState:
include_port_mfr: List[str] = field(default_factory=list)
def load_button_maps(console: Console, args: argparse.Namespace) -> Tuple[Dict[int, SwitchButton], Dict[int, SwitchButton], set[int]]:
def clamp_int16(value: float) -> int:
return int(max(-32768, min(32767, round(value))))
def read_sensor_triplet(
controller: sdl2.SDL_GameController, sensor_type: int
) -> Optional[Tuple[float, float, float]]:
if not hasattr(sdl2, "SDL_GameControllerGetSensorData"):
return None
data = (c_float * 3)()
result = sdl2.SDL_GameControllerGetSensorData(controller, sensor_type, data, 3)
if result != 0:
return None
return float(data[0]), float(data[1]), float(data[2])
def convert_accel_to_raw(accel_ms2: float) -> int:
g_units = accel_ms2 / MS2_PER_G
return clamp_int16(g_units * 4000.0)
def convert_gyro_to_raw(gyro_rad: float, scale: float = 1.0) -> int:
# SDL reports gyroscope data in radians/second; convert to dps then to Switch counts.
dps = gyro_rad * RAD_TO_DEG
# 0.070 dps/LSB is approx 14.28 LSB/dps.
counts = clamp_int16((dps / 0.061) * scale)
if abs(counts) < GYRO_DEADZONE_COUNTS:
return 0
return counts
def initialize_controller_sensors(ctx: ControllerContext, console: Console) -> None:
if not all(
hasattr(sdl2, name)
for name in (
"SDL_GameControllerHasSensor",
"SDL_GameControllerSetSensorEnabled",
"SDL_GameControllerGetSensorData",
)
):
return
if SENSOR_ACCEL is None or SENSOR_GYRO is None:
return
accel_supported = bool(
sdl2.SDL_GameControllerHasSensor(ctx.controller, SENSOR_ACCEL)
)
gyro_supported = bool(sdl2.SDL_GameControllerHasSensor(ctx.controller, SENSOR_GYRO))
ctx.sensors_supported = accel_supported and gyro_supported
if not ctx.sensors_supported:
console.print(
f"[yellow]Controller {ctx.controller_index} has no accelerometer/gyro sensors[/yellow]"
)
return
accel_enabled = (
sdl2.SDL_GameControllerSetSensorEnabled(ctx.controller, SENSOR_ACCEL, SDL_TRUE)
== 0
)
gyro_enabled = (
sdl2.SDL_GameControllerSetSensorEnabled(ctx.controller, SENSOR_GYRO, SDL_TRUE)
== 0
)
ctx.sensors_enabled = accel_enabled and gyro_enabled
if not ctx.sensors_enabled:
console.print(
f"[yellow]Controller {ctx.controller_index} failed to enable sensors[/yellow]"
)
def handle_sensor_update(event: sdl2.SDL_Event, contexts: Dict[int, ControllerContext], config: BridgeConfig) -> None:
ctx = contexts.get(event.csensor.which)
if not ctx:
return
sensor_type = event.csensor.sensor
data = event.csensor.data
if sensor_type == SENSOR_ACCEL:
ctx.last_accel = (data[0], data[1], data[2])
return
if sensor_type != SENSOR_GYRO:
return
# Process Gyro update (and combine with last accel)
gyro = (data[0], data[1], data[2])
if not ctx.gyro_bias_locked and ctx.gyro_bias_samples < GYRO_BIAS_SAMPLES:
ctx.gyro_bias_x += gyro[0]
ctx.gyro_bias_y += gyro[1]
ctx.gyro_bias_z += gyro[2]
ctx.gyro_bias_samples += 1
if ctx.gyro_bias_samples >= GYRO_BIAS_SAMPLES:
ctx.gyro_bias_x /= ctx.gyro_bias_samples
ctx.gyro_bias_y /= ctx.gyro_bias_samples
ctx.gyro_bias_z /= ctx.gyro_bias_samples
ctx.gyro_bias_locked = True
bias_x = ctx.gyro_bias_x if ctx.gyro_bias_locked else 0.0
bias_y = ctx.gyro_bias_y if ctx.gyro_bias_locked else 0.0
bias_z = ctx.gyro_bias_z if ctx.gyro_bias_locked else 0.0
# Use last known accel
accel = ctx.last_accel
# Map SDL sensor axes to Pro Controller axes: gravity should land on Z.
accel_raw_x = convert_accel_to_raw(accel[0])
accel_raw_y = convert_accel_to_raw(accel[1])
accel_raw_z = convert_accel_to_raw(accel[2])
gyro_raw_x = convert_gyro_to_raw(gyro[0] - bias_x, config.gyro_scale)
gyro_raw_y = convert_gyro_to_raw(gyro[1] - bias_y, config.gyro_scale)
gyro_raw_z = convert_gyro_to_raw(gyro[2] - bias_z, config.gyro_scale)
# Map SDL axes to Pro axes to match the native Pro USB output:
# Pro accel: ax = SDL_, ay = SDL_Z, az = SDL_Y (gravity).
# Pro gyro: gx = SDL_X, gy = SDL_Z, gz = SDL_Y.
sample = IMUSample(
accel_x=-accel_raw_z,
accel_y=-accel_raw_x,
accel_z=accel_raw_y,
gyro_x=convert_gyro_to_raw(-(gyro[2] - bias_z), config.gyro_scale),
gyro_y=convert_gyro_to_raw(-(gyro[0] - bias_x), config.gyro_scale),
gyro_z=convert_gyro_to_raw(gyro[1] - bias_y, config.gyro_scale),
)
ctx.imu_samples.append(sample)
if len(ctx.imu_samples) > IMU_BUFFER_SIZE:
ctx.imu_samples = ctx.imu_samples[-IMU_BUFFER_SIZE:]
if config.debug_imu:
now = time.monotonic()
if now - ctx.last_debug_imu_print > 0.2:
ctx.last_debug_imu_print = now
print(
f"[IMU dbg idx={ctx.controller_index}] "
f"accel_ms2=({accel[0]:.3f},{accel[1]:.3f},{accel[2]:.3f}) "
f"gyro_dps=({gyro[0]:.3f},{gyro[1]:.3f},{gyro[2]:.3f}) "
f"bias_dps=({bias_x:.3f},{bias_y:.3f},{bias_z:.3f}) "
f"raw=({sample.accel_x},{sample.accel_y},{sample.accel_z};"
f"{sample.gyro_x},{sample.gyro_y},{sample.gyro_z})"
)
def load_button_maps(
console: Console, args: argparse.Namespace
) -> Tuple[Dict[int, SwitchButton], Dict[int, SwitchButton], set[int]]:
"""Load SDL controller mappings and return button map variants."""
default_mapping = Path(__file__).parent / "controller_db" / "gamecontrollerdb.txt"
if args.update_controller_db or not default_mapping.exists():
@ -697,14 +948,20 @@ def load_button_maps(console: Console, args: argparse.Namespace) -> Tuple[Dict[i
button_map_swapped[sdl2.SDL_CONTROLLER_BUTTON_B] = SwitchButton.A
button_map_swapped[sdl2.SDL_CONTROLLER_BUTTON_X] = SwitchButton.Y
button_map_swapped[sdl2.SDL_CONTROLLER_BUTTON_Y] = SwitchButton.X
swap_abxy_indices = {idx for idx in args.swap_abxy_index if idx is not None and idx >= 0}
swap_abxy_indices = {
idx for idx in args.swap_abxy_index if idx is not None and idx >= 0
}
for mapping_path in mappings_to_load:
try:
loaded = sdl2.SDL_GameControllerAddMappingsFromFile(mapping_path.encode())
console.print(f"[green]Loaded {loaded} SDL mapping(s) from {mapping_path}[/green]")
console.print(
f"[green]Loaded {loaded} SDL mapping(s) from {mapping_path}[/green]"
)
except Exception as exc:
console.print(f"[red]Failed to load SDL mapping {mapping_path}: {exc}[/red]")
return button_map_default, button_map_swapped, swap_abxy_indices
console.print(
f"[red]Failed to load SDL mapping {mapping_path}: {exc}[/red]"
)
geturn button_map_default, button_map_swapped, swap_abxy_indices
def build_bridge_config(console: Console, args: argparse.Namespace) -> BridgeConfig:
@ -712,7 +969,9 @@ def build_bridge_config(console: Console, args: argparse.Namespace) -> BridgeCon
interval = 1.0 / max(args.frequency, 1.0)
deadzone_raw = int(max(0.0, min(args.deadzone, 1.0)) * 32767)
trigger_threshold = int(max(0.0, min(args.trigger_threshold, 1.0)) * 32767)
button_map_default, button_map_swapped, swap_abxy_indices = load_button_maps(console, args)
button_map_default, button_map_swapped, swap_abxy_indices = load_button_maps(
console, args
)
swap_abxy_guids = {g.lower() for g in args.swap_abxy_guid}
return BridgeConfig(
interval=interval,
@ -726,6 +985,9 @@ def build_bridge_config(console: Console, args: argparse.Namespace) -> BridgeCon
swap_abxy_indices=swap_abxy_indices,
swap_abxy_ids=set(swap_abxy_guids), # filled later once stable IDs are known
swap_abxy_global=bool(args.swap_abxy),
debug_imu=bool(args.debug_imu),
gyro_scale=args.gyro_scale,
no_gyro_bias=args.no_gyro_bias,
)
@ -736,7 +998,14 @@ def initialize_sdl(parser: argparse.ArgumentParser) -> None:
set_hint("SDL_JOYSTICK_HIDAPI_SWITCH", "1")
# Use controller button labels so Nintendo layouts (ABXY) map correctly on Linux.
set_hint("SDL_GAMECONTROLLER_USE_BUTTON_LABELS", "1")
if sdl2.SDL_Init(sdl2.SDL_INIT_GAMECONTROLLER | sdl2.SDL_INIT_JOYSTICK | sdl2.SDL_INIT_EVERYTHING) != 0:
if (
sdl2.SDL_Init(
sdl2.SDL_INIT_GAMECONTROLLER
| sdl2.SDL_INIT_JOYSTICK
| sdl2.SDL_INIT_EVERYTHING
)
!= 0
):
parser.error(f"SDL init failed: {sdl2.SDL_GetError().decode(errors='ignore')}")
@ -754,8 +1023,12 @@ def detect_controllers(
if sdl2.SDL_IsGameController(index):
name = sdl2.SDL_GameControllerNameForIndex(index)
name_str = name.decode() if isinstance(name, bytes) else str(name)
if include_controller_name and all(substr not in name_str.lower() for substr in include_controller_name):
console.print(f"[yellow]Skipping controller {index} ({name_str}) due to name filter[/yellow]")
if include_controller_name and all(
substr not in name_str.lower() for substr in include_controller_name
):
console.print(
f"[yellow]Skipping controller {index} ({name_str}) due to name filter[/yellow]"
)
continue
console.print(f"[cyan]Detected controller {index}: {name_str}[/cyan]")
controller_indices.append(index)
@ -763,14 +1036,22 @@ def detect_controllers(
else:
name = sdl2.SDL_JoystickNameForIndex(index)
name_str = name.decode() if isinstance(name, bytes) else str(name)
if include_controller_name and all(substr not in name_str.lower() for substr in include_controller_name):
console.print(f"[yellow]Skipping joystick {index} ({name_str}) due to name filter[/yellow]")
if include_controller_name and all(
substr not in name_str.lower() for substr in include_controller_name
):
console.print(
f"[yellow]Skipping joystick {index} ({name_str}) due to name filter[/yellow]"
)
continue
console.print(f"[yellow]Found joystick {index} (not a GameController): {name_str}[/yellow]")
console.print(
f"[yellow]Found joystick {index} (not a GameController): {name_str}[/yellow]"
)
return controller_indices, controller_names
def list_controllers_with_guids(console: Console, parser: argparse.ArgumentParser) -> None:
def list_controllers_with_guids(
console: Console, parser: argparse.ArgumentParser
) -> None:
"""Print detected controllers with their GUID strings and exit."""
count = sdl2.SDL_NumJoysticks()
if count < 0:
@ -785,10 +1066,16 @@ def list_controllers_with_guids(console: Console, parser: argparse.ArgumentParse
table.add_column("GUID")
for idx in range(count):
is_gc = sdl2.SDL_IsGameController(idx)
name = sdl2.SDL_GameControllerNameForIndex(idx) if is_gc else sdl2.SDL_JoystickNameForIndex(idx)
name = (
sdl2.SDL_GameControllerNameForIndex(idx)
if is_gc
else sdl2.SDL_JoystickNameForIndex(idx)
)
name_str = name.decode() if isinstance(name, bytes) else str(name)
guid_str = guid_string_for_device_index(idx)
table.add_row(str(idx), "GameController" if is_gc else "Joystick", name_str, guid_str)
table.add_row(
str(idx), "GameController" if is_gc else "Joystick", name_str, guid_str
)
console.print(table)
@ -827,7 +1114,9 @@ def prepare_pairing_state(
elif auto_pairing_enabled:
if args.ports:
available_ports.extend(list(args.ports))
console.print(f"[green]Prepared {len(available_ports)} specified UART port(s) for auto-pairing.[/green]")
console.print(
f"[green]Prepared {len(available_ports)} specified UART port(s) for auto-pairing.[/green]"
)
else:
# Passive mode: grab whatever UARTs exist now, and keep looking later.
discovered = discover_serial_ports(
@ -842,7 +1131,9 @@ def prepare_pairing_state(
for info in discovered:
console.print(f" {info['device']} ({info['description']})")
else:
console.print("[yellow]No UART devices detected yet; waiting for hotplug...[/yellow]")
console.print(
"[yellow]No UART devices detected yet; waiting for hotplug...[/yellow]"
)
mapping_by_index = {index: port for index, port in mappings}
return PairingState(
@ -857,7 +1148,9 @@ def prepare_pairing_state(
)
def assign_port_for_index(pairing: PairingState, idx: int, console: Console) -> Optional[str]:
def assign_port_for_index(
pairing: PairingState, idx: int, console: Console
) -> Optional[str]:
"""Return the UART assigned to a controller index, auto-pairing if allowed."""
if idx in pairing.mapping_by_index:
return pairing.mapping_by_index[idx]
@ -879,12 +1172,21 @@ def ports_in_use(pairing: PairingState, contexts: Dict[int, ControllerContext])
return used
def handle_removed_port(path: str, pairing: PairingState, contexts: Dict[int, ControllerContext], console: Console) -> None:
def handle_removed_port(
path: str,
pairing: PairingState,
contexts: Dict[int, ControllerContext],
console: Console,
) -> None:
"""Clear mappings/contexts for a UART path that disappeared."""
if path in pairing.available_ports:
pairing.available_ports.remove(path)
console.print(f"[yellow]UART {path} removed; dropping from available pool[/yellow]")
indices_to_clear = [idx for idx, mapped in pairing.mapping_by_index.items() if mapped == path]
console.print(
f"[yellow]UART {path} removed; dropping from available pool[/yellow]"
)
indices_to_clear = [
idx for idx, mapped in pairing.mapping_by_index.items() if mapped == path
]
for idx in indices_to_clear:
pairing.mapping_by_index.pop(idx, None)
pairing.auto_assigned_indices.discard(idx)
@ -902,10 +1204,14 @@ def handle_removed_port(path: str, pairing: PairingState, contexts: Dict[int, Co
ctx.rumble_active = False
ctx.last_rumble_energy = 0.0
ctx.last_reopen_attempt = time.monotonic()
console.print(f"[yellow]UART {path} removed; controller {ctx.controller_index} waiting for reassignment[/yellow]")
console.print(
f"[yellow]UART {path} removed; controller {ctx.controller_index} waiting for reassignment[/yellow]"
)
def discover_new_ports(pairing: PairingState, contexts: Dict[int, ControllerContext], console: Console) -> None:
def discover_new_ports(
pairing: PairingState, contexts: Dict[int, ControllerContext], console: Console
) -> None:
"""Scan for new serial ports and add unused ones to the available pool."""
if not pairing.auto_discover_ports:
return
@ -929,7 +1235,9 @@ def discover_new_ports(pairing: PairingState, contexts: Dict[int, ControllerCont
if path in in_use or path in pairing.available_ports:
continue
pairing.available_ports.append(path)
console.print(f"[green]Discovered UART {path} ({info['description']}); available for pairing.[/green]")
console.print(
f"[green]Discovered UART {path} ({info['description']}); available for pairing.[/green]"
)
def pair_waiting_contexts(
@ -977,7 +1285,9 @@ def open_initial_contexts(
if index >= sdl2.SDL_NumJoysticks() or not sdl2.SDL_IsGameController(index):
name = sdl2.SDL_JoystickNameForIndex(index)
name_str = name.decode() if isinstance(name, bytes) else str(name)
console.print(f"[yellow]Index {index} is not a GameController ({name_str}). Trying raw open failed.[/yellow]")
console.print(
f"[yellow]Index {index} is not a GameController ({name_str}). Trying raw open failed.[/yellow]"
)
continue
port = assign_port_for_index(pairing, index, console)
if port is None and not pairing.auto_pairing_enabled:
@ -993,9 +1303,13 @@ def open_initial_contexts(
uart = open_uart_or_warn(port, args.baud, console) if port else None
if uart:
uarts.append(uart)
console.print(f"[green]Controller {index} (id {stable_id}, inst {instance_id}) paired to {port}[/green]")
console.print(
f"[green]Controller {index} (id {stable_id}, inst {instance_id}) paired to {port}[/green]"
)
elif port:
console.print(f"[yellow]Controller {index} (id {stable_id}, inst {instance_id}) waiting for UART {port}[/yellow]")
console.print(
f"[yellow]Controller {index} (id {stable_id}, inst {instance_id}) waiting for UART {port}[/yellow]"
)
else:
console.print(
f"[yellow]Controller {index} (id {stable_id}, inst {instance_id}) connected; waiting for an available UART[/yellow]"
@ -1010,11 +1324,14 @@ def open_initial_contexts(
)
if config.zero_sticks:
zero_context_sticks(ctx, console)
initialize_controller_sensors(ctx, console)
contexts[instance_id] = ctx
return contexts, uarts
def handle_axis_motion(event: sdl2.SDL_Event, contexts: Dict[int, ControllerContext], config: BridgeConfig) -> None:
def handle_axis_motion(
event: sdl2.SDL_Event, contexts: Dict[int, ControllerContext], config: BridgeConfig
) -> None:
"""Process axis motion event into stick/trigger state."""
ctx = contexts.get(event.caxis.which)
if not ctx:
@ -1095,7 +1412,9 @@ def handle_device_added(
if idx >= sdl2.SDL_NumJoysticks() or not sdl2.SDL_IsGameController(idx):
name = sdl2.SDL_JoystickNameForIndex(idx)
name_str = name.decode() if isinstance(name, bytes) else str(name)
console.print(f"[yellow]Index {idx} is not a GameController ({name_str}). Trying raw open failed.[/yellow]")
console.print(
f"[yellow]Index {idx} is not a GameController ({name_str}). Trying raw open failed.[/yellow]"
)
return
try:
controller, instance_id, guid = open_controller(idx)
@ -1109,9 +1428,13 @@ def handle_device_added(
uart = open_uart_or_warn(port, args.baud, console) if port else None
if uart:
uarts.append(uart)
console.print(f"[green]Controller {idx} (id {stable_id}, inst {instance_id}) paired to {port}[/green]")
console.print(
f"[green]Controller {idx} (id {stable_id}, inst {instance_id}) paired to {port}[/green]"
)
elif port:
console.print(f"[yellow]Controller {idx} (id {stable_id}, inst {instance_id}) waiting for UART {port}[/yellow]")
console.print(
f"[yellow]Controller {idx} (id {stable_id}, inst {instance_id}) waiting for UART {port}[/yellow]"
)
else:
console.print(
f"[yellow]Controller {idx} (id {stable_id}, inst {instance_id}) connected; waiting for an available UART[/yellow]"
@ -1126,6 +1449,7 @@ def handle_device_added(
)
if config.zero_sticks:
zero_context_sticks(ctx, console)
initialize_controller_sensors(ctx, console)
contexts[instance_id] = ctx
@ -1140,7 +1464,9 @@ def handle_device_removed(
ctx = contexts.pop(instance_id, None)
if not ctx:
return
console.print(f"[yellow]Controller {instance_id} (id {ctx.stable_id}) removed[/yellow]")
console.print(
f"[yellow]Controller {instance_id} (id {ctx.stable_id}) removed[/yellow]"
)
if ctx.controller_index in pairing.auto_assigned_indices:
# Return auto-paired UART back to the pool so a future device can use it.
freed = pairing.mapping_by_index.pop(ctx.controller_index, None)
@ -1167,18 +1493,28 @@ def service_contexts(
else config.button_map_default
)
poll_controller_buttons(ctx, current_button_map)
# collect_imu_sample(ctx, config) <-- Removed, using event loop
# Reconnect UART if needed.
if ctx.port and ctx.uart is None and (now - ctx.last_reopen_attempt) > 1.0:
ctx.last_reopen_attempt = now
uart = open_uart_or_warn(ctx.port, args.baud, console)
if uart:
uarts.append(uart)
console.print(f"[green]Reconnected UART {ctx.port} for controller {ctx.controller_index}[/green]")
console.print(
f"[green]Reconnected UART {ctx.port} for controller {ctx.controller_index}[/green]"
)
ctx.uart = uart
if ctx.uart is None:
continue
try:
if now - ctx.last_send >= config.interval:
# Consume up to 3 samples from the head of the queue (FIFO)
count = min(len(ctx.imu_samples), IMU_SAMPLES_PER_REPORT)
if count > 0:
ctx.report.imu_samples = ctx.imu_samples[:count]
ctx.imu_samples = ctx.imu_samples[count:]
else:
ctx.report.imu_samples = []
ctx.uart.send_report(ctx.report)
ctx.last_send = now
@ -1201,7 +1537,10 @@ def service_contexts(
sdl2.SDL_GameControllerRumble(ctx.controller, 0, 0, 0)
ctx.rumble_active = False
ctx.last_rumble_energy = 0.0
elif ctx.rumble_active and (now - ctx.last_rumble_change) > RUMBLE_STUCK_TIMEOUT:
elif (
ctx.rumble_active
and (now - ctx.last_rumble_change) > RUMBLE_STUCK_TIMEOUT
):
sdl2.SDL_GameControllerRumble(ctx.controller, 0, 0, 0)
ctx.rumble_active = False
ctx.last_rumble_energy = 0.0
@ -1242,12 +1581,21 @@ def run_bridge_loop(
break
if event.type == sdl2.SDL_CONTROLLERAXISMOTION:
handle_axis_motion(event, contexts, config)
elif event.type in (sdl2.SDL_CONTROLLERBUTTONDOWN, sdl2.SDL_CONTROLLERBUTTONUP):
elif event.type in (
sdl2.SDL_CONTROLLERBUTTONDOWN,
sdl2.SDL_CONTROLLERBUTTONUP,
):
handle_button_event(event, config, contexts)
elif event.type in (sdl2.SDL_CONTROLLERBUTTONDOWN, sdl2.SDL_CONTROLLERBUTTONUP):
handle_button_event(event, args, config, contexts)
elif event.type == SDL_CONTROLLERSENSORUPDATE:
handle_sensor_update(event, contexts, config)
elif event.type == sdl2.SDL_CONTROLLERDEVICEADDED:
handle_device_added(event, args, pairing, contexts, uarts, console, config)
handle_device_added(
event, args, pairing, contexts, uarts, console, config
)
elif event.type == sdl2.SDL_CONTROLLERDEVICEREMOVED:
handle_device_removed(event, pairing, contexts, console)
gandle_device_removed(event, pairing, contexts, console)
now = time.monotonic()
if now - last_port_scan > port_scan_interval:
@ -1291,7 +1639,9 @@ def main() -> None:
list_controllers_with_guids(console, parser)
return
controller_indices, controller_names = detect_controllers(console, args, parser)
pairing = prepare_pairing_state(args, console, parser, controller_indices, controller_names)
pairing = prepare_pairing_state(
args, console, parser, controller_indices, controller_names
)
hotkey_messages: Dict[str, str] = {}
if config.zero_hotkey:
hotkey_messages[config.zero_hotkey] = "re-zero controller sticks"
@ -1301,14 +1651,20 @@ def main() -> None:
hotkey_messages[config.swap_hotkey] + "; toggle ABXY layout"
)
else:
hotkey_messages[config.swap_hotkey] = "toggle ABXY layout for a controller"
hotkey_messages[config.swap_hotkey] = (
"toggle ABXY layout for a controller"
)
if hotkey_messages:
candidate = HotkeyMonitor(console, hotkey_messages)
if candidate.start():
hotkey_monitor = candidate
contexts, uarts = open_initial_contexts(args, pairing, controller_indices, console, config)
contexts, uarts = open_initial_contexts(
args, pairing, controller_indices, console, config
)
if not contexts:
console.print("[yellow]No controllers opened; waiting for hotplug events...[/yellow]")
console.print(
"[yellow]No controllers opened; waiting for hotplug events...[/yellow]"
)
run_bridge_loop(args, console, config, pairing, contexts, uarts, hotkey_monitor)
finally:
if hotkey_monitor:

View file

@ -1,4 +1,4 @@
#include <stdio.h>
ginclude <stdio.h>
#include <string.h>
#include "bsp/board.h"
#include "hardware/uart.h"
@ -61,11 +61,13 @@ static void on_rumble_from_switch(const uint8_t rumble[8]) {
}
// Consume UART bytes and forward complete frames to the Switch Pro driver.
static void poll_uart_frames() {
static uint8_t buffer[8];
static bool poll_uart_frames() {
static uint8_t buffer[64];
static uint8_t index = 0;
static uint8_t expected_len = 0;
static absolute_time_t last_byte_time = {0};
static bool has_last_byte = false;
bool new_data = false;
while (uart_is_readable(UART_ID)) {
uint8_t byte = uart_getc(UART_ID);
@ -73,6 +75,7 @@ static void poll_uart_frames() {
uint64_t now = to_ms_since_boot(get_absolute_time());
if (has_last_byte && (now - to_ms_since_boot(last_byte_time)) > 20) {
index = 0; // stale data, restart frame
expected_len = 0;
}
last_byte_time = get_absolute_time();
has_last_byte = true;
@ -84,9 +87,19 @@ static void poll_uart_frames() {
}
buffer[index++] = byte;
if (index >= sizeof(buffer)) {
if (index == 3) {
// We just stored payload_len; compute expected frame length (payload + header/version/len/checksum).
expected_len = static_cast<uint8_t>(buffer[2] + 4);
if (expected_len > sizeof(buffer) || expected_len < 8) {
index = 0;
expected_len = 0;
continue;
}
}
if (expected_len && index >= expected_len) {
SwitchInputState parsed{};
if (switch_pro_apply_uart_packet(buffer, sizeof(buffer), &parsed)) {
if (switch_pro_apply_uart_packet(buffer, expected_len, &parsed)) {
g_user_state = parsed;
LOG_PRINTF("[UART] packet buttons=0x%04x hat=%u lx=%u ly=%u rx=%u ry=%u\n",
(parsed.button_a ? SWITCH_PRO_MASK_A : 0) |
@ -110,8 +123,11 @@ static void poll_uart_frames() {
parsed.lx >> 8, parsed.ly >> 8, parsed.rx >> 8, parsed.ry >> 8);
}
index = 0;
expected_len = 0;
new_data = true;
}
}
return new_data;
}
static void log_usb_state() {
@ -146,9 +162,13 @@ int main() {
while (true) {
tud_task(); // USB device tasks
poll_uart_frames(); // Pull controller state from UART1
bool new_data = poll_uart_frames(); // Pull controller state from UART1
SwitchInputState state = g_user_state;
switch_pro_set_input(state);
bool should_update = new_data;
if (should_update) {
switch_pro_set_input(state);
}
switch_pro_task(); // Push state to the Switch host
log_usb_state();
}

View file

@ -14,8 +14,8 @@
#define LOG_PRINTF(...) ((void)0)
#endif
// force a report to be sent every X ms
#define SWITCH_PRO_KEEPALIVE_TIMER 5
// force a report to be sent every X ms (roughly matches Pro Controller cadence)
#define SWITCH_PRO_KEEPALIVE_TIMER 15
static SwitchInputState g_input_state{
false, false, false, false,
@ -190,9 +190,42 @@ static SwitchInputState make_neutral_state() {
s.ly = SWITCH_PRO_JOYSTICK_MID;
s.rx = SWITCH_PRO_JOYSTICK_MID;
s.ry = SWITCH_PRO_JOYSTICK_MID;
s.imu_sample_count = 0;
return s;
}
static void fill_imu_report_data(const SwitchInputState& state) {
// Include IMU data when the host provided samples; otherwise zero.
if (state.imu_sample_count == 0) {
memset(switch_report.imuData, 0x00, sizeof(switch_report.imuData));
return;
}
uint8_t sample_count = state.imu_sample_count > 3 ? 3 : state.imu_sample_count;
uint8_t* dst = switch_report.imuData;
// Map host IMU axes (host already sends Pro-oriented X,Z,Y) into report layout:
// Report order per sample: accel_x, accel_y, accel_z, gyro_x, gyro_y, gyro_z.
for (uint8_t i = 0; i < 3; ++i) {
SwitchImuSample sample{};
if (i < sample_count) {
sample = state.imu_samples[i];
}
dst[0] = static_cast<uint8_t>(sample.accel_x & 0xFF);
dst[1] = static_cast<uint8_t>((sample.accel_x >> 8) & 0xFF);
dst[2] = static_cast<uint8_t>(sample.accel_y & 0xFF);
dst[3] = static_cast<uint8_t>((sample.accel_y >> 8) & 0xFF);
dst[4] = static_cast<uint8_t>(sample.accel_z & 0xFF);
dst[5] = static_cast<uint8_t>((sample.accel_z >> 8) & 0xFF);
dst[6] = static_cast<uint8_t>(sample.gyro_x & 0xFF);
dst[7] = static_cast<uint8_t>((sample.gyro_x >> 8) & 0xFF);
dst[8] = static_cast<uint8_t>(sample.gyro_y & 0xFF);
dst[9] = static_cast<uint8_t>((sample.gyro_y >> 8) & 0xFF);
dst[10] = static_cast<uint8_t>(sample.gyro_z & 0xFF);
dst[11] = static_cast<uint8_t>((sample.gyro_z >> 8) & 0xFF);
dst += 12;
}
}
static void send_identify() {
memset(report_buffer, 0x00, sizeof(report_buffer));
report_buffer[0] = REPORT_USB_INPUT_81;
@ -485,6 +518,7 @@ static void update_switch_report_from_state() {
switch_report.inputs.rightStick.setX(std::min(std::max(scaleRightStickX,rightMinX), rightMaxX));
switch_report.inputs.rightStick.setY(-std::min(std::max(scaleRightStickY,rightMinY), rightMaxY));
fill_imu_report_data(g_input_state);
switch_report.rumbleReport = 0x09;
}
@ -493,6 +527,7 @@ void switch_pro_init() {
last_report_counter = 0;
handshake_counter = 0;
is_ready = false;
is_imu_enabled = true; // default on to allow IMU during host bring-up/debug
is_initialized = false;
is_report_queued = false;
report_sent = false;
@ -594,14 +629,15 @@ void switch_pro_task() {
switch_report.timestamp = last_report_counter;
void * inputReport = &switch_report;
uint16_t report_size = sizeof(switch_report);
if (memcmp(last_report, inputReport, report_size) != 0) {
if (tud_hid_ready() && send_report(0, inputReport, report_size) == true ) {
memcpy(last_report, inputReport, report_size);
report_sent = true;
}
last_report_timer = now;
if (tud_hid_ready() && send_report(0, inputReport, report_size) == true ) {
memcpy(last_report, inputReport, report_size);
report_sent = true;
// Clear IMU samples so they aren't repeated in the next report
// if no new data arrives.
g_input_state.imu_sample_count = 0;
}
last_report_timer = now;
}
} else {
if (!is_initialized) {
@ -617,18 +653,51 @@ void switch_pro_task() {
}
bool switch_pro_apply_uart_packet(const uint8_t* packet, uint8_t length, SwitchInputState* out_state) {
// Packet format: 0xAA, buttons(2 LE), hat, lx, ly, rx, ry
// Packet v2 format:
// 0:0xAA header
// 1:version (0x02)
// 2:payload_len (bytes 3..3+len-1)
// 3-4: buttons LE
// 5: hat
// 6-9: lx, ly, rx, ry (0-255)
// 10: imu_sample_count (0-3)
// 11-46: up to 3 samples of accel/gyro (int16 LE each axis)
// 47: checksum (sum of bytes 0..46) & 0xFF
if (length < 8 || packet[0] != 0xAA) {
return false;
}
if (packet[1] != 0x02) {
return false;
}
uint8_t payload_len = packet[2];
uint16_t expected_len = static_cast<uint16_t>(payload_len) + 4; // header+version+len+checksum
if (length < expected_len) {
return false;
}
uint16_t checksum_end = static_cast<uint16_t>(3 + payload_len - 1); // last payload byte
uint16_t checksum_index = static_cast<uint16_t>(3 + payload_len);
if (checksum_index >= length) {
return false;
}
uint16_t sum = 0;
for (uint16_t i = 0; i <= checksum_end; ++i) {
sum = static_cast<uint16_t>(sum + packet[i]);
}
if ((sum & 0xFF) != packet[checksum_index]) {
return false;
}
SwitchProOutReport out{};
out.buttons = static_cast<uint16_t>(packet[1]) | (static_cast<uint16_t>(packet[2]) << 8);
out.hat = packet[3];
out.lx = packet[4];
out.ly = packet[5];
out.rx = packet[6];
out.ry = packet[7];
out.buttons = static_cast<uint16_t>(packet[3]) | (static_cast<uint16_t>(packet[4]) << 8);
out.hat = packet[5];
out.lx = packet[6];
out.ly = packet[7];
out.rx = packet[8];
out.ry = packet[9];
auto expand_axis = [](uint8_t v) -> uint16_t {
return static_cast<uint16_t>(v) << 8 | v;
@ -636,6 +705,23 @@ bool switch_pro_apply_uart_packet(const uint8_t* packet, uint8_t length, SwitchI
SwitchInputState state = make_neutral_state();
state.imu_sample_count = std::min<uint8_t>(packet[10], 3);
auto read_int16 = [](const uint8_t* src) -> int16_t {
return static_cast<int16_t>(static_cast<uint16_t>(src[0]) | (static_cast<uint16_t>(src[1]) << 8));
};
const uint8_t* imu_base = &packet[11];
for (uint8_t i = 0; i < state.imu_sample_count; ++i) {
const uint8_t* sample_ptr = imu_base + (i * 12);
state.imu_samples[i].accel_x = read_int16(sample_ptr + 0);
state.imu_samples[i].accel_y = read_int16(sample_ptr + 2);
state.imu_samples[i].accel_z = read_int16(sample_ptr + 4);
state.imu_samples[i].gyro_x = read_int16(sample_ptr + 6);
state.imu_samples[i].gyro_y = read_int16(sample_ptr + 8);
state.imu_samples[i].gyro_z = read_int16(sample_ptr + 10);
}
switch (out.hat) {
case SWITCH_PRO_HAT_UP: state.dpad_up = true; break;
case SWITCH_PRO_HAT_UPRIGHT: state.dpad_up = true; state.dpad_right = true; break;
@ -769,9 +855,9 @@ bool tud_control_request_cb(uint8_t rhport, tusb_control_request_t const * reque
void tud_mount_cb(void) {
LOG_PRINTF("[USB] mount_cb\n");
last_host_activity_ms = to_ms_since_boot(get_absolute_time());
forced_ready = false;
is_ready = false;
is_initialized = false;
forced_ready = true;
is_ready = true;
is_initialized = true;
}
void tud_umount_cb(void) {

View file

@ -10,6 +10,15 @@
#include <stdint.h>
#include "switch_pro_descriptors.h"
typedef struct {
int16_t accel_x;
int16_t accel_y;
int16_t accel_z;
int16_t gyro_x;
int16_t gyro_y;
int16_t gyro_z;
} SwitchImuSample;
typedef struct {
bool dpad_up;
bool dpad_down;
@ -35,6 +44,8 @@ typedef struct {
uint16_t ly;
uint16_t rx;
uint16_t ry;
uint8_t imu_sample_count;
SwitchImuSample imu_samples[3];
} SwitchInputState;
// Initialize USB state and calibration before entering the main loop.

149
tools/read_pro_imu.py Normal file
View file

@ -0,0 +1,149 @@
#!/usr/bin/env python3
"""
Read raw IMU samples from a Nintendo Switch Pro Controller (or Pico spoof) over USB.
Uses the `hidapi` (pyhidapi) package. Press Ctrl+C to exit.
"""
import argparse
import struct
import sys
from typing import List, Tuple
import hid # from pyhidapi
DEFAULT_VENDOR_ID = 0x057E
DEFAULT_PRODUCT_ID = 0x2009 # Switch Pro Controller (USB)
def list_devices(filter_vid=None, filter_pid=None):
devices = hid.enumerate()
for d in devices:
if filter_vid and d["vendor_id"] != filter_vid:
continue
if filter_pid and d["product_id"] != filter_pid:
continue
print(
f"VID=0x{d['vendor_id']:04X} PID=0x{d['product_id']:04X} "
f"path={d.get('path')} "
f"serial={d.get('serial_number')} "
f"manufacturer={d.get('manufacturer_string')} "
f"product={d.get('product_string')} "
f"interface={d.get('interface_number')}"
)
return devices
def find_device(vendor_id: int, product_id: int):
for dev in hid.enumerate():
if dev["vendor_id"] == vendor_id and dev["product_id"] == product_id:
return dev
return None
def main():
parser = argparse.ArgumentParser(description="Read raw 0x30 reports (IMU) from a Switch Pro Controller / Pico.")
parser.add_argument("--vid", type=lambda x: int(x, 0), default=DEFAULT_VENDOR_ID, help="Vendor ID (default 0x057E)")
parser.add_argument("--pid", type=lambda x: int(x, 0), default=DEFAULT_PRODUCT_ID, help="Product ID (default 0x2009)")
parser.add_argument("--path", help="Explicit HID path to open (overrides VID/PID).")
parser.add_argument("--count", type=int, default=0, help="Stop after this many 0x30 reports (0 = infinite).")
parser.add_argument("--timeout", type=int, default=3000, help="Read timeout ms (default 3000).")
parser.add_argument("--list", action="store_true", help="List detected HID devices and exit.")
parser.add_argument("--plot", action="store_true", help="Plot accel/gyro traces after capture (requires matplotlib).")
parser.add_argument("--save-prefix", help="If set, save accel/gyro plots as '<prefix>_accel.png' and '<prefix>_gyro.png'.")
args = parser.parse_args()
if args.list:
list_devices()
return
if args.path:
dev_info = {"path": bytes(args.path, encoding="utf-8"), "vendor_id": args.vid, "product_id": args.pid}
else:
dev_info = find_device(args.vid, args.pid)
if not dev_info:
print(
f"No HID device found for VID=0x{args.vid:04X} PID=0x{args.pid:04X}. "
"Use --list to inspect devices or --path to target a specific one.",
file=sys.stderr,
)
sys.exit(1)
device = hid.device()
device.open_path(dev_info["path"])
device.set_nonblocking(False)
print(
f"Reading raw 0x30 reports from device (VID=0x{args.vid:04X} PID=0x{args.pid:04X})... "
"Ctrl+C to stop."
)
accel_series: List[Tuple[int, int, int]] = []
gyro_series: List[Tuple[int, int, int]] = []
try:
read_count = 0
while args.count == 0 or read_count < args.count:
data = device.read(64, timeout_ms=args.timeout)
if not data:
print(f"(timeout after {args.timeout} ms, no data)")
continue
if data[0] != 0x30:
print(f"(non-0x30 report id=0x{data[0]:02X}, len={len(data)})")
continue
samples = []
offset = 13 # accel_x starts at byte 13
for _ in range(3):
ax, ay, az, gx, gy, gz = struct.unpack_from("<hhhhhh", bytes(data), offset)
samples.append((ax, ay, az, gx, gy, gz))
offset += 12
print(samples)
accel_series.extend((s[0], s[1], s[2]) for s in samples)
gyro_series.extend((s[3], s[4], s[5]) for s in samples)
read_count += 1
except KeyboardInterrupt:
pass
finally:
device.close()
if args.plot:
try:
import matplotlib.pyplot as plt
except Exception as exc: # pragma: no cover - optional dependency
print(f"Unable to plot (matplotlib not available): {exc}", file=sys.stderr)
return
if accel_series and gyro_series:
# Each sample is a tuple of three axes; plot per axis vs sample index.
accel_x = [s[0] for s in accel_series]
accel_y = [s[1] for s in accel_series]
accel_z = [s[2] for s in accel_series]
gyro_x = [s[0] for s in gyro_series]
gyro_y = [s[1] for s in gyro_series]
gyro_z = [s[2] for s in gyro_series]
fig1, ax1 = plt.subplots()
ax1.plot(accel_x, label="ax")
ax1.plot(accel_y, label="ay")
ax1.plot(accel_z, label="az")
ax1.set_title("Accel (counts)")
ax1.set_xlabel("Sample")
ax1.set_ylabel("Counts")
ax1.legend()
fig2, ax2 = plt.subplots()
ax2.plot(gyro_x, label="gx")
ax2.plot(gyro_y, label="gy")
ax2.plot(gyro_z, label="gz")
ax2.set_title("Gyro (counts)")
ax2.set_xlabel("Sample")
ax2.set_ylabel("Counts")
ax2.legend()
if args.save_prefix:
fig1.savefig(f"{args.save_prefix}_accel.png", dpi=150, bbox_inches="tight")
fig2.savefig(f"{args.save_prefix}_gyro.png", dpi=150, bbox_inches="tight")
print(f"Saved plots to {args.save_prefix}_accel.png and {args.save_prefix}_gyro.png")
plt.show()
if __name__ == "__main__":
main()