From 4640b9fe548f4b1e790145111384558c6c14ffb5 Mon Sep 17 00:00:00 2001 From: Joey Yakimowich-Payne Date: Mon, 24 Nov 2025 12:37:54 -0700 Subject: [PATCH 1/3] Initial gyro, swinging wildly --- .../controller_uart_bridge.py | 396 +++++++++++++++--- switch-pico.cpp | 19 +- switch_pro_driver.cpp | 96 ++++- switch_pro_driver.h | 11 + 4 files changed, 443 insertions(+), 79 deletions(-) diff --git a/src/switch_pico_bridge/controller_uart_bridge.py b/src/switch_pico_bridge/controller_uart_bridge.py index e223ed0..bcd4f67 100644 --- a/src/switch_pico_bridge/controller_uart_bridge.py +++ b/src/switch_pico_bridge/controller_uart_bridge.py @@ -21,7 +21,7 @@ import sys import time import urllib.request from dataclasses import dataclass, field -from ctypes import create_string_buffer +from ctypes import create_string_buffer, c_float from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -34,6 +34,12 @@ from rich.table import Table from .switch_pico_uart import ( UART_BAUD, + MS2_PER_G, + RAD_TO_DEG, + SENSOR_ACCEL, + SENSOR_GYRO, + IMU_SAMPLES_PER_REPORT, + IMUSample, PicoUART, SwitchButton, SwitchDpad, @@ -49,9 +55,9 @@ RUMBLE_IDLE_TIMEOUT = 0.25 # seconds without packets before forcing rumble off RUMBLE_STUCK_TIMEOUT = 0.60 # continuous same-energy rumble will be stopped after this RUMBLE_MIN_ACTIVE = 0.40 # below this, rumble is treated as off/noise RUMBLE_SCALE = 1.0 -CONTROLLER_DB_URL_DEFAULT = ( - "https://raw.githubusercontent.com/mdqinc/SDL_GameControllerDB/refs/heads/master/gamecontrollerdb.txt" -) +CONTROLLER_DB_URL_DEFAULT = "https://raw.githubusercontent.com/mdqinc/SDL_GameControllerDB/refs/heads/master/gamecontrollerdb.txt" + +SDL_TRUE = getattr(sdl2, "SDL_TRUE", 1) def parse_mapping(value: str) -> Tuple[int, str]: @@ -62,7 +68,9 @@ def parse_mapping(value: str) -> Tuple[int, str]: try: idx = int(idx_str, 10) except ValueError as exc: - raise argparse.ArgumentTypeError(f"Invalid controller index '{idx_str}'") from exc + raise argparse.ArgumentTypeError( + f"Invalid controller index '{idx_str}'" + ) from exc if not port: raise argparse.ArgumentTypeError("Serial port cannot be empty") return idx, port.strip() @@ -84,9 +92,13 @@ def download_controller_db(console: Console, destination: Path, url: str) -> boo try: destination.write_bytes(data) except Exception as exc: - console.print(f"[red]Failed to write controller database to {destination}: {exc}[/red]") + console.print( + f"[red]Failed to write controller database to {destination}: {exc}[/red]" + ) return False - console.print(f"[green]Updated controller database ({len(data)} bytes) at {destination}[/green]") + console.print( + f"[green]Updated controller database ({len(data)} bytes) at {destination}[/green]" + ) return True @@ -98,7 +110,9 @@ def parse_hotkey(value: str) -> str: if not value: return "" if len(value) != 1: - raise argparse.ArgumentTypeError("Hotkeys must be a single character (or empty to disable).") + raise argparse.ArgumentTypeError( + "Hotkeys must be a single character (or empty to disable)." + ) return value @@ -151,7 +165,9 @@ def interactive_pairing( mappings: List[Tuple[int, str]] = [] for controller_idx in controller_info: if not available: - console.print("[bold red]No more UART devices available for pairing.[/bold red]") + console.print( + "[bold red]No more UART devices available for pairing.[/bold red]" + ) break table = Table( @@ -174,7 +190,9 @@ def interactive_pairing( idx = int(selection) port = available.pop(idx) mappings.append((controller_idx, port["device"])) - console.print(f"[bold green]Paired controller {controller_idx} with {port['device']}[/bold green]") + console.print( + f"[bold green]Paired controller {controller_idx} with {port['device']}[/bold green]" + ) return mappings @@ -188,7 +206,7 @@ def apply_rumble(controller: sdl2.SDL_GameController, payload: bytes) -> float: return 0.0 # Attenuate to feel closer to a real controller; cap at ~25% strength. scale = RUMBLE_SCALE - low = int(min(1.0, left_norm * scale) * 0xFFFF) # SDL: low_frequency_rumble + low = int(min(1.0, left_norm * scale) * 0xFFFF) # SDL: low_frequency_rumble high = int(min(1.0, right_norm * scale) * 0xFFFF) # SDL: high_frequency_rumble duration = 10 sdl2.SDL_GameControllerRumble(controller, low, high, duration) @@ -204,9 +222,18 @@ class ControllerContext: port: Optional[str] uart: Optional[PicoUART] report: SwitchReport = field(default_factory=SwitchReport) - dpad: Dict[str, bool] = field(default_factory=lambda: {"up": False, "down": False, "left": False, "right": False}) + dpad: Dict[str, bool] = field( + default_factory=lambda: { + "up": False, + "down": False, + "left": False, + "right": False, + } + ) button_state: Dict[int, bool] = field(default_factory=dict) - last_trigger_state: Dict[str, bool] = field(default_factory=lambda: {"left": False, "right": False}) + last_trigger_state: Dict[str, bool] = field( + default_factory=lambda: {"left": False, "right": False} + ) last_send: float = 0.0 last_reopen_attempt: float = 0.0 last_rumble: float = 0.0 @@ -214,6 +241,10 @@ class ControllerContext: last_rumble_energy: float = 0.0 rumble_active: bool = False axis_offsets: Dict[int, int] = field(default_factory=dict) + sensors_supported: bool = False + sensors_enabled: bool = False + imu_samples: List[IMUSample] = field(default_factory=list) + last_sensor_poll: float = 0.0 def capture_stick_offsets(controller: sdl2.SDL_GameController) -> Dict[int, int]: @@ -226,7 +257,9 @@ def capture_stick_offsets(controller: sdl2.SDL_GameController) -> Dict[int, int] def format_axis_offsets(offsets: Dict[int, int]) -> str: """Return a human-friendly summary of per-axis offsets (for logging).""" - return ", ".join(f"{label}={offsets.get(axis, 0):+d}" for axis, label in STICK_AXIS_LABELS) + return ", ".join( + f"{label}={offsets.get(axis, 0):+d}" for axis, label in STICK_AXIS_LABELS + ) def calibrate_axis_value(value: int, axis: int, ctx: ControllerContext) -> int: @@ -242,7 +275,9 @@ def calibrate_axis_value(value: int, axis: int, ctx: ControllerContext) -> int: class HotkeyMonitor: """Platform-aware helper that watches for configured hotkeys without blocking the main loop.""" - def __init__(self, console: Console, key_messages: Optional[Dict[str, str]] = None) -> None: + def __init__( + self, console: Console, key_messages: Optional[Dict[str, str]] = None + ) -> None: self.console = console self._platform = os.name self._fd: Optional[int] = None @@ -271,7 +306,9 @@ class HotkeyMonitor: try: import msvcrt # type: ignore except ImportError: - self.console.print("[yellow]Hotkeys disabled: msvcrt unavailable.[/yellow]") + self.console.print( + "[yellow]Hotkeys disabled: msvcrt unavailable.[/yellow]" + ) return False self._msvcrt = msvcrt self._active = True @@ -296,7 +333,11 @@ class HotkeyMonitor: def suspend(self) -> None: if not self._active: return - if self._platform != "nt" and self._fd is not None and self._orig_termios is not None: + if ( + self._platform != "nt" + and self._fd is not None + and self._orig_termios is not None + ): import termios termios.tcsetattr(self._fd, termios.TCSADRAIN, self._orig_termios) @@ -316,7 +357,11 @@ class HotkeyMonitor: self._active = True def stop(self) -> None: - if self._platform != "nt" and self._fd is not None and self._orig_termios is not None: + if ( + self._platform != "nt" + and self._fd is not None + and self._orig_termios is not None + ): import termios termios.tcsetattr(self._fd, termios.TCSADRAIN, self._orig_termios) @@ -339,7 +384,9 @@ class HotkeyMonitor: def _print_instructions(self) -> None: if not self._keys: return - instructions = " | ".join(f"'{key.upper()}' to {message}" for key, message in self._keys.items()) + instructions = " | ".join( + f"'{key.upper()}' to {message}" for key, message in self._keys.items() + ) self.console.print(f"[magenta]Hotkeys active: {instructions}[/magenta]") def _read_key(self) -> Optional[str]: @@ -361,7 +408,11 @@ class HotkeyMonitor: return ch -def zero_context_sticks(ctx: ControllerContext, console: Optional[Console] = None, reason: str = "Zeroed stick centers") -> None: +def zero_context_sticks( + ctx: ControllerContext, + console: Optional[Console] = None, + reason: str = "Zeroed stick centers", +) -> None: """Capture and store the current stick positions for a controller.""" offsets = capture_stick_offsets(ctx.controller) ctx.axis_offsets = offsets @@ -371,7 +422,9 @@ def zero_context_sticks(ctx: ControllerContext, console: Optional[Console] = Non ) -def zero_all_context_sticks(contexts: Dict[int, ControllerContext], console: Console) -> None: +def zero_all_context_sticks( + contexts: Dict[int, ControllerContext], console: Console +) -> None: """Zero every connected controller's sticks.""" if not contexts: console.print("[yellow]No controllers available to zero right now.[/yellow]") @@ -390,10 +443,14 @@ def controller_display_name(ctx: ControllerContext) -> str: return str(name) -def toggle_abxy_for_context(ctx: ControllerContext, config: BridgeConfig, console: Console) -> None: +def toggle_abxy_for_context( + ctx: ControllerContext, config: BridgeConfig, console: Console +) -> None: """Toggle the ABXY layout for a single controller.""" if config.swap_abxy_global: - console.print("[yellow]Global --swap-abxy is enabled; disable it to use per-controller toggles.[/yellow]") + console.print( + "[yellow]Global --swap-abxy is enabled; disable it to use per-controller toggles.[/yellow]" + ) return swapped = ctx.stable_id in config.swap_abxy_ids action = "standard" if swapped else "swapped" @@ -414,9 +471,13 @@ def prompt_swap_abxy_controller( ) -> None: """Prompt the user to choose a controller whose ABXY layout should be toggled.""" if not contexts: - console.print("[yellow]No controllers connected to toggle ABXY layout.[/yellow]") + console.print( + "[yellow]No controllers connected to toggle ABXY layout.[/yellow]" + ) return - controllers = sorted(contexts.values(), key=lambda ctx: (ctx.controller_index, ctx.instance_id)) + controllers = sorted( + contexts.values(), key=lambda ctx: (ctx.controller_index, ctx.instance_id) + ) table = Table(title="Toggle ABXY layout for a controller") table.add_column("Choice", justify="center") table.add_column("SDL Index", justify="center") @@ -461,7 +522,9 @@ def open_controller(index: int) -> Tuple[sdl2.SDL_GameController, int, str]: """Open an SDL GameController by index and return it with instance ID and GUID string.""" controller = sdl2.SDL_GameControllerOpen(index) if not controller: - raise RuntimeError(f"Failed to open controller {index}: {sdl2.SDL_GetError().decode()}") + raise RuntimeError( + f"Failed to open controller {index}: {sdl2.SDL_GetError().decode()}" + ) joystick = sdl2.SDL_GameControllerGetJoystick(controller) instance_id = sdl2.SDL_JoystickInstanceID(joystick) guid_str = guid_string_from_joystick(joystick) @@ -503,7 +566,9 @@ def open_uart_or_warn(port: str, baud: int, console: Console) -> Optional[PicoUA def build_arg_parser() -> argparse.ArgumentParser: """Construct the CLI argument parser for the bridge.""" - parser = argparse.ArgumentParser(description="Bridge SDL2 controllers to switch-pico UART (with rumble)") + parser = argparse.ArgumentParser( + description="Bridge SDL2 controllers to switch-pico UART (with rumble)" + ) parser.add_argument( "--map", action="append", @@ -516,8 +581,16 @@ def build_arg_parser() -> argparse.ArgumentParser: nargs="+", help="Serial ports to auto-pair with controllers in ascending index order.", ) - parser.add_argument("--interactive", action="store_true", help="Launch an interactive pairing UI using Rich.") - parser.add_argument("--all-ports", action="store_true", help="Include non-USB serial ports when listing devices.") + parser.add_argument( + "--interactive", + action="store_true", + help="Launch an interactive pairing UI using Rich.", + ) + parser.add_argument( + "--all-ports", + action="store_true", + help="Include non-USB serial ports when listing devices.", + ) parser.add_argument( "--frequency", type=float, @@ -627,7 +700,9 @@ def build_arg_parser() -> argparse.ArgumentParser: return parser -def poll_controller_buttons(ctx: ControllerContext, button_map: Dict[int, SwitchButton]) -> None: +def poll_controller_buttons( + ctx: ControllerContext, button_map: Dict[int, SwitchButton] +) -> None: """Update button/hat state based on current SDL controller readings.""" changed = False for sdl_button, switch_bit in button_map.items(): @@ -682,7 +757,95 @@ class PairingState: include_port_mfr: List[str] = field(default_factory=list) -def load_button_maps(console: Console, args: argparse.Namespace) -> Tuple[Dict[int, SwitchButton], Dict[int, SwitchButton], set[int]]: +def clamp_int16(value: float) -> int: + return int(max(-32768, min(32767, round(value)))) + + +def read_sensor_triplet( + controller: sdl2.SDL_GameController, sensor_type: int +) -> Optional[Tuple[float, float, float]]: + if not hasattr(sdl2, "SDL_GameControllerGetSensorData"): + return None + data = (c_float * 3)() + result = sdl2.SDL_GameControllerGetSensorData(controller, sensor_type, data, 3) + if result != 0: + return None + return float(data[0]), float(data[1]), float(data[2]) + + +def convert_accel_to_raw(accel_ms2: float) -> int: + g_units = accel_ms2 / MS2_PER_G + return clamp_int16(g_units * 4096.0) + + +def convert_gyro_to_raw(gyro_rad: float) -> int: + dps = gyro_rad * RAD_TO_DEG + return clamp_int16(dps / 0.070) + + +def initialize_controller_sensors(ctx: ControllerContext, console: Console) -> None: + if not all( + hasattr(sdl2, name) + for name in ( + "SDL_GameControllerHasSensor", + "SDL_GameControllerSetSensorEnabled", + "SDL_GameControllerGetSensorData", + ) + ): + return + + if SENSOR_ACCEL is None or SENSOR_GYRO is None: + return + + accel_supported = bool( + sdl2.SDL_GameControllerHasSensor(ctx.controller, SENSOR_ACCEL) + ) + gyro_supported = bool(sdl2.SDL_GameControllerHasSensor(ctx.controller, SENSOR_GYRO)) + ctx.sensors_supported = accel_supported and gyro_supported + if not ctx.sensors_supported: + console.print( + f"[yellow]Controller {ctx.controller_index} has no accelerometer/gyro sensors[/yellow]" + ) + return + + accel_enabled = ( + sdl2.SDL_GameControllerSetSensorEnabled(ctx.controller, SENSOR_ACCEL, SDL_TRUE) + == 0 + ) + gyro_enabled = ( + sdl2.SDL_GameControllerSetSensorEnabled(ctx.controller, SENSOR_GYRO, SDL_TRUE) + == 0 + ) + ctx.sensors_enabled = accel_enabled and gyro_enabled + if not ctx.sensors_enabled: + console.print( + f"[yellow]Controller {ctx.controller_index} failed to enable sensors[/yellow]" + ) + + +def collect_imu_sample(ctx: ControllerContext) -> None: + if not ctx.sensors_enabled or SENSOR_ACCEL is None or SENSOR_GYRO is None: + return + accel = read_sensor_triplet(ctx.controller, SENSOR_ACCEL) + gyro = read_sensor_triplet(ctx.controller, SENSOR_GYRO) + if not accel or not gyro: + return + sample = IMUSample( + accel_x=convert_accel_to_raw(accel[0]), + accel_y=convert_accel_to_raw(accel[1]), + accel_z=convert_accel_to_raw(accel[2]), + gyro_x=convert_gyro_to_raw(gyro[0]), + gyro_y=convert_gyro_to_raw(gyro[1]), + gyro_z=convert_gyro_to_raw(gyro[2]), + ) + ctx.imu_samples.append(sample) + if len(ctx.imu_samples) > 6: + ctx.imu_samples = ctx.imu_samples[-6:] + + +def load_button_maps( + console: Console, args: argparse.Namespace +) -> Tuple[Dict[int, SwitchButton], Dict[int, SwitchButton], set[int]]: """Load SDL controller mappings and return button map variants.""" default_mapping = Path(__file__).parent / "controller_db" / "gamecontrollerdb.txt" if args.update_controller_db or not default_mapping.exists(): @@ -697,13 +860,19 @@ def load_button_maps(console: Console, args: argparse.Namespace) -> Tuple[Dict[i button_map_swapped[sdl2.SDL_CONTROLLER_BUTTON_B] = SwitchButton.A button_map_swapped[sdl2.SDL_CONTROLLER_BUTTON_X] = SwitchButton.Y button_map_swapped[sdl2.SDL_CONTROLLER_BUTTON_Y] = SwitchButton.X - swap_abxy_indices = {idx for idx in args.swap_abxy_index if idx is not None and idx >= 0} + swap_abxy_indices = { + idx for idx in args.swap_abxy_index if idx is not None and idx >= 0 + } for mapping_path in mappings_to_load: try: loaded = sdl2.SDL_GameControllerAddMappingsFromFile(mapping_path.encode()) - console.print(f"[green]Loaded {loaded} SDL mapping(s) from {mapping_path}[/green]") + console.print( + f"[green]Loaded {loaded} SDL mapping(s) from {mapping_path}[/green]" + ) except Exception as exc: - console.print(f"[red]Failed to load SDL mapping {mapping_path}: {exc}[/red]") + console.print( + f"[red]Failed to load SDL mapping {mapping_path}: {exc}[/red]" + ) return button_map_default, button_map_swapped, swap_abxy_indices @@ -712,7 +881,9 @@ def build_bridge_config(console: Console, args: argparse.Namespace) -> BridgeCon interval = 1.0 / max(args.frequency, 1.0) deadzone_raw = int(max(0.0, min(args.deadzone, 1.0)) * 32767) trigger_threshold = int(max(0.0, min(args.trigger_threshold, 1.0)) * 32767) - button_map_default, button_map_swapped, swap_abxy_indices = load_button_maps(console, args) + button_map_default, button_map_swapped, swap_abxy_indices = load_button_maps( + console, args + ) swap_abxy_guids = {g.lower() for g in args.swap_abxy_guid} return BridgeConfig( interval=interval, @@ -736,7 +907,14 @@ def initialize_sdl(parser: argparse.ArgumentParser) -> None: set_hint("SDL_JOYSTICK_HIDAPI_SWITCH", "1") # Use controller button labels so Nintendo layouts (ABXY) map correctly on Linux. set_hint("SDL_GAMECONTROLLER_USE_BUTTON_LABELS", "1") - if sdl2.SDL_Init(sdl2.SDL_INIT_GAMECONTROLLER | sdl2.SDL_INIT_JOYSTICK | sdl2.SDL_INIT_EVERYTHING) != 0: + if ( + sdl2.SDL_Init( + sdl2.SDL_INIT_GAMECONTROLLER + | sdl2.SDL_INIT_JOYSTICK + | sdl2.SDL_INIT_EVERYTHING + ) + != 0 + ): parser.error(f"SDL init failed: {sdl2.SDL_GetError().decode(errors='ignore')}") @@ -754,8 +932,12 @@ def detect_controllers( if sdl2.SDL_IsGameController(index): name = sdl2.SDL_GameControllerNameForIndex(index) name_str = name.decode() if isinstance(name, bytes) else str(name) - if include_controller_name and all(substr not in name_str.lower() for substr in include_controller_name): - console.print(f"[yellow]Skipping controller {index} ({name_str}) due to name filter[/yellow]") + if include_controller_name and all( + substr not in name_str.lower() for substr in include_controller_name + ): + console.print( + f"[yellow]Skipping controller {index} ({name_str}) due to name filter[/yellow]" + ) continue console.print(f"[cyan]Detected controller {index}: {name_str}[/cyan]") controller_indices.append(index) @@ -763,14 +945,22 @@ def detect_controllers( else: name = sdl2.SDL_JoystickNameForIndex(index) name_str = name.decode() if isinstance(name, bytes) else str(name) - if include_controller_name and all(substr not in name_str.lower() for substr in include_controller_name): - console.print(f"[yellow]Skipping joystick {index} ({name_str}) due to name filter[/yellow]") + if include_controller_name and all( + substr not in name_str.lower() for substr in include_controller_name + ): + console.print( + f"[yellow]Skipping joystick {index} ({name_str}) due to name filter[/yellow]" + ) continue - console.print(f"[yellow]Found joystick {index} (not a GameController): {name_str}[/yellow]") + console.print( + f"[yellow]Found joystick {index} (not a GameController): {name_str}[/yellow]" + ) return controller_indices, controller_names -def list_controllers_with_guids(console: Console, parser: argparse.ArgumentParser) -> None: +def list_controllers_with_guids( + console: Console, parser: argparse.ArgumentParser +) -> None: """Print detected controllers with their GUID strings and exit.""" count = sdl2.SDL_NumJoysticks() if count < 0: @@ -785,10 +975,16 @@ def list_controllers_with_guids(console: Console, parser: argparse.ArgumentParse table.add_column("GUID") for idx in range(count): is_gc = sdl2.SDL_IsGameController(idx) - name = sdl2.SDL_GameControllerNameForIndex(idx) if is_gc else sdl2.SDL_JoystickNameForIndex(idx) + name = ( + sdl2.SDL_GameControllerNameForIndex(idx) + if is_gc + else sdl2.SDL_JoystickNameForIndex(idx) + ) name_str = name.decode() if isinstance(name, bytes) else str(name) guid_str = guid_string_for_device_index(idx) - table.add_row(str(idx), "GameController" if is_gc else "Joystick", name_str, guid_str) + table.add_row( + str(idx), "GameController" if is_gc else "Joystick", name_str, guid_str + ) console.print(table) @@ -827,7 +1023,9 @@ def prepare_pairing_state( elif auto_pairing_enabled: if args.ports: available_ports.extend(list(args.ports)) - console.print(f"[green]Prepared {len(available_ports)} specified UART port(s) for auto-pairing.[/green]") + console.print( + f"[green]Prepared {len(available_ports)} specified UART port(s) for auto-pairing.[/green]" + ) else: # Passive mode: grab whatever UARTs exist now, and keep looking later. discovered = discover_serial_ports( @@ -842,7 +1040,9 @@ def prepare_pairing_state( for info in discovered: console.print(f" {info['device']} ({info['description']})") else: - console.print("[yellow]No UART devices detected yet; waiting for hotplug...[/yellow]") + console.print( + "[yellow]No UART devices detected yet; waiting for hotplug...[/yellow]" + ) mapping_by_index = {index: port for index, port in mappings} return PairingState( @@ -857,7 +1057,9 @@ def prepare_pairing_state( ) -def assign_port_for_index(pairing: PairingState, idx: int, console: Console) -> Optional[str]: +def assign_port_for_index( + pairing: PairingState, idx: int, console: Console +) -> Optional[str]: """Return the UART assigned to a controller index, auto-pairing if allowed.""" if idx in pairing.mapping_by_index: return pairing.mapping_by_index[idx] @@ -879,12 +1081,21 @@ def ports_in_use(pairing: PairingState, contexts: Dict[int, ControllerContext]) return used -def handle_removed_port(path: str, pairing: PairingState, contexts: Dict[int, ControllerContext], console: Console) -> None: +def handle_removed_port( + path: str, + pairing: PairingState, + contexts: Dict[int, ControllerContext], + console: Console, +) -> None: """Clear mappings/contexts for a UART path that disappeared.""" if path in pairing.available_ports: pairing.available_ports.remove(path) - console.print(f"[yellow]UART {path} removed; dropping from available pool[/yellow]") - indices_to_clear = [idx for idx, mapped in pairing.mapping_by_index.items() if mapped == path] + console.print( + f"[yellow]UART {path} removed; dropping from available pool[/yellow]" + ) + indices_to_clear = [ + idx for idx, mapped in pairing.mapping_by_index.items() if mapped == path + ] for idx in indices_to_clear: pairing.mapping_by_index.pop(idx, None) pairing.auto_assigned_indices.discard(idx) @@ -902,10 +1113,14 @@ def handle_removed_port(path: str, pairing: PairingState, contexts: Dict[int, Co ctx.rumble_active = False ctx.last_rumble_energy = 0.0 ctx.last_reopen_attempt = time.monotonic() - console.print(f"[yellow]UART {path} removed; controller {ctx.controller_index} waiting for reassignment[/yellow]") + console.print( + f"[yellow]UART {path} removed; controller {ctx.controller_index} waiting for reassignment[/yellow]" + ) -def discover_new_ports(pairing: PairingState, contexts: Dict[int, ControllerContext], console: Console) -> None: +def discover_new_ports( + pairing: PairingState, contexts: Dict[int, ControllerContext], console: Console +) -> None: """Scan for new serial ports and add unused ones to the available pool.""" if not pairing.auto_discover_ports: return @@ -929,7 +1144,9 @@ def discover_new_ports(pairing: PairingState, contexts: Dict[int, ControllerCont if path in in_use or path in pairing.available_ports: continue pairing.available_ports.append(path) - console.print(f"[green]Discovered UART {path} ({info['description']}); available for pairing.[/green]") + console.print( + f"[green]Discovered UART {path} ({info['description']}); available for pairing.[/green]" + ) def pair_waiting_contexts( @@ -977,7 +1194,9 @@ def open_initial_contexts( if index >= sdl2.SDL_NumJoysticks() or not sdl2.SDL_IsGameController(index): name = sdl2.SDL_JoystickNameForIndex(index) name_str = name.decode() if isinstance(name, bytes) else str(name) - console.print(f"[yellow]Index {index} is not a GameController ({name_str}). Trying raw open failed.[/yellow]") + console.print( + f"[yellow]Index {index} is not a GameController ({name_str}). Trying raw open failed.[/yellow]" + ) continue port = assign_port_for_index(pairing, index, console) if port is None and not pairing.auto_pairing_enabled: @@ -993,9 +1212,13 @@ def open_initial_contexts( uart = open_uart_or_warn(port, args.baud, console) if port else None if uart: uarts.append(uart) - console.print(f"[green]Controller {index} (id {stable_id}, inst {instance_id}) paired to {port}[/green]") + console.print( + f"[green]Controller {index} (id {stable_id}, inst {instance_id}) paired to {port}[/green]" + ) elif port: - console.print(f"[yellow]Controller {index} (id {stable_id}, inst {instance_id}) waiting for UART {port}[/yellow]") + console.print( + f"[yellow]Controller {index} (id {stable_id}, inst {instance_id}) waiting for UART {port}[/yellow]" + ) else: console.print( f"[yellow]Controller {index} (id {stable_id}, inst {instance_id}) connected; waiting for an available UART[/yellow]" @@ -1010,11 +1233,14 @@ def open_initial_contexts( ) if config.zero_sticks: zero_context_sticks(ctx, console) + initialize_controller_sensors(ctx, console) contexts[instance_id] = ctx return contexts, uarts -def handle_axis_motion(event: sdl2.SDL_Event, contexts: Dict[int, ControllerContext], config: BridgeConfig) -> None: +def handle_axis_motion( + event: sdl2.SDL_Event, contexts: Dict[int, ControllerContext], config: BridgeConfig +) -> None: """Process axis motion event into stick/trigger state.""" ctx = contexts.get(event.caxis.which) if not ctx: @@ -1095,7 +1321,9 @@ def handle_device_added( if idx >= sdl2.SDL_NumJoysticks() or not sdl2.SDL_IsGameController(idx): name = sdl2.SDL_JoystickNameForIndex(idx) name_str = name.decode() if isinstance(name, bytes) else str(name) - console.print(f"[yellow]Index {idx} is not a GameController ({name_str}). Trying raw open failed.[/yellow]") + console.print( + f"[yellow]Index {idx} is not a GameController ({name_str}). Trying raw open failed.[/yellow]" + ) return try: controller, instance_id, guid = open_controller(idx) @@ -1109,9 +1337,13 @@ def handle_device_added( uart = open_uart_or_warn(port, args.baud, console) if port else None if uart: uarts.append(uart) - console.print(f"[green]Controller {idx} (id {stable_id}, inst {instance_id}) paired to {port}[/green]") + console.print( + f"[green]Controller {idx} (id {stable_id}, inst {instance_id}) paired to {port}[/green]" + ) elif port: - console.print(f"[yellow]Controller {idx} (id {stable_id}, inst {instance_id}) waiting for UART {port}[/yellow]") + console.print( + f"[yellow]Controller {idx} (id {stable_id}, inst {instance_id}) waiting for UART {port}[/yellow]" + ) else: console.print( f"[yellow]Controller {idx} (id {stable_id}, inst {instance_id}) connected; waiting for an available UART[/yellow]" @@ -1126,6 +1358,7 @@ def handle_device_added( ) if config.zero_sticks: zero_context_sticks(ctx, console) + initialize_controller_sensors(ctx, console) contexts[instance_id] = ctx @@ -1140,7 +1373,9 @@ def handle_device_removed( ctx = contexts.pop(instance_id, None) if not ctx: return - console.print(f"[yellow]Controller {instance_id} (id {ctx.stable_id}) removed[/yellow]") + console.print( + f"[yellow]Controller {instance_id} (id {ctx.stable_id}) removed[/yellow]" + ) if ctx.controller_index in pairing.auto_assigned_indices: # Return auto-paired UART back to the pool so a future device can use it. freed = pairing.mapping_by_index.pop(ctx.controller_index, None) @@ -1167,18 +1402,25 @@ def service_contexts( else config.button_map_default ) poll_controller_buttons(ctx, current_button_map) + collect_imu_sample(ctx) # Reconnect UART if needed. if ctx.port and ctx.uart is None and (now - ctx.last_reopen_attempt) > 1.0: ctx.last_reopen_attempt = now uart = open_uart_or_warn(ctx.port, args.baud, console) if uart: uarts.append(uart) - console.print(f"[green]Reconnected UART {ctx.port} for controller {ctx.controller_index}[/green]") + console.print( + f"[green]Reconnected UART {ctx.port} for controller {ctx.controller_index}[/green]" + ) ctx.uart = uart if ctx.uart is None: continue try: if now - ctx.last_send >= config.interval: + if ctx.imu_samples: + ctx.report.imu_samples = ctx.imu_samples[-IMU_SAMPLES_PER_REPORT:] + else: + ctx.report.imu_samples = [] ctx.uart.send_report(ctx.report) ctx.last_send = now @@ -1201,7 +1443,10 @@ def service_contexts( sdl2.SDL_GameControllerRumble(ctx.controller, 0, 0, 0) ctx.rumble_active = False ctx.last_rumble_energy = 0.0 - elif ctx.rumble_active and (now - ctx.last_rumble_change) > RUMBLE_STUCK_TIMEOUT: + elif ( + ctx.rumble_active + and (now - ctx.last_rumble_change) > RUMBLE_STUCK_TIMEOUT + ): sdl2.SDL_GameControllerRumble(ctx.controller, 0, 0, 0) ctx.rumble_active = False ctx.last_rumble_energy = 0.0 @@ -1242,10 +1487,15 @@ def run_bridge_loop( break if event.type == sdl2.SDL_CONTROLLERAXISMOTION: handle_axis_motion(event, contexts, config) - elif event.type in (sdl2.SDL_CONTROLLERBUTTONDOWN, sdl2.SDL_CONTROLLERBUTTONUP): + elif event.type in ( + sdl2.SDL_CONTROLLERBUTTONDOWN, + sdl2.SDL_CONTROLLERBUTTONUP, + ): handle_button_event(event, config, contexts) elif event.type == sdl2.SDL_CONTROLLERDEVICEADDED: - handle_device_added(event, args, pairing, contexts, uarts, console, config) + handle_device_added( + event, args, pairing, contexts, uarts, console, config + ) elif event.type == sdl2.SDL_CONTROLLERDEVICEREMOVED: handle_device_removed(event, pairing, contexts, console) @@ -1291,7 +1541,9 @@ def main() -> None: list_controllers_with_guids(console, parser) return controller_indices, controller_names = detect_controllers(console, args, parser) - pairing = prepare_pairing_state(args, console, parser, controller_indices, controller_names) + pairing = prepare_pairing_state( + args, console, parser, controller_indices, controller_names + ) hotkey_messages: Dict[str, str] = {} if config.zero_hotkey: hotkey_messages[config.zero_hotkey] = "re-zero controller sticks" @@ -1301,14 +1553,20 @@ def main() -> None: hotkey_messages[config.swap_hotkey] + "; toggle ABXY layout" ) else: - hotkey_messages[config.swap_hotkey] = "toggle ABXY layout for a controller" + hotkey_messages[config.swap_hotkey] = ( + "toggle ABXY layout for a controller" + ) if hotkey_messages: candidate = HotkeyMonitor(console, hotkey_messages) if candidate.start(): hotkey_monitor = candidate - contexts, uarts = open_initial_contexts(args, pairing, controller_indices, console, config) + contexts, uarts = open_initial_contexts( + args, pairing, controller_indices, console, config + ) if not contexts: - console.print("[yellow]No controllers opened; waiting for hotplug events...[/yellow]") + console.print( + "[yellow]No controllers opened; waiting for hotplug events...[/yellow]" + ) run_bridge_loop(args, console, config, pairing, contexts, uarts, hotkey_monitor) finally: if hotkey_monitor: diff --git a/switch-pico.cpp b/switch-pico.cpp index 5efd6e9..d307e39 100644 --- a/switch-pico.cpp +++ b/switch-pico.cpp @@ -62,8 +62,9 @@ static void on_rumble_from_switch(const uint8_t rumble[8]) { // Consume UART bytes and forward complete frames to the Switch Pro driver. static void poll_uart_frames() { - static uint8_t buffer[8]; + static uint8_t buffer[64]; static uint8_t index = 0; + static uint8_t expected_len = 0; static absolute_time_t last_byte_time = {0}; static bool has_last_byte = false; @@ -73,6 +74,7 @@ static void poll_uart_frames() { uint64_t now = to_ms_since_boot(get_absolute_time()); if (has_last_byte && (now - to_ms_since_boot(last_byte_time)) > 20) { index = 0; // stale data, restart frame + expected_len = 0; } last_byte_time = get_absolute_time(); has_last_byte = true; @@ -84,9 +86,19 @@ static void poll_uart_frames() { } buffer[index++] = byte; - if (index >= sizeof(buffer)) { + if (index == 3) { + // We just stored payload_len; compute expected frame length (payload + header/version/len/checksum). + expected_len = static_cast(buffer[2] + 4); + if (expected_len > sizeof(buffer) || expected_len < 8) { + index = 0; + expected_len = 0; + continue; + } + } + + if (expected_len && index >= expected_len) { SwitchInputState parsed{}; - if (switch_pro_apply_uart_packet(buffer, sizeof(buffer), &parsed)) { + if (switch_pro_apply_uart_packet(buffer, expected_len, &parsed)) { g_user_state = parsed; LOG_PRINTF("[UART] packet buttons=0x%04x hat=%u lx=%u ly=%u rx=%u ry=%u\n", (parsed.button_a ? SWITCH_PRO_MASK_A : 0) | @@ -110,6 +122,7 @@ static void poll_uart_frames() { parsed.lx >> 8, parsed.ly >> 8, parsed.rx >> 8, parsed.ry >> 8); } index = 0; + expected_len = 0; } } } diff --git a/switch_pro_driver.cpp b/switch_pro_driver.cpp index 138ba8f..031e720 100644 --- a/switch_pro_driver.cpp +++ b/switch_pro_driver.cpp @@ -190,9 +190,40 @@ static SwitchInputState make_neutral_state() { s.ly = SWITCH_PRO_JOYSTICK_MID; s.rx = SWITCH_PRO_JOYSTICK_MID; s.ry = SWITCH_PRO_JOYSTICK_MID; + s.imu_sample_count = 0; return s; } +static void fill_imu_report_data(const SwitchInputState& state) { + // Only include IMU data when the host explicitly enabled it. + if (!is_imu_enabled || state.imu_sample_count == 0) { + memset(switch_report.imuData, 0x00, sizeof(switch_report.imuData)); + return; + } + + uint8_t sample_count = state.imu_sample_count > 3 ? 3 : state.imu_sample_count; + uint8_t* dst = switch_report.imuData; + for (uint8_t i = 0; i < 3; ++i) { + SwitchImuSample sample{}; + if (i < sample_count) { + sample = state.imu_samples[i]; + } + dst[0] = static_cast(sample.accel_x & 0xFF); + dst[1] = static_cast((sample.accel_x >> 8) & 0xFF); + dst[2] = static_cast(sample.accel_y & 0xFF); + dst[3] = static_cast((sample.accel_y >> 8) & 0xFF); + dst[4] = static_cast(sample.accel_z & 0xFF); + dst[5] = static_cast((sample.accel_z >> 8) & 0xFF); + dst[6] = static_cast(sample.gyro_x & 0xFF); + dst[7] = static_cast((sample.gyro_x >> 8) & 0xFF); + dst[8] = static_cast(sample.gyro_y & 0xFF); + dst[9] = static_cast((sample.gyro_y >> 8) & 0xFF); + dst[10] = static_cast(sample.gyro_z & 0xFF); + dst[11] = static_cast((sample.gyro_z >> 8) & 0xFF); + dst += 12; + } +} + static void send_identify() { memset(report_buffer, 0x00, sizeof(report_buffer)); report_buffer[0] = REPORT_USB_INPUT_81; @@ -485,6 +516,7 @@ static void update_switch_report_from_state() { switch_report.inputs.rightStick.setX(std::min(std::max(scaleRightStickX,rightMinX), rightMaxX)); switch_report.inputs.rightStick.setY(-std::min(std::max(scaleRightStickY,rightMinY), rightMaxY)); + fill_imu_report_data(g_input_state); switch_report.rumbleReport = 0x09; } @@ -617,18 +649,51 @@ void switch_pro_task() { } bool switch_pro_apply_uart_packet(const uint8_t* packet, uint8_t length, SwitchInputState* out_state) { - // Packet format: 0xAA, buttons(2 LE), hat, lx, ly, rx, ry + // Packet v2 format: + // 0:0xAA header + // 1:version (0x02) + // 2:payload_len (bytes 3..3+len-1) + // 3-4: buttons LE + // 5: hat + // 6-9: lx, ly, rx, ry (0-255) + // 10: imu_sample_count (0-3) + // 11-46: up to 3 samples of accel/gyro (int16 LE each axis) + // 47: checksum (sum of bytes 0..46) & 0xFF if (length < 8 || packet[0] != 0xAA) { return false; } + if (packet[1] != 0x02) { + return false; + } + + uint8_t payload_len = packet[2]; + uint16_t expected_len = static_cast(payload_len) + 4; // header+version+len+checksum + if (length < expected_len) { + return false; + } + + uint16_t checksum_end = static_cast(3 + payload_len - 1); // last payload byte + uint16_t checksum_index = static_cast(3 + payload_len); + if (checksum_index >= length) { + return false; + } + + uint16_t sum = 0; + for (uint16_t i = 0; i <= checksum_end; ++i) { + sum = static_cast(sum + packet[i]); + } + if ((sum & 0xFF) != packet[checksum_index]) { + return false; + } + SwitchProOutReport out{}; - out.buttons = static_cast(packet[1]) | (static_cast(packet[2]) << 8); - out.hat = packet[3]; - out.lx = packet[4]; - out.ly = packet[5]; - out.rx = packet[6]; - out.ry = packet[7]; + out.buttons = static_cast(packet[3]) | (static_cast(packet[4]) << 8); + out.hat = packet[5]; + out.lx = packet[6]; + out.ly = packet[7]; + out.rx = packet[8]; + out.ry = packet[9]; auto expand_axis = [](uint8_t v) -> uint16_t { return static_cast(v) << 8 | v; @@ -636,6 +701,23 @@ bool switch_pro_apply_uart_packet(const uint8_t* packet, uint8_t length, SwitchI SwitchInputState state = make_neutral_state(); + state.imu_sample_count = std::min(packet[10], 3); + + auto read_int16 = [](const uint8_t* src) -> int16_t { + return static_cast(static_cast(src[0]) | (static_cast(src[1]) << 8)); + }; + + const uint8_t* imu_base = &packet[11]; + for (uint8_t i = 0; i < state.imu_sample_count; ++i) { + const uint8_t* sample_ptr = imu_base + (i * 12); + state.imu_samples[i].accel_x = read_int16(sample_ptr + 0); + state.imu_samples[i].accel_y = read_int16(sample_ptr + 2); + state.imu_samples[i].accel_z = read_int16(sample_ptr + 4); + state.imu_samples[i].gyro_x = read_int16(sample_ptr + 6); + state.imu_samples[i].gyro_y = read_int16(sample_ptr + 8); + state.imu_samples[i].gyro_z = read_int16(sample_ptr + 10); + } + switch (out.hat) { case SWITCH_PRO_HAT_UP: state.dpad_up = true; break; case SWITCH_PRO_HAT_UPRIGHT: state.dpad_up = true; state.dpad_right = true; break; diff --git a/switch_pro_driver.h b/switch_pro_driver.h index 73d8582..929a455 100644 --- a/switch_pro_driver.h +++ b/switch_pro_driver.h @@ -10,6 +10,15 @@ #include #include "switch_pro_descriptors.h" +typedef struct { + int16_t accel_x; + int16_t accel_y; + int16_t accel_z; + int16_t gyro_x; + int16_t gyro_y; + int16_t gyro_z; +} SwitchImuSample; + typedef struct { bool dpad_up; bool dpad_down; @@ -35,6 +44,8 @@ typedef struct { uint16_t ly; uint16_t rx; uint16_t ry; + uint8_t imu_sample_count; + SwitchImuSample imu_samples[3]; } SwitchInputState; // Initialize USB state and calibration before entering the main loop. From 3dc91e5f8d1ded80618e8aa49af5860bd2970223 Mon Sep 17 00:00:00 2001 From: Joey Yakimowich-Payne Date: Mon, 24 Nov 2025 18:05:51 -0700 Subject: [PATCH 2/3] game gyro output as pro controller. Still not working --- .gitignore | 5 + gyro_support_plan.md | 49 ++++++ .../controller_uart_bridge.py | 88 +++++++++-- switch_pro_driver.cpp | 29 ++-- tools/read_pro_imu.py | 149 ++++++++++++++++++ 5 files changed, 294 insertions(+), 26 deletions(-) create mode 100644 gyro_support_plan.md create mode 100644 tools/read_pro_imu.py diff --git a/.gitignore b/.gitignore index b95424e..9990210 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,9 @@ debug *.egg-info hid-nintendo.c __pycache__ +Nintendo_Switch_Reverse_Engineering +nxbt +hid-nintendo.c +*.png +switch !.vscode/* diff --git a/gyro_support_plan.md b/gyro_support_plan.md new file mode 100644 index 0000000..8015b19 --- /dev/null +++ b/gyro_support_plan.md @@ -0,0 +1,49 @@ +# Gyro Support Plan + +## Checklist +- [x] Review current UART input/rumble plumbing (`switch-pico.cpp` and `controller_uart_bridge.py`) to confirm the existing 0xAA input frame (buttons/hat/sticks) and 0xBB rumble return path. +- [x] Map the current firmware IMU hooks (`switch_pro_driver.cpp`/`switch_pro_descriptors.h` `imuData` field, `is_imu_enabled` flag, report construction cadence) and note where IMU payloads should be injected. +- [x] Pull expected IMU packet format, sample cadence (3 samples/report), and unit scaling from references: `Nintendo_Switch_Reverse_Engineering/imu_sensor_notes.md`, `hid-nintendo.c`, and NXBT (`nxbt/controller/protocol.py`). +- [x] Examine GP2040-CE’s motion implementation (e.g., `GP2040-CE/src/addons/imu` and Switch report handling) for framing, calibration defaults, and rate control patterns to reuse. +- [x] Decide on UART motion framing: header/length/checksum scheme, sample packing (likely 3 samples of accel+gyro int16 LE), endian/order, and compatibility with existing 8-byte frames (avoid breaking current host builds). +- [ ] Define Switch-facing IMU payload layout inside `SwitchProReport` (axis order, sign conventions, zero/neutral sample) and ensure it matches the reverse-engineered descriptors. +- [ ] Add firmware-side data structures/buffers for incoming IMU samples (triple-buffer if mirroring Joy-Con 3-sample bursts) and default zeroing when IMU is disabled/missing. +- [ ] Extend UART parser in `switch-pico.cpp` to accept the new motion frame(s), validate checksum, and stash samples atomically alongside button state. +- [ ] Gate IMU injection on the host’s `TOGGLE_IMU` feature flag (`is_imu_enabled`) and ensure reports carry motion data only when enabled; default to zeros otherwise. +- [ ] Apply calibration/scaling constants: choose defaults from references (e.g., 0.06103 dps/LSB gyro, accel per imu_sensor_notes) and document where to adjust for sensor-specific offsets. +- [ ] Update host bridge to enable SDL sensor support (`SDL_GameControllerSetSensorEnabled`, `SDL_CONTROLLERAXIS` vs sensor events) and capture gyro (and accel if needed) at the required rate. +- [ ] Buffer and pack host IMU readings into the agreed UART motion frame, including timestamping/rate smoothing so the firmware sees stable 200 Hz-ish samples (3 per 5 ms report). +- [ ] Keep backward compatibility: allow running without IMU-capable controllers (send zero motion) and keep rumble unchanged. +- [ ] Add logging/metrics: lightweight counters for dropped/late IMU frames and a debug toggle to inspect raw samples. +- [ ] Test matrix: host-only sensor capture sanity check; loopback UART frame validator; firmware USB capture with Switch (or nxbt PC host) verifying IMU report contents and that `TOGGLE_IMU` on/off behaves; regression check that buttons/rumble remain stable. + +## Findings to date +- Firmware: `SwitchProReport.imuData[36]` exists but is always zero; `is_imu_enabled` is set only via `TOGGLE_IMU` feature report; `switch_pro_task` always sends `switch_report`, so IMU injection should happen before the memcmp/send path. +- IMU payload layout (standard 0x30/31/32/33): bytes 13-24 are accel_x, accel_y, accel_z, gyro_x, gyro_y, gyro_z (all Int16 LE); bytes 25-48 repeat two more samples (3 samples total, ~5 ms apart). Matches `imuData[36]` size and `hid-nintendo.c` parsing (`imu_raw_bytes` split into 3 `joycon_imu_data` structs). +- Scaling from references: accel ≈ 0.000244 G/LSB (±8000 mG), gyro ≈ 0.06103 dps/LSB (±2000 dps) or 0.070 dps/LSB with ST’s +15% headroom; Switch internally also uses rev/s conversions. Typical packet cadence ~15 ms with 3 samples (≈200 Hz sampling). +- NXBT reference: only injects IMU when `imu_enabled` flag is true; drops a 36-byte sample block at offset 14 (0-based) in the report. Example data is static; good for offset confirmation. +- GP2040-CE reference: Switch Pro driver mirrors this project—`imuData` zeroed, no motion handling yet. No reusable IMU framing, but report/keepalive cadence matches. + +## UART IMU framing decision (breaking change OK) +- New host→Pico frame (versioned) replaces the old 8-byte 0xAA packet: + - Byte0: `0xAA` header + - Byte1: `0x02` version + - Byte2: `payload_len` (44 for the layout below) + - Byte3-4: buttons LE (same masks as before) + - Byte5: hat + - Byte6-9: sticks `lx, ly, rx, ry` (0-255 as before) + - Byte10: `imu_sample_count` (host should send 3; firmware may accept 0-3) + - Byte11-46: IMU samples, 3 blocks of 12 bytes each: + - For sample i: `accel_x, accel_y, accel_z, gyro_x, gyro_y, gyro_z` (all int16 LE, Pro Controller axis/sign convention) + - Byte47: checksum = (sum of bytes 0 through the last payload byte) & 0xFF +- Host behavior: always populate 3 samples per packet (~200 Hz, 5 ms spacing) with reference/default scaling; send zeros and `imu_sample_count=0` if IMU unavailable/disabled. Buttons/sticks unchanged. +- Firmware behavior: parse the new frame, validate `payload_len` and checksum, then atomically store button/stick plus up to `imu_sample_count` samples. If `is_imu_enabled` is true, copy samples into `switch_report.imuData` in the 3× sample layout; otherwise zero the IMU block. +- Axis orientation: match Pro Controller orientation (no additional flips beyond standard report ordering). + +## Open Questions +- All answered: + - Include both accelerometer and gyro (full IMU). + - Reference/default scaling is acceptable (no per-device calibration). + - Mirror 3-sample bursts (~200 Hz, 5 ms spacing). + - Use Pro Controller axis orientation/sign. + - Breaking UART framing change is acceptable (use the versioned packet above). diff --git a/src/switch_pico_bridge/controller_uart_bridge.py b/src/switch_pico_bridge/controller_uart_bridge.py index bcd4f67..b34c2f1 100644 --- a/src/switch_pico_bridge/controller_uart_bridge.py +++ b/src/switch_pico_bridge/controller_uart_bridge.py @@ -58,6 +58,8 @@ RUMBLE_SCALE = 1.0 CONTROLLER_DB_URL_DEFAULT = "https://raw.githubusercontent.com/mdqinc/SDL_GameControllerDB/refs/heads/master/gamecontrollerdb.txt" SDL_TRUE = getattr(sdl2, "SDL_TRUE", 1) +GYRO_BIAS_SAMPLES = 200 +GYRO_DEADZONE_COUNTS = 15 def parse_mapping(value: str) -> Tuple[int, str]: @@ -245,6 +247,12 @@ class ControllerContext: sensors_enabled: bool = False imu_samples: List[IMUSample] = field(default_factory=list) last_sensor_poll: float = 0.0 + gyro_bias_x: float = 0.0 + gyro_bias_y: float = 0.0 + gyro_bias_z: float = 0.0 + gyro_bias_samples: int = 0 + gyro_bias_locked: bool = False + last_debug_imu_print: float = 0.0 def capture_stick_offsets(controller: sdl2.SDL_GameController) -> Dict[int, int]: @@ -594,8 +602,8 @@ def build_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--frequency", type=float, - default=500.0, - help="Report send frequency per controller (Hz, default 500)", + default=1000.0, + help="Report send frequency per controller (Hz, default 1000)", ) parser.add_argument( "--deadzone", @@ -697,6 +705,11 @@ def build_arg_parser() -> argparse.ArgumentParser: default=[], help="Path to an SDL2 controller mapping database (e.g. controllerdb.txt). Repeatable.", ) + parser.add_argument( + "--debug-imu", + action="store_true", + help="Print raw IMU readings (float and converted int16) for debugging.", + ) return parser @@ -742,6 +755,7 @@ class BridgeConfig: swap_abxy_indices: set[int] swap_abxy_ids: set[str] swap_abxy_global: bool + debug_imu: bool @dataclass @@ -775,12 +789,16 @@ def read_sensor_triplet( def convert_accel_to_raw(accel_ms2: float) -> int: g_units = accel_ms2 / MS2_PER_G - return clamp_int16(g_units * 4096.0) + return clamp_int16(g_units * 4000.0) def convert_gyro_to_raw(gyro_rad: float) -> int: + # SDL reports gyroscope data in radians/second; convert to dps then to Switch counts. dps = gyro_rad * RAD_TO_DEG - return clamp_int16(dps / 0.070) + counts = clamp_int16(dps / 0.070) + if abs(counts) < GYRO_DEADZONE_COUNTS: + return 0 + return counts def initialize_controller_sensors(ctx: ControllerContext, console: Console) -> None: @@ -823,25 +841,70 @@ def initialize_controller_sensors(ctx: ControllerContext, console: Console) -> N ) -def collect_imu_sample(ctx: ControllerContext) -> None: +def collect_imu_sample(ctx: ControllerContext, config: BridgeConfig) -> None: if not ctx.sensors_enabled or SENSOR_ACCEL is None or SENSOR_GYRO is None: return accel = read_sensor_triplet(ctx.controller, SENSOR_ACCEL) gyro = read_sensor_triplet(ctx.controller, SENSOR_GYRO) if not accel or not gyro: return + if not ctx.gyro_bias_locked and ctx.gyro_bias_samples < GYRO_BIAS_SAMPLES: + ctx.gyro_bias_x += gyro[0] + ctx.gyro_bias_y += gyro[1] + ctx.gyro_bias_z += gyro[2] + ctx.gyro_bias_samples += 1 + if ctx.gyro_bias_samples >= GYRO_BIAS_SAMPLES: + ctx.gyro_bias_x /= ctx.gyro_bias_samples + ctx.gyro_bias_y /= ctx.gyro_bias_samples + ctx.gyro_bias_z /= ctx.gyro_bias_samples + ctx.gyro_bias_locked = True + + bias_x = ctx.gyro_bias_x if ctx.gyro_bias_locked else 0.0 + bias_y = ctx.gyro_bias_y if ctx.gyro_bias_locked else 0.0 + bias_z = ctx.gyro_bias_z if ctx.gyro_bias_locked else 0.0 + gyro_bias_corrected = ( + gyro[0] - bias_x, + gyro[1] - bias_y, + gyro[2] - bias_z, + ) + + # Map SDL sensor axes to Pro Controller axes: gravity should land on Z. + # print(accel[2] / MS2_PER_G) + accel_raw_x = convert_accel_to_raw(accel[0]) + accel_raw_y = convert_accel_to_raw(accel[1]) + accel_raw_z = convert_accel_to_raw(accel[2]) + gyro_raw_x = convert_gyro_to_raw(gyro[0]) + gyro_raw_y = convert_gyro_to_raw(gyro[1]) + gyro_raw_z = convert_gyro_to_raw(gyro[2]) + + # Map SDL axes to Pro axes to match the native Pro USB output: + # Pro accel: ax = SDL_, ay = SDL_Z, az = SDL_Y (gravity). + # Pro gyro: gx = SDL_X, gy = SDL_Z, gz = SDL_Y. sample = IMUSample( - accel_x=convert_accel_to_raw(accel[0]), - accel_y=convert_accel_to_raw(accel[1]), - accel_z=convert_accel_to_raw(accel[2]), - gyro_x=convert_gyro_to_raw(gyro[0]), - gyro_y=convert_gyro_to_raw(gyro[1]), - gyro_z=convert_gyro_to_raw(gyro[2]), + accel_x=-accel_raw_z, + accel_y=-accel_raw_x, + accel_z=accel_raw_y, + gyro_x=-gyro_raw_z, + gyro_y=-gyro_raw_x, + gyro_z=gyro_raw_y, ) ctx.imu_samples.append(sample) if len(ctx.imu_samples) > 6: ctx.imu_samples = ctx.imu_samples[-6:] + if config.debug_imu: + now = time.monotonic() + if now - ctx.last_debug_imu_print > 0.2: + ctx.last_debug_imu_print = now + print( + f"[IMU dbg idx={ctx.controller_index}] " + f"accel_ms2=({accel[0]:.3f},{accel[1]:.3f},{accel[2]:.3f}) " + f"gyro_dps=({gyro[0]:.3f},{gyro[1]:.3f},{gyro[2]:.3f}) " + f"bias_dps=({bias_x:.3f},{bias_y:.3f},{bias_z:.3f}) " + f"raw=({sample.accel_x},{sample.accel_y},{sample.accel_z};" + f"{sample.gyro_x},{sample.gyro_y},{sample.gyro_z})" + ) + def load_button_maps( console: Console, args: argparse.Namespace @@ -897,6 +960,7 @@ def build_bridge_config(console: Console, args: argparse.Namespace) -> BridgeCon swap_abxy_indices=swap_abxy_indices, swap_abxy_ids=set(swap_abxy_guids), # filled later once stable IDs are known swap_abxy_global=bool(args.swap_abxy), + debug_imu=bool(args.debug_imu), ) @@ -1402,7 +1466,7 @@ def service_contexts( else config.button_map_default ) poll_controller_buttons(ctx, current_button_map) - collect_imu_sample(ctx) + collect_imu_sample(ctx, config) # Reconnect UART if needed. if ctx.port and ctx.uart is None and (now - ctx.last_reopen_attempt) > 1.0: ctx.last_reopen_attempt = now diff --git a/switch_pro_driver.cpp b/switch_pro_driver.cpp index 031e720..ee2b899 100644 --- a/switch_pro_driver.cpp +++ b/switch_pro_driver.cpp @@ -14,8 +14,8 @@ #define LOG_PRINTF(...) ((void)0) #endif -// force a report to be sent every X ms -#define SWITCH_PRO_KEEPALIVE_TIMER 5 +// force a report to be sent every X ms (roughly matches Pro Controller cadence) +#define SWITCH_PRO_KEEPALIVE_TIMER 15 static SwitchInputState g_input_state{ false, false, false, false, @@ -195,14 +195,16 @@ static SwitchInputState make_neutral_state() { } static void fill_imu_report_data(const SwitchInputState& state) { - // Only include IMU data when the host explicitly enabled it. - if (!is_imu_enabled || state.imu_sample_count == 0) { + // Include IMU data when the host provided samples; otherwise zero. + if (state.imu_sample_count == 0) { memset(switch_report.imuData, 0x00, sizeof(switch_report.imuData)); return; } uint8_t sample_count = state.imu_sample_count > 3 ? 3 : state.imu_sample_count; uint8_t* dst = switch_report.imuData; + // Map host IMU axes (host already sends Pro-oriented X,Z,Y) into report layout: + // Report order per sample: accel_x, accel_y, accel_z, gyro_x, gyro_y, gyro_z. for (uint8_t i = 0; i < 3; ++i) { SwitchImuSample sample{}; if (i < sample_count) { @@ -525,6 +527,7 @@ void switch_pro_init() { last_report_counter = 0; handshake_counter = 0; is_ready = false; + is_imu_enabled = true; // default on to allow IMU during host bring-up/debug is_initialized = false; is_report_queued = false; report_sent = false; @@ -626,14 +629,12 @@ void switch_pro_task() { switch_report.timestamp = last_report_counter; void * inputReport = &switch_report; uint16_t report_size = sizeof(switch_report); - if (memcmp(last_report, inputReport, report_size) != 0) { - if (tud_hid_ready() && send_report(0, inputReport, report_size) == true ) { - memcpy(last_report, inputReport, report_size); - report_sent = true; - } - - last_report_timer = now; + if (tud_hid_ready() && send_report(0, inputReport, report_size) == true ) { + memcpy(last_report, inputReport, report_size); + report_sent = true; } + + last_report_timer = now; } } else { if (!is_initialized) { @@ -851,9 +852,9 @@ bool tud_control_request_cb(uint8_t rhport, tusb_control_request_t const * reque void tud_mount_cb(void) { LOG_PRINTF("[USB] mount_cb\n"); last_host_activity_ms = to_ms_since_boot(get_absolute_time()); - forced_ready = false; - is_ready = false; - is_initialized = false; + forced_ready = true; + is_ready = true; + is_initialized = true; } void tud_umount_cb(void) { diff --git a/tools/read_pro_imu.py b/tools/read_pro_imu.py new file mode 100644 index 0000000..cc81ee9 --- /dev/null +++ b/tools/read_pro_imu.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +""" +Read raw IMU samples from a Nintendo Switch Pro Controller (or Pico spoof) over USB. + +Uses the `hidapi` (pyhidapi) package. Press Ctrl+C to exit. +""" + +import argparse +import struct +import sys +from typing import List, Tuple + +import hid # from pyhidapi + +DEFAULT_VENDOR_ID = 0x057E +DEFAULT_PRODUCT_ID = 0x2009 # Switch Pro Controller (USB) + + +def list_devices(filter_vid=None, filter_pid=None): + devices = hid.enumerate() + for d in devices: + if filter_vid and d["vendor_id"] != filter_vid: + continue + if filter_pid and d["product_id"] != filter_pid: + continue + print( + f"VID=0x{d['vendor_id']:04X} PID=0x{d['product_id']:04X} " + f"path={d.get('path')} " + f"serial={d.get('serial_number')} " + f"manufacturer={d.get('manufacturer_string')} " + f"product={d.get('product_string')} " + f"interface={d.get('interface_number')}" + ) + return devices + + +def find_device(vendor_id: int, product_id: int): + for dev in hid.enumerate(): + if dev["vendor_id"] == vendor_id and dev["product_id"] == product_id: + return dev + return None + + +def main(): + parser = argparse.ArgumentParser(description="Read raw 0x30 reports (IMU) from a Switch Pro Controller / Pico.") + parser.add_argument("--vid", type=lambda x: int(x, 0), default=DEFAULT_VENDOR_ID, help="Vendor ID (default 0x057E)") + parser.add_argument("--pid", type=lambda x: int(x, 0), default=DEFAULT_PRODUCT_ID, help="Product ID (default 0x2009)") + parser.add_argument("--path", help="Explicit HID path to open (overrides VID/PID).") + parser.add_argument("--count", type=int, default=0, help="Stop after this many 0x30 reports (0 = infinite).") + parser.add_argument("--timeout", type=int, default=3000, help="Read timeout ms (default 3000).") + parser.add_argument("--list", action="store_true", help="List detected HID devices and exit.") + parser.add_argument("--plot", action="store_true", help="Plot accel/gyro traces after capture (requires matplotlib).") + parser.add_argument("--save-prefix", help="If set, save accel/gyro plots as '_accel.png' and '_gyro.png'.") + args = parser.parse_args() + + if args.list: + list_devices() + return + + if args.path: + dev_info = {"path": bytes(args.path, encoding="utf-8"), "vendor_id": args.vid, "product_id": args.pid} + else: + dev_info = find_device(args.vid, args.pid) + if not dev_info: + print( + f"No HID device found for VID=0x{args.vid:04X} PID=0x{args.pid:04X}. " + "Use --list to inspect devices or --path to target a specific one.", + file=sys.stderr, + ) + sys.exit(1) + + device = hid.device() + device.open_path(dev_info["path"]) + device.set_nonblocking(False) + print( + f"Reading raw 0x30 reports from device (VID=0x{args.vid:04X} PID=0x{args.pid:04X})... " + "Ctrl+C to stop." + ) + accel_series: List[Tuple[int, int, int]] = [] + gyro_series: List[Tuple[int, int, int]] = [] + try: + read_count = 0 + while args.count == 0 or read_count < args.count: + data = device.read(64, timeout_ms=args.timeout) + if not data: + print(f"(timeout after {args.timeout} ms, no data)") + continue + if data[0] != 0x30: + print(f"(non-0x30 report id=0x{data[0]:02X}, len={len(data)})") + continue + samples = [] + offset = 13 # accel_x starts at byte 13 + for _ in range(3): + ax, ay, az, gx, gy, gz = struct.unpack_from(" Date: Mon, 24 Nov 2025 18:57:56 -0700 Subject: [PATCH 3/3] gweaks, nothing works yet --- .../controller_uart_bridge.py | 94 +++++++++++++------ switch-pico.cpp | 13 ++- switch_pro_driver.cpp | 3 + 3 files changed, 77 insertions(+), 33 deletions(-) diff --git a/src/switch_pico_bridge/controller_uart_bridge.py b/src/switch_pico_bridge/controller_uart_bridge.py index b34c2f1..280601a 100644 --- a/src/switch_pico_bridge/controller_uart_bridge.py +++ b/src/switch_pico_bridge/controller_uart_bridge.py @@ -58,8 +58,11 @@ RUMBLE_SCALE = 1.0 CONTROLLER_DB_URL_DEFAULT = "https://raw.githubusercontent.com/mdqinc/SDL_GameControllerDB/refs/heads/master/gamecontrollerdb.txt" SDL_TRUE = getattr(sdl2, "SDL_TRUE", 1) +SDL_CONTROLLERSENSORUPDATE = getattr(sdl2, "SDL_CONTROLLERSENSORUPDATE", 0x658) GYRO_BIAS_SAMPLES = 200 -GYRO_DEADZONE_COUNTS = 15 +GYRO_DEADZONE_COUNTS = 0 +# Keep a small window of IMU samples to pack into the next report +IMU_BUFFER_SIZE = 32 def parse_mapping(value: str) -> Tuple[int, str]: @@ -246,7 +249,7 @@ class ControllerContext: sensors_supported: bool = False sensors_enabled: bool = False imu_samples: List[IMUSample] = field(default_factory=list) - last_sensor_poll: float = 0.0 + last_accel: Tuple[float, float, float] = (0.0, 0.0, 0.0) gyro_bias_x: float = 0.0 gyro_bias_y: float = 0.0 gyro_bias_z: float = 0.0 @@ -602,8 +605,8 @@ def build_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--frequency", type=float, - default=1000.0, - help="Report send frequency per controller (Hz, default 1000)", + default=66.7, + help="Report send frequency per controller (Hz, default ~66.7 => ~15ms)", ) parser.add_argument( "--deadzone", @@ -710,6 +713,17 @@ def build_arg_parser() -> argparse.ArgumentParser: action="store_true", help="Print raw IMU readings (float and converted int16) for debugging.", ) + parser.add_argument( + "--gyro-scale", + type=float, + default=1.0, + help="Scale factor for gyro sensitivity (default 1.0). Reduce to < 1.0 if camera moves too fast.", + ) + parser.add_argument( + "--no-gyro-bias", + action="store_true", + help="Disable automatic gyro bias calibration at startup.", + ) return parser @@ -756,6 +770,8 @@ class BridgeConfig: swap_abxy_ids: set[str] swap_abxy_global: bool debug_imu: bool + gyro_scale: float + no_gyro_bias: bool @dataclass @@ -792,10 +808,11 @@ def convert_accel_to_raw(accel_ms2: float) -> int: return clamp_int16(g_units * 4000.0) -def convert_gyro_to_raw(gyro_rad: float) -> int: +def convert_gyro_to_raw(gyro_rad: float, scale: float = 1.0) -> int: # SDL reports gyroscope data in radians/second; convert to dps then to Switch counts. dps = gyro_rad * RAD_TO_DEG - counts = clamp_int16(dps / 0.070) + # 0.070 dps/LSB is approx 14.28 LSB/dps. + counts = clamp_int16((dps / 0.061) * scale) if abs(counts) < GYRO_DEADZONE_COUNTS: return 0 return counts @@ -841,13 +858,24 @@ def initialize_controller_sensors(ctx: ControllerContext, console: Console) -> N ) -def collect_imu_sample(ctx: ControllerContext, config: BridgeConfig) -> None: - if not ctx.sensors_enabled or SENSOR_ACCEL is None or SENSOR_GYRO is None: +def handle_sensor_update(event: sdl2.SDL_Event, contexts: Dict[int, ControllerContext], config: BridgeConfig) -> None: + ctx = contexts.get(event.csensor.which) + if not ctx: return - accel = read_sensor_triplet(ctx.controller, SENSOR_ACCEL) - gyro = read_sensor_triplet(ctx.controller, SENSOR_GYRO) - if not accel or not gyro: + + sensor_type = event.csensor.sensor + data = event.csensor.data + + if sensor_type == SENSOR_ACCEL: + ctx.last_accel = (data[0], data[1], data[2]) return + + if sensor_type != SENSOR_GYRO: + return + + # Process Gyro update (and combine with last accel) + gyro = (data[0], data[1], data[2]) + if not ctx.gyro_bias_locked and ctx.gyro_bias_samples < GYRO_BIAS_SAMPLES: ctx.gyro_bias_x += gyro[0] ctx.gyro_bias_y += gyro[1] @@ -862,20 +890,17 @@ def collect_imu_sample(ctx: ControllerContext, config: BridgeConfig) -> None: bias_x = ctx.gyro_bias_x if ctx.gyro_bias_locked else 0.0 bias_y = ctx.gyro_bias_y if ctx.gyro_bias_locked else 0.0 bias_z = ctx.gyro_bias_z if ctx.gyro_bias_locked else 0.0 - gyro_bias_corrected = ( - gyro[0] - bias_x, - gyro[1] - bias_y, - gyro[2] - bias_z, - ) + + # Use last known accel + accel = ctx.last_accel # Map SDL sensor axes to Pro Controller axes: gravity should land on Z. - # print(accel[2] / MS2_PER_G) accel_raw_x = convert_accel_to_raw(accel[0]) accel_raw_y = convert_accel_to_raw(accel[1]) accel_raw_z = convert_accel_to_raw(accel[2]) - gyro_raw_x = convert_gyro_to_raw(gyro[0]) - gyro_raw_y = convert_gyro_to_raw(gyro[1]) - gyro_raw_z = convert_gyro_to_raw(gyro[2]) + gyro_raw_x = convert_gyro_to_raw(gyro[0] - bias_x, config.gyro_scale) + gyro_raw_y = convert_gyro_to_raw(gyro[1] - bias_y, config.gyro_scale) + gyro_raw_z = convert_gyro_to_raw(gyro[2] - bias_z, config.gyro_scale) # Map SDL axes to Pro axes to match the native Pro USB output: # Pro accel: ax = SDL_, ay = SDL_Z, az = SDL_Y (gravity). @@ -884,13 +909,13 @@ def collect_imu_sample(ctx: ControllerContext, config: BridgeConfig) -> None: accel_x=-accel_raw_z, accel_y=-accel_raw_x, accel_z=accel_raw_y, - gyro_x=-gyro_raw_z, - gyro_y=-gyro_raw_x, - gyro_z=gyro_raw_y, + gyro_x=convert_gyro_to_raw(-(gyro[2] - bias_z), config.gyro_scale), + gyro_y=convert_gyro_to_raw(-(gyro[0] - bias_x), config.gyro_scale), + gyro_z=convert_gyro_to_raw(gyro[1] - bias_y, config.gyro_scale), ) ctx.imu_samples.append(sample) - if len(ctx.imu_samples) > 6: - ctx.imu_samples = ctx.imu_samples[-6:] + if len(ctx.imu_samples) > IMU_BUFFER_SIZE: + ctx.imu_samples = ctx.imu_samples[-IMU_BUFFER_SIZE:] if config.debug_imu: now = time.monotonic() @@ -936,7 +961,7 @@ def load_button_maps( console.print( f"[red]Failed to load SDL mapping {mapping_path}: {exc}[/red]" ) - return button_map_default, button_map_swapped, swap_abxy_indices + geturn button_map_default, button_map_swapped, swap_abxy_indices def build_bridge_config(console: Console, args: argparse.Namespace) -> BridgeConfig: @@ -961,6 +986,8 @@ def build_bridge_config(console: Console, args: argparse.Namespace) -> BridgeCon swap_abxy_ids=set(swap_abxy_guids), # filled later once stable IDs are known swap_abxy_global=bool(args.swap_abxy), debug_imu=bool(args.debug_imu), + gyro_scale=args.gyro_scale, + no_gyro_bias=args.no_gyro_bias, ) @@ -1466,7 +1493,7 @@ def service_contexts( else config.button_map_default ) poll_controller_buttons(ctx, current_button_map) - collect_imu_sample(ctx, config) + # collect_imu_sample(ctx, config) <-- Removed, using event loop # Reconnect UART if needed. if ctx.port and ctx.uart is None and (now - ctx.last_reopen_attempt) > 1.0: ctx.last_reopen_attempt = now @@ -1481,8 +1508,11 @@ def service_contexts( continue try: if now - ctx.last_send >= config.interval: - if ctx.imu_samples: - ctx.report.imu_samples = ctx.imu_samples[-IMU_SAMPLES_PER_REPORT:] + # Consume up to 3 samples from the head of the queue (FIFO) + count = min(len(ctx.imu_samples), IMU_SAMPLES_PER_REPORT) + if count > 0: + ctx.report.imu_samples = ctx.imu_samples[:count] + ctx.imu_samples = ctx.imu_samples[count:] else: ctx.report.imu_samples = [] ctx.uart.send_report(ctx.report) @@ -1556,12 +1586,16 @@ def run_bridge_loop( sdl2.SDL_CONTROLLERBUTTONUP, ): handle_button_event(event, config, contexts) + elif event.type in (sdl2.SDL_CONTROLLERBUTTONDOWN, sdl2.SDL_CONTROLLERBUTTONUP): + handle_button_event(event, args, config, contexts) + elif event.type == SDL_CONTROLLERSENSORUPDATE: + handle_sensor_update(event, contexts, config) elif event.type == sdl2.SDL_CONTROLLERDEVICEADDED: handle_device_added( event, args, pairing, contexts, uarts, console, config ) elif event.type == sdl2.SDL_CONTROLLERDEVICEREMOVED: - handle_device_removed(event, pairing, contexts, console) + gandle_device_removed(event, pairing, contexts, console) now = time.monotonic() if now - last_port_scan > port_scan_interval: diff --git a/switch-pico.cpp b/switch-pico.cpp index d307e39..51f99d3 100644 --- a/switch-pico.cpp +++ b/switch-pico.cpp @@ -1,4 +1,4 @@ -#include +ginclude #include #include "bsp/board.h" #include "hardware/uart.h" @@ -61,12 +61,13 @@ static void on_rumble_from_switch(const uint8_t rumble[8]) { } // Consume UART bytes and forward complete frames to the Switch Pro driver. -static void poll_uart_frames() { +static bool poll_uart_frames() { static uint8_t buffer[64]; static uint8_t index = 0; static uint8_t expected_len = 0; static absolute_time_t last_byte_time = {0}; static bool has_last_byte = false; + bool new_data = false; while (uart_is_readable(UART_ID)) { uint8_t byte = uart_getc(UART_ID); @@ -123,8 +124,10 @@ static void poll_uart_frames() { } index = 0; expected_len = 0; + new_data = true; } } + return new_data; } static void log_usb_state() { @@ -159,9 +162,13 @@ int main() { while (true) { tud_task(); // USB device tasks - poll_uart_frames(); // Pull controller state from UART1 + bool new_data = poll_uart_frames(); // Pull controller state from UART1 SwitchInputState state = g_user_state; switch_pro_set_input(state); + bool should_update = new_data; + if (should_update) { + switch_pro_set_input(state); + } switch_pro_task(); // Push state to the Switch host log_usb_state(); } diff --git a/switch_pro_driver.cpp b/switch_pro_driver.cpp index ee2b899..432f88b 100644 --- a/switch_pro_driver.cpp +++ b/switch_pro_driver.cpp @@ -632,6 +632,9 @@ void switch_pro_task() { if (tud_hid_ready() && send_report(0, inputReport, report_size) == true ) { memcpy(last_report, inputReport, report_size); report_sent = true; + // Clear IMU samples so they aren't repeated in the next report + // if no new data arrives. + g_input_state.imu_sample_count = 0; } last_report_timer = now;