Refactor into lib
This commit is contained in:
parent
491a35888a
commit
396d8ba33d
3 changed files with 359 additions and 184 deletions
|
|
@ -16,15 +16,12 @@ Features inspired by ``host/controller_bridge.py``:
|
|||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from ctypes import create_string_buffer
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import serial
|
||||
from serial import SerialException
|
||||
from serial.tools import list_ports
|
||||
from serial.tools import list_ports_common
|
||||
|
|
@ -35,46 +32,24 @@ from rich.console import Console
|
|||
from rich.prompt import Prompt
|
||||
from rich.table import Table
|
||||
|
||||
UART_HEADER = 0xAA
|
||||
RUMBLE_HEADER = 0xBB
|
||||
RUMBLE_TYPE_RUMBLE = 0x01
|
||||
UART_BAUD = 921600
|
||||
from switch_pico_uart import (
|
||||
UART_BAUD,
|
||||
PicoUART,
|
||||
SwitchButton,
|
||||
SwitchHat,
|
||||
SwitchReport,
|
||||
axis_to_stick,
|
||||
dpad_to_hat,
|
||||
decode_rumble,
|
||||
trigger_to_button,
|
||||
)
|
||||
|
||||
RUMBLE_IDLE_TIMEOUT = 0.25 # seconds without packets before forcing rumble off
|
||||
RUMBLE_STUCK_TIMEOUT = 0.60 # continuous same-energy rumble will be stopped after this
|
||||
RUMBLE_MIN_ACTIVE = 0.50 # below this, rumble is treated as off/noise
|
||||
RUMBLE_SCALE = 0.8
|
||||
|
||||
|
||||
class SwitchButton:
|
||||
# Mirrors the masks defined in switch_pro_descriptors.h
|
||||
Y = 1 << 0
|
||||
B = 1 << 1
|
||||
A = 1 << 2
|
||||
X = 1 << 3
|
||||
L = 1 << 4
|
||||
R = 1 << 5
|
||||
ZL = 1 << 6
|
||||
ZR = 1 << 7
|
||||
MINUS = 1 << 8
|
||||
PLUS = 1 << 9
|
||||
LCLICK = 1 << 10
|
||||
RCLICK = 1 << 11
|
||||
HOME = 1 << 12
|
||||
CAPTURE = 1 << 13
|
||||
|
||||
|
||||
class SwitchHat:
|
||||
TOP = 0x00
|
||||
TOP_RIGHT = 0x01
|
||||
RIGHT = 0x02
|
||||
BOTTOM_RIGHT = 0x03
|
||||
BOTTOM = 0x04
|
||||
BOTTOM_LEFT = 0x05
|
||||
LEFT = 0x06
|
||||
TOP_LEFT = 0x07
|
||||
CENTER = 0x08
|
||||
|
||||
|
||||
def parse_mapping(value: str) -> Tuple[int, str]:
|
||||
"""Parse 'index:serial_port' CLI mapping argument."""
|
||||
if ":" not in value:
|
||||
|
|
@ -89,19 +64,6 @@ def parse_mapping(value: str) -> Tuple[int, str]:
|
|||
return idx, port.strip()
|
||||
|
||||
|
||||
def axis_to_stick(value: int, deadzone: int) -> int:
|
||||
"""Convert a signed SDL axis value to 0-255 stick range with deadzone."""
|
||||
if abs(value) < deadzone:
|
||||
value = 0
|
||||
scaled = int((value + 32768) * 255 / 65535)
|
||||
return max(0, min(255, scaled))
|
||||
|
||||
|
||||
def trigger_to_button(value: int, threshold: int) -> bool:
|
||||
"""Return True if analog trigger crosses digital threshold."""
|
||||
return value >= threshold
|
||||
|
||||
|
||||
def set_hint(name: str, value: str) -> None:
|
||||
"""Set an SDL hint safely even if the constant is missing in PySDL2."""
|
||||
try:
|
||||
|
|
@ -134,32 +96,6 @@ DPAD_BUTTONS = {
|
|||
}
|
||||
|
||||
|
||||
def dpad_to_hat(flags: Dict[str, bool]) -> int:
|
||||
"""Translate DPAD button flags into a Switch hat value."""
|
||||
up = flags["up"]
|
||||
down = flags["down"]
|
||||
left = flags["left"]
|
||||
right = flags["right"]
|
||||
|
||||
if up and right:
|
||||
return SwitchHat.TOP_RIGHT
|
||||
if up and left:
|
||||
return SwitchHat.TOP_LEFT
|
||||
if down and right:
|
||||
return SwitchHat.BOTTOM_RIGHT
|
||||
if down and left:
|
||||
return SwitchHat.BOTTOM_LEFT
|
||||
if up:
|
||||
return SwitchHat.TOP
|
||||
if down:
|
||||
return SwitchHat.BOTTOM
|
||||
if right:
|
||||
return SwitchHat.RIGHT
|
||||
if left:
|
||||
return SwitchHat.LEFT
|
||||
return SwitchHat.CENTER
|
||||
|
||||
|
||||
def is_usb_serial(path: str) -> bool:
|
||||
"""Heuristic for USB serial path prefixes."""
|
||||
if path.startswith("/dev/tty.") and not path.startswith("/dev/tty.usb"):
|
||||
|
|
@ -244,110 +180,6 @@ def interactive_pairing(
|
|||
return mappings
|
||||
|
||||
|
||||
@dataclass
|
||||
class SwitchReport:
|
||||
buttons: int = 0
|
||||
hat: int = SwitchHat.CENTER
|
||||
lx: int = 128
|
||||
ly: int = 128
|
||||
rx: int = 128
|
||||
ry: int = 128
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
"""Serialize the report into the UART packet format."""
|
||||
return struct.pack(
|
||||
"<BHBBBBB", UART_HEADER, self.buttons & 0xFFFF, self.hat & 0xFF, self.lx, self.ly, self.rx, self.ry
|
||||
)
|
||||
|
||||
|
||||
class PicoUART:
|
||||
def __init__(self, port: str, baudrate: int = UART_BAUD) -> None:
|
||||
"""Open a UART connection to the Pico with non-blocking IO."""
|
||||
self.serial = serial.Serial(
|
||||
port=port,
|
||||
baudrate=baudrate,
|
||||
bytesize=serial.EIGHTBITS,
|
||||
stopbits=serial.STOPBITS_ONE,
|
||||
parity=serial.PARITY_NONE,
|
||||
timeout=0.0,
|
||||
write_timeout=0.0,
|
||||
xonxoff=False,
|
||||
rtscts=False,
|
||||
dsrdtr=False,
|
||||
)
|
||||
self._buffer = bytearray()
|
||||
|
||||
def send_report(self, report: SwitchReport) -> None:
|
||||
"""Send a controller report to the Pico."""
|
||||
# Non-blocking write; no flush to avoid sync stalls.
|
||||
self.serial.write(report.to_bytes())
|
||||
|
||||
def read_rumble_payload(self) -> Optional[bytes]:
|
||||
"""
|
||||
Drain all currently available UART bytes into an internal buffer,
|
||||
then try to extract a single valid rumble frame.
|
||||
|
||||
Frame format:
|
||||
0: 0xBB (RUMBLE_HEADER)
|
||||
1: type (0x01 for rumble)
|
||||
2-9: 8-byte rumble payload
|
||||
10: checksum (sum of first 10 bytes) & 0xFF
|
||||
"""
|
||||
# Read whatever is waiting in OS buffer
|
||||
waiting = self.serial.in_waiting
|
||||
if waiting:
|
||||
self._buffer.extend(self.serial.read(waiting))
|
||||
|
||||
while True:
|
||||
if not self._buffer:
|
||||
return None
|
||||
|
||||
start = self._buffer.find(bytes([RUMBLE_HEADER]))
|
||||
if start < 0:
|
||||
# No header at all, drop garbage
|
||||
self._buffer.clear()
|
||||
return None
|
||||
|
||||
# Not enough data for a full frame yet
|
||||
if len(self._buffer) - start < 11:
|
||||
if start > 0:
|
||||
del self._buffer[:start]
|
||||
return None
|
||||
|
||||
frame = self._buffer[start:start + 11]
|
||||
checksum = sum(frame[:10]) & 0xFF
|
||||
|
||||
if frame[1] == RUMBLE_TYPE_RUMBLE and checksum == frame[10]:
|
||||
payload = bytes(frame[2:10])
|
||||
del self._buffer[:start + 11]
|
||||
return payload
|
||||
|
||||
# Bad frame, drop this header and resync to the next candidate
|
||||
del self._buffer[:start + 1]
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the UART connection."""
|
||||
self.serial.close()
|
||||
|
||||
|
||||
def decode_rumble(payload: bytes) -> Tuple[float, float]:
|
||||
"""Return normalized rumble amplitudes (0.0-1.0) for left/right."""
|
||||
if len(payload) < 8:
|
||||
return 0.0, 0.0
|
||||
# Neutral/idle pattern used by Switch: no rumble energy.
|
||||
if payload == b"\x00\x01\x40\x40\x00\x01\x40\x40":
|
||||
return 0.0, 0.0
|
||||
# Rumble amp is 10 bits across bytes 0/1 and 4/5.
|
||||
# Switch format is right rumble first, then left rumble (4 bytes each).
|
||||
right_raw = ((payload[1] & 0x03) << 8) | payload[0]
|
||||
left_raw = ((payload[5] & 0x03) << 8) | payload[4]
|
||||
if left_raw < 8 and right_raw < 8:
|
||||
return 0.0, 0.0
|
||||
left = min(max(left_raw / 1023.0, 0.0), 1.0)
|
||||
right = min(max(right_raw / 1023.0, 0.0), 1.0)
|
||||
return left, right
|
||||
|
||||
|
||||
def apply_rumble(controller: sdl2.SDL_GameController, payload: bytes) -> float:
|
||||
"""Apply rumble payload to SDL controller and return max normalized energy."""
|
||||
left_norm, right_norm = decode_rumble(payload)
|
||||
|
|
@ -515,7 +347,7 @@ def build_arg_parser() -> argparse.ArgumentParser:
|
|||
return parser
|
||||
|
||||
|
||||
def poll_controller_buttons(ctx: ControllerContext, button_map: Dict[int, int]) -> None:
|
||||
def poll_controller_buttons(ctx: ControllerContext, button_map: Dict[int, SwitchButton]) -> None:
|
||||
"""Update button/hat state based on current SDL controller readings."""
|
||||
changed = False
|
||||
for sdl_button, switch_bit in button_map.items():
|
||||
|
|
@ -547,8 +379,8 @@ class BridgeConfig:
|
|||
interval: float
|
||||
deadzone_raw: int
|
||||
trigger_threshold: int
|
||||
button_map_default: Dict[int, int]
|
||||
button_map_swapped: Dict[int, int]
|
||||
button_map_default: Dict[int, SwitchButton]
|
||||
button_map_swapped: Dict[int, SwitchButton]
|
||||
swap_abxy_indices: set[int]
|
||||
swap_abxy_ids: set[str]
|
||||
|
||||
|
|
@ -565,7 +397,7 @@ class PairingState:
|
|||
include_port_desc: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def load_button_maps(console: Console, args: argparse.Namespace) -> Tuple[Dict[int, int], Dict[int, int], set[int]]:
|
||||
def load_button_maps(console: Console, args: argparse.Namespace) -> Tuple[Dict[int, SwitchButton], Dict[int, SwitchButton], set[int]]:
|
||||
"""Load SDL controller mappings and return button map variants."""
|
||||
default_mapping = Path(__file__).parent / "controller_db" / "gamecontrollerdb.txt"
|
||||
mappings_to_load: List[str] = []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue