package auth

import (
	"context"
	"crypto/rand"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"time"

	"golang.org/x/crypto/bcrypt"
)

// Sentinel errors produced by AuthService. Callers should compare them with
// errors.Is.
var (
	ErrInvalidPIN      = errors.New("auth: invalid PIN")
	ErrTooManyAttempts = errors.New("auth: too many failed attempts, please wait")
	ErrOperatorLocked  = errors.New("auth: operator account locked")
	ErrOperatorExists  = errors.New("auth: operator already exists")
)

// Operator is the externally visible view of an operator account. It has no
// PIN hash field. CreatedAt and UpdatedAt are Unix timestamps in seconds.
type Operator struct {
	ID        string `json:"id"`
	Name      string `json:"name"`
	Role      string `json:"role"`
	CreatedAt int64  `json:"created_at"`
	UpdatedAt int64  `json:"updated_at"`
}

// AuditRecorder is the callback AuthService invokes to emit audit events.
// Accepting a plain function instead of importing the audit package keeps the
// auth and audit packages free of an import cycle.
type AuditRecorder func(ctx context.Context, action, targetType, targetID string, metadata map[string]interface{}) error

// AuthService performs PIN-based operator authentication against a SQL
// database and tracks failed attempts in the login_attempts table.
type AuthService struct {
	db       *sql.DB
	jwt      *JWTService
	recorder AuditRecorder // nil until SetAuditRecorder is called; audit events are skipped while nil
}

// NewAuthService builds an AuthService that stores data in db and issues
// tokens through jwt. No audit recorder is attached initially.
func NewAuthService(db *sql.DB, jwt *JWTService) *AuthService {
	svc := new(AuthService)
	svc.db = db
	svc.jwt = jwt
	return svc
}

// SetAuditRecorder attaches the callback used for audit logging. It lets the
// recorder be wired in after the audit trail has been constructed, which
// avoids an initialization cycle.
func (s *AuthService) SetAuditRecorder(recorder AuditRecorder) {
	s.recorder = recorder
}

// initRateLimiter creates the login_attempts table when it is missing. The
// statement uses IF NOT EXISTS, so calling it repeatedly is harmless.
func (s *AuthService) initRateLimiter() error {
	const ddl = `CREATE TABLE IF NOT EXISTS login_attempts (
		operator_id TEXT PRIMARY KEY,
		consecutive_failures INTEGER NOT NULL DEFAULT 0,
		last_failure_at INTEGER NOT NULL DEFAULT 0,
		locked_until INTEGER NOT NULL DEFAULT 0
	)`
	if _, err := s.db.Exec(ddl); err != nil {
		return err
	}
	return nil
}

