This commit is contained in:
Joey Yakimowich-Payne 2025-11-23 09:31:43 -07:00
commit 4b7dbc4fbd
No known key found for this signature in database
GPG key ID: 6BFE655FA5ABD1E1

View file

@ -75,6 +75,7 @@ class SwitchHat:
def parse_mapping(value: str) -> Tuple[int, str]: def parse_mapping(value: str) -> Tuple[int, str]:
"""Parse 'index:serial_port' CLI mapping argument."""
if ":" not in value: if ":" not in value:
raise argparse.ArgumentTypeError("Mapping must look like 'index:serial_port'") raise argparse.ArgumentTypeError("Mapping must look like 'index:serial_port'")
idx_str, port = value.split(":", 1) idx_str, port = value.split(":", 1)
@ -88,6 +89,7 @@ def parse_mapping(value: str) -> Tuple[int, str]:
def axis_to_stick(value: int, deadzone: int) -> int: def axis_to_stick(value: int, deadzone: int) -> int:
"""Convert a signed SDL axis value to 0-255 stick range with deadzone."""
if abs(value) < deadzone: if abs(value) < deadzone:
value = 0 value = 0
scaled = int((value + 32768) * 255 / 65535) scaled = int((value + 32768) * 255 / 65535)
@ -95,6 +97,7 @@ def axis_to_stick(value: int, deadzone: int) -> int:
def trigger_to_button(value: int, threshold: int) -> bool: def trigger_to_button(value: int, threshold: int) -> bool:
"""Return True if analog trigger crosses digital threshold."""
return value >= threshold return value >= threshold
@ -131,6 +134,7 @@ DPAD_BUTTONS = {
def dpad_to_hat(flags: Dict[str, bool]) -> int: def dpad_to_hat(flags: Dict[str, bool]) -> int:
"""Translate DPAD button flags into a Switch hat value."""
up = flags["up"] up = flags["up"]
down = flags["down"] down = flags["down"]
left = flags["left"] left = flags["left"]
@ -156,6 +160,7 @@ def dpad_to_hat(flags: Dict[str, bool]) -> int:
def is_usb_serial(path: str) -> bool: def is_usb_serial(path: str) -> bool:
"""Heuristic for USB serial path prefixes."""
if path.startswith("/dev/tty.") and not path.startswith("/dev/tty.usb"): if path.startswith("/dev/tty.") and not path.startswith("/dev/tty.usb"):
return False return False
if path.startswith("/dev/cu.") and not path.startswith("/dev/cu.usb"): if path.startswith("/dev/cu.") and not path.startswith("/dev/cu.usb"):
@ -179,6 +184,7 @@ def discover_ports(
ignore_descriptions: Optional[List[str]] = None, ignore_descriptions: Optional[List[str]] = None,
include_descriptions: Optional[List[str]] = None, include_descriptions: Optional[List[str]] = None,
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
"""List serial ports, optionally filtering by description and USB-ness."""
ignored = [d.lower() for d in ignore_descriptions or []] ignored = [d.lower() for d in ignore_descriptions or []]
includes = [d.lower() for d in include_descriptions or []] includes = [d.lower() for d in include_descriptions or []]
results: List[Dict[str, str]] = [] results: List[Dict[str, str]] = []
@ -205,6 +211,7 @@ def discover_ports(
def interactive_pairing( def interactive_pairing(
console: Console, controller_info: Dict[int, str], ports: List[Dict[str, str]] console: Console, controller_info: Dict[int, str], ports: List[Dict[str, str]]
) -> List[Tuple[int, str]]: ) -> List[Tuple[int, str]]:
"""Prompt the user to pair controllers to UART ports via Rich UI."""
available = ports.copy() available = ports.copy()
mappings: List[Tuple[int, str]] = [] mappings: List[Tuple[int, str]] = []
for controller_idx in controller_info: for controller_idx in controller_info:
@ -246,6 +253,7 @@ class SwitchReport:
ry: int = 128 ry: int = 128
def to_bytes(self) -> bytes: def to_bytes(self) -> bytes:
"""Serialize the report into the UART packet format."""
return struct.pack( return struct.pack(
"<BHBBBBB", UART_HEADER, self.buttons & 0xFFFF, self.hat & 0xFF, self.lx, self.ly, self.rx, self.ry "<BHBBBBB", UART_HEADER, self.buttons & 0xFFFF, self.hat & 0xFF, self.lx, self.ly, self.rx, self.ry
) )
@ -253,6 +261,7 @@ class SwitchReport:
class PicoUART: class PicoUART:
def __init__(self, port: str, baudrate: int = UART_BAUD) -> None: def __init__(self, port: str, baudrate: int = UART_BAUD) -> None:
"""Open a UART connection to the Pico with non-blocking IO."""
self.serial = serial.Serial( self.serial = serial.Serial(
port=port, port=port,
baudrate=baudrate, baudrate=baudrate,
@ -268,6 +277,7 @@ class PicoUART:
self._buffer = bytearray() self._buffer = bytearray()
def send_report(self, report: SwitchReport) -> None: def send_report(self, report: SwitchReport) -> None:
"""Send a controller report to the Pico."""
# Non-blocking write; no flush to avoid sync stalls. # Non-blocking write; no flush to avoid sync stalls.
self.serial.write(report.to_bytes()) self.serial.write(report.to_bytes())
@ -311,10 +321,11 @@ class PicoUART:
del self._buffer[:start + 11] del self._buffer[:start + 11]
return payload return payload
# Bad frame, drop this header and resync # Bad frame, drop this header and resync to the next candidate
del self._buffer[:start + 1] del self._buffer[:start + 1]
def close(self) -> None: def close(self) -> None:
"""Close the UART connection."""
self.serial.close() self.serial.close()
@ -337,6 +348,7 @@ def decode_rumble(payload: bytes) -> Tuple[float, float]:
def apply_rumble(controller: sdl2.SDL_GameController, payload: bytes) -> float: def apply_rumble(controller: sdl2.SDL_GameController, payload: bytes) -> float:
"""Apply rumble payload to SDL controller and return max normalized energy."""
left_norm, right_norm = decode_rumble(payload) left_norm, right_norm = decode_rumble(payload)
max_norm = max(left_norm, right_norm) max_norm = max(left_norm, right_norm)
# Treat small rumble as "off" to avoid idle buzz. # Treat small rumble as "off" to avoid idle buzz.
@ -372,6 +384,7 @@ class ControllerContext:
def open_controller(index: int) -> Tuple[sdl2.SDL_GameController, int]: def open_controller(index: int) -> Tuple[sdl2.SDL_GameController, int]:
"""Open an SDL GameController by index and return it with instance ID."""
controller = sdl2.SDL_GameControllerOpen(index) controller = sdl2.SDL_GameControllerOpen(index)
if not controller: if not controller:
raise RuntimeError(f"Failed to open controller {index}: {sdl2.SDL_GetError().decode()}") raise RuntimeError(f"Failed to open controller {index}: {sdl2.SDL_GetError().decode()}")
@ -381,6 +394,7 @@ def open_controller(index: int) -> Tuple[sdl2.SDL_GameController, int]:
def try_open_uart(port: str, baud: int) -> Optional[PicoUART]: def try_open_uart(port: str, baud: int) -> Optional[PicoUART]:
"""Attempt to open a UART without logging; return None on failure."""
try: try:
return PicoUART(port, baud) return PicoUART(port, baud)
except Exception: except Exception:
@ -388,6 +402,7 @@ def try_open_uart(port: str, baud: int) -> Optional[PicoUART]:
def open_uart_or_warn(port: str, baud: int, console: Console) -> Optional[PicoUART]: def open_uart_or_warn(port: str, baud: int, console: Console) -> Optional[PicoUART]:
"""Open a UART and warn on failure."""
try: try:
return PicoUART(port, baud) return PicoUART(port, baud)
except Exception as exc: except Exception as exc:
@ -396,6 +411,7 @@ def open_uart_or_warn(port: str, baud: int, console: Console) -> Optional[PicoUA
def build_arg_parser() -> argparse.ArgumentParser: def build_arg_parser() -> argparse.ArgumentParser:
"""Construct the CLI argument parser for the bridge."""
parser = argparse.ArgumentParser(description="Bridge SDL2 controllers to switch-pico UART (with rumble)") parser = argparse.ArgumentParser(description="Bridge SDL2 controllers to switch-pico UART (with rumble)")
parser.add_argument( parser.add_argument(
"--map", "--map",
@ -470,6 +486,7 @@ def build_arg_parser() -> argparse.ArgumentParser:
def poll_controller_buttons(ctx: ControllerContext, button_map: Dict[int, int]) -> None: def poll_controller_buttons(ctx: ControllerContext, button_map: Dict[int, int]) -> None:
"""Update button/hat state based on current SDL controller readings."""
changed = False changed = False
for sdl_button, switch_bit in button_map.items(): for sdl_button, switch_bit in button_map.items():
pressed = bool(sdl2.SDL_GameControllerGetButton(ctx.controller, sdl_button)) pressed = bool(sdl2.SDL_GameControllerGetButton(ctx.controller, sdl_button))
@ -541,6 +558,7 @@ def load_button_maps(console: Console, args: argparse.Namespace) -> Tuple[Dict[i
def build_bridge_config(console: Console, args: argparse.Namespace) -> BridgeConfig: def build_bridge_config(console: Console, args: argparse.Namespace) -> BridgeConfig:
"""Derive bridge runtime configuration from CLI arguments."""
interval = 1.0 / max(args.frequency, 1.0) interval = 1.0 / max(args.frequency, 1.0)
deadzone_raw = int(max(0.0, min(args.deadzone, 1.0)) * 32767) deadzone_raw = int(max(0.0, min(args.deadzone, 1.0)) * 32767)
trigger_threshold = int(max(0.0, min(args.trigger_threshold, 1.0)) * 32767) trigger_threshold = int(max(0.0, min(args.trigger_threshold, 1.0)) * 32767)
@ -556,6 +574,7 @@ def build_bridge_config(console: Console, args: argparse.Namespace) -> BridgeCon
def initialize_sdl(parser: argparse.ArgumentParser) -> None: def initialize_sdl(parser: argparse.ArgumentParser) -> None:
"""Set SDL hints and initialize subsystems needed for controllers."""
sdl2.SDL_SetHint(sdl2.SDL_HINT_JOYSTICK_ALLOW_BACKGROUND_EVENTS, b"1") sdl2.SDL_SetHint(sdl2.SDL_HINT_JOYSTICK_ALLOW_BACKGROUND_EVENTS, b"1")
set_hint("SDL_JOYSTICK_HIDAPI", "1") set_hint("SDL_JOYSTICK_HIDAPI", "1")
set_hint("SDL_JOYSTICK_HIDAPI_SWITCH", "1") set_hint("SDL_JOYSTICK_HIDAPI_SWITCH", "1")
@ -568,6 +587,7 @@ def initialize_sdl(parser: argparse.ArgumentParser) -> None:
def detect_controllers( def detect_controllers(
console: Console, args: argparse.Namespace, parser: argparse.ArgumentParser console: Console, args: argparse.Namespace, parser: argparse.ArgumentParser
) -> Tuple[List[int], Dict[int, str]]: ) -> Tuple[List[int], Dict[int, str]]:
"""Detect available controllers and return usable indices and names."""
controller_indices: List[int] = [] controller_indices: List[int] = []
controller_names: Dict[int, str] = {} controller_names: Dict[int, str] = {}
controller_count = sdl2.SDL_NumJoysticks() controller_count = sdl2.SDL_NumJoysticks()
@ -601,6 +621,7 @@ def prepare_pairing_state(
controller_indices: List[int], controller_indices: List[int],
controller_names: Dict[int, str], controller_names: Dict[int, str],
) -> PairingState: ) -> PairingState:
"""Prepare pairing preferences and pre-seeded mappings from CLI options."""
auto_pairing_enabled = not args.map and not args.interactive auto_pairing_enabled = not args.map and not args.interactive
auto_discover_ports = auto_pairing_enabled and not args.ports auto_discover_ports = auto_pairing_enabled and not args.ports
include_non_usb = args.all_ports or False include_non_usb = args.all_ports or False
@ -612,6 +633,7 @@ def prepare_pairing_state(
if args.interactive: if args.interactive:
if not controller_indices: if not controller_indices:
parser.error("No controllers detected for interactive pairing.") parser.error("No controllers detected for interactive pairing.")
# Interactive pairing shows the discovered ports and lets the user bind explicitly.
discovered = discover_ports( discovered = discover_ports(
include_non_usb=include_non_usb, include_non_usb=include_non_usb,
ignore_descriptions=ignore_port_desc, ignore_descriptions=ignore_port_desc,
@ -627,6 +649,7 @@ def prepare_pairing_state(
available_ports.extend(list(args.ports)) available_ports.extend(list(args.ports))
console.print(f"[green]Prepared {len(available_ports)} specified UART port(s) for auto-pairing.[/green]") console.print(f"[green]Prepared {len(available_ports)} specified UART port(s) for auto-pairing.[/green]")
else: else:
# Passive mode: grab whatever UARTs exist now, and keep looking later.
discovered = discover_ports( discovered = discover_ports(
include_non_usb=include_non_usb, include_non_usb=include_non_usb,
ignore_descriptions=ignore_port_desc, ignore_descriptions=ignore_port_desc,
@ -653,6 +676,7 @@ def prepare_pairing_state(
def assign_port_for_index(pairing: PairingState, idx: int, console: Console) -> Optional[str]: def assign_port_for_index(pairing: PairingState, idx: int, console: Console) -> Optional[str]:
"""Return the UART assigned to a controller index, auto-pairing if allowed."""
if idx in pairing.mapping_by_index: if idx in pairing.mapping_by_index:
return pairing.mapping_by_index[idx] return pairing.mapping_by_index[idx]
if not pairing.auto_pairing_enabled: if not pairing.auto_pairing_enabled:
@ -667,12 +691,14 @@ def assign_port_for_index(pairing: PairingState, idx: int, console: Console) ->
def ports_in_use(pairing: PairingState, contexts: Dict[int, ControllerContext]) -> set: def ports_in_use(pairing: PairingState, contexts: Dict[int, ControllerContext]) -> set:
"""Return a set of UART paths currently reserved or mapped."""
used = set(pairing.mapping_by_index.values()) used = set(pairing.mapping_by_index.values())
used.update(ctx.port for ctx in contexts.values() if ctx.port) used.update(ctx.port for ctx in contexts.values() if ctx.port)
return used return used
def discover_new_ports(pairing: PairingState, contexts: Dict[int, ControllerContext], console: Console) -> None: def discover_new_ports(pairing: PairingState, contexts: Dict[int, ControllerContext], console: Console) -> None:
"""Scan for new serial ports and add unused ones to the available pool."""
if not pairing.auto_discover_ports: if not pairing.auto_discover_ports:
return return
discovered = discover_ports( discovered = discover_ports(
@ -696,9 +722,11 @@ def pair_waiting_contexts(
uarts: List[PicoUART], uarts: List[PicoUART],
console: Console, console: Console,
) -> None: ) -> None:
"""Attach UARTs to contexts that are waiting for a port assignment/open."""
for ctx in list(contexts.values()): for ctx in list(contexts.values()):
if ctx.port is not None: if ctx.port is not None:
continue continue
# Try to grab a port for this controller; if none are available, leave it waiting.
port_choice = assign_port_for_index(pairing, ctx.controller_index, console) port_choice = assign_port_for_index(pairing, ctx.controller_index, console)
if port_choice is None: if port_choice is None:
continue continue
@ -717,6 +745,7 @@ def pair_waiting_contexts(
def open_initial_contexts( def open_initial_contexts(
args: argparse.Namespace, pairing: PairingState, controller_indices: List[int], console: Console args: argparse.Namespace, pairing: PairingState, controller_indices: List[int], console: Console
) -> Tuple[Dict[int, ControllerContext], List[PicoUART]]: ) -> Tuple[Dict[int, ControllerContext], List[PicoUART]]:
"""Open initial controllers and UARTs for detected indices."""
contexts: Dict[int, ControllerContext] = {} contexts: Dict[int, ControllerContext] = {}
uarts: List[PicoUART] = [] uarts: List[PicoUART] = []
for index in controller_indices: for index in controller_indices:
@ -753,6 +782,7 @@ def open_initial_contexts(
def handle_axis_motion(event: sdl2.SDL_Event, contexts: Dict[int, ControllerContext], config: BridgeConfig) -> None: def handle_axis_motion(event: sdl2.SDL_Event, contexts: Dict[int, ControllerContext], config: BridgeConfig) -> None:
"""Process axis motion event into stick/trigger state."""
ctx = contexts.get(event.caxis.which) ctx = contexts.get(event.caxis.which)
if not ctx: if not ctx:
return return
@ -790,6 +820,7 @@ def handle_button_event(
config: BridgeConfig, config: BridgeConfig,
contexts: Dict[int, ControllerContext], contexts: Dict[int, ControllerContext],
) -> None: ) -> None:
"""Process button events into report/dpad state."""
ctx = contexts.get(event.cbutton.which) ctx = contexts.get(event.cbutton.which)
if not ctx: if not ctx:
return return
@ -820,7 +851,9 @@ def handle_device_added(
uarts: List[PicoUART], uarts: List[PicoUART],
console: Console, console: Console,
) -> None: ) -> None:
"""Handle controller hotplug by opening and pairing UART if possible."""
idx = event.cdevice.which idx = event.cdevice.which
# If we already have a context for this logical index, ignore the duplicate event.
if any(c.controller_index == idx for c in contexts.values()): if any(c.controller_index == idx for c in contexts.values()):
return return
port = assign_port_for_index(pairing, idx, console) port = assign_port_for_index(pairing, idx, console)
@ -860,12 +893,14 @@ def handle_device_removed(
contexts: Dict[int, ControllerContext], contexts: Dict[int, ControllerContext],
console: Console, console: Console,
) -> None: ) -> None:
"""Handle controller removal and release any auto-assigned UART."""
instance_id = event.cdevice.which instance_id = event.cdevice.which
ctx = contexts.pop(instance_id, None) ctx = contexts.pop(instance_id, None)
if not ctx: if not ctx:
return return
console.print(f"[yellow]Controller {instance_id} removed[/yellow]") console.print(f"[yellow]Controller {instance_id} removed[/yellow]")
if ctx.controller_index in pairing.auto_assigned_indices: if ctx.controller_index in pairing.auto_assigned_indices:
# Return auto-paired UART back to the pool so a future device can use it.
freed = pairing.mapping_by_index.pop(ctx.controller_index, None) freed = pairing.mapping_by_index.pop(ctx.controller_index, None)
pairing.auto_assigned_indices.discard(ctx.controller_index) pairing.auto_assigned_indices.discard(ctx.controller_index)
if freed and freed not in pairing.available_ports: if freed and freed not in pairing.available_ports:
@ -882,6 +917,7 @@ def service_contexts(
uarts: List[PicoUART], uarts: List[PicoUART],
console: Console, console: Console,
) -> None: ) -> None:
"""Poll controllers, reconnect UARTs, send reports, and apply rumble."""
for ctx in list(contexts.values()): for ctx in list(contexts.values()):
current_button_map = ( current_button_map = (
config.button_map_swapped config.button_map_swapped
@ -912,6 +948,7 @@ def service_contexts(
last_payload = p last_payload = p
if last_payload is not None: if last_payload is not None:
# Apply only the freshest rumble payload seen during this tick.
energy = apply_rumble(ctx.controller, last_payload) energy = apply_rumble(ctx.controller, last_payload)
ctx.rumble_active = energy >= RUMBLE_MIN_ACTIVE ctx.rumble_active = energy >= RUMBLE_MIN_ACTIVE
if ctx.rumble_active and energy != ctx.last_rumble_energy: if ctx.rumble_active and energy != ctx.last_rumble_energy:
@ -949,6 +986,7 @@ def run_bridge_loop(
contexts: Dict[int, ControllerContext], contexts: Dict[int, ControllerContext],
uarts: List[PicoUART], uarts: List[PicoUART],
) -> None: ) -> None:
"""Main event loop for bridging controllers to UART and handling rumble."""
event = sdl2.SDL_Event() event = sdl2.SDL_Event()
port_scan_interval = 2.0 port_scan_interval = 2.0
last_port_scan = time.monotonic() last_port_scan = time.monotonic()
@ -970,6 +1008,7 @@ def run_bridge_loop(
now = time.monotonic() now = time.monotonic()
if now - last_port_scan > port_scan_interval: if now - last_port_scan > port_scan_interval:
# Periodically rescan for new UARTs to auto-pair hotplugged devices.
discover_new_ports(pairing, contexts, console) discover_new_ports(pairing, contexts, console)
last_port_scan = now last_port_scan = now
pair_waiting_contexts(args, pairing, contexts, uarts, console) pair_waiting_contexts(args, pairing, contexts, uarts, console)
@ -980,6 +1019,7 @@ def run_bridge_loop(
def cleanup(contexts: Dict[int, ControllerContext], uarts: List[PicoUART]) -> None: def cleanup(contexts: Dict[int, ControllerContext], uarts: List[PicoUART]) -> None:
"""Gracefully close controllers, UARTs, and SDL subsystems."""
for ctx in contexts.values(): for ctx in contexts.values():
sdl2.SDL_GameControllerClose(ctx.controller) sdl2.SDL_GameControllerClose(ctx.controller)
for uart in uarts: for uart in uarts:
@ -988,6 +1028,7 @@ def cleanup(contexts: Dict[int, ControllerContext], uarts: List[PicoUART]) -> No
def main() -> None: def main() -> None:
"""Entry point: parse args, set up SDL, and run the bridge loop."""
parser = build_arg_parser() parser = build_arg_parser()
args = parser.parse_args() args = parser.parse_args()
console = Console() console = Console()