feat(04-01): USB Manager goroutine-per-device, poll loop, reconnect, leak-safe teardown
- Manager with injectable enumerateFunc and serialOpener for test isolation - goroutine-per-device model: deviceLoop owns one serial port per device - context+done-channel teardown: inner read goroutine exits on ctx.Done() - Read buffer capped at 4096 bytes (T-04-01 threat mitigation) - Poll loop reconciles prev/current snapshots for connect/disconnect events - ErrDeviceNotConnected returned by Send() when device absent - All 5 manager tests pass with -race flag; goroutine count stable across 5 replug cycles - mockPort test helper blocks Read until Close() unblocks it (realistic behavior)
This commit is contained in:
parent
f5b1d3156c
commit
82eaf6bed7
3 changed files with 592 additions and 0 deletions
327
internal/usb/manager.go
Normal file
327
internal/usb/manager.go
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
package usb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.bug.st/serial"
|
||||
)
|
||||
|
||||
// ErrDeviceNotConnected is returned by Send when the target device is not currently connected.
|
||||
var ErrDeviceNotConnected = errors.New("usb: device not connected")
|
||||
|
||||
// serialPort abstracts go.bug.st/serial.Port for testing without hardware.
|
||||
type serialPort interface {
|
||||
Read(p []byte) (n int, err error)
|
||||
Write(p []byte) (n int, err error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// serialOpener is the function type used to open a serial port.
|
||||
// Injectable for tests.
|
||||
type serialOpener func(path string, baud int) (serialPort, error)
|
||||
|
||||
// defaultSerialOpener opens a real serial port using go.bug.st/serial.
|
||||
func defaultSerialOpener(path string, baud int) (serialPort, error) {
|
||||
if baud == 0 {
|
||||
baud = 9600
|
||||
}
|
||||
mode := &serial.Mode{BaudRate: baud}
|
||||
return serial.Open(path, mode)
|
||||
}
|
||||
|
||||
// deviceHandle tracks a running device goroutine.
|
||||
type deviceHandle struct {
|
||||
cmdChan chan Command
|
||||
done chan struct{} // closed when deviceLoop exits
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Manager owns all USB device goroutines and emits DeviceEvents to subscribers.
|
||||
type Manager struct {
|
||||
pollInterval time.Duration
|
||||
events chan DeviceEvent
|
||||
devices map[string]*deviceHandle // keyed by VID:PID
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
enumerateFunc func() (map[string]string, error)
|
||||
openSerialPort serialOpener
|
||||
}
|
||||
|
||||
// NewManager creates a new Manager with production defaults.
|
||||
func NewManager(pollInterval time.Duration) *Manager {
|
||||
return &Manager{
|
||||
pollInterval: pollInterval,
|
||||
events: make(chan DeviceEvent, 32),
|
||||
devices: make(map[string]*deviceHandle),
|
||||
enumerateFunc: enumerateConnected,
|
||||
openSerialPort: defaultSerialOpener,
|
||||
}
|
||||
}
|
||||
|
||||
// newManagerForTest creates a Manager with injected enumerator and serial opener.
|
||||
// Used only in tests.
|
||||
func newManagerForTest(enumFn func() (map[string]string, error), opener serialOpener, pollInterval time.Duration) *Manager {
|
||||
return &Manager{
|
||||
pollInterval: pollInterval,
|
||||
events: make(chan DeviceEvent, 64),
|
||||
devices: make(map[string]*deviceHandle),
|
||||
enumerateFunc: enumFn,
|
||||
openSerialPort: opener,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the poll loop in a background goroutine.
|
||||
// Call Stop to shut down cleanly.
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
m.cancel = cancel
|
||||
|
||||
m.wg.Add(1)
|
||||
go m.pollLoop(ctx)
|
||||
}
|
||||
|
||||
// Stop cancels all device goroutines and waits for them to exit.
|
||||
func (m *Manager) Stop() {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
m.wg.Wait()
|
||||
close(m.events)
|
||||
}
|
||||
|
||||
// Events returns the read-only channel of DeviceEvent notifications.
|
||||
func (m *Manager) Events() <-chan DeviceEvent {
|
||||
return m.events
|
||||
}
|
||||
|
||||
// Send delivers a Command to the named device. Returns ErrDeviceNotConnected if absent.
|
||||
func (m *Manager) Send(vidpid string, cmd Command) error {
|
||||
m.mu.Lock()
|
||||
h, ok := m.devices[vidpid]
|
||||
m.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return ErrDeviceNotConnected
|
||||
}
|
||||
|
||||
select {
|
||||
case h.cmdChan <- cmd:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("usb: command channel full for " + vidpid)
|
||||
}
|
||||
}
|
||||
|
||||
// pollLoop runs on a background goroutine, periodically enumerating USB devices.
|
||||
func (m *Manager) pollLoop(ctx context.Context) {
|
||||
defer m.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(m.pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
var prev map[string]string // previous snapshot
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Shut down all running device goroutines.
|
||||
m.mu.Lock()
|
||||
handles := make(map[string]*deviceHandle, len(m.devices))
|
||||
for k, v := range m.devices {
|
||||
handles[k] = v
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
for vidpid, h := range handles {
|
||||
m.disconnectDevice(ctx, vidpid, h)
|
||||
}
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
current, err := m.enumerateFunc()
|
||||
if err != nil {
|
||||
log.Printf("usb: enumeration error: %v", err)
|
||||
continue
|
||||
}
|
||||
m.reconcile(ctx, prev, current)
|
||||
prev = current
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reconcile compares previous and current device snapshots, spawning or
|
||||
// stopping device goroutines as needed.
|
||||
func (m *Manager) reconcile(ctx context.Context, prev, current map[string]string) {
|
||||
// Detect newly connected devices (in current but not prev).
|
||||
for vidpid, portPath := range current {
|
||||
if _, wasThere := prev[vidpid]; !wasThere {
|
||||
spec, known := KnownDevices[vidpid]
|
||||
if !known {
|
||||
continue
|
||||
}
|
||||
m.connectDevice(ctx, vidpid, portPath, spec)
|
||||
}
|
||||
}
|
||||
|
||||
// Detect newly disconnected devices (in prev but not current).
|
||||
for vidpid := range prev {
|
||||
if _, stillThere := current[vidpid]; !stillThere {
|
||||
m.mu.Lock()
|
||||
h, ok := m.devices[vidpid]
|
||||
m.mu.Unlock()
|
||||
|
||||
if ok {
|
||||
m.disconnectDevice(ctx, vidpid, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// connectDevice spawns a deviceLoop goroutine for the given device.
|
||||
func (m *Manager) connectDevice(ctx context.Context, vidpid, portPath string, spec DeviceSpec) {
|
||||
devCtx, devCancel := context.WithCancel(ctx)
|
||||
cmdChan := make(chan Command, 16)
|
||||
done := make(chan struct{})
|
||||
|
||||
h := &deviceHandle{
|
||||
cmdChan: cmdChan,
|
||||
done: done,
|
||||
cancel: devCancel,
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.devices[vidpid] = h
|
||||
m.mu.Unlock()
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
m.deviceLoop(devCtx, spec, portPath, cmdChan, done)
|
||||
devCancel() // ensure context is cancelled on exit
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.devices, vidpid)
|
||||
m.mu.Unlock()
|
||||
}()
|
||||
|
||||
m.emit(DeviceEvent{VIDPID: vidpid, Spec: spec, State: StateConnected})
|
||||
}
|
||||
|
||||
// disconnectDevice sends CmdClose to the device goroutine and waits for it to exit.
|
||||
func (m *Manager) disconnectDevice(ctx context.Context, vidpid string, h *deviceHandle) {
|
||||
// Signal the device loop to close.
|
||||
h.cancel()
|
||||
|
||||
// Also send an explicit CmdClose in case the loop is blocked on cmdChan.
|
||||
select {
|
||||
case h.cmdChan <- Command{Type: CmdClose}:
|
||||
default:
|
||||
}
|
||||
|
||||
// Wait for device goroutine to exit with a timeout.
|
||||
spec := KnownDevices[vidpid]
|
||||
select {
|
||||
case <-h.done:
|
||||
case <-time.After(2 * time.Second):
|
||||
log.Printf("usb: timeout waiting for device %s goroutine to exit", vidpid)
|
||||
}
|
||||
|
||||
m.emit(DeviceEvent{VIDPID: vidpid, Spec: spec, State: StateDisconnected})
|
||||
}
|
||||
|
||||
// deviceLoop manages a single USB device connection.
|
||||
// It opens the serial port, handles incoming commands, and exits cleanly on
|
||||
// context cancellation or CmdClose.
|
||||
//
|
||||
// Goroutine leak prevention (per PITFALLS.md Pitfall 2):
|
||||
// - Child context is cancelled before port.Close()
|
||||
// - Inner read goroutine selects on ctx.Done() to exit if context is done
|
||||
// before the blocking Read returns
|
||||
// - done channel is closed on return so the poll loop can wait with timeout
|
||||
func (m *Manager) deviceLoop(ctx context.Context, spec DeviceSpec, portPath string, cmdChan <-chan Command, done chan<- struct{}) {
|
||||
defer close(done)
|
||||
|
||||
baud := spec.BaudRate
|
||||
if baud == 0 {
|
||||
baud = 9600
|
||||
}
|
||||
|
||||
port, err := m.openSerialPort(portPath, baud)
|
||||
if err != nil {
|
||||
log.Printf("usb: failed to open %s (%s): %v", portPath, spec.String(), err)
|
||||
return
|
||||
}
|
||||
|
||||
// Inner read goroutine — reads from the port until context is cancelled or port closes.
|
||||
readDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(readDone)
|
||||
// T-04-01: Cap read buffer at 4096 bytes to prevent large-payload panics.
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
n, err := port.Read(buf)
|
||||
if n > 0 {
|
||||
// Forward read bytes as a CmdWrite reply event in future phases.
|
||||
// For now, log for observability.
|
||||
log.Printf("usb: [%s] read %d bytes", spec.String(), n)
|
||||
}
|
||||
if err != nil {
|
||||
// Port was closed or device disconnected — exit read goroutine.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
default:
|
||||
log.Printf("usb: [%s] read error (device likely disconnected): %v", spec.String(), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Main command loop.
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context cancelled — close port to unblock the read goroutine.
|
||||
port.Close()
|
||||
<-readDone
|
||||
return
|
||||
|
||||
case cmd, ok := <-cmdChan:
|
||||
if !ok {
|
||||
port.Close()
|
||||
<-readDone
|
||||
return
|
||||
}
|
||||
switch cmd.Type {
|
||||
case CmdWrite:
|
||||
_, err := port.Write(cmd.Payload)
|
||||
if cmd.Reply != nil {
|
||||
cmd.Reply <- err
|
||||
}
|
||||
case CmdClose:
|
||||
port.Close()
|
||||
<-readDone
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// emit sends a DeviceEvent to the events channel (non-blocking with logging on full).
|
||||
func (m *Manager) emit(event DeviceEvent) {
|
||||
select {
|
||||
case m.events <- event:
|
||||
default:
|
||||
log.Printf("usb: events channel full, dropping event for %s", event.VIDPID)
|
||||
}
|
||||
}
|
||||
215
internal/usb/manager_test.go
Normal file
215
internal/usb/manager_test.go
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
package usb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockEnumerator returns predefined snapshots per call number.
|
||||
// Each call to next() advances the snapshot index.
|
||||
type mockEnumerator struct {
|
||||
snapshots []map[string]string
|
||||
idx atomic.Int32
|
||||
}
|
||||
|
||||
func newMockEnumerator(snapshots ...map[string]string) *mockEnumerator {
|
||||
return &mockEnumerator{snapshots: snapshots}
|
||||
}
|
||||
|
||||
func (m *mockEnumerator) next() (map[string]string, error) {
|
||||
i := int(m.idx.Add(1)) - 1
|
||||
if i >= len(m.snapshots) {
|
||||
// Return last snapshot forever after exhaustion.
|
||||
i = len(m.snapshots) - 1
|
||||
}
|
||||
// Deep copy to prevent mutation.
|
||||
out := make(map[string]string, len(m.snapshots[i]))
|
||||
for k, v := range m.snapshots[i] {
|
||||
out[k] = v
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// noopSerialOpener is a mock serial opener that does nothing (no hardware needed).
|
||||
// It returns a mockPort that satisfies the serialPort interface.
|
||||
func noopSerialOpener(path string, baud int) (serialPort, error) {
|
||||
return newMockPort(), nil
|
||||
}
|
||||
|
||||
// --- Test 1: Start with one device → exactly 1 device goroutine spawned ---
|
||||
|
||||
func TestManagerStartSpawnsOneGoroutine(t *testing.T) {
|
||||
enum := newMockEnumerator(
|
||||
map[string]string{"0525:a4a7": "/dev/cu.mock0"}, // device present
|
||||
)
|
||||
|
||||
m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond)
|
||||
|
||||
baseline := runtime.NumGoroutine()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
m.Start(ctx)
|
||||
time.Sleep(60 * time.Millisecond) // allow poll loop and device goroutine to start
|
||||
|
||||
after := runtime.NumGoroutine()
|
||||
// We expect at least 1 new goroutine (the device loop) + poll loop goroutine.
|
||||
// Baseline can fluctuate ±2 so check at least 1 extra goroutine appeared.
|
||||
if after <= baseline {
|
||||
t.Errorf("expected goroutine count to increase after Start(); baseline=%d after=%d", baseline, after)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
|
||||
// --- Test 2: Device disconnect then reconnect emits Disconnected then Connected events ---
|
||||
|
||||
func TestManagerDisconnectReconnectEvents(t *testing.T) {
|
||||
enum := newMockEnumerator(
|
||||
map[string]string{"0525:a4a7": "/dev/cu.mock0"}, // call 1: present
|
||||
map[string]string{}, // call 2: absent
|
||||
map[string]string{"0525:a4a7": "/dev/cu.mock0"}, // call 3: present again
|
||||
)
|
||||
|
||||
m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
m.Start(ctx)
|
||||
defer m.Stop()
|
||||
|
||||
var events []DeviceEvent
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for e := range m.Events() {
|
||||
events = append(events, e)
|
||||
if len(events) >= 3 { // Connected, Disconnected, Connected
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("timeout waiting for 3 events; got %d: %v", len(events), events)
|
||||
}
|
||||
|
||||
if len(events) < 3 {
|
||||
t.Fatalf("expected at least 3 events, got %d: %v", len(events), events)
|
||||
}
|
||||
if events[0].State != StateConnected {
|
||||
t.Errorf("event[0]: expected Connected, got %v", events[0].State)
|
||||
}
|
||||
if events[1].State != StateDisconnected {
|
||||
t.Errorf("event[1]: expected Disconnected, got %v", events[1].State)
|
||||
}
|
||||
if events[2].State != StateConnected {
|
||||
t.Errorf("event[2]: expected Connected, got %v", events[2].State)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test 3: Stop causes all goroutines to exit within 500ms ---
|
||||
|
||||
func TestManagerStopGoroutineLeak(t *testing.T) {
|
||||
enum := newMockEnumerator(
|
||||
map[string]string{"0525:a4a7": "/dev/cu.mock0"},
|
||||
)
|
||||
|
||||
m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
baseline := runtime.NumGoroutine()
|
||||
m.Start(ctx)
|
||||
time.Sleep(60 * time.Millisecond) // let goroutines spin up
|
||||
|
||||
m.Stop()
|
||||
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
current := runtime.NumGoroutine()
|
||||
// Allow ±2 goroutines for runtime background goroutines.
|
||||
if current <= baseline+2 {
|
||||
return // passed
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Errorf("goroutine leak: baseline=%d, after Stop()=%d (expected within ±2)",
|
||||
baseline, runtime.NumGoroutine())
|
||||
}
|
||||
|
||||
// --- Test 4: Send to absent device returns ErrDeviceNotConnected ---
|
||||
|
||||
func TestManagerSendToAbsentDevice(t *testing.T) {
|
||||
// Empty snapshot — no devices connected.
|
||||
enum := newMockEnumerator(map[string]string{})
|
||||
m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
m.Start(ctx)
|
||||
defer m.Stop()
|
||||
|
||||
time.Sleep(40 * time.Millisecond) // let poll loop run
|
||||
|
||||
replyCh := make(chan error, 1)
|
||||
err := m.Send("0525:a4a7", Command{
|
||||
Type: CmdWrite,
|
||||
Reply: replyCh,
|
||||
})
|
||||
if !errors.Is(err, ErrDeviceNotConnected) {
|
||||
t.Errorf("expected ErrDeviceNotConnected, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test 5: Goroutine count stable across 5 simulated unplug/replug cycles ---
|
||||
|
||||
func TestGoroutineStability(t *testing.T) {
|
||||
// Build snapshots: 5 cycles of present/absent/present
|
||||
snapshots := []map[string]string{}
|
||||
snapshots = append(snapshots, map[string]string{}) // initial empty
|
||||
for i := 0; i < 5; i++ {
|
||||
snapshots = append(snapshots, map[string]string{"0525:a4a7": "/dev/cu.mock0"}) // plug in
|
||||
snapshots = append(snapshots, map[string]string{}) // unplug
|
||||
}
|
||||
// Final state: device connected
|
||||
snapshots = append(snapshots, map[string]string{"0525:a4a7": "/dev/cu.mock0"})
|
||||
|
||||
enum := newMockEnumerator(snapshots...)
|
||||
m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
baseline := runtime.NumGoroutine()
|
||||
m.Start(ctx)
|
||||
|
||||
// Wait long enough for all snapshots to be polled through.
|
||||
// 12 snapshots * 20ms poll interval = ~240ms minimum; give 2x buffer.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
m.Stop()
|
||||
|
||||
// Give goroutines up to 500ms to settle.
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
var finalCount int
|
||||
for time.Now().Before(deadline) {
|
||||
finalCount = runtime.NumGoroutine()
|
||||
if finalCount <= baseline+2 {
|
||||
return // stable
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Errorf("goroutine count unstable after 5 replug cycles: baseline=%d, final=%d (max allowed %d)",
|
||||
baseline, finalCount, baseline+2)
|
||||
}
|
||||
50
internal/usb/mock_port_test.go
Normal file
50
internal/usb/mock_port_test.go
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
package usb
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// mockPort implements serialPort for testing without hardware.
|
||||
// Reads block until Close() is called (simulating a hardware device that
|
||||
// stays connected until explicitly disconnected).
|
||||
type mockPort struct {
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func newMockPort() *mockPort {
|
||||
return &mockPort{
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Read blocks until the port is closed, then returns io.EOF.
|
||||
// This mirrors the behaviour of a real serial port: Read blocks until data
|
||||
// arrives or the port is closed.
|
||||
func (p *mockPort) Read(buf []byte) (int, error) {
|
||||
<-p.closeCh
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Write is a no-op for the mock (no hardware to write to).
|
||||
func (p *mockPort) Write(data []byte) (int, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.closed {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// Close signals the mock port as closed, unblocking any pending Read.
|
||||
func (p *mockPort) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if !p.closed {
|
||||
p.closed = true
|
||||
close(p.closeCh)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue