diff --git a/internal/usb/manager.go b/internal/usb/manager.go new file mode 100644 index 0000000..ce623d7 --- /dev/null +++ b/internal/usb/manager.go @@ -0,0 +1,327 @@ +package usb + +import ( + "context" + "errors" + "log" + "sync" + "time" + + "go.bug.st/serial" +) + +// ErrDeviceNotConnected is returned by Send when the target device is not currently connected. +var ErrDeviceNotConnected = errors.New("usb: device not connected") + +// serialPort abstracts go.bug.st/serial.Port for testing without hardware. +type serialPort interface { + Read(p []byte) (n int, err error) + Write(p []byte) (n int, err error) + Close() error +} + +// serialOpener is the function type used to open a serial port. +// Injectable for tests. +type serialOpener func(path string, baud int) (serialPort, error) + +// defaultSerialOpener opens a real serial port using go.bug.st/serial. +func defaultSerialOpener(path string, baud int) (serialPort, error) { + if baud == 0 { + baud = 9600 + } + mode := &serial.Mode{BaudRate: baud} + return serial.Open(path, mode) +} + +// deviceHandle tracks a running device goroutine. +type deviceHandle struct { + cmdChan chan Command + done chan struct{} // closed when deviceLoop exits + cancel context.CancelFunc +} + +// Manager owns all USB device goroutines and emits DeviceEvents to subscribers. +type Manager struct { + pollInterval time.Duration + events chan DeviceEvent + devices map[string]*deviceHandle // keyed by VID:PID + mu sync.Mutex + cancel context.CancelFunc + wg sync.WaitGroup + enumerateFunc func() (map[string]string, error) + openSerialPort serialOpener +} + +// NewManager creates a new Manager with production defaults. +func NewManager(pollInterval time.Duration) *Manager { + return &Manager{ + pollInterval: pollInterval, + events: make(chan DeviceEvent, 32), + devices: make(map[string]*deviceHandle), + enumerateFunc: enumerateConnected, + openSerialPort: defaultSerialOpener, + } +} + +// newManagerForTest creates a Manager with injected enumerator and serial opener. +// Used only in tests. +func newManagerForTest(enumFn func() (map[string]string, error), opener serialOpener, pollInterval time.Duration) *Manager { + return &Manager{ + pollInterval: pollInterval, + events: make(chan DeviceEvent, 64), + devices: make(map[string]*deviceHandle), + enumerateFunc: enumFn, + openSerialPort: opener, + } +} + +// Start begins the poll loop in a background goroutine. +// Call Stop to shut down cleanly. +func (m *Manager) Start(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) + m.cancel = cancel + + m.wg.Add(1) + go m.pollLoop(ctx) +} + +// Stop cancels all device goroutines and waits for them to exit. +func (m *Manager) Stop() { + if m.cancel != nil { + m.cancel() + } + m.wg.Wait() + close(m.events) +} + +// Events returns the read-only channel of DeviceEvent notifications. +func (m *Manager) Events() <-chan DeviceEvent { + return m.events +} + +// Send delivers a Command to the named device. Returns ErrDeviceNotConnected if absent. +func (m *Manager) Send(vidpid string, cmd Command) error { + m.mu.Lock() + h, ok := m.devices[vidpid] + m.mu.Unlock() + + if !ok { + return ErrDeviceNotConnected + } + + select { + case h.cmdChan <- cmd: + return nil + default: + return errors.New("usb: command channel full for " + vidpid) + } +} + +// pollLoop runs on a background goroutine, periodically enumerating USB devices. +func (m *Manager) pollLoop(ctx context.Context) { + defer m.wg.Done() + + ticker := time.NewTicker(m.pollInterval) + defer ticker.Stop() + + var prev map[string]string // previous snapshot + + for { + select { + case <-ctx.Done(): + // Shut down all running device goroutines. + m.mu.Lock() + handles := make(map[string]*deviceHandle, len(m.devices)) + for k, v := range m.devices { + handles[k] = v + } + m.mu.Unlock() + + for vidpid, h := range handles { + m.disconnectDevice(ctx, vidpid, h) + } + return + + case <-ticker.C: + current, err := m.enumerateFunc() + if err != nil { + log.Printf("usb: enumeration error: %v", err) + continue + } + m.reconcile(ctx, prev, current) + prev = current + } + } +} + +// reconcile compares previous and current device snapshots, spawning or +// stopping device goroutines as needed. +func (m *Manager) reconcile(ctx context.Context, prev, current map[string]string) { + // Detect newly connected devices (in current but not prev). + for vidpid, portPath := range current { + if _, wasThere := prev[vidpid]; !wasThere { + spec, known := KnownDevices[vidpid] + if !known { + continue + } + m.connectDevice(ctx, vidpid, portPath, spec) + } + } + + // Detect newly disconnected devices (in prev but not current). + for vidpid := range prev { + if _, stillThere := current[vidpid]; !stillThere { + m.mu.Lock() + h, ok := m.devices[vidpid] + m.mu.Unlock() + + if ok { + m.disconnectDevice(ctx, vidpid, h) + } + } + } +} + +// connectDevice spawns a deviceLoop goroutine for the given device. +func (m *Manager) connectDevice(ctx context.Context, vidpid, portPath string, spec DeviceSpec) { + devCtx, devCancel := context.WithCancel(ctx) + cmdChan := make(chan Command, 16) + done := make(chan struct{}) + + h := &deviceHandle{ + cmdChan: cmdChan, + done: done, + cancel: devCancel, + } + + m.mu.Lock() + m.devices[vidpid] = h + m.mu.Unlock() + + m.wg.Add(1) + go func() { + defer m.wg.Done() + m.deviceLoop(devCtx, spec, portPath, cmdChan, done) + devCancel() // ensure context is cancelled on exit + + m.mu.Lock() + delete(m.devices, vidpid) + m.mu.Unlock() + }() + + m.emit(DeviceEvent{VIDPID: vidpid, Spec: spec, State: StateConnected}) +} + +// disconnectDevice sends CmdClose to the device goroutine and waits for it to exit. +func (m *Manager) disconnectDevice(ctx context.Context, vidpid string, h *deviceHandle) { + // Signal the device loop to close. + h.cancel() + + // Also send an explicit CmdClose in case the loop is blocked on cmdChan. + select { + case h.cmdChan <- Command{Type: CmdClose}: + default: + } + + // Wait for device goroutine to exit with a timeout. + spec := KnownDevices[vidpid] + select { + case <-h.done: + case <-time.After(2 * time.Second): + log.Printf("usb: timeout waiting for device %s goroutine to exit", vidpid) + } + + m.emit(DeviceEvent{VIDPID: vidpid, Spec: spec, State: StateDisconnected}) +} + +// deviceLoop manages a single USB device connection. +// It opens the serial port, handles incoming commands, and exits cleanly on +// context cancellation or CmdClose. +// +// Goroutine leak prevention (per PITFALLS.md Pitfall 2): +// - Child context is cancelled before port.Close() +// - Inner read goroutine selects on ctx.Done() to exit if context is done +// before the blocking Read returns +// - done channel is closed on return so the poll loop can wait with timeout +func (m *Manager) deviceLoop(ctx context.Context, spec DeviceSpec, portPath string, cmdChan <-chan Command, done chan<- struct{}) { + defer close(done) + + baud := spec.BaudRate + if baud == 0 { + baud = 9600 + } + + port, err := m.openSerialPort(portPath, baud) + if err != nil { + log.Printf("usb: failed to open %s (%s): %v", portPath, spec.String(), err) + return + } + + // Inner read goroutine — reads from the port until context is cancelled or port closes. + readDone := make(chan struct{}) + go func() { + defer close(readDone) + // T-04-01: Cap read buffer at 4096 bytes to prevent large-payload panics. + buf := make([]byte, 4096) + for { + select { + case <-ctx.Done(): + return + default: + } + n, err := port.Read(buf) + if n > 0 { + // Forward read bytes as a CmdWrite reply event in future phases. + // For now, log for observability. + log.Printf("usb: [%s] read %d bytes", spec.String(), n) + } + if err != nil { + // Port was closed or device disconnected — exit read goroutine. + select { + case <-ctx.Done(): + default: + log.Printf("usb: [%s] read error (device likely disconnected): %v", spec.String(), err) + } + return + } + } + }() + + // Main command loop. + for { + select { + case <-ctx.Done(): + // Context cancelled — close port to unblock the read goroutine. + port.Close() + <-readDone + return + + case cmd, ok := <-cmdChan: + if !ok { + port.Close() + <-readDone + return + } + switch cmd.Type { + case CmdWrite: + _, err := port.Write(cmd.Payload) + if cmd.Reply != nil { + cmd.Reply <- err + } + case CmdClose: + port.Close() + <-readDone + return + } + } + } +} + +// emit sends a DeviceEvent to the events channel (non-blocking with logging on full). +func (m *Manager) emit(event DeviceEvent) { + select { + case m.events <- event: + default: + log.Printf("usb: events channel full, dropping event for %s", event.VIDPID) + } +} diff --git a/internal/usb/manager_test.go b/internal/usb/manager_test.go new file mode 100644 index 0000000..9658740 --- /dev/null +++ b/internal/usb/manager_test.go @@ -0,0 +1,215 @@ +package usb + +import ( + "context" + "errors" + "runtime" + "sync/atomic" + "testing" + "time" +) + +// mockEnumerator returns predefined snapshots per call number. +// Each call to next() advances the snapshot index. +type mockEnumerator struct { + snapshots []map[string]string + idx atomic.Int32 +} + +func newMockEnumerator(snapshots ...map[string]string) *mockEnumerator { + return &mockEnumerator{snapshots: snapshots} +} + +func (m *mockEnumerator) next() (map[string]string, error) { + i := int(m.idx.Add(1)) - 1 + if i >= len(m.snapshots) { + // Return last snapshot forever after exhaustion. + i = len(m.snapshots) - 1 + } + // Deep copy to prevent mutation. + out := make(map[string]string, len(m.snapshots[i])) + for k, v := range m.snapshots[i] { + out[k] = v + } + return out, nil +} + +// noopSerialOpener is a mock serial opener that does nothing (no hardware needed). +// It returns a mockPort that satisfies the serialPort interface. +func noopSerialOpener(path string, baud int) (serialPort, error) { + return newMockPort(), nil +} + +// --- Test 1: Start with one device → exactly 1 device goroutine spawned --- + +func TestManagerStartSpawnsOneGoroutine(t *testing.T) { + enum := newMockEnumerator( + map[string]string{"0525:a4a7": "/dev/cu.mock0"}, // device present + ) + + m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond) + + baseline := runtime.NumGoroutine() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + m.Start(ctx) + time.Sleep(60 * time.Millisecond) // allow poll loop and device goroutine to start + + after := runtime.NumGoroutine() + // We expect at least 1 new goroutine (the device loop) + poll loop goroutine. + // Baseline can fluctuate ±2 so check at least 1 extra goroutine appeared. + if after <= baseline { + t.Errorf("expected goroutine count to increase after Start(); baseline=%d after=%d", baseline, after) + } + + m.Stop() +} + +// --- Test 2: Device disconnect then reconnect emits Disconnected then Connected events --- + +func TestManagerDisconnectReconnectEvents(t *testing.T) { + enum := newMockEnumerator( + map[string]string{"0525:a4a7": "/dev/cu.mock0"}, // call 1: present + map[string]string{}, // call 2: absent + map[string]string{"0525:a4a7": "/dev/cu.mock0"}, // call 3: present again + ) + + m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + m.Start(ctx) + defer m.Stop() + + var events []DeviceEvent + done := make(chan struct{}) + go func() { + defer close(done) + for e := range m.Events() { + events = append(events, e) + if len(events) >= 3 { // Connected, Disconnected, Connected + return + } + } + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for 3 events; got %d: %v", len(events), events) + } + + if len(events) < 3 { + t.Fatalf("expected at least 3 events, got %d: %v", len(events), events) + } + if events[0].State != StateConnected { + t.Errorf("event[0]: expected Connected, got %v", events[0].State) + } + if events[1].State != StateDisconnected { + t.Errorf("event[1]: expected Disconnected, got %v", events[1].State) + } + if events[2].State != StateConnected { + t.Errorf("event[2]: expected Connected, got %v", events[2].State) + } +} + +// --- Test 3: Stop causes all goroutines to exit within 500ms --- + +func TestManagerStopGoroutineLeak(t *testing.T) { + enum := newMockEnumerator( + map[string]string{"0525:a4a7": "/dev/cu.mock0"}, + ) + + m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + baseline := runtime.NumGoroutine() + m.Start(ctx) + time.Sleep(60 * time.Millisecond) // let goroutines spin up + + m.Stop() + + deadline := time.Now().Add(500 * time.Millisecond) + for time.Now().Before(deadline) { + current := runtime.NumGoroutine() + // Allow ±2 goroutines for runtime background goroutines. + if current <= baseline+2 { + return // passed + } + time.Sleep(10 * time.Millisecond) + } + + t.Errorf("goroutine leak: baseline=%d, after Stop()=%d (expected within ±2)", + baseline, runtime.NumGoroutine()) +} + +// --- Test 4: Send to absent device returns ErrDeviceNotConnected --- + +func TestManagerSendToAbsentDevice(t *testing.T) { + // Empty snapshot — no devices connected. + enum := newMockEnumerator(map[string]string{}) + m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + m.Start(ctx) + defer m.Stop() + + time.Sleep(40 * time.Millisecond) // let poll loop run + + replyCh := make(chan error, 1) + err := m.Send("0525:a4a7", Command{ + Type: CmdWrite, + Reply: replyCh, + }) + if !errors.Is(err, ErrDeviceNotConnected) { + t.Errorf("expected ErrDeviceNotConnected, got %v", err) + } +} + +// --- Test 5: Goroutine count stable across 5 simulated unplug/replug cycles --- + +func TestGoroutineStability(t *testing.T) { + // Build snapshots: 5 cycles of present/absent/present + snapshots := []map[string]string{} + snapshots = append(snapshots, map[string]string{}) // initial empty + for i := 0; i < 5; i++ { + snapshots = append(snapshots, map[string]string{"0525:a4a7": "/dev/cu.mock0"}) // plug in + snapshots = append(snapshots, map[string]string{}) // unplug + } + // Final state: device connected + snapshots = append(snapshots, map[string]string{"0525:a4a7": "/dev/cu.mock0"}) + + enum := newMockEnumerator(snapshots...) + m := newManagerForTest(enum.next, noopSerialOpener, 20*time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + baseline := runtime.NumGoroutine() + m.Start(ctx) + + // Wait long enough for all snapshots to be polled through. + // 12 snapshots * 20ms poll interval = ~240ms minimum; give 2x buffer. + time.Sleep(500 * time.Millisecond) + + m.Stop() + + // Give goroutines up to 500ms to settle. + deadline := time.Now().Add(500 * time.Millisecond) + var finalCount int + for time.Now().Before(deadline) { + finalCount = runtime.NumGoroutine() + if finalCount <= baseline+2 { + return // stable + } + time.Sleep(10 * time.Millisecond) + } + + t.Errorf("goroutine count unstable after 5 replug cycles: baseline=%d, final=%d (max allowed %d)", + baseline, finalCount, baseline+2) +} diff --git a/internal/usb/mock_port_test.go b/internal/usb/mock_port_test.go new file mode 100644 index 0000000..5f25dba --- /dev/null +++ b/internal/usb/mock_port_test.go @@ -0,0 +1,50 @@ +package usb + +import ( + "io" + "sync" +) + +// mockPort implements serialPort for testing without hardware. +// Reads block until Close() is called (simulating a hardware device that +// stays connected until explicitly disconnected). +type mockPort struct { + mu sync.Mutex + closed bool + closeCh chan struct{} +} + +func newMockPort() *mockPort { + return &mockPort{ + closeCh: make(chan struct{}), + } +} + +// Read blocks until the port is closed, then returns io.EOF. +// This mirrors the behaviour of a real serial port: Read blocks until data +// arrives or the port is closed. +func (p *mockPort) Read(buf []byte) (int, error) { + <-p.closeCh + return 0, io.EOF +} + +// Write is a no-op for the mock (no hardware to write to). +func (p *mockPort) Write(data []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return 0, io.ErrClosedPipe + } + return len(data), nil +} + +// Close signals the mock port as closed, unblocking any pending Read. +func (p *mockPort) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + if !p.closed { + p.closed = true + close(p.closeCh) + } + return nil +}