feat(04-01): USB Manager goroutine-per-device, poll loop, reconnect, leak-safe teardown

- Manager with injectable enumerateFunc and serialOpener for test isolation
- goroutine-per-device model: deviceLoop owns one serial port per device
- context+done-channel teardown: inner read goroutine exits on ctx.Done()
- Read buffer capped at 4096 bytes (T-04-01 threat mitigation)
- Poll loop reconciles prev/current snapshots for connect/disconnect events
- ErrDeviceNotConnected returned by Send() when device absent
- All 5 manager tests pass with -race flag; goroutine count stable across 5 replug cycles
- mockPort test helper blocks Read until Close() unblocks it (realistic behavior)
This commit is contained in:
Mikkel Georgsen 2026-04-10 06:45:26 +00:00
parent f5b1d3156c
commit 82eaf6bed7
3 changed files with 592 additions and 0 deletions

327
internal/usb/manager.go Normal file
View file

@ -0,0 +1,327 @@
package usb
import (
"context"
"errors"
"log"
"sync"
"time"
"go.bug.st/serial"
)
// ErrDeviceNotConnected is the sentinel error returned by Send when the target
// device has no entry in the Manager's device table. Match it with errors.Is.
var ErrDeviceNotConnected = errors.New("usb: device not connected")
// serialPort is the subset of go.bug.st/serial.Port that this package uses.
// Declaring it locally lets tests substitute a fake port and run without
// hardware.
type serialPort interface {
	// Read is called in a loop by the per-device read goroutine.
	Read(p []byte) (n int, err error)
	// Write is called by deviceLoop to deliver the payload of a CmdWrite.
	Write(p []byte) (n int, err error)
	// Close is called on teardown; deviceLoop relies on it to unblock a
	// pending Read so the read goroutine can exit.
	Close() error
}
// serialOpener is the function type used to open the serial port at path with
// the given baud rate. The Manager stores one in its openSerialPort field so
// tests can inject a fake opener in place of defaultSerialOpener.
type serialOpener func(path string, baud int) (serialPort, error)
// defaultSerialOpener is the production serialOpener: it opens path through
// go.bug.st/serial. A zero baud rate falls back to 9600.
func defaultSerialOpener(path string, baud int) (serialPort, error) {
	const fallbackBaud = 9600

	rate := baud
	if rate == 0 {
		rate = fallbackBaud
	}
	return serial.Open(path, &serial.Mode{BaudRate: rate})
}
// deviceHandle is the Manager's bookkeeping record for one running device
// goroutine. connectDevice creates it and stores it in Manager.devices.
type deviceHandle struct {
	cmdChan chan Command       // commands consumed by deviceLoop; written by Send and disconnectDevice
	done    chan struct{}      // closed when deviceLoop exits
	cancel  context.CancelFunc // cancels the per-device context handed to deviceLoop
}
// Manager owns all USB device goroutines and emits DeviceEvents to subscribers.
//
// One poll goroutine periodically calls enumerateFunc and reconciles the result
// against the previous snapshot; each connected known device gets its own
// goroutine running deviceLoop.
type Manager struct {
	pollInterval   time.Duration                    // period of the enumeration ticker in pollLoop
	events         chan DeviceEvent                 // buffered; written by emit, closed by Stop
	devices        map[string]*deviceHandle         // keyed by VID:PID; accessed with mu held
	mu             sync.Mutex                       // guards devices
	cancel         context.CancelFunc               // set by Start; nil until then
	wg             sync.WaitGroup                   // counts the poll goroutine and every device goroutine
	enumerateFunc  func() (map[string]string, error) // returns VID:PID -> serial port path
	openSerialPort serialOpener                     // opens the port for a device; injectable for tests
}
// NewManager builds a Manager wired to the real enumerator and the real serial
// opener. Events are buffered up to 32 entries. The Manager is idle until Start
// is called.
func NewManager(pollInterval time.Duration) *Manager {
	m := new(Manager)
	m.pollInterval = pollInterval
	m.events = make(chan DeviceEvent, 32)
	m.devices = make(map[string]*deviceHandle)
	m.enumerateFunc = enumerateConnected
	m.openSerialPort = defaultSerialOpener
	return m
}
// newManagerForTest builds a Manager whose enumerator and serial opener are
// supplied by the caller, so tests can run without hardware. The events buffer
// holds 64 entries. Used only in tests.
func newManagerForTest(enumFn func() (map[string]string, error), opener serialOpener, pollInterval time.Duration) *Manager {
	m := new(Manager)
	m.pollInterval = pollInterval
	m.events = make(chan DeviceEvent, 64)
	m.devices = make(map[string]*deviceHandle)
	m.enumerateFunc = enumFn
	m.openSerialPort = opener
	return m
}
// Start launches the poll loop on a background goroutine. The loop runs under
// a context derived from ctx, whose cancel function is kept so Stop can end it.
// Call Stop to shut down cleanly.
func (m *Manager) Start(ctx context.Context) {
	pollCtx, stop := context.WithCancel(ctx)
	m.cancel = stop

	m.wg.Add(1)
	go m.pollLoop(pollCtx)
}
// Stop cancels the poll loop (and through it every device goroutine), waits
// for all of them to exit, then closes the events channel.
//
// NOTE: the events channel is closed unconditionally, so a second call to Stop
// panics with "close of closed channel". Call it exactly once.
func (m *Manager) Stop() {
	if stop := m.cancel; stop != nil {
		stop()
	}

	// Nothing can emit after Wait returns, so closing events is safe here.
	m.wg.Wait()
	close(m.events)
}
// Events exposes the DeviceEvent stream as a receive-only channel. Stop closes
// the channel.
func (m *Manager) Events() <-chan DeviceEvent {
	var stream <-chan DeviceEvent = m.events
	return stream
}
// Send queues cmd for the device identified by vidpid.
//
// It returns ErrDeviceNotConnected when the device has no running goroutine,
// and a descriptive error when the device's command queue is full. Send never
// blocks.
func (m *Manager) Send(vidpid string, cmd Command) error {
	m.mu.Lock()
	handle, connected := m.devices[vidpid]
	m.mu.Unlock()

	if !connected {
		return ErrDeviceNotConnected
	}

	// Non-blocking enqueue: a full queue is reported instead of stalling the caller.
	select {
	case handle.cmdChan <- cmd:
		return nil
	default:
	}
	return errors.New("usb: command channel full for " + vidpid)
}
// pollLoop runs on a background goroutine, enumerating USB devices once per
// poll interval and reconciling each snapshot against the previous one.
//
// When ctx is cancelled it disconnects every device still registered and
// returns. Enumeration errors are logged and the tick is skipped; the previous
// snapshot is kept so the next successful enumeration is diffed against it.
func (m *Manager) pollLoop(ctx context.Context) {
	defer m.wg.Done()

	// time.NewTicker panics on a non-positive duration, which would crash the
	// whole process from this background goroutine. Fall back to a sane
	// default instead.
	interval := m.pollInterval
	if interval <= 0 {
		const fallbackInterval = time.Second
		log.Printf("usb: invalid poll interval %v, using %v", interval, fallbackInterval)
		interval = fallbackInterval
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	var prev map[string]string // previous snapshot; nil until the first successful poll
	for {
		select {
		case <-ctx.Done():
			// Copy the handles under the lock, then disconnect without
			// holding it: device goroutines take the same lock on exit.
			m.mu.Lock()
			handles := make(map[string]*deviceHandle, len(m.devices))
			for k, v := range m.devices {
				handles[k] = v
			}
			m.mu.Unlock()

			for vidpid, h := range handles {
				m.disconnectDevice(ctx, vidpid, h)
			}
			return

		case <-ticker.C:
			current, err := m.enumerateFunc()
			if err != nil {
				log.Printf("usb: enumeration error: %v", err)
				continue
			}
			m.reconcile(ctx, prev, current)
			prev = current
		}
	}
}
// reconcile diffs two enumeration snapshots (VID:PID -> port path). Devices
// that appear only in current are connected if they are listed in
// KnownDevices; devices that appear only in prev are disconnected if they
// still have a running goroutine.
func (m *Manager) reconcile(ctx context.Context, prev, current map[string]string) {
	// Arrivals: present now, absent before.
	for vidpid, portPath := range current {
		if _, seen := prev[vidpid]; seen {
			continue
		}
		spec, known := KnownDevices[vidpid]
		if !known {
			continue
		}
		m.connectDevice(ctx, vidpid, portPath, spec)
	}

	// Departures: present before, absent now.
	for vidpid := range prev {
		if _, present := current[vidpid]; present {
			continue
		}

		m.mu.Lock()
		h, running := m.devices[vidpid]
		m.mu.Unlock()

		if !running {
			continue
		}
		m.disconnectDevice(ctx, vidpid, h)
	}
}
// connectDevice registers a handle for vidpid, spawns a goroutine running
// deviceLoop for it, and emits a StateConnected event.
//
// The handle is stored in m.devices before the goroutine starts so Send can
// find it immediately. When deviceLoop returns, the goroutine cancels the
// device context and removes the handle from m.devices.
func (m *Manager) connectDevice(ctx context.Context, vidpid, portPath string, spec DeviceSpec) {
	devCtx, devCancel := context.WithCancel(ctx)
	h := &deviceHandle{
		cmdChan: make(chan Command, 16),
		done:    make(chan struct{}),
		cancel:  devCancel,
	}

	m.mu.Lock()
	m.devices[vidpid] = h
	m.mu.Unlock()

	m.wg.Add(1)
	go func() {
		defer m.wg.Done()
		m.deviceLoop(devCtx, spec, portPath, h.cmdChan, h.done)
		devCancel() // ensure context is cancelled on exit

		// h.done is already closed at this point, so the poll loop may have
		// re-connected the same VID:PID and stored a newer handle (likewise
		// after a disconnect timeout). Only remove the entry if it is still
		// ours, otherwise we would orphan the replacement goroutine.
		m.mu.Lock()
		if m.devices[vidpid] == h {
			delete(m.devices, vidpid)
		}
		m.mu.Unlock()
	}()

	m.emit(DeviceEvent{VIDPID: vidpid, Spec: spec, State: StateConnected})
}
// disconnectDevice asks the device goroutine behind h to stop, waits up to two
// seconds for it to exit, and then emits a StateDisconnected event for vidpid.
// The event is emitted even if the wait times out.
func (m *Manager) disconnectDevice(ctx context.Context, vidpid string, h *deviceHandle) {
	// Primary stop signal: cancel the device's context.
	h.cancel()

	// Secondary stop signal: a best-effort CmdClose, dropped if the queue is full.
	closeCmd := Command{Type: CmdClose}
	select {
	case h.cmdChan <- closeCmd:
	default:
	}

	spec := KnownDevices[vidpid]

	// Bounded wait so one stuck device cannot stall the poll loop forever.
	giveUp := time.NewTimer(2 * time.Second)
	defer giveUp.Stop()
	select {
	case <-h.done:
	case <-giveUp.C:
		log.Printf("usb: timeout waiting for device %s goroutine to exit", vidpid)
	}

	m.emit(DeviceEvent{VIDPID: vidpid, Spec: spec, State: StateDisconnected})
}
// deviceLoop manages a single USB device connection.
//
// It opens the serial port at portPath (spec.BaudRate, or 9600 when that is
// zero), starts an inner goroutine that reads from the port, and then serves
// commands from cmdChan until ctx is cancelled, cmdChan is closed, or a
// CmdClose arrives. done is closed on every return path, including a failed
// open.
//
// Goroutine leak prevention (per PITFALLS.md Pitfall 2):
//   - every exit path closes the port, which unblocks a pending Read, and then
//     waits for the read goroutine before closing done
//   - the read goroutine checks ctx.Done() before each Read
//   - delivering a write result on cmd.Reply is abandoned when ctx is
//     cancelled, so an unbuffered or abandoned reply channel cannot wedge the
//     loop
func (m *Manager) deviceLoop(ctx context.Context, spec DeviceSpec, portPath string, cmdChan <-chan Command, done chan<- struct{}) {
	defer close(done)

	baud := spec.BaudRate
	if baud == 0 {
		baud = 9600
	}
	port, err := m.openSerialPort(portPath, baud)
	if err != nil {
		log.Printf("usb: failed to open %s (%s): %v", portPath, spec.String(), err)
		return
	}

	// Inner read goroutine — reads from the port until context is cancelled or port closes.
	readDone := make(chan struct{})
	go func() {
		defer close(readDone)
		// T-04-01: Cap read buffer at 4096 bytes to prevent large-payload panics.
		buf := make([]byte, 4096)
		for {
			select {
			case <-ctx.Done():
				return
			default:
			}
			n, err := port.Read(buf)
			if n > 0 {
				// Forward read bytes as a CmdWrite reply event in future phases.
				// For now, log for observability.
				log.Printf("usb: [%s] read %d bytes", spec.String(), n)
			}
			if err != nil {
				// Port was closed or device disconnected — exit read goroutine.
				// Stay quiet when the error is just the result of our own shutdown.
				select {
				case <-ctx.Done():
				default:
					log.Printf("usb: [%s] read error (device likely disconnected): %v", spec.String(), err)
				}
				return
			}
		}
	}()

	// Single teardown for every exit path of the command loop. Deferred
	// functions run last-in-first-out, so this runs before close(done): the
	// port is closed and the reader has exited by the time done is signalled.
	defer func() {
		if cerr := port.Close(); cerr != nil {
			log.Printf("usb: [%s] close error: %v", spec.String(), cerr)
		}
		<-readDone
	}()

	// Main command loop.
	for {
		select {
		case <-ctx.Done():
			return
		case cmd, ok := <-cmdChan:
			if !ok {
				return
			}
			switch cmd.Type {
			case CmdWrite:
				_, werr := port.Write(cmd.Payload)
				if cmd.Reply != nil {
					// Do not block forever on a reply nobody is reading.
					select {
					case cmd.Reply <- werr:
					case <-ctx.Done():
						return
					}
				} else if werr != nil {
					// No reply channel: log so the failure is not lost silently.
					log.Printf("usb: [%s] write error: %v", spec.String(), werr)
				}
			case CmdClose:
				return
			}
		}
	}
}
// emit publishes event on the events channel without blocking. If the buffer
// is full the event is dropped and a log line records the loss.
func (m *Manager) emit(event DeviceEvent) {
	select {
	case m.events <- event:
		return
	default:
	}
	log.Printf("usb: events channel full, dropping event for %s", event.VIDPID)
}

View file

@ -0,0 +1,215 @@
package usb
import (
"context"
"errors"
"runtime"
"sync/atomic"
"testing"
"time"
)
// mockEnumerator replays a fixed sequence of enumeration snapshots
// (VID:PID -> port path), one per call to next. The call counter is atomic.
type mockEnumerator struct {
	snapshots []map[string]string
	idx       atomic.Int32
}

// newMockEnumerator returns an enumerator that yields the given snapshots in
// order.
func newMockEnumerator(snapshots ...map[string]string) *mockEnumerator {
	return &mockEnumerator{snapshots: snapshots}
}

// next returns a copy of the snapshot for this call and advances the counter.
// Once the sequence is exhausted the last snapshot is returned forever. With
// no snapshots configured it reports an empty device set instead of panicking
// on an out-of-range index. The error is always nil.
func (m *mockEnumerator) next() (map[string]string, error) {
	if len(m.snapshots) == 0 {
		return map[string]string{}, nil
	}

	i := int(m.idx.Add(1)) - 1
	if i >= len(m.snapshots) {
		// Return last snapshot forever after exhaustion.
		i = len(m.snapshots) - 1
	}

	// Copy so callers cannot mutate the canned snapshot.
	src := m.snapshots[i]
	out := make(map[string]string, len(src))
	for k, v := range src {
		out[k] = v
	}
	return out, nil
}
// noopSerialOpener is a serialOpener for tests: it ignores path and baud and
// hands back a fresh mockPort, so no hardware is needed. It never fails.
func noopSerialOpener(path string, baud int) (serialPort, error) {
	var port serialPort = newMockPort()
	return port, nil
}
// --- Test 1: Start with one device → goroutine count rises above the baseline ---
func TestManagerStartSpawnsOneGoroutine(t *testing.T) {
	present := map[string]string{"0525:a4a7": "/dev/cu.mock0"} // device present
	enum := newMockEnumerator(present)
	m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond)

	baseline := runtime.NumGoroutine()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	m.Start(ctx)

	// Three poll intervals: enough for the poll loop to tick and spawn the device goroutine.
	time.Sleep(60 * time.Millisecond)

	// The poll loop and the device loop should both be running by now, but the
	// runtime's own goroutines fluctuate, so only require a net increase.
	if after := runtime.NumGoroutine(); after <= baseline {
		t.Errorf("expected goroutine count to increase after Start(); baseline=%d after=%d", baseline, after)
	}

	m.Stop()
}
// --- Test 2: Device disconnect then reconnect emits Connected, Disconnected, Connected ---
//
// Events are collected in the test goroutine itself. Collecting them in a
// helper goroutine and then reading the slice on the timeout path is a data
// race, which -race reports.
func TestManagerDisconnectReconnectEvents(t *testing.T) {
	enum := newMockEnumerator(
		map[string]string{"0525:a4a7": "/dev/cu.mock0"}, // call 1: present
		map[string]string{},                             // call 2: absent
		map[string]string{"0525:a4a7": "/dev/cu.mock0"}, // call 3: present again
	)
	m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	m.Start(ctx)
	defer m.Stop()

	const wantEvents = 3 // Connected, Disconnected, Connected
	events := make([]DeviceEvent, 0, wantEvents)
	timeout := time.After(2 * time.Second)
	for len(events) < wantEvents {
		select {
		case e, ok := <-m.Events():
			if !ok {
				t.Fatalf("events channel closed after %d events: %v", len(events), events)
			}
			events = append(events, e)
		case <-timeout:
			t.Fatalf("timeout waiting for 3 events; got %d: %v", len(events), events)
		}
	}

	if events[0].State != StateConnected {
		t.Errorf("event[0]: expected Connected, got %v", events[0].State)
	}
	if events[1].State != StateDisconnected {
		t.Errorf("event[1]: expected Disconnected, got %v", events[1].State)
	}
	if events[2].State != StateConnected {
		t.Errorf("event[2]: expected Connected, got %v", events[2].State)
	}
}
// --- Test 3: Stop causes all goroutines to exit within 500ms ---
func TestManagerStopGoroutineLeak(t *testing.T) {
	enum := newMockEnumerator(
		map[string]string{"0525:a4a7": "/dev/cu.mock0"},
	)
	m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	baseline := runtime.NumGoroutine()
	m.Start(ctx)
	time.Sleep(60 * time.Millisecond) // let goroutines spin up
	m.Stop()

	// Poll every 10ms until the count settles; tolerate two extra goroutines
	// for runtime background work.
	limit := baseline + 2
	for deadline := time.Now().Add(500 * time.Millisecond); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) {
		if runtime.NumGoroutine() <= limit {
			return // passed
		}
	}
	t.Errorf("goroutine leak: baseline=%d, after Stop()=%d (expected within ±2)",
		baseline, runtime.NumGoroutine())
}
// --- Test 4: Send to absent device returns ErrDeviceNotConnected ---
func TestManagerSendToAbsentDevice(t *testing.T) {
	// The enumerator only ever reports an empty device set.
	noDevices := map[string]string{}
	enum := newMockEnumerator(noDevices)
	m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	m.Start(ctx)
	defer m.Stop()

	time.Sleep(40 * time.Millisecond) // let poll loop run

	cmd := Command{
		Type:  CmdWrite,
		Reply: make(chan error, 1),
	}
	if err := m.Send("0525:a4a7", cmd); !errors.Is(err, ErrDeviceNotConnected) {
		t.Errorf("expected ErrDeviceNotConnected, got %v", err)
	}
}
// --- Test 5: Goroutine count stable across 5 simulated unplug/replug cycles ---
func TestGoroutineStability(t *testing.T) {
	const cycles = 5
	plugged := func() map[string]string {
		return map[string]string{"0525:a4a7": "/dev/cu.mock0"}
	}
	unplugged := func() map[string]string {
		return map[string]string{}
	}

	// Sequence: empty, then (plugged, unplugged) x cycles, then plugged.
	snapshots := make([]map[string]string, 0, 2*cycles+2)
	snapshots = append(snapshots, unplugged()) // initial empty
	for i := 0; i < cycles; i++ {
		snapshots = append(snapshots, plugged(), unplugged())
	}
	snapshots = append(snapshots, plugged()) // final state: device connected

	enum := newMockEnumerator(snapshots...)
	m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	baseline := runtime.NumGoroutine()
	m.Start(ctx)

	// 12 snapshots at a 20ms poll interval need ~240ms; allow roughly double.
	time.Sleep(500 * time.Millisecond)
	m.Stop()

	// Give goroutines up to 500ms to settle, sampling every 10ms.
	limit := baseline + 2
	var finalCount int
	for deadline := time.Now().Add(500 * time.Millisecond); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) {
		finalCount = runtime.NumGoroutine()
		if finalCount <= limit {
			return // stable
		}
	}
	t.Errorf("goroutine count unstable after 5 replug cycles: baseline=%d, final=%d (max allowed %d)",
		baseline, finalCount, baseline+2)
}

