feat(firmware): add IMU data structures and UART v2 parser

- Add SwitchImuSample struct and IMU fields to SwitchInputState
- Add fill_imu_report_data() to pack samples into imuData[36]
- Rewrite switch_pro_apply_uart_packet() for v2 with checksum validation
- Enlarge poll_uart_frames() buffer to 64 bytes for variable-length frames
- Clear imu_sample_count after each USB report send (prevent stale IMU)
This commit is contained in:
Joey Yakimowich-Payne 2026-03-16 11:50:32 -06:00
commit 0db04be858
No known key found for this signature in database
GPG key ID: DDF6AF5B21B407D4
3 changed files with 135 additions and 28 deletions

View file

@ -184,12 +184,39 @@ static std::map<uint32_t, const uint8_t*> spi_flash_data = {
static inline uint16_t scale16To12(uint16_t pos) { return pos >> 4; }
static void fill_imu_report_data(const SwitchInputState& state) {
if (state.imu_sample_count == 0) {
memset(switch_report.imuData, 0x00, sizeof(switch_report.imuData));
return;
}
uint8_t sample_count = state.imu_sample_count > 3 ? 3 : state.imu_sample_count;
// If fewer than 3 samples, duplicate the last one to fill all 3 slots
uint8_t* dst = switch_report.imuData;
for (uint8_t i = 0; i < 3; ++i) {
const SwitchImuSample& s = (i < sample_count) ? state.imu_samples[i] : state.imu_samples[sample_count - 1];
dst[0] = static_cast<uint8_t>(s.accel_x & 0xFF);
dst[1] = static_cast<uint8_t>((s.accel_x >> 8) & 0xFF);
dst[2] = static_cast<uint8_t>(s.accel_y & 0xFF);
dst[3] = static_cast<uint8_t>((s.accel_y >> 8) & 0xFF);
dst[4] = static_cast<uint8_t>(s.accel_z & 0xFF);
dst[5] = static_cast<uint8_t>((s.accel_z >> 8) & 0xFF);
dst[6] = static_cast<uint8_t>(s.gyro_x & 0xFF);
dst[7] = static_cast<uint8_t>((s.gyro_x >> 8) & 0xFF);
dst[8] = static_cast<uint8_t>(s.gyro_y & 0xFF);
dst[9] = static_cast<uint8_t>((s.gyro_y >> 8) & 0xFF);
dst[10] = static_cast<uint8_t>(s.gyro_z & 0xFF);
dst[11] = static_cast<uint8_t>((s.gyro_z >> 8) & 0xFF);
dst += 12;
}
}
static SwitchInputState make_neutral_state() {
SwitchInputState s{};
s.lx = SWITCH_PRO_JOYSTICK_MID;
s.ly = SWITCH_PRO_JOYSTICK_MID;
s.rx = SWITCH_PRO_JOYSTICK_MID;
s.ry = SWITCH_PRO_JOYSTICK_MID;
s.imu_sample_count = 0;
return s;
}
@ -485,6 +512,7 @@ static void update_switch_report_from_state() {
switch_report.inputs.rightStick.setX(std::min(std::max(scaleRightStickX,rightMinX), rightMaxX));
switch_report.inputs.rightStick.setY(-std::min(std::max(scaleRightStickY,rightMinY), rightMaxY));
fill_imu_report_data(g_input_state);
switch_report.rumbleReport = 0x09;
}
@ -594,14 +622,13 @@ void switch_pro_task() {
switch_report.timestamp = last_report_counter;
void * inputReport = &switch_report;
uint16_t report_size = sizeof(switch_report);
if (memcmp(last_report, inputReport, report_size) != 0) {
if (tud_hid_ready() && send_report(0, inputReport, report_size) == true ) {
memcpy(last_report, inputReport, report_size);
report_sent = true;
}
last_report_timer = now;
if (tud_hid_ready() && send_report(0, inputReport, report_size) == true ) {
memcpy(last_report, inputReport, report_size);
g_input_state.imu_sample_count = 0;
report_sent = true;
}
last_report_timer = now;
}
} else {
if (!is_initialized) {
@ -617,24 +644,71 @@ void switch_pro_task() {
}
bool switch_pro_apply_uart_packet(const uint8_t* packet, uint8_t length, SwitchInputState* out_state) {
// Packet format: 0xAA, buttons(2 LE), hat, lx, ly, rx, ry
if (length < 8 || packet[0] != 0xAA) {
// v2 format: 0xAA + 0x02 + payload_len + payload... + checksum
if (length < 12) {
return false;
}
if (packet[0] != 0xAA) {
return false;
}
if (packet[1] != 0x02) {
return false;
}
uint8_t payload_len = packet[2];
if ((uint16_t)payload_len + 4u != length) {
return false;
}
uint16_t sum = 0;
for (uint16_t i = 0; i < (uint16_t)(3u + payload_len); ++i) {
sum += packet[i];
}
if ((sum & 0xFF) != packet[length - 1]) {
return false;
}
// payload: buttons(2 LE), hat, lx, ly, rx, ry, imu_count, [imu_samples...]
if (payload_len < 8) {
return false;
}
SwitchProOutReport out{};
out.buttons = static_cast<uint16_t>(packet[1]) | (static_cast<uint16_t>(packet[2]) << 8);
out.hat = packet[3];
out.lx = packet[4];
out.ly = packet[5];
out.rx = packet[6];
out.ry = packet[7];
out.buttons = static_cast<uint16_t>(packet[3]) | (static_cast<uint16_t>(packet[4]) << 8);
out.hat = packet[5];
out.lx = packet[6];
out.ly = packet[7];
out.rx = packet[8];
out.ry = packet[9];
uint8_t imu_count = packet[10];
if (imu_count > 3) {
imu_count = 3;
}
uint16_t required_payload_len = static_cast<uint16_t>(8u + static_cast<uint16_t>(imu_count) * 12u);
if (payload_len < required_payload_len) {
return false;
}
auto expand_axis = [](uint8_t v) -> uint16_t {
return static_cast<uint16_t>(v) << 8 | v;
};
SwitchInputState state = make_neutral_state();
state.imu_sample_count = imu_count;
auto read_int16 = [](const uint8_t* src) -> int16_t {
return static_cast<int16_t>(static_cast<uint16_t>(src[0]) | (static_cast<uint16_t>(src[1]) << 8));
};
for (uint8_t i = 0; i < imu_count; ++i) {
const uint8_t* base = &packet[11 + i * 12];
state.imu_samples[i].accel_x = read_int16(base + 0);
state.imu_samples[i].accel_y = read_int16(base + 2);
state.imu_samples[i].accel_z = read_int16(base + 4);
state.imu_samples[i].gyro_x = read_int16(base + 6);
state.imu_samples[i].gyro_y = read_int16(base + 8);
state.imu_samples[i].gyro_z = read_int16(base + 10);
}
switch (out.hat) {
case SWITCH_PRO_HAT_UP: state.dpad_up = true; break;
@ -668,11 +742,10 @@ bool switch_pro_apply_uart_packet(const uint8_t* packet, uint8_t length, SwitchI
state.rx = expand_axis(out.rx);
state.ry = expand_axis(out.ry);
if (out_state) {
*out_state = state;
} else {
switch_pro_set_input(state);
if (!out_state) {
return false;
}
*out_state = state;
return true;
}