From a14a80fd6bad0bb46f215e11a592c709b5330581 Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Wed, 18 Mar 2026 20:39:31 -0300 Subject: [PATCH] Harden VM delete cleanup and SQLite settings Multi-VM delete exposed two separate regressions: NAT teardown was still running after stopped VMs had already dropped their tap metadata, and the store was relying on one-off SQLite pragmas instead of configuring every pooled connection. Skip NAT cleanup when the runtime no longer has the network handles needed to identify rules, and move the SQLite profile into the DSN so WAL, busy timeouts, foreign keys, and the other connection-scoped settings apply consistently across the pool. Keep the write mutex in place for concurrent mutations, and update the daemon/store tests to use valid image fixtures now that foreign key enforcement is real. Validated with go test ./... and make build. --- internal/daemon/capabilities.go | 6 ++ internal/daemon/vm_test.go | 91 ++++++++++++++--------- internal/store/store.go | 45 +++++++++++- internal/store/store_test.go | 123 +++++++++++++++++++++++++++++++- 4 files changed, 228 insertions(+), 37 deletions(-) diff --git a/internal/daemon/capabilities.go b/internal/daemon/capabilities.go index 0cbcabd..bbf37cb 100644 --- a/internal/daemon/capabilities.go +++ b/internal/daemon/capabilities.go @@ -238,6 +238,12 @@ func (natCapability) Cleanup(ctx context.Context, d *Daemon, vm model.VMRecord) if !vm.Spec.NATEnabled { return nil } + if strings.TrimSpace(vm.Runtime.GuestIP) == "" || strings.TrimSpace(vm.Runtime.TapDevice) == "" { + if d.logger != nil { + d.logger.Debug("skipping nat cleanup without runtime network handles", append(vmLogAttrs(vm), "guest_ip", vm.Runtime.GuestIP, "tap_device", vm.Runtime.TapDevice)...) + } + return nil + } return d.ensureNAT(ctx, vm, false) } diff --git a/internal/daemon/vm_test.go b/internal/daemon/vm_test.go index 5b2441c..b4fec45 100644 --- a/internal/daemon/vm_test.go +++ b/internal/daemon/vm_test.go @@ -33,9 +33,7 @@ func TestFindVMPrefixResolution(t *testing.T) { testVM("alpine", "image-alpha", "172.16.0.3"), testVM("bravo", "image-alpha", "172.16.0.4"), } { - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM(%s): %v", vm.Name, err) - } + upsertDaemonVM(t, ctx, db, vm) } vm, err := d.FindVM(ctx, "alpha") @@ -116,9 +114,7 @@ func TestReconcileStopsStaleRunningVMAndClearsRuntimeHandles(t *testing.T) { vm.Runtime.COWLoop = "/dev/loop11" vm.Runtime.BaseLoop = "/dev/loop10" vm.Runtime.DNSName = "" - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) runner := &scriptedRunner{ t: t, @@ -176,9 +172,7 @@ func TestRebuildDNSIncludesOnlyLiveRunningVMs(t *testing.T) { stopped := testVM("stopped", "image-stopped", "172.16.0.23") for _, vm := range []model.VMRecord{live, stale, stopped} { - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM(%s): %v", vm.Name, err) - } + upsertDaemonVM(t, ctx, db, vm) } server, err := vmdns.New("127.0.0.1:0", nil) @@ -224,9 +218,7 @@ func TestSetVMRejectsStoppedOnlyChangesForRunningVM(t *testing.T) { vm.Runtime.State = model.VMStateRunning vm.Runtime.PID = cmd.Process.Pid vm.Runtime.APISockPath = apiSock - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) d := &Daemon{store: db} tests := []struct { @@ -324,9 +316,7 @@ func TestPingVMReturnsAliveForRunningGuest(t *testing.T) { vm.Runtime.APISockPath = apiSock vm.Runtime.VSockPath = vsockSock vm.Runtime.VSockCID = 10041 - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) runner := &scriptedRunner{ t: t, @@ -355,9 +345,7 @@ func TestPingVMReturnsFalseForStoppedVM(t *testing.T) { ctx := context.Background() db := openDaemonStore(t) vm := testVM("stopped-ping", "image-stopped", "172.16.0.42") - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) d := &Daemon{store: db} result, err := d.PingVM(ctx, vm.Name) @@ -379,9 +367,7 @@ func TestSetVMDiskResizeFailsPreflightWhenToolsMissing(t *testing.T) { vm := testVM("resize", "image-resize", "172.16.0.11") vm.Runtime.WorkDiskPath = workDisk vm.Spec.WorkDiskSizeBytes = 8 * 1024 * 1024 * 1024 - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) t.Setenv("PATH", t.TempDir()) d := &Daemon{store: db} @@ -464,9 +450,7 @@ func TestSetVMRejectsNonPositiveCPUAndMemory(t *testing.T) { ctx := context.Background() db := openDaemonStore(t) vm := testVM("validate", "image-validate", "172.16.0.13") - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) d := &Daemon{store: db} if _, err := d.SetVM(ctx, api.VMSetParams{IDOrName: vm.ID, VCPUCount: ptr(0)}); err == nil || !strings.Contains(err.Error(), "vcpu must be a positive integer") { @@ -590,6 +574,40 @@ func TestCleanupRuntimeRediscoversLiveFirecrackerPID(t *testing.T) { } } +func TestDeleteStoppedNATVMDoesNotFailWithoutTapDevice(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openDaemonStore(t) + vmDir := filepath.Join(t.TempDir(), "stopped-nat-vm") + if err := os.MkdirAll(vmDir, 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + vm := testVM("stopped-nat", "image-stopped-nat", "172.16.0.24") + vm.Spec.NATEnabled = true + vm.Runtime.VMDir = vmDir + vm.Runtime.TapDevice = "" + vm.State = model.VMStateStopped + vm.Runtime.State = model.VMStateStopped + upsertDaemonVM(t, ctx, db, vm) + + d := &Daemon{store: db} + deleted, err := d.DeleteVM(ctx, vm.Name) + if err != nil { + t.Fatalf("DeleteVM: %v", err) + } + if deleted.ID != vm.ID { + t.Fatalf("deleted VM = %+v, want %s", deleted, vm.ID) + } + if _, err := db.GetVMByID(ctx, vm.ID); err == nil { + t.Fatal("expected VM record to be deleted") + } + if _, err := os.Stat(vmDir); !os.IsNotExist(err) { + t.Fatalf("vm dir still exists or stat failed: %v", err) + } +} + func TestStopVMFallsBackToForcedCleanupAfterGracefulTimeout(t *testing.T) { ctx := context.Background() db := openDaemonStore(t) @@ -615,9 +633,7 @@ func TestStopVMFallsBackToForcedCleanupAfterGracefulTimeout(t *testing.T) { vm.Runtime.State = model.VMStateRunning vm.Runtime.PID = fake.Process.Pid vm.Runtime.APISockPath = apiSock - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) runner := &processKillingRunner{ scriptedRunner: &scriptedRunner{ @@ -650,9 +666,7 @@ func TestWithVMLockByIDSerializesSameVM(t *testing.T) { ctx := context.Background() db := openDaemonStore(t) vm := testVM("serial", "image-serial", "172.16.0.30") - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) d := &Daemon{store: db} firstEntered := make(chan struct{}) @@ -710,9 +724,7 @@ func TestWithVMLockByIDAllowsDifferentVMsConcurrently(t *testing.T) { vmA := testVM("alpha-lock", "image-alpha", "172.16.0.31") vmB := testVM("bravo-lock", "image-bravo", "172.16.0.32") for _, vm := range []model.VMRecord{vmA, vmB} { - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM(%s): %v", vm.Name, err) - } + upsertDaemonVM(t, ctx, db, vm) } d := &Daemon{store: db} @@ -760,6 +772,19 @@ func openDaemonStore(t *testing.T) *store.Store { return db } +func upsertDaemonVM(t *testing.T, ctx context.Context, db *store.Store, vm model.VMRecord) { + t.Helper() + image := testImage(vm.ImageID) + image.ID = vm.ImageID + image.Name = vm.ImageID + if err := db.UpsertImage(ctx, image); err != nil { + t.Fatalf("UpsertImage(%s): %v", image.Name, err) + } + if err := db.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM(%s): %v", vm.Name, err) + } +} + func testVM(name, imageID, guestIP string) model.VMRecord { now := time.Date(2026, time.March, 16, 12, 0, 0, 0, time.UTC) return model.VMRecord{ diff --git a/internal/store/store.go b/internal/store/store.go index 1696ee1..fa4ee27 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -6,6 +6,9 @@ import ( "encoding/json" "errors" "fmt" + "net/url" + "path/filepath" + "sync" "time" _ "modernc.org/sqlite" @@ -14,11 +17,16 @@ import ( ) type Store struct { - db *sql.DB + db *sql.DB + writeMu sync.Mutex } func Open(path string) (*Store, error) { - db, err := sql.Open("sqlite", path) + dsn, err := sqliteDSN(path) + if err != nil { + return nil, err + } + db, err := sql.Open("sqlite", dsn) if err != nil { return nil, err } @@ -34,9 +42,32 @@ func (s *Store) Close() error { return s.db.Close() } +func sqliteDSN(path string) (string, error) { + absPath, err := filepath.Abs(path) + if err != nil { + return "", fmt.Errorf("resolve sqlite path: %w", err) + } + query := url.Values{} + for _, pragma := range []string{ + "journal_mode(WAL)", + "synchronous(NORMAL)", + "foreign_keys(1)", + "busy_timeout(5000)", + "temp_store(MEMORY)", + "wal_autocheckpoint(1000)", + "journal_size_limit(67108864)", + } { + query.Add("_pragma", pragma) + } + return (&url.URL{ + Scheme: "file", + Path: filepath.ToSlash(absPath), + RawQuery: query.Encode(), + }).String(), nil +} + func (s *Store) migrate() error { stmts := []string{ - `PRAGMA journal_mode=WAL;`, `CREATE TABLE IF NOT EXISTS images ( id TEXT PRIMARY KEY, name TEXT NOT NULL UNIQUE, @@ -76,6 +107,8 @@ func (s *Store) migrate() error { } func (s *Store) UpsertImage(ctx context.Context, image model.Image) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() const query = ` INSERT INTO images ( id, name, managed, artifact_dir, rootfs_path, kernel_path, initrd_path, @@ -137,11 +170,15 @@ func (s *Store) ListImages(ctx context.Context) ([]model.Image, error) { } func (s *Store) DeleteImage(ctx context.Context, id string) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() _, err := s.db.ExecContext(ctx, "DELETE FROM images WHERE id = ?", id) return err } func (s *Store) UpsertVM(ctx context.Context, vm model.VMRecord) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() specJSON, err := json.Marshal(vm.Spec) if err != nil { return err @@ -225,6 +262,8 @@ func (s *Store) ListVMs(ctx context.Context) ([]model.VMRecord, error) { } func (s *Store) DeleteVM(ctx context.Context, id string) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() _, err := s.db.ExecContext(ctx, "DELETE FROM vms WHERE id = ?", id) return err } diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 9efb813..0cbc123 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -8,6 +8,7 @@ import ( "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -137,6 +138,10 @@ func TestGetVMRejectsMalformedRuntimeJSON(t *testing.T) { ctx := context.Background() store := openTestStore(t) + image := sampleImage("image-malformed-runtime") + if err := store.UpsertImage(ctx, image); err != nil { + t.Fatalf("UpsertImage: %v", err) + } now := fixedTime() _, err := store.db.ExecContext(ctx, ` INSERT INTO vms ( @@ -145,7 +150,7 @@ func TestGetVMRejectsMalformedRuntimeJSON(t *testing.T) { ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, "vm-malformed-runtime", "vm-malformed-runtime", - "image-id", + image.ID, "172.16.0.8", string(model.VMStateCreated), now.Format(time.RFC3339), @@ -199,6 +204,122 @@ func TestGetImageRejectsMalformedTimestamp(t *testing.T) { } } +func TestStoreSerializesConcurrentVMWrites(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestStore(t) + image := sampleImage("image-concurrent") + if err := store.UpsertImage(ctx, image); err != nil { + t.Fatalf("UpsertImage: %v", err) + } + + vms := []model.VMRecord{ + sampleVM("vm-a", image.ID, "172.16.0.20"), + sampleVM("vm-b", image.ID, "172.16.0.21"), + } + for _, vm := range vms { + if err := store.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM(%s): %v", vm.Name, err) + } + } + + errCh := make(chan error, 32) + var wg sync.WaitGroup + for i := 0; i < 8; i++ { + vm := vms[i%len(vms)] + wg.Add(1) + go func(iter int, vm model.VMRecord) { + defer wg.Done() + for j := 0; j < 25; j++ { + vm.UpdatedAt = fixedTime().Add(time.Duration(iter*25+j) * time.Second) + if err := store.UpsertVM(ctx, vm); err != nil { + errCh <- err + return + } + if err := store.DeleteVM(ctx, vm.ID); err != nil { + errCh <- err + return + } + if err := store.UpsertVM(ctx, vm); err != nil { + errCh <- err + return + } + } + }(i, vm) + } + wg.Wait() + close(errCh) + + for err := range errCh { + if err != nil { + t.Fatalf("concurrent write error: %v", err) + } + } +} + +func TestStoreConfiguresSQLitePragmasOnPooledConnections(t *testing.T) { + t.Parallel() + + store := openTestStore(t) + store.db.SetMaxOpenConns(2) + + ctx := context.Background() + conn1, err := store.db.Conn(ctx) + if err != nil { + t.Fatalf("db.Conn(1): %v", err) + } + defer conn1.Close() + + conn2, err := store.db.Conn(ctx) + if err != nil { + t.Fatalf("db.Conn(2): %v", err) + } + defer conn2.Close() + + for i, conn := range []*sql.Conn{conn1, conn2} { + var mode string + if err := conn.QueryRowContext(ctx, "PRAGMA journal_mode;").Scan(&mode); err != nil { + t.Fatalf("conn %d PRAGMA journal_mode: %v", i+1, err) + } + if mode != "wal" { + t.Fatalf("conn %d journal_mode = %q, want wal", i+1, mode) + } + + var timeout int + if err := conn.QueryRowContext(ctx, "PRAGMA busy_timeout;").Scan(&timeout); err != nil { + t.Fatalf("conn %d PRAGMA busy_timeout: %v", i+1, err) + } + if timeout != 5000 { + t.Fatalf("conn %d busy_timeout = %d, want 5000", i+1, timeout) + } + + var foreignKeys int + if err := conn.QueryRowContext(ctx, "PRAGMA foreign_keys;").Scan(&foreignKeys); err != nil { + t.Fatalf("conn %d PRAGMA foreign_keys: %v", i+1, err) + } + if foreignKeys != 1 { + t.Fatalf("conn %d foreign_keys = %d, want 1", i+1, foreignKeys) + } + + var synchronous int + if err := conn.QueryRowContext(ctx, "PRAGMA synchronous;").Scan(&synchronous); err != nil { + t.Fatalf("conn %d PRAGMA synchronous: %v", i+1, err) + } + if synchronous != 1 { + t.Fatalf("conn %d synchronous = %d, want 1 (NORMAL)", i+1, synchronous) + } + + var tempStore int + if err := conn.QueryRowContext(ctx, "PRAGMA temp_store;").Scan(&tempStore); err != nil { + t.Fatalf("conn %d PRAGMA temp_store: %v", i+1, err) + } + if tempStore != 2 { + t.Fatalf("conn %d temp_store = %d, want 2 (MEMORY)", i+1, tempStore) + } + } +} + func openTestStore(t *testing.T) *Store { t.Helper() store, err := Open(filepath.Join(t.TempDir(), "state.db"))