View file

@ -0,0 +1,50 @@
package usb
import (
"io"
"sync"
)
// mockPort is an in-memory serialPort for tests. A Read parks until Close is
// called, standing in for a device that stays connected and silent until it
// is explicitly disconnected.
type mockPort struct {
	mu      sync.Mutex
	closed  bool          // set once by Close; guarded by mu
	closeCh chan struct{} // closed by Close to release blocked Reads
}

// newMockPort returns an open mockPort.
func newMockPort() *mockPort {
	p := new(mockPort)
	p.closeCh = make(chan struct{})
	return p
}

// Read never delivers data: it waits for the port to be closed and then
// reports io.EOF with zero bytes read.
func (p *mockPort) Read(buf []byte) (int, error) {
	select {
	case <-p.closeCh:
	}
	return 0, io.EOF
}

// Write pretends to accept all of data while the port is open. After Close it
// writes nothing and returns io.ErrClosedPipe.
func (p *mockPort) Write(data []byte) (int, error) {
	p.mu.Lock()
	alreadyClosed := p.closed
	p.mu.Unlock()

	if alreadyClosed {
		return 0, io.ErrClosedPipe
	}
	return len(data), nil
}

// Close marks the port closed and releases every pending Read. It is safe to
// call more than once and always returns nil.
func (p *mockPort) Close() error {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.closed {
		return nil
	}
	p.closed = true
	close(p.closeCh)
	return nil
}