- Manager with injectable enumerateFunc and serialOpener for test isolation
- Goroutine-per-device model: deviceLoop owns one serial port per device
- Context + done-channel teardown: the inner read goroutine exits on ctx.Done()
- Read buffer capped at 4096 bytes (T-04-01 threat mitigation)
- Poll loop reconciles previous/current snapshots to emit connect/disconnect events
- Send() returns ErrDeviceNotConnected when the device is absent
- All 5 manager tests pass with -race; goroutine count stays stable across 5 replug cycles
- mockPort test helper blocks Read until Close() unblocks it (realistic behavior)
327 lines
8.3 KiB
Go
327 lines
8.3 KiB
Go
package usb
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log"
|
|
"sync"
|
|
"time"
|
|
|
|
"go.bug.st/serial"
|
|
)
|
|
|
|
// ErrDeviceNotConnected is returned by Send when the target device is not
// currently connected. Callers can match it with errors.Is.
var (
	ErrDeviceNotConnected = errors.New("usb: device not connected")
)
|
|
|
|
// serialPort is the small slice of go.bug.st/serial.Port behavior that the
// manager relies on (read, write, close). Depending on this local interface
// instead of the concrete library type lets tests plug in an in-memory fake.
type serialPort interface {
	Close() error
	Read(buf []byte) (int, error)
	Write(buf []byte) (int, error)
}
|
|
|
|
// serialOpener is the function type used to open a serial port.
|
|
// Injectable for tests.
|
|
type serialOpener func(path string, baud int) (serialPort, error)
|
|
|
|
// defaultSerialOpener opens a real serial port using go.bug.st/serial.
|
|
func defaultSerialOpener(path string, baud int) (serialPort, error) {
|
|
if baud == 0 {
|
|
baud = 9600
|
|
}
|
|
mode := &serial.Mode{BaudRate: baud}
|
|
return serial.Open(path, mode)
|
|
}
|
|
|
|
// deviceHandle tracks a running device goroutine.
|
|
type deviceHandle struct {
|
|
cmdChan chan Command
|
|
done chan struct{} // closed when deviceLoop exits
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// Manager owns all USB device goroutines and emits DeviceEvents to subscribers.
|
|
type Manager struct {
|
|
pollInterval time.Duration
|
|
events chan DeviceEvent
|
|
devices map[string]*deviceHandle // keyed by VID:PID
|
|
mu sync.Mutex
|
|
cancel context.CancelFunc
|
|
wg sync.WaitGroup
|
|
enumerateFunc func() (map[string]string, error)
|
|
openSerialPort serialOpener
|
|
}
|
|
|
|
// NewManager creates a new Manager with production defaults.
|
|
func NewManager(pollInterval time.Duration) *Manager {
|
|
return &Manager{
|
|
pollInterval: pollInterval,
|
|
events: make(chan DeviceEvent, 32),
|
|
devices: make(map[string]*deviceHandle),
|
|
enumerateFunc: enumerateConnected,
|
|
openSerialPort: defaultSerialOpener,
|
|
}
|
|
}
|
|
|
|
// newManagerForTest creates a Manager with injected enumerator and serial opener.
|
|
// Used only in tests.
|
|
func newManagerForTest(enumFn func() (map[string]string, error), opener serialOpener, pollInterval time.Duration) *Manager {
|
|
return &Manager{
|
|
pollInterval: pollInterval,
|
|
events: make(chan DeviceEvent, 64),
|
|
devices: make(map[string]*deviceHandle),
|
|
enumerateFunc: enumFn,
|
|
openSerialPort: opener,
|
|
}
|
|
}
|
|
|
|
// Start begins the poll loop in a background goroutine.
|
|
// Call Stop to shut down cleanly.
|
|
func (m *Manager) Start(ctx context.Context) {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
m.cancel = cancel
|
|
|
|
m.wg.Add(1)
|
|
go m.pollLoop(ctx)
|
|
}
|
|
|
|
// Stop cancels all device goroutines and waits for them to exit.
|
|
func (m *Manager) Stop() {
|
|
if m.cancel != nil {
|
|
m.cancel()
|
|
}
|
|
m.wg.Wait()
|
|
close(m.events)
|
|
}
|
|
|
|
// Events returns the read-only channel of DeviceEvent notifications.
|
|
func (m *Manager) Events() <-chan DeviceEvent {
|
|
return m.events
|
|
}
|
|
|
|
// Send delivers a Command to the named device. Returns ErrDeviceNotConnected if absent.
|
|
func (m *Manager) Send(vidpid string, cmd Command) error {
|
|
m.mu.Lock()
|
|
h, ok := m.devices[vidpid]
|
|
m.mu.Unlock()
|
|
|
|
if !ok {
|
|
return ErrDeviceNotConnected
|
|
}
|
|
|
|
select {
|
|
case h.cmdChan <- cmd:
|
|
return nil
|
|
default:
|
|
return errors.New("usb: command channel full for " + vidpid)
|
|
}
|
|
}
|
|
|
|
// pollLoop runs on a background goroutine, periodically enumerating USB devices.
|
|
func (m *Manager) pollLoop(ctx context.Context) {
|
|
defer m.wg.Done()
|
|
|
|
ticker := time.NewTicker(m.pollInterval)
|
|
defer ticker.Stop()
|
|
|
|
var prev map[string]string // previous snapshot
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
// Shut down all running device goroutines.
|
|
m.mu.Lock()
|
|
handles := make(map[string]*deviceHandle, len(m.devices))
|
|
for k, v := range m.devices {
|
|
handles[k] = v
|
|
}
|
|
m.mu.Unlock()
|
|
|
|
for vidpid, h := range handles {
|
|
m.disconnectDevice(ctx, vidpid, h)
|
|
}
|
|
return
|
|
|
|
case <-ticker.C:
|
|
current, err := m.enumerateFunc()
|
|
if err != nil {
|
|
log.Printf("usb: enumeration error: %v", err)
|
|
continue
|
|
}
|
|
m.reconcile(ctx, prev, current)
|
|
prev = current
|
|
}
|
|
}
|
|
}
|
|
|
|
// reconcile compares previous and current device snapshots, spawning or
|
|
// stopping device goroutines as needed.
|
|
func (m *Manager) reconcile(ctx context.Context, prev, current map[string]string) {
|
|
// Detect newly connected devices (in current but not prev).
|
|
for vidpid, portPath := range current {
|
|
if _, wasThere := prev[vidpid]; !wasThere {
|
|
spec, known := KnownDevices[vidpid]
|
|
if !known {
|
|
continue
|
|
}
|
|
m.connectDevice(ctx, vidpid, portPath, spec)
|
|
}
|
|
}
|
|
|
|
// Detect newly disconnected devices (in prev but not current).
|
|
for vidpid := range prev {
|
|
if _, stillThere := current[vidpid]; !stillThere {
|
|
m.mu.Lock()
|
|
h, ok := m.devices[vidpid]
|
|
m.mu.Unlock()
|
|
|
|
if ok {
|
|
m.disconnectDevice(ctx, vidpid, h)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// connectDevice spawns a deviceLoop goroutine for the given device.
|
|
func (m *Manager) connectDevice(ctx context.Context, vidpid, portPath string, spec DeviceSpec) {
|
|
devCtx, devCancel := context.WithCancel(ctx)
|
|
cmdChan := make(chan Command, 16)
|
|
done := make(chan struct{})
|
|
|
|
h := &deviceHandle{
|
|
cmdChan: cmdChan,
|
|
done: done,
|
|
cancel: devCancel,
|
|
}
|
|
|
|
m.mu.Lock()
|
|
m.devices[vidpid] = h
|
|
m.mu.Unlock()
|
|
|
|
m.wg.Add(1)
|
|
go func() {
|
|
defer m.wg.Done()
|
|
m.deviceLoop(devCtx, spec, portPath, cmdChan, done)
|
|
devCancel() // ensure context is cancelled on exit
|
|
|
|
m.mu.Lock()
|
|
delete(m.devices, vidpid)
|
|
m.mu.Unlock()
|
|
}()
|
|
|
|
m.emit(DeviceEvent{VIDPID: vidpid, Spec: spec, State: StateConnected})
|
|
}
|
|
|
|
// disconnectDevice sends CmdClose to the device goroutine and waits for it to exit.
|
|
func (m *Manager) disconnectDevice(ctx context.Context, vidpid string, h *deviceHandle) {
|
|
// Signal the device loop to close.
|
|
h.cancel()
|
|
|
|
// Also send an explicit CmdClose in case the loop is blocked on cmdChan.
|
|
select {
|
|
case h.cmdChan <- Command{Type: CmdClose}:
|
|
default:
|
|
}
|
|
|
|
// Wait for device goroutine to exit with a timeout.
|
|
spec := KnownDevices[vidpid]
|
|
select {
|
|
case <-h.done:
|
|
case <-time.After(2 * time.Second):
|
|
log.Printf("usb: timeout waiting for device %s goroutine to exit", vidpid)
|
|
}
|
|
|
|
m.emit(DeviceEvent{VIDPID: vidpid, Spec: spec, State: StateDisconnected})
|
|
}
|
|
|
|
// deviceLoop manages a single USB device connection.
|
|
// It opens the serial port, handles incoming commands, and exits cleanly on
|
|
// context cancellation or CmdClose.
|
|
//
|
|
// Goroutine leak prevention (per PITFALLS.md Pitfall 2):
|
|
// - Child context is cancelled before port.Close()
|
|
// - Inner read goroutine selects on ctx.Done() to exit if context is done
|
|
// before the blocking Read returns
|
|
// - done channel is closed on return so the poll loop can wait with timeout
|
|
func (m *Manager) deviceLoop(ctx context.Context, spec DeviceSpec, portPath string, cmdChan <-chan Command, done chan<- struct{}) {
|
|
defer close(done)
|
|
|
|
baud := spec.BaudRate
|
|
if baud == 0 {
|
|
baud = 9600
|
|
}
|
|
|
|
port, err := m.openSerialPort(portPath, baud)
|
|
if err != nil {
|
|
log.Printf("usb: failed to open %s (%s): %v", portPath, spec.String(), err)
|
|
return
|
|
}
|
|
|
|
// Inner read goroutine — reads from the port until context is cancelled or port closes.
|
|
readDone := make(chan struct{})
|
|
go func() {
|
|
defer close(readDone)
|
|
// T-04-01: Cap read buffer at 4096 bytes to prevent large-payload panics.
|
|
buf := make([]byte, 4096)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
n, err := port.Read(buf)
|
|
if n > 0 {
|
|
// Forward read bytes as a CmdWrite reply event in future phases.
|
|
// For now, log for observability.
|
|
log.Printf("usb: [%s] read %d bytes", spec.String(), n)
|
|
}
|
|
if err != nil {
|
|
// Port was closed or device disconnected — exit read goroutine.
|
|
select {
|
|
case <-ctx.Done():
|
|
default:
|
|
log.Printf("usb: [%s] read error (device likely disconnected): %v", spec.String(), err)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Main command loop.
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
// Context cancelled — close port to unblock the read goroutine.
|
|
port.Close()
|
|
<-readDone
|
|
return
|
|
|
|
case cmd, ok := <-cmdChan:
|
|
if !ok {
|
|
port.Close()
|
|
<-readDone
|
|
return
|
|
}
|
|
switch cmd.Type {
|
|
case CmdWrite:
|
|
_, err := port.Write(cmd.Payload)
|
|
if cmd.Reply != nil {
|
|
cmd.Reply <- err
|
|
}
|
|
case CmdClose:
|
|
port.Close()
|
|
<-readDone
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// emit sends a DeviceEvent to the events channel (non-blocking with logging on full).
|
|
func (m *Manager) emit(event DeviceEvent) {
|
|
select {
|
|
case m.events <- event:
|
|
default:
|
|
log.Printf("usb: events channel full, dropping event for %s", event.VIDPID)
|
|
}
|
|
}
|