package main

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/coder/websocket"
	"github.com/golang-jwt/jwt/v5"

	feltauth "github.com/felt-app/felt/internal/auth"
	"github.com/felt-app/felt/internal/clock"
	feltnats "github.com/felt-app/felt/internal/nats"
	"github.com/felt-app/felt/internal/server"
	"github.com/felt-app/felt/internal/server/middleware"
	"github.com/felt-app/felt/internal/server/ws"
	"github.com/felt-app/felt/internal/store"
)

// setupTestServer wires up a complete server stack for one test: a database
// and an embedded NATS server rooted in a per-test temp directory, the JWT and
// auth services, the WebSocket hub, the clock registry, and the HTTP handler
// served through httptest.
//
// Every component is torn down through t.Cleanup, so callers never have to
// close anything themselves. Cleanups run in reverse registration order, i.e.
// the HTTP server goes down first and the database last.
//
// It returns the running test server, the database handle, the embedded NATS
// server, the HMAC signing key (so tests can mint their own tokens), and the
// auth service.
func setupTestServer(t *testing.T) (*httptest.Server, *store.DB, *feltnats.EmbeddedServer, []byte, *feltauth.AuthService) {
	t.Helper()
	ctx := context.Background()
	tmpDir := t.TempDir()

	// Open database
	db, err := store.Open(tmpDir, true)
	if err != nil {
		t.Fatalf("open database: %v", err)
	}
	t.Cleanup(func() { db.Close() })

	// Start NATS
	ns, err := feltnats.Start(ctx, tmpDir)
	if err != nil {
		t.Fatalf("start nats: %v", err)
	}
	t.Cleanup(func() { ns.Shutdown() })

	// Setup JWT signing
	signingKey := []byte("test-signing-key-32-bytes-long!!")

	// Create auth service
	jwtService := feltauth.NewJWTService(signingKey, 7*24*time.Hour)
	authService := feltauth.NewAuthService(db.DB, jwtService)

	// The hub validates tokens with the same key the HTTP middleware uses.
	tokenValidator := func(tokenStr string) (string, string, error) {
		return middleware.ValidateJWT(tokenStr, signingKey)
	}
	hub := ws.NewHub(tokenValidator, nil, nil)
	t.Cleanup(func() { hub.Shutdown() })

	// Clock registry
	clockRegistry := clock.NewRegistry(hub)
	t.Cleanup(func() { clockRegistry.Shutdown() })

	// Create HTTP server
	srv := server.New(server.Config{
		Addr:       ":0",
		SigningKey: signingKey,
		DevMode:    true,
	}, db.DB, ns.Server(), hub, authService, clockRegistry)

	ts := httptest.NewServer(srv.Handler())
	t.Cleanup(func() { ts.Close() })

	return ts, db, ns, signingKey, authService
}

// makeToken returns an HS256-signed JWT carrying operatorID as "sub" and role
// as "role", expiring one hour from now. It fails the test if signing fails.
func makeToken(t *testing.T, signingKey []byte, operatorID, role string) string {
	t.Helper()
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"sub":  operatorID,
		"role": role,
		"exp":  time.Now().Add(time.Hour).Unix(),
	})
	tokenStr, err := token.SignedString(signingKey)
	if err != nil {
		t.Fatalf("sign token: %v", err)
	}
	return tokenStr
}

// TestHealthEndpoint checks that /api/v1/health answers 200 with an overall
// "ok" status and "ok" for the database, nats and websocket subsystems.
func TestHealthEndpoint(t *testing.T) {
	ts, _, _, _, _ := setupTestServer(t)

	resp, err := http.Get(ts.URL + "/api/v1/health")
	if err != nil {
		t.Fatalf("health request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("expected 200, got %d", resp.StatusCode)
	}

	var health map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&health); err != nil {
		t.Fatalf("decode health: %v", err)
	}

	if health["status"] != "ok" {
		t.Fatalf("expected status ok, got %v", health["status"])
	}

	// Check subsystems
	subsystems, ok := health["subsystems"].(map[string]any)
	if !ok {
		t.Fatal("missing subsystems in health response")
	}

	dbStatus, ok := subsystems["database"].(map[string]any)
	if !ok || dbStatus["status"] != "ok" {
		t.Fatalf("database not ok: %v", dbStatus)
	}

	natsStatus, ok := subsystems["nats"].(map[string]any)
	if !ok || natsStatus["status"] != "ok" {
		t.Fatalf("nats not ok: %v", natsStatus)
	}

	wsStatus, ok := subsystems["websocket"].(map[string]any)
	if !ok || wsStatus["status"] != "ok" {
		t.Fatalf("websocket not ok: %v", wsStatus)
	}
}

