Add hot plugging

This commit is contained in:
Joey Yakimowich-Payne 2025-11-20 22:50:14 -07:00
commit f73b9f8604
No known key found for this signature in database
GPG key ID: 6BFE655FA5ABD1E1

View file

@ -23,6 +23,7 @@ from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import serial import serial
from serial import SerialException
from serial.tools import list_ports from serial.tools import list_ports
import sdl2 import sdl2
@ -287,7 +288,7 @@ def apply_rumble(controller: sdl2.SDL_GameController, payload: bytes) -> None:
scale = 0.40 scale = 0.40
left = int(min(1.0, left_norm * scale) * 0xFFFF) left = int(min(1.0, left_norm * scale) * 0xFFFF)
right = int(min(1.0, right_norm * scale) * 0xFFFF) right = int(min(1.0, right_norm * scale) * 0xFFFF)
duration = 25 duration = 10
sdl2.SDL_GameControllerRumble(controller, left, right, duration) sdl2.SDL_GameControllerRumble(controller, left, right, duration)
@ -295,13 +296,14 @@ def apply_rumble(controller: sdl2.SDL_GameController, payload: bytes) -> None:
class ControllerContext: class ControllerContext:
controller: sdl2.SDL_GameController controller: sdl2.SDL_GameController
instance_id: int instance_id: int
uart: PicoUART controller_index: int
port: str
uart: Optional[PicoUART]
report: SwitchReport = field(default_factory=SwitchReport) report: SwitchReport = field(default_factory=SwitchReport)
dpad: Dict[str, bool] = field(default_factory=lambda: {"up": False, "down": False, "left": False, "right": False}) dpad: Dict[str, bool] = field(default_factory=lambda: {"up": False, "down": False, "left": False, "right": False})
last_trigger_state: Dict[str, bool] = field(default_factory=lambda: {"left": False, "right": False}) last_trigger_state: Dict[str, bool] = field(default_factory=lambda: {"left": False, "right": False})
last_send: float = 0.0 last_send: float = 0.0
stop_event: threading.Event = field(default_factory=threading.Event) last_reopen_attempt: float = 0.0
rumble_thread: Optional[threading.Thread] = None
def open_controller(index: int) -> Tuple[sdl2.SDL_GameController, int]: def open_controller(index: int) -> Tuple[sdl2.SDL_GameController, int]:
@ -313,18 +315,16 @@ def open_controller(index: int) -> Tuple[sdl2.SDL_GameController, int]:
return controller, instance_id return controller, instance_id
def start_rumble_listener(ctx: ControllerContext) -> threading.Thread: def try_open_uart(port: str, baud: int) -> Optional[PicoUART]:
def _worker() -> None: try:
while not ctx.stop_event.is_set(): return PicoUART(port, baud)
payload = ctx.uart.read_rumble_payload() except Exception:
if not payload: return None
time.sleep(0.0005) # small yield to avoid busy-spin
continue
apply_rumble(ctx.controller, payload)
thread = threading.Thread(target=_worker, name=f"rumble-{ctx.instance_id}", daemon=True)
thread.start() def start_rumble_listener(ctx: ControllerContext) -> threading.Thread:
return thread # No-op placeholder (rumble is polled in the main loop for hotplug safety).
return None
def build_arg_parser() -> argparse.ArgumentParser: def build_arg_parser() -> argparse.ArgumentParser:
@ -356,6 +356,12 @@ def build_arg_parser() -> argparse.ArgumentParser:
default=0.35, default=0.35,
help="Trigger threshold treated as a digital press (0.0-1.0, default 0.35)", help="Trigger threshold treated as a digital press (0.0-1.0, default 0.35)",
) )
parser.add_argument(
"--baud",
type=int,
default=UART_BAUD,
help=f"UART baud rate (default {UART_BAUD}; must match switch-pico firmware)",
)
return parser return parser
@ -369,6 +375,7 @@ def main() -> None:
sdl2.SDL_Init(sdl2.SDL_INIT_GAMECONTROLLER) sdl2.SDL_Init(sdl2.SDL_INIT_GAMECONTROLLER)
contexts: Dict[int, ControllerContext] = {} contexts: Dict[int, ControllerContext] = {}
uarts: List[PicoUART] = [] uarts: List[PicoUART] = []
mapping_by_index: Dict[int, str] = {}
console = Console() console = Console()
try: try:
controller_indices: List[int] = [] controller_indices: List[int] = []
@ -430,15 +437,33 @@ def main() -> None:
console.print(f" Controller {idx} -> {port}") console.print(f" Controller {idx} -> {port}")
for index, port in mappings: for index, port in mappings:
controller, instance_id = open_controller(index) mapping_by_index[index] = port
uart = PicoUART(port)
uarts.append(uart)
ctx = ControllerContext(controller=controller, instance_id=instance_id, uart=uart)
ctx.rumble_thread = start_rumble_listener(ctx)
contexts[instance_id] = ctx
console.print(f"[green]Controller {index} ({instance_id}) paired to {port}[/green]")
if not contexts: # Open currently connected controllers that match the mapping.
for index, port in mappings:
if index >= sdl2.SDL_NumJoysticks() or not sdl2.SDL_IsGameController(index):
continue
try:
controller, instance_id = open_controller(index)
except Exception as exc:
console.print(f"[red]Failed to open controller {index}: {exc}[/red]")
continue
uart = try_open_uart(port, args.baud)
if uart:
uarts.append(uart)
console.print(f"[green]Controller {index} ({instance_id}) paired to {port}[/green]")
else:
console.print(f"[yellow]Controller {index} ({instance_id}) waiting for UART {port}[/yellow]")
ctx = ControllerContext(
controller=controller,
instance_id=instance_id,
controller_index=index,
port=port,
uart=uart,
)
contexts[instance_id] = ctx
if not contexts and not mapping_by_index:
parser.error("No controllers opened. Check --map/--ports/--interactive values.") parser.error("No controllers opened. Check --map/--ports/--interactive values.")
event = sdl2.SDL_Event() event = sdl2.SDL_Event()
@ -493,19 +518,74 @@ def main() -> None:
elif button in DPAD_BUTTONS: elif button in DPAD_BUTTONS:
ctx.dpad[DPAD_BUTTONS[button]] = pressed ctx.dpad[DPAD_BUTTONS[button]] = pressed
ctx.report.hat = dpad_to_hat(ctx.dpad) ctx.report.hat = dpad_to_hat(ctx.dpad)
elif event.type == sdl2.SDL_CONTROLLERDEVICEADDED:
idx = event.cdevice.which
port = mapping_by_index.get(idx)
if port is None:
continue
# Avoid duplicate opens for already connected instance IDs.
already = any(c.controller_index == idx for c in contexts.values())
if already:
continue
try:
controller, instance_id = open_controller(idx)
except Exception as exc:
console.print(f"[red]Hotplug open failed for controller {idx}: {exc}[/red]")
continue
uart = try_open_uart(port, args.baud)
if uart:
uarts.append(uart)
console.print(f"[green]Controller {idx} ({instance_id}) paired to {port}[/green]")
else:
console.print(f"[yellow]Controller {idx} ({instance_id}) waiting for UART {port}[/yellow]")
ctx = ControllerContext(
controller=controller,
instance_id=instance_id,
controller_index=idx,
port=port,
uart=uart,
)
contexts[instance_id] = ctx
elif event.type == sdl2.SDL_CONTROLLERDEVICEREMOVED:
instance_id = event.cdevice.which
ctx = contexts.pop(instance_id, None)
if ctx:
console.print(f"[yellow]Controller {instance_id} removed[/yellow]")
sdl2.SDL_GameControllerClose(ctx.controller)
now = time.monotonic() now = time.monotonic()
for ctx in contexts.values(): for ctx in list(contexts.values()):
if now - ctx.last_send >= interval: # Reconnect UART if needed.
ctx.uart.send_report(ctx.report) if ctx.uart is None and (now - ctx.last_reopen_attempt) > 1.0:
ctx.last_send = now ctx.last_reopen_attempt = now
uart = try_open_uart(ctx.port, args.baud)
if uart:
uarts.append(uart)
console.print(f"[green]Reconnected UART {ctx.port} for controller {ctx.controller_index}[/green]")
ctx.uart = uart
if ctx.uart is None:
continue
try:
if now - ctx.last_send >= interval:
ctx.uart.send_report(ctx.report)
ctx.last_send = now
# Poll rumble quickly while we have the port.
payload = ctx.uart.read_rumble_payload()
if payload:
apply_rumble(ctx.controller, payload)
except SerialException as exc:
console.print(f"[yellow]UART {ctx.port} disconnected: {exc}[/yellow]")
try:
ctx.uart.close()
except Exception:
pass
ctx.uart = None
ctx.last_reopen_attempt = now
except Exception as exc:
console.print(f"[red]UART error on {ctx.port}: {exc}[/red]")
sdl2.SDL_Delay(1) sdl2.SDL_Delay(1)
finally: finally:
for ctx in contexts.values(): for ctx in contexts.values():
ctx.stop_event.set()
if ctx.rumble_thread:
ctx.rumble_thread.join(timeout=0.2)
sdl2.SDL_GameControllerClose(ctx.controller) sdl2.SDL_GameControllerClose(ctx.controller)
for uart in uarts: for uart in uarts:
uart.close() uart.close()