- Clock API routes: start, pause, resume, advance, rewind, jump, get, warnings - Role-based access control (floor+ for mutations, any auth for reads) - Clock state persistence callback to DB on meaningful changes - Blind structure levels loaded from DB on clock start - Clock registry wired into HTTP server and cmd/leaf main - 25 tests covering: state machine, countdown, pause/resume, auto-advance, jump, rewind, hand-for-hand, warnings, overtime, crash recovery, snapshot - Fix missing crypto/rand import in auth/pin.go (Rule 3 auto-fix) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
454 lines
12 KiB
Go
454 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/coder/websocket"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
|
|
feltauth "github.com/felt-app/felt/internal/auth"
|
|
"github.com/felt-app/felt/internal/clock"
|
|
feltnats "github.com/felt-app/felt/internal/nats"
|
|
"github.com/felt-app/felt/internal/server"
|
|
"github.com/felt-app/felt/internal/server/middleware"
|
|
"github.com/felt-app/felt/internal/server/ws"
|
|
"github.com/felt-app/felt/internal/store"
|
|
)
|
|
|
|
func setupTestServer(t *testing.T) (*httptest.Server, *store.DB, *feltnats.EmbeddedServer, []byte, *feltauth.AuthService) {
|
|
t.Helper()
|
|
ctx := context.Background()
|
|
tmpDir := t.TempDir()
|
|
|
|
// Open database
|
|
db, err := store.Open(tmpDir, true)
|
|
if err != nil {
|
|
t.Fatalf("open database: %v", err)
|
|
}
|
|
t.Cleanup(func() { db.Close() })
|
|
|
|
// Start NATS
|
|
ns, err := feltnats.Start(ctx, tmpDir)
|
|
if err != nil {
|
|
t.Fatalf("start nats: %v", err)
|
|
}
|
|
t.Cleanup(func() { ns.Shutdown() })
|
|
|
|
// Setup JWT signing
|
|
signingKey := []byte("test-signing-key-32-bytes-long!!")
|
|
|
|
// Create auth service
|
|
jwtService := feltauth.NewJWTService(signingKey, 7*24*time.Hour)
|
|
authService := feltauth.NewAuthService(db.DB, jwtService)
|
|
|
|
tokenValidator := func(tokenStr string) (string, string, error) {
|
|
return middleware.ValidateJWT(tokenStr, signingKey)
|
|
}
|
|
|
|
hub := ws.NewHub(tokenValidator, nil, nil)
|
|
t.Cleanup(func() { hub.Shutdown() })
|
|
|
|
// Clock registry
|
|
clockRegistry := clock.NewRegistry(hub)
|
|
t.Cleanup(func() { clockRegistry.Shutdown() })
|
|
|
|
// Create HTTP server
|
|
srv := server.New(server.Config{
|
|
Addr: ":0",
|
|
SigningKey: signingKey,
|
|
DevMode: true,
|
|
}, db.DB, ns.Server(), hub, authService, clockRegistry)
|
|
|
|
ts := httptest.NewServer(srv.Handler())
|
|
t.Cleanup(func() { ts.Close() })
|
|
|
|
return ts, db, ns, signingKey, authService
|
|
}
|
|
|
|
func makeToken(t *testing.T, signingKey []byte, operatorID, role string) string {
|
|
t.Helper()
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"sub": operatorID,
|
|
"role": role,
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
})
|
|
tokenStr, err := token.SignedString(signingKey)
|
|
if err != nil {
|
|
t.Fatalf("sign token: %v", err)
|
|
}
|
|
return tokenStr
|
|
}
|
|
|
|
func TestHealthEndpoint(t *testing.T) {
|
|
ts, _, _, _, _ := setupTestServer(t)
|
|
|
|
resp, err := http.Get(ts.URL + "/api/v1/health")
|
|
if err != nil {
|
|
t.Fatalf("health request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 200 {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
var health map[string]interface{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&health); err != nil {
|
|
t.Fatalf("decode health: %v", err)
|
|
}
|
|
|
|
if health["status"] != "ok" {
|
|
t.Fatalf("expected status ok, got %v", health["status"])
|
|
}
|
|
|
|
// Check subsystems
|
|
subsystems, ok := health["subsystems"].(map[string]interface{})
|
|
if !ok {
|
|
t.Fatal("missing subsystems in health response")
|
|
}
|
|
|
|
dbStatus, ok := subsystems["database"].(map[string]interface{})
|
|
if !ok || dbStatus["status"] != "ok" {
|
|
t.Fatalf("database not ok: %v", dbStatus)
|
|
}
|
|
|
|
natsStatus, ok := subsystems["nats"].(map[string]interface{})
|
|
if !ok || natsStatus["status"] != "ok" {
|
|
t.Fatalf("nats not ok: %v", natsStatus)
|
|
}
|
|
|
|
wsStatus, ok := subsystems["websocket"].(map[string]interface{})
|
|
if !ok || wsStatus["status"] != "ok" {
|
|
t.Fatalf("websocket not ok: %v", wsStatus)
|
|
}
|
|
}
|
|
|
|
func TestSPAFallback(t *testing.T) {
|
|
ts, _, _, _, _ := setupTestServer(t)
|
|
|
|
// Root path
|
|
resp, err := http.Get(ts.URL + "/")
|
|
if err != nil {
|
|
t.Fatalf("root request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != 200 {
|
|
t.Fatalf("expected 200 for root, got %d", resp.StatusCode)
|
|
}
|
|
|
|
// Unknown path (SPA fallback)
|
|
resp2, err := http.Get(ts.URL + "/some/unknown/route")
|
|
if err != nil {
|
|
t.Fatalf("unknown path request: %v", err)
|
|
}
|
|
defer resp2.Body.Close()
|
|
if resp2.StatusCode != 200 {
|
|
t.Fatalf("expected 200 for SPA fallback, got %d", resp2.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestWebSocketRejectsMissingToken(t *testing.T) {
|
|
ts, _, _, _, _ := setupTestServer(t)
|
|
ctx := context.Background()
|
|
|
|
_, resp, err := websocket.Dial(ctx, "ws"+ts.URL[4:]+"/ws", nil)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing token")
|
|
}
|
|
if resp != nil && resp.StatusCode != 401 {
|
|
t.Fatalf("expected 401, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestWebSocketRejectsInvalidToken(t *testing.T) {
|
|
ts, _, _, _, _ := setupTestServer(t)
|
|
ctx := context.Background()
|
|
|
|
_, resp, err := websocket.Dial(ctx, "ws"+ts.URL[4:]+"/ws?token=invalid", nil)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid token")
|
|
}
|
|
if resp != nil && resp.StatusCode != 401 {
|
|
t.Fatalf("expected 401, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestWebSocketAcceptsValidToken(t *testing.T) {
|
|
ts, _, _, signingKey, _ := setupTestServer(t)
|
|
ctx := context.Background()
|
|
|
|
tokenStr := makeToken(t, signingKey, "operator-123", "admin")
|
|
|
|
wsURL := "ws" + ts.URL[4:] + "/ws?token=" + tokenStr
|
|
conn, _, err := websocket.Dial(ctx, wsURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("websocket dial: %v", err)
|
|
}
|
|
defer conn.CloseNow()
|
|
|
|
// Should receive a connected message
|
|
readCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
|
defer cancel()
|
|
|
|
_, msgBytes, err := conn.Read(readCtx)
|
|
if err != nil {
|
|
t.Fatalf("read connected message: %v", err)
|
|
}
|
|
|
|
var msg map[string]interface{}
|
|
if err := json.Unmarshal(msgBytes, &msg); err != nil {
|
|
t.Fatalf("decode message: %v", err)
|
|
}
|
|
|
|
if msg["type"] != "connected" {
|
|
t.Fatalf("expected 'connected' message, got %v", msg["type"])
|
|
}
|
|
|
|
conn.Close(websocket.StatusNormalClosure, "test done")
|
|
}
|
|
|
|
func TestNATSStreamsExist(t *testing.T) {
|
|
_, _, ns, _, _ := setupTestServer(t)
|
|
ctx := context.Background()
|
|
|
|
js := ns.JetStream()
|
|
|
|
// Check AUDIT stream
|
|
stream, err := js.Stream(ctx, "AUDIT")
|
|
if err != nil {
|
|
t.Fatalf("get AUDIT stream: %v", err)
|
|
}
|
|
info, err := stream.Info(ctx)
|
|
if err != nil {
|
|
t.Fatalf("get AUDIT stream info: %v", err)
|
|
}
|
|
if info.Config.Name != "AUDIT" {
|
|
t.Fatalf("expected AUDIT stream, got %s", info.Config.Name)
|
|
}
|
|
|
|
// Check STATE stream
|
|
stream, err = js.Stream(ctx, "STATE")
|
|
if err != nil {
|
|
t.Fatalf("get STATE stream: %v", err)
|
|
}
|
|
info, err = stream.Info(ctx)
|
|
if err != nil {
|
|
t.Fatalf("get STATE stream info: %v", err)
|
|
}
|
|
if info.Config.Name != "STATE" {
|
|
t.Fatalf("expected STATE stream, got %s", info.Config.Name)
|
|
}
|
|
}
|
|
|
|
func TestPublisherUUIDValidation(t *testing.T) {
|
|
_, _, ns, _, _ := setupTestServer(t)
|
|
ctx := context.Background()
|
|
|
|
js := ns.JetStream()
|
|
pub := feltnats.NewPublisher(js)
|
|
|
|
// Empty UUID
|
|
_, err := pub.Publish(ctx, "", "audit", []byte("test"))
|
|
if err == nil {
|
|
t.Fatal("expected error for empty UUID")
|
|
}
|
|
|
|
// UUID with NATS wildcards
|
|
_, err = pub.Publish(ctx, "test.*.injection", "audit", []byte("test"))
|
|
if err == nil {
|
|
t.Fatal("expected error for UUID with wildcards")
|
|
}
|
|
|
|
// Invalid format
|
|
_, err = pub.Publish(ctx, "not-a-uuid", "audit", []byte("test"))
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid UUID format")
|
|
}
|
|
|
|
// Valid UUID should succeed
|
|
_, err = pub.Publish(ctx, "550e8400-e29b-41d4-a716-446655440000", "audit", []byte(`{"test":true}`))
|
|
if err != nil {
|
|
t.Fatalf("expected success for valid UUID, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestLibSQLWALMode(t *testing.T) {
|
|
_, db, _, _, _ := setupTestServer(t)
|
|
|
|
var mode string
|
|
err := db.QueryRow("PRAGMA journal_mode").Scan(&mode)
|
|
if err != nil {
|
|
t.Fatalf("query journal_mode: %v", err)
|
|
}
|
|
if mode != "wal" {
|
|
t.Fatalf("expected WAL mode, got %s", mode)
|
|
}
|
|
}
|
|
|
|
func TestLibSQLForeignKeys(t *testing.T) {
|
|
_, db, _, _, _ := setupTestServer(t)
|
|
|
|
var fk int
|
|
err := db.QueryRow("PRAGMA foreign_keys").Scan(&fk)
|
|
if err != nil {
|
|
t.Fatalf("query foreign_keys: %v", err)
|
|
}
|
|
if fk != 1 {
|
|
t.Fatalf("expected foreign_keys=1, got %d", fk)
|
|
}
|
|
}
|
|
|
|
// ---- Auth Tests (Plan C) ----
|
|
|
|
func TestLoginWithCorrectPIN(t *testing.T) {
|
|
ts, _, _, _, _ := setupTestServer(t)
|
|
|
|
// Dev seed creates admin with PIN 1234
|
|
body := bytes.NewBufferString(`{"pin":"1234"}`)
|
|
resp, err := http.Post(ts.URL+"/api/v1/auth/login", "application/json", body)
|
|
if err != nil {
|
|
t.Fatalf("login request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 200 {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
var loginResp map[string]interface{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&loginResp); err != nil {
|
|
t.Fatalf("decode login response: %v", err)
|
|
}
|
|
|
|
if loginResp["token"] == nil || loginResp["token"] == "" {
|
|
t.Fatal("expected token in login response")
|
|
}
|
|
|
|
operator, ok := loginResp["operator"].(map[string]interface{})
|
|
if !ok {
|
|
t.Fatal("expected operator in login response")
|
|
}
|
|
|
|
if operator["name"] != "Admin" {
|
|
t.Fatalf("expected operator name Admin, got %v", operator["name"])
|
|
}
|
|
if operator["role"] != "admin" {
|
|
t.Fatalf("expected operator role admin, got %v", operator["role"])
|
|
}
|
|
}
|
|
|
|
func TestLoginWithWrongPIN(t *testing.T) {
|
|
ts, _, _, _, _ := setupTestServer(t)
|
|
|
|
body := bytes.NewBufferString(`{"pin":"9999"}`)
|
|
resp, err := http.Post(ts.URL+"/api/v1/auth/login", "application/json", body)
|
|
if err != nil {
|
|
t.Fatalf("login request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 401 {
|
|
t.Fatalf("expected 401, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestLoginTokenAccessesProtectedEndpoint(t *testing.T) {
|
|
ts, _, _, _, _ := setupTestServer(t)
|
|
|
|
// Login first
|
|
body := bytes.NewBufferString(`{"pin":"1234"}`)
|
|
resp, err := http.Post(ts.URL+"/api/v1/auth/login", "application/json", body)
|
|
if err != nil {
|
|
t.Fatalf("login request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var loginResp map[string]interface{}
|
|
json.NewDecoder(resp.Body).Decode(&loginResp)
|
|
token := loginResp["token"].(string)
|
|
|
|
// Use token to access /auth/me
|
|
req, _ := http.NewRequest("GET", ts.URL+"/api/v1/auth/me", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
meResp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("me request: %v", err)
|
|
}
|
|
defer meResp.Body.Close()
|
|
|
|
if meResp.StatusCode != 200 {
|
|
t.Fatalf("expected 200, got %d", meResp.StatusCode)
|
|
}
|
|
|
|
var me map[string]interface{}
|
|
json.NewDecoder(meResp.Body).Decode(&me)
|
|
|
|
if me["role"] != "admin" {
|
|
t.Fatalf("expected role admin, got %v", me["role"])
|
|
}
|
|
}
|
|
|
|
func TestProtectedEndpointWithoutToken(t *testing.T) {
|
|
ts, _, _, _, _ := setupTestServer(t)
|
|
|
|
resp, err := http.Get(ts.URL + "/api/v1/auth/me")
|
|
if err != nil {
|
|
t.Fatalf("me request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 401 {
|
|
t.Fatalf("expected 401, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestRoleMiddlewareBlocksInsufficientRole(t *testing.T) {
|
|
ts, _, _, signingKey, _ := setupTestServer(t)
|
|
|
|
// Create a viewer token -- viewers can't access admin endpoints
|
|
viewerToken := makeToken(t, signingKey, "viewer-op", "viewer")
|
|
|
|
// /tournaments requires auth but should be accessible by any role for now
|
|
req, _ := http.NewRequest("GET", ts.URL+"/api/v1/tournaments", nil)
|
|
req.Header.Set("Authorization", "Bearer "+viewerToken)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Currently all authenticated users can access stub endpoints
|
|
if resp.StatusCode != 200 {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestJWTValidationEnforcesHS256(t *testing.T) {
|
|
signingKey := []byte("test-signing-key-32-bytes-long!!")
|
|
|
|
// Create a valid HS256 token -- should work
|
|
_, _, err := middleware.ValidateJWT(makeToken(t, signingKey, "op-1", "admin"), signingKey)
|
|
if err != nil {
|
|
t.Fatalf("valid HS256 token should pass: %v", err)
|
|
}
|
|
|
|
// Create an expired token -- should fail
|
|
expiredToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"sub": "op-1",
|
|
"role": "admin",
|
|
"exp": time.Now().Add(-time.Hour).Unix(),
|
|
})
|
|
expiredStr, _ := expiredToken.SignedString(signingKey)
|
|
_, _, err = middleware.ValidateJWT(expiredStr, signingKey)
|
|
if err == nil {
|
|
t.Fatal("expired token should fail validation")
|
|
}
|
|
}
|