package main import ( "context" "encoding/json" "net/http" "net/http/httptest" "testing" "time" "github.com/coder/websocket" "github.com/golang-jwt/jwt/v5" feltnats "github.com/felt-app/felt/internal/nats" "github.com/felt-app/felt/internal/server" "github.com/felt-app/felt/internal/server/middleware" "github.com/felt-app/felt/internal/server/ws" "github.com/felt-app/felt/internal/store" ) func setupTestServer(t *testing.T) (*httptest.Server, *store.DB, *feltnats.EmbeddedServer, []byte) { t.Helper() ctx := context.Background() tmpDir := t.TempDir() // Open database db, err := store.Open(tmpDir, true) if err != nil { t.Fatalf("open database: %v", err) } t.Cleanup(func() { db.Close() }) // Start NATS ns, err := feltnats.Start(ctx, tmpDir) if err != nil { t.Fatalf("start nats: %v", err) } t.Cleanup(func() { ns.Shutdown() }) // Setup JWT signing signingKey := []byte("test-signing-key-32-bytes-long!!") tokenValidator := func(tokenStr string) (string, string, error) { return middleware.ValidateJWT(tokenStr, signingKey) } hub := ws.NewHub(tokenValidator, nil, nil) t.Cleanup(func() { hub.Shutdown() }) // Create HTTP server srv := server.New(server.Config{ Addr: ":0", SigningKey: signingKey, DevMode: true, }, db.DB, ns.Server(), hub) ts := httptest.NewServer(srv.Handler()) t.Cleanup(func() { ts.Close() }) return ts, db, ns, signingKey } func makeToken(t *testing.T, signingKey []byte, operatorID, role string) string { t.Helper() token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "sub": operatorID, "role": role, "exp": time.Now().Add(time.Hour).Unix(), }) tokenStr, err := token.SignedString(signingKey) if err != nil { t.Fatalf("sign token: %v", err) } return tokenStr } func TestHealthEndpoint(t *testing.T) { ts, _, _, _ := setupTestServer(t) resp, err := http.Get(ts.URL + "/api/v1/health") if err != nil { t.Fatalf("health request: %v", err) } defer resp.Body.Close() if resp.StatusCode != 200 { t.Fatalf("expected 200, got %d", resp.StatusCode) } var health map[string]interface{} if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { t.Fatalf("decode health: %v", err) } if health["status"] != "ok" { t.Fatalf("expected status ok, got %v", health["status"]) } // Check subsystems subsystems, ok := health["subsystems"].(map[string]interface{}) if !ok { t.Fatal("missing subsystems in health response") } dbStatus, ok := subsystems["database"].(map[string]interface{}) if !ok || dbStatus["status"] != "ok" { t.Fatalf("database not ok: %v", dbStatus) } natsStatus, ok := subsystems["nats"].(map[string]interface{}) if !ok || natsStatus["status"] != "ok" { t.Fatalf("nats not ok: %v", natsStatus) } wsStatus, ok := subsystems["websocket"].(map[string]interface{}) if !ok || wsStatus["status"] != "ok" { t.Fatalf("websocket not ok: %v", wsStatus) } } func TestSPAFallback(t *testing.T) { ts, _, _, _ := setupTestServer(t) // Root path resp, err := http.Get(ts.URL + "/") if err != nil { t.Fatalf("root request: %v", err) } defer resp.Body.Close() if resp.StatusCode != 200 { t.Fatalf("expected 200 for root, got %d", resp.StatusCode) } // Unknown path (SPA fallback) resp2, err := http.Get(ts.URL + "/some/unknown/route") if err != nil { t.Fatalf("unknown path request: %v", err) } defer resp2.Body.Close() if resp2.StatusCode != 200 { t.Fatalf("expected 200 for SPA fallback, got %d", resp2.StatusCode) } } func TestWebSocketRejectsMissingToken(t *testing.T) { ts, _, _, _ := setupTestServer(t) ctx := context.Background() _, resp, err := websocket.Dial(ctx, "ws"+ts.URL[4:]+"/ws", nil) if err == nil { t.Fatal("expected error for missing token") } if resp != nil && resp.StatusCode != 401 { t.Fatalf("expected 401, got %d", resp.StatusCode) } } func TestWebSocketRejectsInvalidToken(t *testing.T) { ts, _, _, _ := setupTestServer(t) ctx := context.Background() _, resp, err := websocket.Dial(ctx, "ws"+ts.URL[4:]+"/ws?token=invalid", nil) if err == nil { t.Fatal("expected error for invalid token") } if resp != nil && resp.StatusCode != 401 { t.Fatalf("expected 401, got %d", resp.StatusCode) } } func TestWebSocketAcceptsValidToken(t *testing.T) { ts, _, _, signingKey := setupTestServer(t) ctx := context.Background() tokenStr := makeToken(t, signingKey, "operator-123", "admin") wsURL := "ws" + ts.URL[4:] + "/ws?token=" + tokenStr conn, _, err := websocket.Dial(ctx, wsURL, nil) if err != nil { t.Fatalf("websocket dial: %v", err) } defer conn.CloseNow() // Should receive a connected message readCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() _, msgBytes, err := conn.Read(readCtx) if err != nil { t.Fatalf("read connected message: %v", err) } var msg map[string]interface{} if err := json.Unmarshal(msgBytes, &msg); err != nil { t.Fatalf("decode message: %v", err) } if msg["type"] != "connected" { t.Fatalf("expected 'connected' message, got %v", msg["type"]) } conn.Close(websocket.StatusNormalClosure, "test done") } func TestNATSStreamsExist(t *testing.T) { _, _, ns, _ := setupTestServer(t) ctx := context.Background() js := ns.JetStream() // Check AUDIT stream stream, err := js.Stream(ctx, "AUDIT") if err != nil { t.Fatalf("get AUDIT stream: %v", err) } info, err := stream.Info(ctx) if err != nil { t.Fatalf("get AUDIT stream info: %v", err) } if info.Config.Name != "AUDIT" { t.Fatalf("expected AUDIT stream, got %s", info.Config.Name) } // Check STATE stream stream, err = js.Stream(ctx, "STATE") if err != nil { t.Fatalf("get STATE stream: %v", err) } info, err = stream.Info(ctx) if err != nil { t.Fatalf("get STATE stream info: %v", err) } if info.Config.Name != "STATE" { t.Fatalf("expected STATE stream, got %s", info.Config.Name) } } func TestPublisherUUIDValidation(t *testing.T) { _, _, ns, _ := setupTestServer(t) ctx := context.Background() js := ns.JetStream() pub := feltnats.NewPublisher(js) // Empty UUID _, err := pub.Publish(ctx, "", "audit", []byte("test")) if err == nil { t.Fatal("expected error for empty UUID") } // UUID with NATS wildcards _, err = pub.Publish(ctx, "test.*.injection", "audit", []byte("test")) if err == nil { t.Fatal("expected error for UUID with wildcards") } // Invalid format _, err = pub.Publish(ctx, "not-a-uuid", "audit", []byte("test")) if err == nil { t.Fatal("expected error for invalid UUID format") } // Valid UUID should succeed _, err = pub.Publish(ctx, "550e8400-e29b-41d4-a716-446655440000", "audit", []byte(`{"test":true}`)) if err != nil { t.Fatalf("expected success for valid UUID, got: %v", err) } } func TestLibSQLWALMode(t *testing.T) { _, db, _, _ := setupTestServer(t) var mode string err := db.QueryRow("PRAGMA journal_mode").Scan(&mode) if err != nil { t.Fatalf("query journal_mode: %v", err) } if mode != "wal" { t.Fatalf("expected WAL mode, got %s", mode) } } func TestLibSQLForeignKeys(t *testing.T) { _, db, _, _ := setupTestServer(t) var fk int err := db.QueryRow("PRAGMA foreign_keys").Scan(&fk) if err != nil { t.Fatalf("query foreign_keys: %v", err) } if fk != 1 { t.Fatalf("expected foreign_keys=1, got %d", fk) } }