package store import ( "database/sql" "embed" "fmt" "log" "sort" "strings" "time" ) //go:embed migrations/*.sql var migrationsFS embed.FS // devOnlyMigrations lists migration filenames that should only be applied // in development mode (--dev flag). var devOnlyMigrations = map[string]bool{ "004_dev_seed.sql": true, } // RunMigrations applies all pending SQL migrations embedded in the binary. // Migrations are sorted by filename (numeric prefix ensures order) and // executed statement-by-statement. Each successful migration is recorded // in the _migrations table to prevent re-application. // // go-libsql does not support multi-statement Exec, so each SQL statement // is executed individually. // // If devMode is false, migrations listed in devOnlyMigrations are skipped. func RunMigrations(db *sql.DB, devMode bool) error { // Force single connection during migration to ensure all tables are // visible across migration steps. db.SetMaxOpenConns(1) defer db.SetMaxOpenConns(0) // restore default (unlimited) after migrations // Create migrations tracking table _, err := db.Exec(` CREATE TABLE IF NOT EXISTS _migrations ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL UNIQUE, applied_at INTEGER NOT NULL ) `) if err != nil { return fmt.Errorf("create _migrations table: %w", err) } // Read all migration files from embedded filesystem entries, err := migrationsFS.ReadDir("migrations") if err != nil { return fmt.Errorf("read migrations dir: %w", err) } // Sort by filename to ensure order sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) applied := 0 skipped := 0 for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") { continue } name := entry.Name() // Skip dev-only migrations unless in dev mode if devOnlyMigrations[name] && !devMode { log.Printf("store: skipping dev-only migration %s (not in dev mode)", name) continue } // Check if already applied var count int err := db.QueryRow("SELECT 
COUNT(*) FROM _migrations WHERE name = ?", name).Scan(&count) if err != nil { return fmt.Errorf("check migration %s: %w", name, err) } if count > 0 { skipped++ continue } // Read migration SQL sqlBytes, err := migrationsFS.ReadFile("migrations/" + name) if err != nil { return fmt.Errorf("read migration %s: %w", name, err) } // Split into individual statements and execute each one. // go-libsql does not support multi-statement Exec. stmts := splitStatements(string(sqlBytes)) for i, stmt := range stmts { if _, err := db.Exec(stmt); err != nil { return fmt.Errorf("execute migration %s (statement %d): %w\nSQL: %s", name, i+1, err, truncate(stmt, 200)) } } // Record migration as applied now := time.Now().Unix() if _, err := db.Exec( "INSERT INTO _migrations (name, applied_at) VALUES (?, ?)", name, now, ); err != nil { return fmt.Errorf("record migration %s: %w", name, err) } applied++ log.Printf("store: applied migration %s (%d statements)", name, len(stmts)) } if applied > 0 { log.Printf("store: %d migration(s) applied, %d already up-to-date", applied, skipped) } else if skipped > 0 { log.Printf("store: all %d migration(s) already applied", skipped) } return nil } // splitStatements splits a SQL file into individual statements. // It handles: // - Single-line comments (-- ...) 
//   - Multi-statement files separated by semicolons, including several
//     statements on the same line
//   - String literals and quoted identifiers containing semicolons or "--"
//     (never split or stripped inside quotes, even across lines)
//   - CREATE TRIGGER statements that contain semicolons inside BEGIN...END blocks
//
// Blank lines and "--" comments (whole-line or trailing) are removed from
// the output; /* ... */ comments are kept verbatim but never split on.
// Each returned statement is whitespace-trimmed and keeps its terminating
// semicolon, except that a final statement without a semicolon is returned
// as-is. Statements consisting of nothing but ";" are dropped.
//
// Inside a CREATE [TEMP|TEMPORARY] TRIGGER statement only an END that is
// directly followed by ";" and does not close a CASE expression terminates
// the statement.
func splitStatements(sql string) []string {
	var (
		stmts   []string
		current strings.Builder // text of the statement being assembled
		word    []byte          // keyword/identifier currently being scanned

		quote     byte      // closing quote byte while inside quoted text, else 0
		inBlock   bool      // inside a /* ... */ comment
		inTrigger bool      // current statement is a CREATE TRIGGER
		head      [3]string // first words of the current statement, upper-cased
		nWords    int       // words seen so far in the current statement
		caseDepth int       // CASE expressions opened but not yet closed by END
		lastWord  string    // upper-cased word directly before the scan position
	)

	// endWord finishes the word being scanned (if any) and updates the
	// keyword-driven state: trigger detection, CASE nesting and lastWord.
	endWord := func() {
		if len(word) == 0 {
			return
		}
		w := strings.ToUpper(string(word))
		word = word[:0]
		if nWords < len(head) {
			head[nWords] = w
		}
		nWords++
		lastWord = w

		switch {
		case nWords == 2:
			inTrigger = head[0] == "CREATE" && w == "TRIGGER"
		case nWords == 3 && !inTrigger:
			inTrigger = head[0] == "CREATE" &&
				(head[1] == "TEMP" || head[1] == "TEMPORARY") && w == "TRIGGER"
		}

		switch w {
		case "CASE":
			caseDepth++
		case "END":
			if caseDepth > 0 {
				// This END closes a CASE expression, not the trigger body.
				caseDepth--
				lastWord = ""
			}
		}
	}

	// flush emits the accumulated statement and resets per-statement state.
	flush := func() {
		stmt := strings.TrimSpace(current.String())
		current.Reset()
		if stmt != "" && stmt != ";" {
			stmts = append(stmts, stmt)
		}
		inTrigger, nWords, caseDepth, lastWord = false, 0, 0, ""
	}

	for _, line := range strings.Split(sql, "\n") {
		// Skip empty lines and pure comment lines, unless the line is part
		// of a multi-line quoted value or block comment.
		if quote == 0 && !inBlock {
			trimmed := strings.TrimSpace(line)
			if trimmed == "" || strings.HasPrefix(trimmed, "--") {
				continue
			}
		}

		for i := 0; i < len(line); i++ {
			c := line[i]
			switch {
			case inBlock:
				current.WriteByte(c)
				if c == '*' && i+1 < len(line) && line[i+1] == '/' {
					current.WriteByte('/')
					i++
					inBlock = false
				}
			case quote != 0:
				// A doubled quote ('' or "") closes and immediately reopens
				// the quoted text, which is equivalent for splitting.
				current.WriteByte(c)
				if c == quote {
					quote = 0
				}
			case c == '-' && i+1 < len(line) && line[i+1] == '-':
				// Trailing line comment: drop the rest of the line.
				endWord()
				i = len(line)
			case c == '/' && i+1 < len(line) && line[i+1] == '*':
				endWord()
				current.WriteString("/*")
				i++
				inBlock = true
			case c == '\'' || c == '"' || c == '`':
				endWord()
				lastWord = ""
				current.WriteByte(c)
				quote = c
			case c == ';':
				endWord()
				current.WriteByte(c)
				if !inTrigger || lastWord == "END" {
					flush()
				} else {
					// Statement separator inside the trigger body.
					lastWord = ""
				}
			case c == '_' || c >= 0x80 ||
				('0' <= c && c <= '9') ||
				('a' <= c && c <= 'z') ||
				('A' <= c && c <= 'Z'):
				word = append(word, c)
				current.WriteByte(c)
			default:
				endWord()
				if c != ' ' && c != '\t' && c != '\r' {
					// Punctuation between END and ";" means that END does
					// not terminate the trigger.
					lastWord = ""
				}
				current.WriteByte(c)
			}
		}
		endWord()
		current.WriteByte('\n')
	}

	// Handle any remaining content (statement without trailing semicolon)
	flush()

	return stmts
}

// truncate returns at most n bytes of s, appending "..." if truncated.
// The cut is moved back to the nearest UTF-8 rune boundary so a multi-byte
// character is never split; a negative n is treated as 0.
func truncate(s string, n int) string {
	if n < 0 {
		n = 0
	}
	if len(s) <= n {
		return s
	}
	cut := n
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back
	// until the cut sits on the first byte of a rune.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}