// TestSPAFallback checks that both the root path and an unknown, non-API path
// are answered with 200.
func TestSPAFallback(t *testing.T) {
	ts, _, _, _, _ := setupTestServer(t)

	// Root path
	resp, err := http.Get(ts.URL + "/")
	if err != nil {
		t.Fatalf("root request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("expected 200 for root, got %d", resp.StatusCode)
	}

	// Unknown path (SPA fallback)
	resp2, err := http.Get(ts.URL + "/some/unknown/route")
	if err != nil {
		t.Fatalf("unknown path request: %v", err)
	}
	defer resp2.Body.Close()

	if resp2.StatusCode != http.StatusOK {
		t.Fatalf("expected 200 for SPA fallback, got %d", resp2.StatusCode)
	}
}

// TestWebSocketRejectsMissingToken checks that a WebSocket handshake without a
// token query parameter fails, with 401 when a handshake response is available.
func TestWebSocketRejectsMissingToken(t *testing.T) {
	ts, _, _, _, _ := setupTestServer(t)
	ctx := context.Background()

	// ts.URL starts with "http"; replacing those four bytes yields a ws:// URL.
	_, resp, err := websocket.Dial(ctx, "ws"+ts.URL[4:]+"/ws", nil)
	if err == nil {
		t.Fatal("expected error for missing token")
	}
	// resp can be nil when the dial fails before a response is read, so the
	// status is only asserted when one was received.
	if resp != nil && resp.StatusCode != http.StatusUnauthorized {
		t.Fatalf("expected 401, got %d", resp.StatusCode)
	}
}

// TestWebSocketRejectsInvalidToken checks that a WebSocket handshake carrying
// a malformed token fails, with 401 when a handshake response is available.
func TestWebSocketRejectsInvalidToken(t *testing.T) {
	ts, _, _, _, _ := setupTestServer(t)
	ctx := context.Background()

	// ts.URL starts with "http"; replacing those four bytes yields a ws:// URL.
	_, resp, err := websocket.Dial(ctx, "ws"+ts.URL[4:]+"/ws?token=invalid", nil)
	if err == nil {
		t.Fatal("expected error for invalid token")
	}
	// resp can be nil when the dial fails before a response is read, so the
	// status is only asserted when one was received.
	if resp != nil && resp.StatusCode != http.StatusUnauthorized {
		t.Fatalf("expected 401, got %d", resp.StatusCode)
	}
}

// TestWebSocketAcceptsValidToken checks that a handshake with a valid token
// succeeds and that the first message received has type "connected".
func TestWebSocketAcceptsValidToken(t *testing.T) {
	ts, _, _, signingKey, _ := setupTestServer(t)
	ctx := context.Background()

	tokenStr := makeToken(t, signingKey, "operator-123", "admin")

	// ts.URL starts with "http"; replacing those four bytes yields a ws:// URL.
	wsURL := "ws" + ts.URL[4:] + "/ws?token=" + tokenStr
	conn, _, err := websocket.Dial(ctx, wsURL, nil)
	if err != nil {
		t.Fatalf("websocket dial: %v", err)
	}
	// Safety net for the early-exit paths below.
	defer conn.CloseNow()

	// Should receive a connected message
	readCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	_, msgBytes, err := conn.Read(readCtx)
	if err != nil {
		t.Fatalf("read connected message: %v", err)
	}

	var msg map[string]any
	if err := json.Unmarshal(msgBytes, &msg); err != nil {
		t.Fatalf("decode message: %v", err)
	}

	if msg["type"] != "connected" {
		t.Fatalf("expected 'connected' message, got %v", msg["type"])
	}

	// Best-effort graceful close; the assertions above are what matter.
	conn.Close(websocket.StatusNormalClosure, "test done")
}

// TestNATSStreamsExist checks that the AUDIT and STATE JetStream streams are
// present once the embedded NATS server has started.
func TestNATSStreamsExist(t *testing.T) {
	_, _, ns, _, _ := setupTestServer(t)
	ctx := context.Background()

	js := ns.JetStream()

	// Check AUDIT stream
	stream, err := js.Stream(ctx, "AUDIT")
	if err != nil {
		t.Fatalf("get AUDIT stream: %v", err)
	}
	info, err := stream.Info(ctx)
	if err != nil {
		t.Fatalf("get AUDIT stream info: %v", err)
	}
	if info.Config.Name != "AUDIT" {
		t.Fatalf("expected AUDIT stream, got %s", info.Config.Name)
	}

	// Check STATE stream
	stream, err = js.Stream(ctx, "STATE")
	if err != nil {
		t.Fatalf("get STATE stream: %v", err)
	}
	info, err = stream.Info(ctx)
	if err != nil {
		t.Fatalf("get STATE stream info: %v", err)
	}
	if info.Config.Name != "STATE" {
		t.Fatalf("expected STATE stream, got %s", info.Config.Name)
	}
}

// TestPublisherUUIDValidation checks that the publisher rejects an empty UUID,
// a UUID containing NATS wildcard characters and a malformed UUID, and that it
// accepts a well-formed one.
func TestPublisherUUIDValidation(t *testing.T) {
	_, _, ns, _, _ := setupTestServer(t)
	ctx := context.Background()

	js := ns.JetStream()
	pub := feltnats.NewPublisher(js)

	// Empty UUID
	_, err := pub.Publish(ctx, "", "audit", []byte("test"))
	if err == nil {
		t.Fatal("expected error for empty UUID")
	}

	// UUID with NATS wildcards
	_, err = pub.Publish(ctx, "test.*.injection", "audit", []byte("test"))
	if err == nil {
		t.Fatal("expected error for UUID with wildcards")
	}

	// Invalid format
	_, err = pub.Publish(ctx, "not-a-uuid", "audit", []byte("test"))
	if err == nil {
		t.Fatal("expected error for invalid UUID format")
	}

	// Valid UUID should succeed
	_, err = pub.Publish(ctx, "550e8400-e29b-41d4-a716-446655440000", "audit", []byte(`{"test":true}`))
	if err != nil {
		t.Fatalf("expected success for valid UUID, got: %v", err)
	}
}

// TestLibSQLWALMode checks that the database reports journal_mode "wal".
func TestLibSQLWALMode(t *testing.T) {
	_, db, _, _, _ := setupTestServer(t)

	var mode string
	err := db.QueryRow("PRAGMA journal_mode").Scan(&mode)
	if err != nil {
		t.Fatalf("query journal_mode: %v", err)
	}
	if mode != "wal" {
		t.Fatalf("expected WAL mode, got %s", mode)
	}
}

// TestLibSQLForeignKeys checks that foreign key enforcement is switched on.
func TestLibSQLForeignKeys(t *testing.T) {
	_, db, _, _, _ := setupTestServer(t)

	var fk int
	err := db.QueryRow("PRAGMA foreign_keys").Scan(&fk)
	if err != nil {
		t.Fatalf("query foreign_keys: %v", err)
	}
	if fk != 1 {
		t.Fatalf("expected foreign_keys=1, got %d", fk)
	}
}

// ---- Auth Tests (Plan C) ----

// TestLoginWithCorrectPIN checks that logging in with the seeded PIN returns
// 200, a non-empty token and an operator named "Admin" with role "admin".
func TestLoginWithCorrectPIN(t *testing.T) {
	ts, _, _, _, _ := setupTestServer(t)

	// Dev seed creates admin with PIN 1234
	body := bytes.NewBufferString(`{"pin":"1234"}`)
	resp, err := http.Post(ts.URL+"/api/v1/auth/login", "application/json", body)
	if err != nil {
		t.Fatalf("login request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("expected 200, got %d", resp.StatusCode)
	}

	var loginResp map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&loginResp); err != nil {
		t.Fatalf("decode login response: %v", err)
	}

	if loginResp["token"] == nil || loginResp["token"] == "" {
		t.Fatal("expected token in login response")
	}

	operator, ok := loginResp["operator"].(map[string]any)
	if !ok {
		t.Fatal("expected operator in login response")
	}
	if operator["name"] != "Admin" {
		t.Fatalf("expected operator name Admin, got %v", operator["name"])
	}
	if operator["role"] != "admin" {
		t.Fatalf("expected operator role admin, got %v", operator["role"])
	}
}

// TestLoginWithWrongPIN checks that an unknown PIN is answered with 401.
func TestLoginWithWrongPIN(t *testing.T) {
	ts, _, _, _, _ := setupTestServer(t)

	body := bytes.NewBufferString(`{"pin":"9999"}`)
	resp, err := http.Post(ts.URL+"/api/v1/auth/login", "application/json", body)
	if err != nil {
		t.Fatalf("login request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusUnauthorized {
		t.Fatalf("expected 401, got %d", resp.StatusCode)
	}
}

// TestLoginTokenAccessesProtectedEndpoint logs in with the seeded PIN and then
// uses the issued token as a Bearer credential against /api/v1/auth/me,
// expecting 200 and role "admin".
func TestLoginTokenAccessesProtectedEndpoint(t *testing.T) {
	ts, _, _, _, _ := setupTestServer(t)

	// Login first
	body := bytes.NewBufferString(`{"pin":"1234"}`)
	resp, err := http.Post(ts.URL+"/api/v1/auth/login", "application/json", body)
	if err != nil {
		t.Fatalf("login request: %v", err)
	}
	defer resp.Body.Close()

	// Fail with a clear message instead of panicking on the token lookup
	// below when the login itself did not succeed.
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("expected 200 from login, got %d", resp.StatusCode)
	}

	var loginResp map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&loginResp); err != nil {
		t.Fatalf("decode login response: %v", err)
	}
	token, ok := loginResp["token"].(string)
	if !ok || token == "" {
		t.Fatalf("expected token string in login response, got %v", loginResp["token"])
	}

	// Use token to access /auth/me
	req, err := http.NewRequest("GET", ts.URL+"/api/v1/auth/me", nil)
	if err != nil {
		t.Fatalf("build me request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+token)
	meResp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("me request: %v", err)
	}
	defer meResp.Body.Close()

	if meResp.StatusCode != http.StatusOK {
		t.Fatalf("expected 200, got %d", meResp.StatusCode)
	}

	var me map[string]any
	if err := json.NewDecoder(meResp.Body).Decode(&me); err != nil {
		t.Fatalf("decode me response: %v", err)
	}
	if me["role"] != "admin" {
		t.Fatalf("expected role admin, got %v", me["role"])
	}
}

// TestProtectedEndpointWithoutToken checks that /api/v1/auth/me answers 401
// when no Authorization header is sent.
func TestProtectedEndpointWithoutToken(t *testing.T) {
	ts, _, _, _, _ := setupTestServer(t)

	resp, err := http.Get(ts.URL + "/api/v1/auth/me")
	if err != nil {
		t.Fatalf("me request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusUnauthorized {
		t.Fatalf("expected 401, got %d", resp.StatusCode)
	}
}

// TestRoleMiddlewareBlocksInsufficientRole sends a viewer-role token to
// /api/v1/tournaments and expects 200.
//
// NOTE(review): despite its name, this test currently asserts that a viewer is
// allowed through; it does not exercise a rejection path. Revisit once an
// endpoint with a role restriction exists.
func TestRoleMiddlewareBlocksInsufficientRole(t *testing.T) {
	ts, _, _, signingKey, _ := setupTestServer(t)

	// Create a viewer token -- viewers can't access admin endpoints
	viewerToken := makeToken(t, signingKey, "viewer-op", "viewer")

	// /tournaments requires auth but should be accessible by any role for now
	req, err := http.NewRequest("GET", ts.URL+"/api/v1/tournaments", nil)
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+viewerToken)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("request: %v", err)
	}
	defer resp.Body.Close()

	// Currently all authenticated users can access stub endpoints
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("expected 200, got %d", resp.StatusCode)
	}
}

