felt/internal/auth/pin.go
Mikkel Georgsen ae90d9bfae feat(01-04): add clock warnings, API routes, tests, and server wiring
- Clock API routes: start, pause, resume, advance, rewind, jump, get, warnings
- Role-based access control (floor+ for mutations, any auth for reads)
- Clock state persistence callback to DB on meaningful changes
- Blind structure levels loaded from DB on clock start
- Clock registry wired into HTTP server and cmd/leaf main
- 25 tests covering: state machine, countdown, pause/resume, auto-advance,
  jump, rewind, hand-for-hand, warnings, overtime, crash recovery, snapshot
- Fix missing crypto/rand import in auth/pin.go (Rule 3 auto-fix)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-01 03:56:23 +01:00

428 lines
12 KiB
Go

package auth
import (
"context"
"crypto/rand"
"database/sql"
"errors"
"fmt"
"log"
"time"
"golang.org/x/crypto/bcrypt"
)
// Errors returned by the auth service. Callers should compare against them
// with errors.Is.

// ErrInvalidPIN is returned when the PIN is empty or matches no operator.
var ErrInvalidPIN = errors.New("auth: invalid PIN")

// ErrTooManyAttempts is returned while the login rate limiter is engaged.
var ErrTooManyAttempts = errors.New("auth: too many failed attempts, please wait")

// ErrOperatorLocked is returned while the rate limiter is in hard lockout.
var ErrOperatorLocked = errors.New("auth: operator account locked")

// ErrOperatorExists signals a duplicate operator. NOTE(review): not returned by
// any code visible in this file — confirm where (or whether) it is used.
var ErrOperatorExists = errors.New("auth: operator already exists")
// Operator represents an authenticated operator in the system.
// It is the public view of an operators row: it has no PIN or PIN-hash field.
type Operator struct {
// ID is the operator's identifier; CreateOperator assigns a v4 UUID.
ID string `json:"id"`
// Name is the display name; CreateOperator rejects an empty name.
Name string `json:"name"`
// Role is the access role; CreateOperator accepts only "admin", "floor",
// or "viewer".
Role string `json:"role"`
// CreatedAt is the creation time as Unix seconds.
CreatedAt int64 `json:"created_at"`
// UpdatedAt is the last modification time as Unix seconds.
UpdatedAt int64 `json:"updated_at"`
}
// AuditRecorder is called by the auth service to log audit events.
// This breaks the import cycle between auth and audit packages.
//
// action names the event (for example "operator.login" or
// "operator.login_rate_limited"), targetType/targetID identify the affected
// entity, and metadata carries event-specific fields. AuthService discards
// the returned error, so recording is best-effort.
type AuditRecorder func(ctx context.Context, action, targetType, targetID string, metadata map[string]interface{}) error
// AuthService handles PIN-based authentication with rate limiting.
// Operators are read from the operators table; failed-attempt state is kept
// in the login_attempts table.
type AuthService struct {
// db holds the operators and login_attempts tables.
db *sql.DB
// jwt issues tokens for successful logins.
jwt *JWTService
recorder AuditRecorder // optional, for audit logging; nil disables audit events
}
// NewAuthService returns an AuthService that stores operators and login
// attempts in db and issues tokens through jwt. No audit recorder is attached;
// install one later with SetAuditRecorder.
func NewAuthService(db *sql.DB, jwt *JWTService) *AuthService {
	svc := new(AuthService)
	svc.db, svc.jwt = db, jwt
	return svc
}

// SetAuditRecorder installs the callback used to log auth events. It exists so
// the audit trail can be wired in after construction, breaking the
// initialization cycle between the two services. A nil recorder disables
// audit logging, because every call site checks for nil first.
func (s *AuthService) SetAuditRecorder(recorder AuditRecorder) {
	s.recorder = recorder
}
// initRateLimiter creates the login_attempts table if it is missing.
//
// Each row is keyed by operator_id, which is either an operator ID or a
// sentinel key such as "_global". It stores:
//   - the consecutive failure count
//   - the Unix time of the last failure
//   - the Unix time until which the key is locked (0 when unlocked)
//
// The statement is idempotent, so calling this repeatedly is harmless.
func (s *AuthService) initRateLimiter() error {
	const ddl = `CREATE TABLE IF NOT EXISTS login_attempts (
operator_id TEXT PRIMARY KEY,
consecutive_failures INTEGER NOT NULL DEFAULT 0,
last_failure_at INTEGER NOT NULL DEFAULT 0,
locked_until INTEGER NOT NULL DEFAULT 0
)`
	if _, err := s.db.Exec(ddl); err != nil {
		return err
	}
	return nil
}
// globalLimiterKey is the login_attempts key charged when a PIN matches no
// operator. A PIN is compared against every operator, so a failed attempt
// cannot be attributed to a specific operator ID.
const globalLimiterKey = "_global"

