switch-pico/tools/read_pro_imu.py

149 lines
5.6 KiB
Python

#!/usr/bin/env python3
"""
Read raw IMU samples from a Nintendo Switch Pro Controller (or Pico spoof) over USB.
Uses the `hidapi` Python package (cython-hidapi: `pip install hidapi`). Press Ctrl+C to exit.
"""
import argparse
import struct
import sys
from typing import List, Tuple
import hid # from pyhidapi
# USB identifiers matched when no --vid/--pid is given on the command line:
# Nintendo's vendor ID and the product ID of the Switch Pro Controller (USB).
DEFAULT_VENDOR_ID = 0x057E
DEFAULT_PRODUCT_ID = 0x2009
def list_devices(filter_vid=None, filter_pid=None):
    """Print every HID device that matches the optional VID/PID filters.

    Args:
        filter_vid: Only show devices whose ``vendor_id`` equals this value.
            ``None`` (the default) disables the vendor filter.
        filter_pid: Only show devices whose ``product_id`` equals this value.
            ``None`` (the default) disables the product filter.

    Returns:
        The list of device-info dicts (as produced by ``hid.enumerate()``)
        that matched the filters and were printed. With no filters this is
        every enumerated device.
    """
    matches = []
    for d in hid.enumerate():
        # Compare against None so an explicit 0 filter is honoured instead of
        # being treated as "no filter" by a truthiness check.
        if filter_vid is not None and d["vendor_id"] != filter_vid:
            continue
        if filter_pid is not None and d["product_id"] != filter_pid:
            continue
        print(
            f"VID=0x{d['vendor_id']:04X} PID=0x{d['product_id']:04X} "
            f"path={d.get('path')} "
            f"serial={d.get('serial_number')} "
            f"manufacturer={d.get('manufacturer_string')} "
            f"product={d.get('product_string')} "
            f"interface={d.get('interface_number')}"
        )
        matches.append(d)
    # Return only what was printed so callers see the same set as the user.
    return matches
def find_device(vendor_id: int, product_id: int):
    """Return the first enumerated HID device with the given VID and PID.

    Args:
        vendor_id: USB vendor ID to match exactly.
        product_id: USB product ID to match exactly.

    Returns:
        The matching device-info dict from ``hid.enumerate()``, or ``None``
        when no enumerated device matches.
    """
    candidates = (
        info
        for info in hid.enumerate()
        if info["vendor_id"] == vendor_id and info["product_id"] == product_id
    )
    # next() stops at the first hit, mirroring an early return from a loop.
    return next(candidates, None)