// TestJWTValidationEnforcesHS256 checks middleware.ValidateJWT directly: a
// fresh HS256 token must validate, while an expired HS256 token and an
// unsigned ("alg": "none") token must both be rejected.
func TestJWTValidationEnforcesHS256(t *testing.T) {
	signingKey := []byte("test-signing-key-32-bytes-long!!")

	// Create a valid HS256 token -- should work
	_, _, err := middleware.ValidateJWT(makeToken(t, signingKey, "op-1", "admin"), signingKey)
	if err != nil {
		t.Fatalf("valid HS256 token should pass: %v", err)
	}

	// Create an expired token -- should fail
	expiredToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"sub":  "op-1",
		"role": "admin",
		"exp":  time.Now().Add(-time.Hour).Unix(),
	})
	// A signing failure would leave an empty string that is rejected for the
	// wrong reason, so make sure the token really was produced.
	expiredStr, err := expiredToken.SignedString(signingKey)
	if err != nil {
		t.Fatalf("sign expired token: %v", err)
	}
	_, _, err = middleware.ValidateJWT(expiredStr, signingKey)
	if err == nil {
		t.Fatal("expired token should fail validation")
	}

	// Create an unsigned "alg: none" token with otherwise valid claims --
	// should fail, since only HS256-signed tokens may be accepted.
	noneToken := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{
		"sub":  "op-1",
		"role": "admin",
		"exp":  time.Now().Add(time.Hour).Unix(),
	})
	noneStr, err := noneToken.SignedString(jwt.UnsafeAllowNoneSignatureType)
	if err != nil {
		t.Fatalf("sign none token: %v", err)
	}
	_, _, err = middleware.ValidateJWT(noneStr, signingKey)
	if err == nil {
		t.Fatal("token with alg none should fail validation")
	}
}