// Login authenticates an operator by PIN and returns a JWT token together
// with the matching operator.
//
// Rate limiting is checked before any PIN comparison. While the global
// limiter is engaged, the attempt is rejected without evaluating the PIN and
// is counted as another failure, so guesses cannot proceed during a lockout.
// Operators that are individually rate-limited are skipped during the scan.
//
// Errors:
//   - ErrInvalidPIN: the PIN is empty or matched no operator.
//   - ErrTooManyAttempts: the global limiter is engaged.
//   - ErrOperatorLocked: the global limiter is in hard lockout.
//   - wrapped errors for database or token-issuing failures.
func (s *AuthService) Login(ctx context.Context, pin string) (token string, operator Operator, err error) {
	if pin == "" {
		return "", Operator{}, ErrInvalidPIN
	}
	if err := s.initRateLimiter(); err != nil {
		return "", Operator{}, fmt.Errorf("auth: init rate limiter: %w", err)
	}

	// Gate on the global limiter up front. If it were consulted only after a
	// failed scan, a lockout would never stop further PINs from being
	// evaluated, and a correct guess would succeed mid-lockout.
	switch limitErr := s.checkRateLimit(ctx, globalLimiterKey); {
	case errors.Is(limitErr, ErrTooManyAttempts), errors.Is(limitErr, ErrOperatorLocked):
		return "", Operator{}, s.recordGlobalFailure(ctx, limitErr)
	case limitErr != nil:
		return "", Operator{}, fmt.Errorf("auth: check rate limit: %w", limitErr)
	}

	operators, err := s.loadOperators(ctx)
	if err != nil {
		return "", Operator{}, fmt.Errorf("auth: load operators: %w", err)
	}
	for _, op := range operators {
		// Rate-limited operators (or ones whose limiter state cannot be
		// read) are skipped silently during the scan.
		if s.checkRateLimit(ctx, op.id) != nil {
			continue
		}
		if bcrypt.CompareHashAndPassword([]byte(op.pinHash), []byte(pin)) != nil {
			continue
		}
		return s.completeLogin(ctx, op)
	}

	return "", Operator{}, s.recordGlobalFailure(ctx, ErrInvalidPIN)
}

// recordGlobalFailure charges one failed attempt to the global limiter.
// It returns ErrTooManyAttempts or ErrOperatorLocked if the limiter is
// engaged afterwards; otherwise it returns fallback.
func (s *AuthService) recordGlobalFailure(ctx context.Context, fallback error) error {
	s.recordFailure(ctx, globalLimiterKey)
	if err := s.checkRateLimit(ctx, globalLimiterKey); errors.Is(err, ErrTooManyAttempts) || errors.Is(err, ErrOperatorLocked) {
		return err
	}
	return fallback
}

