From 0db04be85883ec0fb50e47985c9ea61f94503408 Mon Sep 17 00:00:00 2001 From: Joey Yakimowich-Payne Date: Mon, 16 Mar 2026 11:50:32 -0600 Subject: [PATCH] feat(firmware): add IMU data structures and UART v2 parser - Add SwitchImuSample struct and IMU fields to SwitchInputState - Add fill_imu_report_data() to pack samples into imuData[36] - Rewrite switch_pro_apply_uart_packet() for v2 with checksum validation - Enlarge poll_uart_frames() buffer to 64 bytes for variable-length frames - Clear imu_sample_count after each USB report send (prevent stale IMU) --- switch-pico.cpp | 40 +++++++++++---- switch_pro_driver.cpp | 111 ++++++++++++++++++++++++++++++++++-------- switch_pro_driver.h | 12 +++++ 3 files changed, 135 insertions(+), 28 deletions(-) diff --git a/switch-pico.cpp b/switch-pico.cpp index 5efd6e9..22beaef 100644 --- a/switch-pico.cpp +++ b/switch-pico.cpp @@ -61,11 +61,13 @@ static void on_rumble_from_switch(const uint8_t rumble[8]) { } // Consume UART bytes and forward complete frames to the Switch Pro driver. -static void poll_uart_frames() { - static uint8_t buffer[8]; +static bool poll_uart_frames() { + static uint8_t buffer[64]; static uint8_t index = 0; + static uint8_t expected_len = 0; static absolute_time_t last_byte_time = {0}; static bool has_last_byte = false; + bool new_data = false; while (uart_is_readable(UART_ID)) { uint8_t byte = uart_getc(UART_ID); @@ -73,6 +75,7 @@ static void poll_uart_frames() { uint64_t now = to_ms_since_boot(get_absolute_time()); if (has_last_byte && (now - to_ms_since_boot(last_byte_time)) > 20) { index = 0; // stale data, restart frame + expected_len = 0; } last_byte_time = get_absolute_time(); has_last_byte = true; @@ -83,11 +86,26 @@ static void poll_uart_frames() { } } - buffer[index++] = byte; if (index >= sizeof(buffer)) { + index = 0; + expected_len = 0; + } + + buffer[index++] = byte; + if (index == 3) { + expected_len = static_cast(buffer[2] + 4u); + if (expected_len < 12 || expected_len > sizeof(buffer)) { + index = 0; + expected_len = 0; + continue; + } + } + + if (expected_len > 0 && index >= expected_len) { SwitchInputState parsed{}; - if (switch_pro_apply_uart_packet(buffer, sizeof(buffer), &parsed)) { + if (switch_pro_apply_uart_packet(buffer, expected_len, &parsed)) { g_user_state = parsed; + new_data = true; LOG_PRINTF("[UART] packet buttons=0x%04x hat=%u lx=%u ly=%u rx=%u ry=%u\n", (parsed.button_a ? SWITCH_PRO_MASK_A : 0) | (parsed.button_b ? SWITCH_PRO_MASK_B : 0) | @@ -104,14 +122,17 @@ static void poll_uart_frames() { (parsed.button_l3 ? SWITCH_PRO_MASK_L3 : 0) | (parsed.button_r3 ? SWITCH_PRO_MASK_R3 : 0), parsed.dpad_up ? SWITCH_PRO_HAT_UP : - parsed.dpad_down ? SWITCH_PRO_HAT_DOWN : - parsed.dpad_left ? SWITCH_PRO_HAT_LEFT : - parsed.dpad_right ? SWITCH_PRO_HAT_RIGHT : SWITCH_PRO_HAT_NOTHING, - parsed.lx >> 8, parsed.ly >> 8, parsed.rx >> 8, parsed.ry >> 8); + parsed.dpad_down ? SWITCH_PRO_HAT_DOWN : + parsed.dpad_left ? SWITCH_PRO_HAT_LEFT : + parsed.dpad_right ? SWITCH_PRO_HAT_RIGHT : SWITCH_PRO_HAT_NOTHING, + parsed.lx >> 8, parsed.ly >> 8, parsed.rx >> 8, parsed.ry >> 8); } index = 0; + expected_len = 0; } } + + return new_data; } static void log_usb_state() { @@ -146,7 +167,8 @@ int main() { while (true) { tud_task(); // USB device tasks - poll_uart_frames(); // Pull controller state from UART1 + bool new_data = poll_uart_frames(); // Pull controller state from UART1 + (void)new_data; SwitchInputState state = g_user_state; switch_pro_set_input(state); switch_pro_task(); // Push state to the Switch host diff --git a/switch_pro_driver.cpp b/switch_pro_driver.cpp index 138ba8f..ee68f68 100644 --- a/switch_pro_driver.cpp +++ b/switch_pro_driver.cpp @@ -184,12 +184,39 @@ static std::map spi_flash_data = { static inline uint16_t scale16To12(uint16_t pos) { return pos >> 4; } +static void fill_imu_report_data(const SwitchInputState& state) { + if (state.imu_sample_count == 0) { + memset(switch_report.imuData, 0x00, sizeof(switch_report.imuData)); + return; + } + uint8_t sample_count = state.imu_sample_count > 3 ? 3 : state.imu_sample_count; + // If fewer than 3 samples, duplicate the last one to fill all 3 slots + uint8_t* dst = switch_report.imuData; + for (uint8_t i = 0; i < 3; ++i) { + const SwitchImuSample& s = (i < sample_count) ? state.imu_samples[i] : state.imu_samples[sample_count - 1]; + dst[0] = static_cast(s.accel_x & 0xFF); + dst[1] = static_cast((s.accel_x >> 8) & 0xFF); + dst[2] = static_cast(s.accel_y & 0xFF); + dst[3] = static_cast((s.accel_y >> 8) & 0xFF); + dst[4] = static_cast(s.accel_z & 0xFF); + dst[5] = static_cast((s.accel_z >> 8) & 0xFF); + dst[6] = static_cast(s.gyro_x & 0xFF); + dst[7] = static_cast((s.gyro_x >> 8) & 0xFF); + dst[8] = static_cast(s.gyro_y & 0xFF); + dst[9] = static_cast((s.gyro_y >> 8) & 0xFF); + dst[10] = static_cast(s.gyro_z & 0xFF); + dst[11] = static_cast((s.gyro_z >> 8) & 0xFF); + dst += 12; + } +} + static SwitchInputState make_neutral_state() { SwitchInputState s{}; s.lx = SWITCH_PRO_JOYSTICK_MID; s.ly = SWITCH_PRO_JOYSTICK_MID; s.rx = SWITCH_PRO_JOYSTICK_MID; s.ry = SWITCH_PRO_JOYSTICK_MID; + s.imu_sample_count = 0; return s; } @@ -485,6 +512,7 @@ static void update_switch_report_from_state() { switch_report.inputs.rightStick.setX(std::min(std::max(scaleRightStickX,rightMinX), rightMaxX)); switch_report.inputs.rightStick.setY(-std::min(std::max(scaleRightStickY,rightMinY), rightMaxY)); + fill_imu_report_data(g_input_state); switch_report.rumbleReport = 0x09; } @@ -594,14 +622,13 @@ void switch_pro_task() { switch_report.timestamp = last_report_counter; void * inputReport = &switch_report; uint16_t report_size = sizeof(switch_report); - if (memcmp(last_report, inputReport, report_size) != 0) { - if (tud_hid_ready() && send_report(0, inputReport, report_size) == true ) { - memcpy(last_report, inputReport, report_size); - report_sent = true; - } - - last_report_timer = now; + if (tud_hid_ready() && send_report(0, inputReport, report_size) == true ) { + memcpy(last_report, inputReport, report_size); + g_input_state.imu_sample_count = 0; + report_sent = true; } + + last_report_timer = now; } } else { if (!is_initialized) { @@ -617,24 +644,71 @@ void switch_pro_task() { } bool switch_pro_apply_uart_packet(const uint8_t* packet, uint8_t length, SwitchInputState* out_state) { - // Packet format: 0xAA, buttons(2 LE), hat, lx, ly, rx, ry - if (length < 8 || packet[0] != 0xAA) { + // v2 format: 0xAA + 0x02 + payload_len + payload... + checksum + if (length < 12) { + return false; + } + if (packet[0] != 0xAA) { + return false; + } + if (packet[1] != 0x02) { + return false; + } + + uint8_t payload_len = packet[2]; + if ((uint16_t)payload_len + 4u != length) { + return false; + } + + uint16_t sum = 0; + for (uint16_t i = 0; i < (uint16_t)(3u + payload_len); ++i) { + sum += packet[i]; + } + if ((sum & 0xFF) != packet[length - 1]) { + return false; + } + + // payload: buttons(2 LE), hat, lx, ly, rx, ry, imu_count, [imu_samples...] + if (payload_len < 8) { return false; } SwitchProOutReport out{}; - out.buttons = static_cast(packet[1]) | (static_cast(packet[2]) << 8); - out.hat = packet[3]; - out.lx = packet[4]; - out.ly = packet[5]; - out.rx = packet[6]; - out.ry = packet[7]; + out.buttons = static_cast(packet[3]) | (static_cast(packet[4]) << 8); + out.hat = packet[5]; + out.lx = packet[6]; + out.ly = packet[7]; + out.rx = packet[8]; + out.ry = packet[9]; + uint8_t imu_count = packet[10]; + if (imu_count > 3) { + imu_count = 3; + } + + uint16_t required_payload_len = static_cast(8u + static_cast(imu_count) * 12u); + if (payload_len < required_payload_len) { + return false; + } auto expand_axis = [](uint8_t v) -> uint16_t { return static_cast(v) << 8 | v; }; SwitchInputState state = make_neutral_state(); + state.imu_sample_count = imu_count; + + auto read_int16 = [](const uint8_t* src) -> int16_t { + return static_cast(static_cast(src[0]) | (static_cast(src[1]) << 8)); + }; + for (uint8_t i = 0; i < imu_count; ++i) { + const uint8_t* base = &packet[11 + i * 12]; + state.imu_samples[i].accel_x = read_int16(base + 0); + state.imu_samples[i].accel_y = read_int16(base + 2); + state.imu_samples[i].accel_z = read_int16(base + 4); + state.imu_samples[i].gyro_x = read_int16(base + 6); + state.imu_samples[i].gyro_y = read_int16(base + 8); + state.imu_samples[i].gyro_z = read_int16(base + 10); + } switch (out.hat) { case SWITCH_PRO_HAT_UP: state.dpad_up = true; break; @@ -668,11 +742,10 @@ bool switch_pro_apply_uart_packet(const uint8_t* packet, uint8_t length, SwitchI state.rx = expand_axis(out.rx); state.ry = expand_axis(out.ry); - if (out_state) { - *out_state = state; - } else { - switch_pro_set_input(state); + if (!out_state) { + return false; } + *out_state = state; return true; } diff --git a/switch_pro_driver.h b/switch_pro_driver.h index 73d8582..951d68b 100644 --- a/switch_pro_driver.h +++ b/switch_pro_driver.h @@ -10,6 +10,15 @@ #include #include "switch_pro_descriptors.h" +typedef struct { + int16_t accel_x; + int16_t accel_y; + int16_t accel_z; + int16_t gyro_x; + int16_t gyro_y; + int16_t gyro_z; +} SwitchImuSample; + typedef struct { bool dpad_up; bool dpad_down; @@ -35,6 +44,9 @@ typedef struct { uint16_t ly; uint16_t rx; uint16_t ry; + + uint8_t imu_sample_count; // 0-3 + SwitchImuSample imu_samples[3]; } SwitchInputState; // Initialize USB state and calibration before entering the main loop.