# Layout of the IMU payload inside a 0x30 input report, as decoded below:
# three consecutive samples of six little-endian int16 values
# (accel x/y/z followed by gyro x/y/z), starting at byte 13.
_IMU_REPORT_ID = 0x30
_IMU_DATA_OFFSET = 13
_IMU_SAMPLE_STRUCT = struct.Struct("<hhhhhh")
_IMU_SAMPLES_PER_REPORT = 3
_IMU_DATA_END = _IMU_DATA_OFFSET + _IMU_SAMPLES_PER_REPORT * _IMU_SAMPLE_STRUCT.size


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and return the validated arguments."""
    parser = argparse.ArgumentParser(description="Read raw 0x30 reports (IMU) from a Switch Pro Controller / Pico.")
    parser.add_argument("--vid", type=lambda x: int(x, 0), default=DEFAULT_VENDOR_ID, help="Vendor ID (default 0x057E)")
    parser.add_argument("--pid", type=lambda x: int(x, 0), default=DEFAULT_PRODUCT_ID, help="Product ID (default 0x2009)")
    parser.add_argument("--path", help="Explicit HID path to open (overrides VID/PID).")
    parser.add_argument("--count", type=int, default=0, help="Stop after this many 0x30 reports (0 = infinite).")
    parser.add_argument("--timeout", type=int, default=3000, help="Read timeout ms (default 3000).")
    parser.add_argument("--list", action="store_true", help="List detected HID devices and exit.")
    parser.add_argument("--plot", action="store_true", help="Plot accel/gyro traces after capture (requires matplotlib).")
    parser.add_argument("--save-prefix", help="If set, save accel/gyro plots as '<prefix>_accel.png' and '<prefix>_gyro.png'.")
    args = parser.parse_args()
    if args.count < 0:
        # A negative count would make the capture loop exit before reading anything.
        parser.error("--count must be >= 0 (0 = infinite)")
    return args


def _open_device(args: argparse.Namespace):
    """Open the HID device chosen by --path or --vid/--pid; exit with status 1 on failure."""
    if args.path:
        path = args.path.encode("utf-8")
    else:
        dev_info = find_device(args.vid, args.pid)
        if not dev_info:
            print(
                f"No HID device found for VID=0x{args.vid:04X} PID=0x{args.pid:04X}. "
                "Use --list to inspect devices or --path to target a specific one.",
                file=sys.stderr,
            )
            sys.exit(1)
        path = dev_info["path"]

    device = hid.device()
    try:
        device.open_path(path)
    except OSError as exc:
        # cython-hidapi signals open failures (device gone, missing permissions)
        # with OSError/IOError; report it instead of dumping a traceback.
        print(f"Failed to open HID device at path {path!r}: {exc}", file=sys.stderr)
        sys.exit(1)
    device.set_nonblocking(False)
    return device


def _decode_imu_samples(data) -> List[Tuple[int, ...]]:
    """Decode the three (ax, ay, az, gx, gy, gz) samples from one 0x30 report.

    The caller must ensure ``len(data) >= _IMU_DATA_END``.
    """
    payload = bytes(data[_IMU_DATA_OFFSET:_IMU_DATA_END])
    return list(_IMU_SAMPLE_STRUCT.iter_unpack(payload))


def _capture_imu(device, count: int, timeout_ms: int):
    """Read reports until *count* 0x30 reports are decoded (0 = forever) or Ctrl+C.

    Returns:
        ``(accel_series, gyro_series)``: lists of per-sample (x, y, z) tuples
        in raw counts, three entries per decoded report.
    """
    accel_series: List[Tuple[int, int, int]] = []
    gyro_series: List[Tuple[int, int, int]] = []
    read_count = 0
    try:
        while count == 0 or read_count < count:
            data = device.read(64, timeout_ms=timeout_ms)
            if not data:
                print(f"(timeout after {timeout_ms} ms, no data)")
                continue
            if data[0] != _IMU_REPORT_ID:
                print(f"(non-0x30 report id=0x{data[0]:02X}, len={len(data)})")
                continue
            if len(data) < _IMU_DATA_END:
                # Too short to hold three IMU samples; unpacking would raise struct.error.
                print(f"(short 0x30 report, len={len(data)}, need {_IMU_DATA_END})")
                continue
            samples = _decode_imu_samples(data)
            print(samples)
            accel_series.extend(s[:3] for s in samples)
            gyro_series.extend(s[3:] for s in samples)
            read_count += 1
    except KeyboardInterrupt:
        # Ctrl+C ends the capture; keep whatever was collected for plotting.
        pass
    return accel_series, gyro_series


def _plot_axes(plt, series, labels, title):
    """Create one figure with a line per axis of *series* versus sample index."""
    fig, axes = plt.subplots()
    # zip(*series) transposes [(x, y, z), ...] into per-axis value sequences.
    for label, values in zip(labels, zip(*series)):
        axes.plot(list(values), label=label)
    axes.set_title(title)
    axes.set_xlabel("Sample")
    axes.set_ylabel("Counts")
    axes.legend()
    return fig


def _plot_imu(accel_series, gyro_series, save_prefix, show: bool) -> None:
    """Plot captured accel/gyro traces, optionally saving PNGs and/or showing them."""
    try:
        import matplotlib.pyplot as plt
    except ImportError as exc:  # pragma: no cover - optional dependency
        print(f"Unable to plot (matplotlib not available): {exc}", file=sys.stderr)
        return
    if not (accel_series and gyro_series):
        print("No IMU samples captured; nothing to plot.", file=sys.stderr)
        return
    fig1 = _plot_axes(plt, accel_series, ("ax", "ay", "az"), "Accel (counts)")
    fig2 = _plot_axes(plt, gyro_series, ("gx", "gy", "gz"), "Gyro (counts)")
    if save_prefix:
        fig1.savefig(f"{save_prefix}_accel.png", dpi=150, bbox_inches="tight")
        fig2.savefig(f"{save_prefix}_gyro.png", dpi=150, bbox_inches="tight")
        print(f"Saved plots to {save_prefix}_accel.png and {save_prefix}_gyro.png")
    if show:
        plt.show()
    else:
        # Save-only run: release the figures rather than leaving them open.
        plt.close("all")


def main():
    """Stream raw 0x30 IMU reports from the controller and optionally plot them.

    With ``--list`` only the detected HID devices are printed. Otherwise the
    selected device is opened, reports are decoded until ``--count`` is
    reached or Ctrl+C is pressed, and the device is always closed afterwards.
    ``--plot`` shows the traces; ``--save-prefix`` saves them to PNG files
    (with or without ``--plot``).
    """
    args = _parse_args()
    if args.list:
        list_devices()
        return

    device = _open_device(args)
    print(
        f"Reading raw 0x30 reports from device (VID=0x{args.vid:04X} PID=0x{args.pid:04X})... "
        "Ctrl+C to stop."
    )
    try:
        accel_series, gyro_series = _capture_imu(device, args.count, args.timeout)
    finally:
        device.close()

    # --save-prefix is honoured on its own; figures are only shown with --plot.
    if args.plot or args.save_prefix:
        _plot_imu(accel_series, gyro_series, args.save_prefix, show=args.plot)
# Run the capture only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()