// completeLogin finishes a successful login for op. It clears the failure
// counters, issues the JWT, emits the "operator.login" audit event, and
// returns the public view of the operator.
func (s *AuthService) completeLogin(ctx context.Context, op operatorRow) (string, Operator, error) {
	s.resetFailures(ctx, op.id)

	token, err := s.jwt.NewToken(op.id, op.role)
	if err != nil {
		return "", Operator{}, fmt.Errorf("auth: issue token: %w", err)
	}

	if s.recorder != nil {
		// The audit error is deliberately ignored: a successful login does
		// not depend on it.
		_ = s.recorder(ctx, "operator.login", "operator", op.id, map[string]interface{}{
			"operator_name": op.name,
			"role":          op.role,
		})
	}

	return token, Operator{
		ID:        op.id,
		Name:      op.name,
		Role:      op.role,
		CreatedAt: op.createdAt,
		UpdatedAt: op.updatedAt,
	}, nil
}
// operatorRow holds a full operator row for internal use during login.
// Unlike Operator, it carries the bcrypt PIN hash.
type operatorRow struct {
id string
name string
// pinHash is the bcrypt hash produced by HashPIN.
pinHash string
role string
// createdAt and updatedAt are Unix seconds.
createdAt int64
updatedAt int64
}
// loadOperators reads every operator, including its PIN hash, for the login
// scan. Rows come back in whatever order the database yields, since the query
// has no ORDER BY. If iteration fails part-way, the rows read so far are
// returned together with the iteration error.
func (s *AuthService) loadOperators(ctx context.Context) ([]operatorRow, error) {
	const query = "SELECT id, name, pin_hash, role, created_at, updated_at FROM operators"

	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []operatorRow
	for rows.Next() {
		var r operatorRow
		if scanErr := rows.Scan(&r.id, &r.name, &r.pinHash, &r.role, &r.createdAt, &r.updatedAt); scanErr != nil {
			return nil, scanErr
		}
		out = append(out, r)
	}
	return out, rows.Err()
}
// checkRateLimit reports whether operatorID (an operator ID or a sentinel key)
// may attempt a login right now.
//
// It returns:
//   - nil when the key is not limited
//   - ErrOperatorLocked when it is in hard lockout
//   - ErrTooManyAttempts when it is otherwise limited
//   - the underlying error when the limiter state cannot be read
func (s *AuthService) checkRateLimit(ctx context.Context, operatorID string) error {
	limited, err := s.isRateLimited(ctx, operatorID)
	switch {
	case err != nil:
		return err
	case !limited:
		return nil
	}
	// The key is already known to be limited, so a failed lockout lookup just
	// degrades to the softer error.
	if hard, _ := s.isLockedOut(ctx, operatorID); hard {
		return ErrOperatorLocked
	}
	return ErrTooManyAttempts
}
// isRateLimited reports whether operatorID currently has an active lockout
// window, i.e. its locked_until timestamp is still in the future. A key with
// no login_attempts row is never limited.
func (s *AuthService) isRateLimited(ctx context.Context, operatorID string) (bool, error) {
	var (
		count int
		until int64
	)
	row := s.db.QueryRowContext(ctx,
		"SELECT consecutive_failures, locked_until FROM login_attempts WHERE operator_id = ?",
		operatorID)
	switch err := row.Scan(&count, &until); {
	case errors.Is(err, sql.ErrNoRows):
		return false, nil
	case err != nil:
		return false, err
	}
	return until > time.Now().Unix(), nil
}
// isLockedOut reports whether operatorID is in hard lockout. That means it has
// 10 or more consecutive failures and its lockout window has not yet expired.
// A key with no login_attempts row is never locked out.
func (s *AuthService) isLockedOut(ctx context.Context, operatorID string) (bool, error) {
	var (
		count int
		until int64
	)
	row := s.db.QueryRowContext(ctx,
		"SELECT consecutive_failures, locked_until FROM login_attempts WHERE operator_id = ?",
		operatorID)
	switch err := row.Scan(&count, &until); {
	case errors.Is(err, sql.ErrNoRows):
		return false, nil
	case err != nil:
		return false, err
	}
	return count >= 10 && until > time.Now().Unix(), nil
}
// recordFailure counts one failed login attempt for operatorID. Once the
// consecutive-failure count reaches a threshold, it sets a lockout window:
//
//	5-7 failures -> 30 seconds
//	8-9 failures -> 5 minutes
//	10+ failures -> 30 minutes
//
// Database errors are logged, not returned, and abort the remaining steps.
// When a lockout is applied and an audit recorder is configured, an
// "operator.login_rate_limited" event is emitted.
func (s *AuthService) recordFailure(ctx context.Context, operatorID string) {
	ts := time.Now().Unix()

	const upsert = `
INSERT INTO login_attempts (operator_id, consecutive_failures, last_failure_at, locked_until)
VALUES (?, 1, ?, 0)
ON CONFLICT(operator_id) DO UPDATE SET
consecutive_failures = consecutive_failures + 1,
last_failure_at = ?
`
	if _, err := s.db.ExecContext(ctx, upsert, operatorID, ts, ts); err != nil {
		log.Printf("auth: record failure error: %v", err)
		return
	}

	var count int
	countErr := s.db.QueryRowContext(ctx,
		"SELECT consecutive_failures FROM login_attempts WHERE operator_id = ?",
		operatorID).Scan(&count)
	if countErr != nil {
		log.Printf("auth: get failure count error: %v", countErr)
		return
	}

	// Escalating lockout tiers, ordered from most to least severe so the first
	// match wins.
	tiers := []struct {
		minFailures int
		lock        time.Duration
	}{
		{10, 30 * time.Minute},
		{8, 5 * time.Minute},
		{5, 30 * time.Second},
	}
	var lock time.Duration
	for _, tier := range tiers {
		if count >= tier.minFailures {
			lock = tier.lock
			break
		}
	}
	if lock == 0 {
		// Below the first threshold: the failure is counted but nothing is locked.
		return
	}

	until := ts + int64(lock.Seconds())
	if _, err := s.db.ExecContext(ctx,
		"UPDATE login_attempts SET locked_until = ? WHERE operator_id = ?",
		until, operatorID); err != nil {
		log.Printf("auth: set lockout error: %v", err)
		return
	}
	log.Printf("auth: rate limiting activated for %s (failures=%d, locked_for=%s)",
		operatorID, count, lock)

	if s.recorder == nil {
		return
	}
	// Reaching this point implies count >= 5, the lowest lockout tier.
	_ = s.recorder(ctx, "operator.login_rate_limited", "operator", operatorID, map[string]interface{}{
		"failures":     count,
		"locked_until": until,
	})
}
// resetFailures deletes the login_attempts rows for operatorID and for the
// shared "_global" key. That clears both the per-operator counter and the
// counter Login charges when a PIN matches no operator. Errors are logged
// only; each delete is attempted independently.
func (s *AuthService) resetFailures(ctx context.Context, operatorID string) {
	if _, err := s.db.ExecContext(ctx,
		"DELETE FROM login_attempts WHERE operator_id = ?", operatorID); err != nil {
		log.Printf("auth: reset failures error: %v", err)
	}

	// Any successful login also clears the global counter.
	if _, err := s.db.ExecContext(ctx,
		"DELETE FROM login_attempts WHERE operator_id = '_global'"); err != nil {
		log.Printf("auth: reset global failures error: %v", err)
	}
}
// GetFailureCount returns the consecutive failure count recorded for
// operatorID (an operator ID or the "_global" key). It returns 0 with a nil
// error when no login_attempts row exists. It is exposed for rate-limiting
// tests.
func (s *AuthService) GetFailureCount(ctx context.Context, operatorID string) (int, error) {
	row := s.db.QueryRowContext(ctx,
		"SELECT consecutive_failures FROM login_attempts WHERE operator_id = ?",
		operatorID)

	var n int
	err := row.Scan(&n)
	if errors.Is(err, sql.ErrNoRows) {
		return 0, nil
	}
	return n, err
}
// HashPIN returns the bcrypt hash of pin, computed at cost 12, as a string.
// Hashing errors are wrapped with an "auth: hash PIN" prefix.
func HashPIN(pin string) (string, error) {
	const cost = 12
	digest, err := bcrypt.GenerateFromPassword([]byte(pin), cost)
	if err != nil {
		return "", fmt.Errorf("auth: hash PIN: %w", err)
	}
	return string(digest), nil
}
// CreateOperator inserts a new operator and returns its public view.
//
// It requires a non-empty name and PIN and a role of "admin", "floor", or
// "viewer". The PIN is stored only as a bcrypt hash (see HashPIN). The ID is a
// fresh v4 UUID, and both timestamps are set to the current Unix time.
// NOTE(review): PIN uniqueness is not enforced here, although Login returns
// the first operator whose hash matches.
func (s *AuthService) CreateOperator(ctx context.Context, name, pin, role string) (Operator, error) {
	switch {
	case name == "":
		return Operator{}, fmt.Errorf("auth: empty operator name")
	case pin == "":
		return Operator{}, fmt.Errorf("auth: empty PIN")
	}
	switch role {
	case "admin", "floor", "viewer":
		// accepted
	default:
		return Operator{}, fmt.Errorf("auth: invalid role %q (must be admin, floor, or viewer)", role)
	}

	hash, err := HashPIN(pin)
	if err != nil {
		return Operator{}, err
	}

	op := Operator{ID: generateUUID(), Name: name, Role: role}
	op.CreatedAt = time.Now().Unix()
	op.UpdatedAt = op.CreatedAt

	if _, err := s.db.ExecContext(ctx,
		"INSERT INTO operators (id, name, pin_hash, role, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
		op.ID, op.Name, hash, op.Role, op.CreatedAt, op.UpdatedAt); err != nil {
		return Operator{}, fmt.Errorf("auth: create operator: %w", err)
	}
	return op, nil
}
// ListOperators returns all operators ordered by name. The pin_hash column is
// never selected. With no operators the result is a nil slice. If iteration
// fails part-way, the operators read so far are returned along with the error.
func (s *AuthService) ListOperators(ctx context.Context) ([]Operator, error) {
	const query = "SELECT id, name, role, created_at, updated_at FROM operators ORDER BY name"

	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("auth: list operators: %w", err)
	}
	defer rows.Close()

	var list []Operator
	for rows.Next() {
		op := Operator{}
		if scanErr := rows.Scan(&op.ID, &op.Name, &op.Role, &op.CreatedAt, &op.UpdatedAt); scanErr != nil {
			return nil, scanErr
		}
		list = append(list, op)
	}
	return list, rows.Err()
}
// UpdateOperator updates an existing operator's name, PIN, and/or role.
//
// Empty arguments mean "leave unchanged":
//   - an empty name keeps the stored name
//   - an empty pin keeps the stored PIN hash
//   - an empty role keeps the stored role
//
// A non-empty role must be "admin", "floor", or "viewer". updated_at is always
// set to the current Unix time.
//
// It returns an error when id is empty, the role is invalid, hashing or the
// database update fails, or (when the driver reports affected rows) no
// operator with the given id exists.
func (s *AuthService) UpdateOperator(ctx context.Context, id, name, pin, role string) error {
	if id == "" {
		return fmt.Errorf("auth: empty operator ID")
	}
	if role != "" && role != "admin" && role != "floor" && role != "viewer" {
		return fmt.Errorf("auth: invalid role %q", role)
	}

	// NULLIF maps "" to NULL, and COALESCE then falls back to the current
	// column value. An empty name or role therefore keeps the stored value
	// instead of blanking it (an empty role would strip all access).
	query := "UPDATE operators SET name = COALESCE(NULLIF(?, ''), name), " +
		"role = COALESCE(NULLIF(?, ''), role), updated_at = ?"
	args := []interface{}{name, role, time.Now().Unix()}

	if pin != "" {
		hash, err := HashPIN(pin)
		if err != nil {
			return err
		}
		query += ", pin_hash = ?"
		args = append(args, hash)
	}

	query += " WHERE id = ?"
	args = append(args, id)

	res, err := s.db.ExecContext(ctx, query, args...)
	if err != nil {
		return fmt.Errorf("auth: update operator: %w", err)
	}
	// Not every driver supports RowsAffected; only treat a reported zero as
	// "not found".
	if n, raErr := res.RowsAffected(); raErr == nil && n == 0 {
		return fmt.Errorf("auth: update operator: operator %q not found", id)
	}
	return nil
}
// generateUUID returns a random RFC 4122 version-4 UUID in canonical
// lower-case 8-4-4-4-12 hex form.
func generateUUID() string {
	var raw [16]byte
	// The error is ignored: crypto/rand.Read is documented never to fail on
	// Go 1.24+ (it fills the buffer or crashes), and earlier versions only
	// fail if the OS entropy source is unavailable.
	_, _ = rand.Read(raw[:])

	raw[6] = raw[6]&0x0f | 0x40 // version 4
	raw[8] = raw[8]&0x3f | 0x80 // RFC 4122 variant

	return fmt.Sprintf("%x-%x-%x-%x-%x",
		raw[:4], raw[4:6], raw[6:8], raw[8:10], raw[10:])
}