// Login authenticates an operator by PIN and returns a JWT token.
// It applies login rate limiting and records failed attempts.
func (s *AuthService) Login(ctx context.Context, pin string) (token string, operator Operator, err error) { if pin == "" { return "", Operator{}, ErrInvalidPIN } // Ensure rate limiter table exists if err := s.initRateLimiter(); err != nil { return "", Operator{}, fmt.Errorf("auth: init rate limiter: %w", err) } // Load all operators from DB operators, err := s.loadOperators(ctx) if err != nil { return "", Operator{}, fmt.Errorf("auth: load operators: %w", err) } // Try to match PIN against each operator for _, op := range operators { if err := s.checkRateLimit(ctx, op.id); err != nil { // Skip rate-limited operators silently during scan continue } if err := bcrypt.CompareHashAndPassword([]byte(op.pinHash), []byte(pin)); err == nil { // Match found - reset failure counter and issue JWT s.resetFailures(ctx, op.id) operator = Operator{ ID: op.id, Name: op.name, Role: op.role, CreatedAt: op.createdAt, UpdatedAt: op.updatedAt, } token, err = s.jwt.NewToken(op.id, op.role) if err != nil { return "", Operator{}, fmt.Errorf("auth: issue token: %w", err) } // Log successful login if s.recorder != nil { _ = s.recorder(ctx, "operator.login", "operator", op.id, map[string]interface{}{ "operator_name": op.name, "role": op.role, }) } return token, operator, nil } } // No match found - record failure for ALL operators (since PIN is global, not per-operator) // But we track per-operator to prevent brute-force enumeration // Actually, since the PIN is compared against all operators, we can't know // which operator was targeted. Record a global failure counter instead. // However, the spec says "keyed by operator ID" - this makes more sense // when PINs are unique per operator (which they should be in practice). // We'll record failure against a sentinel "global" key. 
s.recordFailure(ctx, "_global") globalLimited, _ := s.isRateLimited(ctx, "_global") if globalLimited { // Check if hard locked locked, _ := s.isLockedOut(ctx, "_global") if locked { return "", Operator{}, ErrOperatorLocked } return "", Operator{}, ErrTooManyAttempts } return "", Operator{}, ErrInvalidPIN } // operatorRow holds a full operator row for internal use during login. type operatorRow struct { id string name string pinHash string role string createdAt int64 updatedAt int64 } // loadOperators loads all operators from the database. func (s *AuthService) loadOperators(ctx context.Context) ([]operatorRow, error) { rows, err := s.db.QueryContext(ctx, "SELECT id, name, pin_hash, role, created_at, updated_at FROM operators") if err != nil { return nil, err } defer rows.Close() var operators []operatorRow for rows.Next() { var op operatorRow if err := rows.Scan(&op.id, &op.name, &op.pinHash, &op.role, &op.createdAt, &op.updatedAt); err != nil { return nil, err } operators = append(operators, op) } return operators, rows.Err() } // checkRateLimit checks if the operator is currently rate-limited. func (s *AuthService) checkRateLimit(ctx context.Context, operatorID string) error { limited, err := s.isRateLimited(ctx, operatorID) if err != nil { return err } if limited { locked, _ := s.isLockedOut(ctx, operatorID) if locked { return ErrOperatorLocked } return ErrTooManyAttempts } return nil } // isRateLimited returns true if the operator has exceeded the failure threshold. 
func (s *AuthService) isRateLimited(ctx context.Context, operatorID string) (bool, error) { var failures int var lockedUntil int64 err := s.db.QueryRowContext(ctx, "SELECT consecutive_failures, locked_until FROM login_attempts WHERE operator_id = ?", operatorID).Scan(&failures, &lockedUntil) if errors.Is(err, sql.ErrNoRows) { return false, nil } if err != nil { return false, err } now := time.Now().Unix() // Check lockout if lockedUntil > now { return true, nil } return false, nil } // isLockedOut returns true if the operator is in a hard lockout (10+ failures). func (s *AuthService) isLockedOut(ctx context.Context, operatorID string) (bool, error) { var failures int var lockedUntil int64 err := s.db.QueryRowContext(ctx, "SELECT consecutive_failures, locked_until FROM login_attempts WHERE operator_id = ?", operatorID).Scan(&failures, &lockedUntil) if errors.Is(err, sql.ErrNoRows) { return false, nil } if err != nil { return false, err } now := time.Now().Unix() return failures >= 10 && lockedUntil > now, nil } // recordFailure increments the failure counter and applies rate limiting. func (s *AuthService) recordFailure(ctx context.Context, operatorID string) { now := time.Now().Unix() // Upsert the failure record _, err := s.db.ExecContext(ctx, ` INSERT INTO login_attempts (operator_id, consecutive_failures, last_failure_at, locked_until) VALUES (?, 1, ?, 0) ON CONFLICT(operator_id) DO UPDATE SET consecutive_failures = consecutive_failures + 1, last_failure_at = ? 
`, operatorID, now, now) if err != nil { log.Printf("auth: record failure error: %v", err) return } // Get current failure count to apply rate limiting var failures int err = s.db.QueryRowContext(ctx, "SELECT consecutive_failures FROM login_attempts WHERE operator_id = ?", operatorID).Scan(&failures) if err != nil { log.Printf("auth: get failure count error: %v", err) return } // Apply rate limiting thresholds var lockDuration time.Duration switch { case failures >= 10: lockDuration = 30 * time.Minute case failures >= 8: lockDuration = 5 * time.Minute case failures >= 5: lockDuration = 30 * time.Second default: return } lockedUntil := now + int64(lockDuration.Seconds()) _, err = s.db.ExecContext(ctx, "UPDATE login_attempts SET locked_until = ? WHERE operator_id = ?", lockedUntil, operatorID) if err != nil { log.Printf("auth: set lockout error: %v", err) return } log.Printf("auth: rate limiting activated for %s (failures=%d, locked_for=%s)", operatorID, failures, lockDuration) // Emit audit entry on 5+ failures if s.recorder != nil && failures >= 5 { _ = s.recorder(ctx, "operator.login_rate_limited", "operator", operatorID, map[string]interface{}{ "failures": failures, "locked_until": lockedUntil, }) } } // resetFailures clears the failure counter for an operator. func (s *AuthService) resetFailures(ctx context.Context, operatorID string) { _, err := s.db.ExecContext(ctx, "DELETE FROM login_attempts WHERE operator_id = ?", operatorID) if err != nil { log.Printf("auth: reset failures error: %v", err) } // Also reset global counter on any successful login _, err = s.db.ExecContext(ctx, "DELETE FROM login_attempts WHERE operator_id = '_global'") if err != nil { log.Printf("auth: reset global failures error: %v", err) } } // GetFailureCount returns the current consecutive failure count for rate limiting tests. 
func (s *AuthService) GetFailureCount(ctx context.Context, operatorID string) (int, error) { var failures int err := s.db.QueryRowContext(ctx, "SELECT consecutive_failures FROM login_attempts WHERE operator_id = ?", operatorID).Scan(&failures) if errors.Is(err, sql.ErrNoRows) { return 0, nil } return failures, err } // HashPIN hashes a PIN using bcrypt with cost 12. func HashPIN(pin string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(pin), 12) if err != nil { return "", fmt.Errorf("auth: hash PIN: %w", err) } return string(hash), nil } // CreateOperator creates a new operator with a hashed PIN. func (s *AuthService) CreateOperator(ctx context.Context, name, pin, role string) (Operator, error) { if name == "" { return Operator{}, fmt.Errorf("auth: empty operator name") } if pin == "" { return Operator{}, fmt.Errorf("auth: empty PIN") } if role != "admin" && role != "floor" && role != "viewer" { return Operator{}, fmt.Errorf("auth: invalid role %q (must be admin, floor, or viewer)", role) } hash, err := HashPIN(pin) if err != nil { return Operator{}, err } id := generateUUID() now := time.Now().Unix() _, err = s.db.ExecContext(ctx, "INSERT INTO operators (id, name, pin_hash, role, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)", id, name, hash, role, now, now) if err != nil { return Operator{}, fmt.Errorf("auth: create operator: %w", err) } return Operator{ ID: id, Name: name, Role: role, CreatedAt: now, UpdatedAt: now, }, nil } // ListOperators returns all operators (without PIN hashes). 
func (s *AuthService) ListOperators(ctx context.Context) ([]Operator, error) { rows, err := s.db.QueryContext(ctx, "SELECT id, name, role, created_at, updated_at FROM operators ORDER BY name") if err != nil { return nil, fmt.Errorf("auth: list operators: %w", err) } defer rows.Close() var operators []Operator for rows.Next() { var op Operator if err := rows.Scan(&op.ID, &op.Name, &op.Role, &op.CreatedAt, &op.UpdatedAt); err != nil { return nil, err } operators = append(operators, op) } return operators, rows.Err() } // UpdateOperator updates an operator's name, PIN (optional), and/or role. // If pin is empty, the PIN is not changed. func (s *AuthService) UpdateOperator(ctx context.Context, id, name, pin, role string) error { if id == "" { return fmt.Errorf("auth: empty operator ID") } if role != "" && role != "admin" && role != "floor" && role != "viewer" { return fmt.Errorf("auth: invalid role %q", role) } now := time.Now().Unix() if pin != "" { hash, err := HashPIN(pin) if err != nil { return err } _, err = s.db.ExecContext(ctx, "UPDATE operators SET name = ?, pin_hash = ?, role = ?, updated_at = ? WHERE id = ?", name, hash, role, now, id) if err != nil { return fmt.Errorf("auth: update operator: %w", err) } } else { _, err := s.db.ExecContext(ctx, "UPDATE operators SET name = ?, role = ?, updated_at = ? WHERE id = ?", name, role, now, id) if err != nil { return fmt.Errorf("auth: update operator: %w", err) } } return nil } // generateUUID generates a v4 UUID. func generateUUID() string { b := make([]byte, 16) _, _ = rand.Read(b) b[6] = (b[6] & 0x0f) | 0x40 // Version 4 b[8] = (b[8] & 0x3f) | 0x80 // Variant 1 return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]) }