- Clock API routes: start, pause, resume, advance, rewind, jump, get, warnings - Role-based access control (floor+ for mutations, any auth for reads) - Clock state persistence callback to DB on meaningful changes - Blind structure levels loaded from DB on clock start - Clock registry wired into HTTP server and cmd/leaf main - 25 tests covering: state machine, countdown, pause/resume, auto-advance, jump, rewind, hand-for-hand, warnings, overtime, crash recovery, snapshot - Fix missing crypto/rand import in auth/pin.go (Rule 3 auto-fix) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
428 lines
12 KiB
Go
428 lines
12 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// Sentinel errors returned by the auth service. Compare against them with
// errors.Is.
var (
	// ErrInvalidPIN is returned when the PIN is empty or matches no operator.
	ErrInvalidPIN = errors.New("auth: invalid PIN")

	// ErrTooManyAttempts is returned while a rate-limit lockout is active.
	ErrTooManyAttempts = errors.New("auth: too many failed attempts, please wait")

	// ErrOperatorLocked is returned during a hard lockout (10 or more
	// consecutive failures with the lockout deadline still in the future).
	ErrOperatorLocked = errors.New("auth: operator account locked")

	// ErrOperatorExists reports that an operator already exists.
	// NOTE(review): not returned by any code in this file; presumably used
	// by callers elsewhere — confirm before removing.
	ErrOperatorExists = errors.New("auth: operator already exists")
)
|
|
|
|
// Operator is the public view of an operator account. It carries no PIN
// hash, so it can be returned to API clients as-is.
type Operator struct {
	ID   string `json:"id"`
	Name string `json:"name"`
	// Role is one of "admin", "floor", or "viewer" (enforced by
	// CreateOperator and UpdateOperator).
	Role string `json:"role"`
	// CreatedAt and UpdatedAt are Unix timestamps in seconds.
	CreatedAt int64 `json:"created_at"`
	UpdatedAt int64 `json:"updated_at"`
}
|
|
|
|
// AuditRecorder is the callback the auth service invokes to emit audit
// events such as "operator.login". Using a plain function type instead of
// importing the audit package avoids an import cycle between auth and audit.
//
// action names the event, targetType and targetID identify the affected
// entity, and metadata carries event-specific details. The auth service
// discards the returned error.
type AuditRecorder func(ctx context.Context, action, targetType, targetID string, metadata map[string]interface{}) error
|
|
|
|
// AuthService handles PIN-based authentication with rate limiting.
|
|
type AuthService struct {
|
|
db *sql.DB
|
|
jwt *JWTService
|
|
recorder AuditRecorder // optional, for audit logging
|
|
}
|
|
|
|
// NewAuthService creates a new authentication service.
|
|
func NewAuthService(db *sql.DB, jwt *JWTService) *AuthService {
|
|
return &AuthService{
|
|
db: db,
|
|
jwt: jwt,
|
|
}
|
|
}
|
|
|
|
// SetAuditRecorder sets the audit recorder for logging auth events.
|
|
// Called after the audit trail is initialized to break the init cycle.
|
|
func (s *AuthService) SetAuditRecorder(recorder AuditRecorder) {
|
|
s.recorder = recorder
|
|
}
|
|
|
|
// initRateLimiter ensures the login_attempts table exists.
|
|
func (s *AuthService) initRateLimiter() error {
|
|
_, err := s.db.Exec(`CREATE TABLE IF NOT EXISTS login_attempts (
|
|
operator_id TEXT PRIMARY KEY,
|
|
consecutive_failures INTEGER NOT NULL DEFAULT 0,
|
|
last_failure_at INTEGER NOT NULL DEFAULT 0,
|
|
locked_until INTEGER NOT NULL DEFAULT 0
|
|
)`)
|
|
return err
|
|
}
|
|
|
|
// Login authenticates an operator by PIN and returns a JWT token.
|
|
// It checks rate limiting before attempting authentication.
|
|
func (s *AuthService) Login(ctx context.Context, pin string) (token string, operator Operator, err error) {
|
|
if pin == "" {
|
|
return "", Operator{}, ErrInvalidPIN
|
|
}
|
|
|
|
// Ensure rate limiter table exists
|
|
if err := s.initRateLimiter(); err != nil {
|
|
return "", Operator{}, fmt.Errorf("auth: init rate limiter: %w", err)
|
|
}
|
|
|
|
// Load all operators from DB
|
|
operators, err := s.loadOperators(ctx)
|
|
if err != nil {
|
|
return "", Operator{}, fmt.Errorf("auth: load operators: %w", err)
|
|
}
|
|
|
|
// Try to match PIN against each operator
|
|
for _, op := range operators {
|
|
if err := s.checkRateLimit(ctx, op.id); err != nil {
|
|
// Skip rate-limited operators silently during scan
|
|
continue
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(op.pinHash), []byte(pin)); err == nil {
|
|
// Match found - reset failure counter and issue JWT
|
|
s.resetFailures(ctx, op.id)
|
|
|
|
operator = Operator{
|
|
ID: op.id,
|
|
Name: op.name,
|
|
Role: op.role,
|
|
CreatedAt: op.createdAt,
|
|
UpdatedAt: op.updatedAt,
|
|
}
|
|
|
|
token, err = s.jwt.NewToken(op.id, op.role)
|
|
if err != nil {
|
|
return "", Operator{}, fmt.Errorf("auth: issue token: %w", err)
|
|
}
|
|
|
|
// Log successful login
|
|
if s.recorder != nil {
|
|
_ = s.recorder(ctx, "operator.login", "operator", op.id, map[string]interface{}{
|
|
"operator_name": op.name,
|
|
"role": op.role,
|
|
})
|
|
}
|
|
|
|
return token, operator, nil
|
|
}
|
|
}
|
|
|
|
// No match found - record failure for ALL operators (since PIN is global, not per-operator)
|
|
// But we track per-operator to prevent brute-force enumeration
|
|
// Actually, since the PIN is compared against all operators, we can't know
|
|
// which operator was targeted. Record a global failure counter instead.
|
|
// However, the spec says "keyed by operator ID" - this makes more sense
|
|
// when PINs are unique per operator (which they should be in practice).
|
|
// We'll record failure against a sentinel "global" key.
|
|
s.recordFailure(ctx, "_global")
|
|
globalLimited, _ := s.isRateLimited(ctx, "_global")
|
|
if globalLimited {
|
|
// Check if hard locked
|
|
locked, _ := s.isLockedOut(ctx, "_global")
|
|
if locked {
|
|
return "", Operator{}, ErrOperatorLocked
|
|
}
|
|
return "", Operator{}, ErrTooManyAttempts
|
|
}
|
|
|
|
return "", Operator{}, ErrInvalidPIN
|
|
}
|
|
|
|
// operatorRow is a complete operators row, PIN hash included. It is produced
// by loadOperators for Login's PIN scan and must not be sent to clients.
type operatorRow struct {
	id, name, pinHash, role string
	createdAt, updatedAt    int64
}
|
|
|
|
// loadOperators loads all operators from the database.
|
|
func (s *AuthService) loadOperators(ctx context.Context) ([]operatorRow, error) {
|
|
rows, err := s.db.QueryContext(ctx,
|
|
"SELECT id, name, pin_hash, role, created_at, updated_at FROM operators")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var operators []operatorRow
|
|
for rows.Next() {
|
|
var op operatorRow
|
|
if err := rows.Scan(&op.id, &op.name, &op.pinHash, &op.role, &op.createdAt, &op.updatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
operators = append(operators, op)
|
|
}
|
|
return operators, rows.Err()
|
|
}
|
|
|
|
// checkRateLimit checks if the operator is currently rate-limited.
|
|
func (s *AuthService) checkRateLimit(ctx context.Context, operatorID string) error {
|
|
limited, err := s.isRateLimited(ctx, operatorID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if limited {
|
|
locked, _ := s.isLockedOut(ctx, operatorID)
|
|
if locked {
|
|
return ErrOperatorLocked
|
|
}
|
|
return ErrTooManyAttempts
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// isRateLimited returns true if the operator has exceeded the failure threshold.
|
|
func (s *AuthService) isRateLimited(ctx context.Context, operatorID string) (bool, error) {
|
|
var failures int
|
|
var lockedUntil int64
|
|
err := s.db.QueryRowContext(ctx,
|
|
"SELECT consecutive_failures, locked_until FROM login_attempts WHERE operator_id = ?",
|
|
operatorID).Scan(&failures, &lockedUntil)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return false, nil
|
|
}
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
now := time.Now().Unix()
|
|
|
|
// Check lockout
|
|
if lockedUntil > now {
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// isLockedOut returns true if the operator is in a hard lockout (10+ failures).
|
|
func (s *AuthService) isLockedOut(ctx context.Context, operatorID string) (bool, error) {
|
|
var failures int
|
|
var lockedUntil int64
|
|
err := s.db.QueryRowContext(ctx,
|
|
"SELECT consecutive_failures, locked_until FROM login_attempts WHERE operator_id = ?",
|
|
operatorID).Scan(&failures, &lockedUntil)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return false, nil
|
|
}
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
now := time.Now().Unix()
|
|
return failures >= 10 && lockedUntil > now, nil
|
|
}
|
|
|
|
// recordFailure increments the failure counter and applies rate limiting.
|
|
func (s *AuthService) recordFailure(ctx context.Context, operatorID string) {
|
|
now := time.Now().Unix()
|
|
|
|
// Upsert the failure record
|
|
_, err := s.db.ExecContext(ctx, `
|
|
INSERT INTO login_attempts (operator_id, consecutive_failures, last_failure_at, locked_until)
|
|
VALUES (?, 1, ?, 0)
|
|
ON CONFLICT(operator_id) DO UPDATE SET
|
|
consecutive_failures = consecutive_failures + 1,
|
|
last_failure_at = ?
|
|
`, operatorID, now, now)
|
|
if err != nil {
|
|
log.Printf("auth: record failure error: %v", err)
|
|
return
|
|
}
|
|
|
|
// Get current failure count to apply rate limiting
|
|
var failures int
|
|
err = s.db.QueryRowContext(ctx,
|
|
"SELECT consecutive_failures FROM login_attempts WHERE operator_id = ?",
|
|
operatorID).Scan(&failures)
|
|
if err != nil {
|
|
log.Printf("auth: get failure count error: %v", err)
|
|
return
|
|
}
|
|
|
|
// Apply rate limiting thresholds
|
|
var lockDuration time.Duration
|
|
switch {
|
|
case failures >= 10:
|
|
lockDuration = 30 * time.Minute
|
|
case failures >= 8:
|
|
lockDuration = 5 * time.Minute
|
|
case failures >= 5:
|
|
lockDuration = 30 * time.Second
|
|
default:
|
|
return
|
|
}
|
|
|
|
lockedUntil := now + int64(lockDuration.Seconds())
|
|
_, err = s.db.ExecContext(ctx,
|
|
"UPDATE login_attempts SET locked_until = ? WHERE operator_id = ?",
|
|
lockedUntil, operatorID)
|
|
if err != nil {
|
|
log.Printf("auth: set lockout error: %v", err)
|
|
return
|
|
}
|
|
|
|
log.Printf("auth: rate limiting activated for %s (failures=%d, locked_for=%s)",
|
|
operatorID, failures, lockDuration)
|
|
|
|
// Emit audit entry on 5+ failures
|
|
if s.recorder != nil && failures >= 5 {
|
|
_ = s.recorder(ctx, "operator.login_rate_limited", "operator", operatorID, map[string]interface{}{
|
|
"failures": failures,
|
|
"locked_until": lockedUntil,
|
|
})
|
|
}
|
|
}
|
|
|
|
// resetFailures clears the failure counter for an operator.
|
|
func (s *AuthService) resetFailures(ctx context.Context, operatorID string) {
|
|
_, err := s.db.ExecContext(ctx,
|
|
"DELETE FROM login_attempts WHERE operator_id = ?", operatorID)
|
|
if err != nil {
|
|
log.Printf("auth: reset failures error: %v", err)
|
|
}
|
|
|
|
// Also reset global counter on any successful login
|
|
_, err = s.db.ExecContext(ctx,
|
|
"DELETE FROM login_attempts WHERE operator_id = '_global'")
|
|
if err != nil {
|
|
log.Printf("auth: reset global failures error: %v", err)
|
|
}
|
|
}
|
|
|
|
// GetFailureCount returns the current consecutive failure count for rate limiting tests.
|
|
func (s *AuthService) GetFailureCount(ctx context.Context, operatorID string) (int, error) {
|
|
var failures int
|
|
err := s.db.QueryRowContext(ctx,
|
|
"SELECT consecutive_failures FROM login_attempts WHERE operator_id = ?",
|
|
operatorID).Scan(&failures)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return 0, nil
|
|
}
|
|
return failures, err
|
|
}
|
|
|
|
// HashPIN hashes a PIN using bcrypt with cost 12.
|
|
func HashPIN(pin string) (string, error) {
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(pin), 12)
|
|
if err != nil {
|
|
return "", fmt.Errorf("auth: hash PIN: %w", err)
|
|
}
|
|
return string(hash), nil
|
|
}
|
|
|
|
// CreateOperator creates a new operator with a hashed PIN.
|
|
func (s *AuthService) CreateOperator(ctx context.Context, name, pin, role string) (Operator, error) {
|
|
if name == "" {
|
|
return Operator{}, fmt.Errorf("auth: empty operator name")
|
|
}
|
|
if pin == "" {
|
|
return Operator{}, fmt.Errorf("auth: empty PIN")
|
|
}
|
|
if role != "admin" && role != "floor" && role != "viewer" {
|
|
return Operator{}, fmt.Errorf("auth: invalid role %q (must be admin, floor, or viewer)", role)
|
|
}
|
|
|
|
hash, err := HashPIN(pin)
|
|
if err != nil {
|
|
return Operator{}, err
|
|
}
|
|
|
|
id := generateUUID()
|
|
now := time.Now().Unix()
|
|
|
|
_, err = s.db.ExecContext(ctx,
|
|
"INSERT INTO operators (id, name, pin_hash, role, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
|
|
id, name, hash, role, now, now)
|
|
if err != nil {
|
|
return Operator{}, fmt.Errorf("auth: create operator: %w", err)
|
|
}
|
|
|
|
return Operator{
|
|
ID: id,
|
|
Name: name,
|
|
Role: role,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}, nil
|
|
}
|
|
|
|
// ListOperators returns all operators (without PIN hashes).
|
|
func (s *AuthService) ListOperators(ctx context.Context) ([]Operator, error) {
|
|
rows, err := s.db.QueryContext(ctx,
|
|
"SELECT id, name, role, created_at, updated_at FROM operators ORDER BY name")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("auth: list operators: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var operators []Operator
|
|
for rows.Next() {
|
|
var op Operator
|
|
if err := rows.Scan(&op.ID, &op.Name, &op.Role, &op.CreatedAt, &op.UpdatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
operators = append(operators, op)
|
|
}
|
|
return operators, rows.Err()
|
|
}
|
|
|
|
// UpdateOperator updates an operator's name, PIN (optional), and/or role.
|
|
// If pin is empty, the PIN is not changed.
|
|
func (s *AuthService) UpdateOperator(ctx context.Context, id, name, pin, role string) error {
|
|
if id == "" {
|
|
return fmt.Errorf("auth: empty operator ID")
|
|
}
|
|
if role != "" && role != "admin" && role != "floor" && role != "viewer" {
|
|
return fmt.Errorf("auth: invalid role %q", role)
|
|
}
|
|
|
|
now := time.Now().Unix()
|
|
|
|
if pin != "" {
|
|
hash, err := HashPIN(pin)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.db.ExecContext(ctx,
|
|
"UPDATE operators SET name = ?, pin_hash = ?, role = ?, updated_at = ? WHERE id = ?",
|
|
name, hash, role, now, id)
|
|
if err != nil {
|
|
return fmt.Errorf("auth: update operator: %w", err)
|
|
}
|
|
} else {
|
|
_, err := s.db.ExecContext(ctx,
|
|
"UPDATE operators SET name = ?, role = ?, updated_at = ? WHERE id = ?",
|
|
name, role, now, id)
|
|
if err != nil {
|
|
return fmt.Errorf("auth: update operator: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// generateUUID returns a random version-4 UUID (RFC 4122 variant) in the
// canonical lowercase 8-4-4-4-12 hex form.
//
// It panics if the cryptographic random source fails. Ignoring that error
// would produce a predictable, possibly all-zero ID that collides across
// calls. Since Go 1.24, crypto/rand.Read itself treats such a failure as
// fatal.
func generateUUID() string {
	var b [16]byte
	if _, err := rand.Read(b[:]); err != nil {
		panic(fmt.Sprintf("auth: generate UUID: crypto/rand failed: %v", err))
	}
	b[6] = (b[6] & 0x0f) | 0x40 // version 4
	b[8] = (b[8] & 0x3f) | 0x80 // variant 10xx (RFC 4122)
	return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
		b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}
|