diff --git a/internal/store/migrations.go b/internal/store/migrations.go new file mode 100644 index 0000000..68b4ad3 --- /dev/null +++ b/internal/store/migrations.go @@ -0,0 +1,197 @@ +package store + +import ( + "database/sql" + "fmt" + "sort" + "time" +) + +// migration is one ordered, atomic schema step. id must be unique and +// strictly increasing across the slice. name is a human-readable label +// stored alongside the id for debugging, and up receives a *sql.Tx so +// DDL + data backfills land atomically — either the migration fully +// applies and a schema_migrations row is written, or the whole thing +// rolls back and gets retried on next Open(). +type migration struct { + id int + name string + up func(*sql.Tx) error +} + +// migrations is the canonical ordered history. Append new migrations +// at the bottom with the next id. Never edit or reorder existing +// entries — installed DBs key off the id column. +var migrations = []migration{ + {id: 1, name: "baseline", up: migrateBaseline}, +} + +// runMigrations ensures schema_migrations exists, then applies every +// migration whose id hasn't been recorded yet, in id order. Existing +// dev databases (schema set up by the pre-versioning inline migrate() +// helper) see the baseline SQL as a no-op because every statement is +// `CREATE TABLE IF NOT EXISTS`; the row that records id=1 is what +// brings them into the new system. +func runMigrations(db *sql.DB) error { + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + applied_at TEXT NOT NULL + )`); err != nil { + return fmt.Errorf("create schema_migrations: %w", err) + } + + applied, err := loadAppliedMigrations(db) + if err != nil { + return err + } + + sorted := make([]migration, len(migrations)) + copy(sorted, migrations) + sort.Slice(sorted, func(i, j int) bool { return sorted[i].id < sorted[j].id }) + seen := map[int]bool{} + for _, m := range sorted { + if seen[m.id] { + return fmt.Errorf("duplicate migration id %d (%q)", m.id, m.name) + } + seen[m.id] = true + } + + for _, m := range sorted { + if _, ok := applied[m.id]; ok { + continue + } + if err := applyMigration(db, m); err != nil { + return fmt.Errorf("migration %d (%s): %w", m.id, m.name, err) + } + } + return nil +} + +func loadAppliedMigrations(db *sql.DB) (map[int]struct{}, error) { + rows, err := db.Query("SELECT id FROM schema_migrations") + if err != nil { + return nil, fmt.Errorf("load schema_migrations: %w", err) + } + defer rows.Close() + applied := map[int]struct{}{} + for rows.Next() { + var id int + if err := rows.Scan(&id); err != nil { + return nil, err + } + applied[id] = struct{}{} + } + return applied, rows.Err() +} + +func applyMigration(db *sql.DB, m migration) error { + tx, err := db.Begin() + if err != nil { + return err + } + if err := m.up(tx); err != nil { + _ = tx.Rollback() + return err + } + if _, err := tx.Exec( + "INSERT INTO schema_migrations (id, name, applied_at) VALUES (?, ?, ?)", + m.id, m.name, time.Now().UTC().Format(time.RFC3339), + ); err != nil { + _ = tx.Rollback() + return fmt.Errorf("record migration: %w", err) + } + return tx.Commit() +} + +// migrateBaseline captures the schema as it stood when the versioned +// migration system was introduced. Uses IF NOT EXISTS on every object +// so existing dev databases — whose tables were set up by the old +// inline migrate() — pass through cleanly and only the +// schema_migrations row gets added. +func migrateBaseline(tx *sql.Tx) error { + stmts := []string{ + `CREATE TABLE IF NOT EXISTS images ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + managed INTEGER NOT NULL DEFAULT 0, + artifact_dir TEXT, + rootfs_path TEXT NOT NULL, + work_seed_path TEXT, + kernel_path TEXT NOT NULL, + initrd_path TEXT, + modules_dir TEXT, + packages_path TEXT, + build_size TEXT, + seeded_ssh_public_key_fingerprint TEXT, + docker INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + );`, + `CREATE TABLE IF NOT EXISTS vms ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + image_id TEXT NOT NULL, + guest_ip TEXT NOT NULL UNIQUE, + state TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + last_touched_at TEXT NOT NULL, + spec_json TEXT NOT NULL, + runtime_json TEXT NOT NULL, + stats_json TEXT NOT NULL DEFAULT '{}', + FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE RESTRICT + );`, + } + for _, stmt := range stmts { + if _, err := tx.Exec(stmt); err != nil { + return err + } + } + // Columns added to the images table across the pre-versioning + // lifetime of the project. New installs get them from the CREATE + // TABLE above; upgraders from an ancient snapshot (pre- + // ensureColumnExists) pick them up here. Idempotent either way. + for _, col := range []struct{ table, name, typ string }{ + {"images", "work_seed_path", "TEXT"}, + {"images", "seeded_ssh_public_key_fingerprint", "TEXT"}, + } { + if err := addColumnIfMissing(tx, col.table, col.name, col.typ); err != nil { + return err + } + } + return nil +} + +// addColumnIfMissing is SQLite's "ALTER TABLE ADD COLUMN IF NOT EXISTS" +// (which the dialect lacks) as a library function. Used inside +// migrations when a column needs to survive a database that went +// through some historical path where the column was added later. +func addColumnIfMissing(tx *sql.Tx, table, column, columnType string) error { + rows, err := tx.Query(fmt.Sprintf("PRAGMA table_info(%s)", table)) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var ( + cid int + name string + valueType string + notNull int + defaultV sql.NullString + pk int + ) + if err := rows.Scan(&cid, &name, &valueType, ¬Null, &defaultV, &pk); err != nil { + return err + } + if name == column { + return nil + } + } + if err := rows.Err(); err != nil { + return err + } + _, err = tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, columnType)) + return err +} diff --git a/internal/store/migrations_test.go b/internal/store/migrations_test.go new file mode 100644 index 0000000..30d72bf --- /dev/null +++ b/internal/store/migrations_test.go @@ -0,0 +1,156 @@ +package store + +import ( + "database/sql" + "errors" + "path/filepath" + "testing" + + _ "modernc.org/sqlite" +) + +// openRawDB opens a SQLite DB at a fresh tempfile without running any +// migrations, so tests can observe migration-runner behaviour directly. +func openRawDB(t *testing.T) *sql.DB { + t.Helper() + path := filepath.Join(t.TempDir(), "state.db") + dsn, err := sqliteDSN(path) + if err != nil { + t.Fatalf("sqliteDSN: %v", err) + } + db, err := sql.Open("sqlite", dsn) + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return db +} + +func TestRunMigrationsAppliesBaselineOnFreshDB(t *testing.T) { + db := openRawDB(t) + if err := runMigrations(db); err != nil { + t.Fatalf("runMigrations: %v", err) + } + // All declared migrations must be recorded. + for _, m := range migrations { + var got string + if err := db.QueryRow("SELECT name FROM schema_migrations WHERE id = ?", m.id).Scan(&got); err != nil { + t.Fatalf("migration %d not recorded: %v", m.id, err) + } + if got != m.name { + t.Errorf("migration %d name = %q, want %q", m.id, got, m.name) + } + } + // Baseline must have created the real tables. + for _, table := range []string{"images", "vms"} { + var name string + if err := db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&name); err != nil { + t.Fatalf("table %s missing: %v", table, err) + } + } +} + +func TestRunMigrationsIsIdempotent(t *testing.T) { + db := openRawDB(t) + if err := runMigrations(db); err != nil { + t.Fatalf("runMigrations first pass: %v", err) + } + if err := runMigrations(db); err != nil { + t.Fatalf("runMigrations second pass: %v", err) + } + // One row per migration, no duplicates. + var count int + if err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations").Scan(&count); err != nil { + t.Fatalf("count: %v", err) + } + if count != len(migrations) { + t.Errorf("schema_migrations rows = %d, want %d", count, len(migrations)) + } +} + +func TestRunMigrationsSkipsAlreadyApplied(t *testing.T) { + db := openRawDB(t) + + // Swap in a test-only migration whose body would error if invoked, + // pre-insert its id into schema_migrations, and confirm the runner + // recognises the marker and skips the body entirely. + orig := migrations + t.Cleanup(func() { migrations = orig }) + migrations = []migration{ + {id: 1, name: "baseline", up: migrateBaseline}, + {id: 99, name: "explodes-if-run", up: func(*sql.Tx) error { + return errors.New("must not execute") + }}, + } + + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + applied_at TEXT NOT NULL + )`); err != nil { + t.Fatalf("seed schema_migrations table: %v", err) + } + if _, err := db.Exec( + "INSERT INTO schema_migrations (id, name, applied_at) VALUES (?, ?, ?)", + 99, "explodes-if-run", "2026-04-20T00:00:00Z", + ); err != nil { + t.Fatalf("seed applied row: %v", err) + } + + if err := runMigrations(db); err != nil { + t.Fatalf("runMigrations: %v", err) + } +} + +func TestApplyMigrationRollsBackOnBodyError(t *testing.T) { + db := openRawDB(t) + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + applied_at TEXT NOT NULL + )`); err != nil { + t.Fatalf("seed schema_migrations: %v", err) + } + + err := applyMigration(db, migration{ + id: 7, + name: "creates-then-fails", + up: func(tx *sql.Tx) error { + if _, err := tx.Exec("CREATE TABLE transient (x INTEGER)"); err != nil { + return err + } + return errors.New("synthetic failure") + }, + }) + if err == nil { + t.Fatal("expected applyMigration to surface body error") + } + + // The transient table must NOT survive the failed migration. + var name string + if err := db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='transient'").Scan(&name); err == nil { + t.Fatal("transient table survived rollback") + } + // And no schema_migrations row for id=7. + var count int + if err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE id=7").Scan(&count); err != nil { + t.Fatalf("count: %v", err) + } + if count != 0 { + t.Fatalf("schema_migrations recorded failed migration: count=%d", count) + } +} + +func TestRunMigrationsRejectsDuplicateID(t *testing.T) { + db := openRawDB(t) + orig := migrations + t.Cleanup(func() { migrations = orig }) + migrations = []migration{ + {id: 1, name: "first", up: func(*sql.Tx) error { return nil }}, + {id: 1, name: "dupe", up: func(*sql.Tx) error { return nil }}, + } + err := runMigrations(db) + if err == nil { + t.Fatal("expected error for duplicate migration id") + } +} diff --git a/internal/store/store.go b/internal/store/store.go index 6ac5a31..7ddc941 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -31,7 +31,7 @@ func Open(path string) (*Store, error) { return nil, err } store := &Store{db: db} - if err := store.migrate(); err != nil { + if err := runMigrations(db); err != nil { _ = db.Close() return nil, err } @@ -66,54 +66,6 @@ func sqliteDSN(path string) (string, error) { }).String(), nil } -func (s *Store) migrate() error { - stmts := []string{ - `CREATE TABLE IF NOT EXISTS images ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL UNIQUE, - managed INTEGER NOT NULL DEFAULT 0, - artifact_dir TEXT, - rootfs_path TEXT NOT NULL, - work_seed_path TEXT, - kernel_path TEXT NOT NULL, - initrd_path TEXT, - modules_dir TEXT, - packages_path TEXT, - build_size TEXT, - seeded_ssh_public_key_fingerprint TEXT, - docker INTEGER NOT NULL DEFAULT 0, - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL - );`, - `CREATE TABLE IF NOT EXISTS vms ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL UNIQUE, - image_id TEXT NOT NULL, - guest_ip TEXT NOT NULL UNIQUE, - state TEXT NOT NULL, - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL, - last_touched_at TEXT NOT NULL, - spec_json TEXT NOT NULL, - runtime_json TEXT NOT NULL, - stats_json TEXT NOT NULL DEFAULT '{}', - FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE RESTRICT - );`, - } - for _, stmt := range stmts { - if _, err := s.db.Exec(stmt); err != nil { - return err - } - } - if err := ensureColumnExists(s.db, "images", "work_seed_path", "TEXT"); err != nil { - return err - } - if err := ensureColumnExists(s.db, "images", "seeded_ssh_public_key_fingerprint", "TEXT"); err != nil { - return err - } - return nil -} - func (s *Store) UpsertImage(ctx context.Context, image model.Image) error { s.writeMu.Lock() defer s.writeMu.Unlock() @@ -432,35 +384,6 @@ func scanVMInto(row scanner) (model.VMRecord, error) { return vm, nil } -func ensureColumnExists(db *sql.DB, table, column, columnType string) error { - rows, err := db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table)) - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var ( - cid int - name string - valueType string - notNull int - defaultV sql.NullString - pk int - ) - if err := rows.Scan(&cid, &name, &valueType, ¬Null, &defaultV, &pk); err != nil { - return err - } - if name == column { - return nil - } - } - if err := rows.Err(); err != nil { - return err - } - _, err = db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, columnType)) - return err -} - func boolToInt(value bool) int { if value { return 1