diff --git a/internal/daemon/capabilities.go b/internal/daemon/capabilities.go index 0cbcabd..bbf37cb 100644 --- a/internal/daemon/capabilities.go +++ b/internal/daemon/capabilities.go @@ -238,6 +238,12 @@ func (natCapability) Cleanup(ctx context.Context, d *Daemon, vm model.VMRecord) if !vm.Spec.NATEnabled { return nil } + if strings.TrimSpace(vm.Runtime.GuestIP) == "" || strings.TrimSpace(vm.Runtime.TapDevice) == "" { + if d.logger != nil { + d.logger.Debug("skipping nat cleanup without runtime network handles", append(vmLogAttrs(vm), "guest_ip", vm.Runtime.GuestIP, "tap_device", vm.Runtime.TapDevice)...) + } + return nil + } return d.ensureNAT(ctx, vm, false) } diff --git a/internal/daemon/vm_test.go b/internal/daemon/vm_test.go index 5b2441c..b4fec45 100644 --- a/internal/daemon/vm_test.go +++ b/internal/daemon/vm_test.go @@ -33,9 +33,7 @@ func TestFindVMPrefixResolution(t *testing.T) { testVM("alpine", "image-alpha", "172.16.0.3"), testVM("bravo", "image-alpha", "172.16.0.4"), } { - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM(%s): %v", vm.Name, err) - } + upsertDaemonVM(t, ctx, db, vm) } vm, err := d.FindVM(ctx, "alpha") @@ -116,9 +114,7 @@ func TestReconcileStopsStaleRunningVMAndClearsRuntimeHandles(t *testing.T) { vm.Runtime.COWLoop = "/dev/loop11" vm.Runtime.BaseLoop = "/dev/loop10" vm.Runtime.DNSName = "" - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) runner := &scriptedRunner{ t: t, @@ -176,9 +172,7 @@ func TestRebuildDNSIncludesOnlyLiveRunningVMs(t *testing.T) { stopped := testVM("stopped", "image-stopped", "172.16.0.23") for _, vm := range []model.VMRecord{live, stale, stopped} { - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM(%s): %v", vm.Name, err) - } + upsertDaemonVM(t, ctx, db, vm) } server, err := vmdns.New("127.0.0.1:0", nil) @@ -224,9 +218,7 @@ func TestSetVMRejectsStoppedOnlyChangesForRunningVM(t *testing.T) { vm.Runtime.State = model.VMStateRunning vm.Runtime.PID = cmd.Process.Pid vm.Runtime.APISockPath = apiSock - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) d := &Daemon{store: db} tests := []struct { @@ -324,9 +316,7 @@ func TestPingVMReturnsAliveForRunningGuest(t *testing.T) { vm.Runtime.APISockPath = apiSock vm.Runtime.VSockPath = vsockSock vm.Runtime.VSockCID = 10041 - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) runner := &scriptedRunner{ t: t, @@ -355,9 +345,7 @@ func TestPingVMReturnsFalseForStoppedVM(t *testing.T) { ctx := context.Background() db := openDaemonStore(t) vm := testVM("stopped-ping", "image-stopped", "172.16.0.42") - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) d := &Daemon{store: db} result, err := d.PingVM(ctx, vm.Name) @@ -379,9 +367,7 @@ func TestSetVMDiskResizeFailsPreflightWhenToolsMissing(t *testing.T) { vm := testVM("resize", "image-resize", "172.16.0.11") vm.Runtime.WorkDiskPath = workDisk vm.Spec.WorkDiskSizeBytes = 8 * 1024 * 1024 * 1024 - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) t.Setenv("PATH", t.TempDir()) d := &Daemon{store: db} @@ -464,9 +450,7 @@ func TestSetVMRejectsNonPositiveCPUAndMemory(t *testing.T) { ctx := context.Background() db := openDaemonStore(t) vm := testVM("validate", "image-validate", "172.16.0.13") - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) d := &Daemon{store: db} if _, err := d.SetVM(ctx, api.VMSetParams{IDOrName: vm.ID, VCPUCount: ptr(0)}); err == nil || !strings.Contains(err.Error(), "vcpu must be a positive integer") { @@ -590,6 +574,40 @@ func TestCleanupRuntimeRediscoversLiveFirecrackerPID(t *testing.T) { } } +func TestDeleteStoppedNATVMDoesNotFailWithoutTapDevice(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openDaemonStore(t) + vmDir := filepath.Join(t.TempDir(), "stopped-nat-vm") + if err := os.MkdirAll(vmDir, 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + vm := testVM("stopped-nat", "image-stopped-nat", "172.16.0.24") + vm.Spec.NATEnabled = true + vm.Runtime.VMDir = vmDir + vm.Runtime.TapDevice = "" + vm.State = model.VMStateStopped + vm.Runtime.State = model.VMStateStopped + upsertDaemonVM(t, ctx, db, vm) + + d := &Daemon{store: db} + deleted, err := d.DeleteVM(ctx, vm.Name) + if err != nil { + t.Fatalf("DeleteVM: %v", err) + } + if deleted.ID != vm.ID { + t.Fatalf("deleted VM = %+v, want %s", deleted, vm.ID) + } + if _, err := db.GetVMByID(ctx, vm.ID); err == nil { + t.Fatal("expected VM record to be deleted") + } + if _, err := os.Stat(vmDir); !os.IsNotExist(err) { + t.Fatalf("vm dir still exists or stat failed: %v", err) + } +} + func TestStopVMFallsBackToForcedCleanupAfterGracefulTimeout(t *testing.T) { ctx := context.Background() db := openDaemonStore(t) @@ -615,9 +633,7 @@ func TestStopVMFallsBackToForcedCleanupAfterGracefulTimeout(t *testing.T) { vm.Runtime.State = model.VMStateRunning vm.Runtime.PID = fake.Process.Pid vm.Runtime.APISockPath = apiSock - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) runner := &processKillingRunner{ scriptedRunner: &scriptedRunner{ @@ -650,9 +666,7 @@ func TestWithVMLockByIDSerializesSameVM(t *testing.T) { ctx := context.Background() db := openDaemonStore(t) vm := testVM("serial", "image-serial", "172.16.0.30") - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM: %v", err) - } + upsertDaemonVM(t, ctx, db, vm) d := &Daemon{store: db} firstEntered := make(chan struct{}) @@ -710,9 +724,7 @@ func TestWithVMLockByIDAllowsDifferentVMsConcurrently(t *testing.T) { vmA := testVM("alpha-lock", "image-alpha", "172.16.0.31") vmB := testVM("bravo-lock", "image-bravo", "172.16.0.32") for _, vm := range []model.VMRecord{vmA, vmB} { - if err := db.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM(%s): %v", vm.Name, err) - } + upsertDaemonVM(t, ctx, db, vm) } d := &Daemon{store: db} @@ -760,6 +772,19 @@ func openDaemonStore(t *testing.T) *store.Store { return db } +func upsertDaemonVM(t *testing.T, ctx context.Context, db *store.Store, vm model.VMRecord) { + t.Helper() + image := testImage(vm.ImageID) + image.ID = vm.ImageID + image.Name = vm.ImageID + if err := db.UpsertImage(ctx, image); err != nil { + t.Fatalf("UpsertImage(%s): %v", image.Name, err) + } + if err := db.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM(%s): %v", vm.Name, err) + } +} + func testVM(name, imageID, guestIP string) model.VMRecord { now := time.Date(2026, time.March, 16, 12, 0, 0, 0, time.UTC) return model.VMRecord{ diff --git a/internal/store/store.go b/internal/store/store.go index 1696ee1..fa4ee27 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -6,6 +6,9 @@ import ( "encoding/json" "errors" "fmt" + "net/url" + "path/filepath" + "sync" "time" _ "modernc.org/sqlite" @@ -14,11 +17,16 @@ import ( ) type Store struct { - db *sql.DB + db *sql.DB + writeMu sync.Mutex } func Open(path string) (*Store, error) { - db, err := sql.Open("sqlite", path) + dsn, err := sqliteDSN(path) + if err != nil { + return nil, err + } + db, err := sql.Open("sqlite", dsn) if err != nil { return nil, err } @@ -34,9 +42,32 @@ func (s *Store) Close() error { return s.db.Close() } +func sqliteDSN(path string) (string, error) { + absPath, err := filepath.Abs(path) + if err != nil { + return "", fmt.Errorf("resolve sqlite path: %w", err) + } + query := url.Values{} + for _, pragma := range []string{ + "journal_mode(WAL)", + "synchronous(NORMAL)", + "foreign_keys(1)", + "busy_timeout(5000)", + "temp_store(MEMORY)", + "wal_autocheckpoint(1000)", + "journal_size_limit(67108864)", + } { + query.Add("_pragma", pragma) + } + return (&url.URL{ + Scheme: "file", + Path: filepath.ToSlash(absPath), + RawQuery: query.Encode(), + }).String(), nil +} + func (s *Store) migrate() error { stmts := []string{ - `PRAGMA journal_mode=WAL;`, `CREATE TABLE IF NOT EXISTS images ( id TEXT PRIMARY KEY, name TEXT NOT NULL UNIQUE, @@ -76,6 +107,8 @@ func (s *Store) migrate() error { } func (s *Store) UpsertImage(ctx context.Context, image model.Image) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() const query = ` INSERT INTO images ( id, name, managed, artifact_dir, rootfs_path, kernel_path, initrd_path, @@ -137,11 +170,15 @@ func (s *Store) ListImages(ctx context.Context) ([]model.Image, error) { } func (s *Store) DeleteImage(ctx context.Context, id string) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() _, err := s.db.ExecContext(ctx, "DELETE FROM images WHERE id = ?", id) return err } func (s *Store) UpsertVM(ctx context.Context, vm model.VMRecord) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() specJSON, err := json.Marshal(vm.Spec) if err != nil { return err @@ -225,6 +262,8 @@ func (s *Store) ListVMs(ctx context.Context) ([]model.VMRecord, error) { } func (s *Store) DeleteVM(ctx context.Context, id string) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() _, err := s.db.ExecContext(ctx, "DELETE FROM vms WHERE id = ?", id) return err } diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 9efb813..0cbc123 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -8,6 +8,7 @@ import ( "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -137,6 +138,10 @@ func TestGetVMRejectsMalformedRuntimeJSON(t *testing.T) { ctx := context.Background() store := openTestStore(t) + image := sampleImage("image-malformed-runtime") + if err := store.UpsertImage(ctx, image); err != nil { + t.Fatalf("UpsertImage: %v", err) + } now := fixedTime() _, err := store.db.ExecContext(ctx, ` INSERT INTO vms ( @@ -145,7 +150,7 @@ func TestGetVMRejectsMalformedRuntimeJSON(t *testing.T) { ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, "vm-malformed-runtime", "vm-malformed-runtime", - "image-id", + image.ID, "172.16.0.8", string(model.VMStateCreated), now.Format(time.RFC3339), @@ -199,6 +204,122 @@ func TestGetImageRejectsMalformedTimestamp(t *testing.T) { } } +func TestStoreSerializesConcurrentVMWrites(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestStore(t) + image := sampleImage("image-concurrent") + if err := store.UpsertImage(ctx, image); err != nil { + t.Fatalf("UpsertImage: %v", err) + } + + vms := []model.VMRecord{ + sampleVM("vm-a", image.ID, "172.16.0.20"), + sampleVM("vm-b", image.ID, "172.16.0.21"), + } + for _, vm := range vms { + if err := store.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM(%s): %v", vm.Name, err) + } + } + + errCh := make(chan error, 32) + var wg sync.WaitGroup + for i := 0; i < 8; i++ { + vm := vms[i%len(vms)] + wg.Add(1) + go func(iter int, vm model.VMRecord) { + defer wg.Done() + for j := 0; j < 25; j++ { + vm.UpdatedAt = fixedTime().Add(time.Duration(iter*25+j) * time.Second) + if err := store.UpsertVM(ctx, vm); err != nil { + errCh <- err + return + } + if err := store.DeleteVM(ctx, vm.ID); err != nil { + errCh <- err + return + } + if err := store.UpsertVM(ctx, vm); err != nil { + errCh <- err + return + } + } + }(i, vm) + } + wg.Wait() + close(errCh) + + for err := range errCh { + if err != nil { + t.Fatalf("concurrent write error: %v", err) + } + } +} + +func TestStoreConfiguresSQLitePragmasOnPooledConnections(t *testing.T) { + t.Parallel() + + store := openTestStore(t) + store.db.SetMaxOpenConns(2) + + ctx := context.Background() + conn1, err := store.db.Conn(ctx) + if err != nil { + t.Fatalf("db.Conn(1): %v", err) + } + defer conn1.Close() + + conn2, err := store.db.Conn(ctx) + if err != nil { + t.Fatalf("db.Conn(2): %v", err) + } + defer conn2.Close() + + for i, conn := range []*sql.Conn{conn1, conn2} { + var mode string + if err := conn.QueryRowContext(ctx, "PRAGMA journal_mode;").Scan(&mode); err != nil { + t.Fatalf("conn %d PRAGMA journal_mode: %v", i+1, err) + } + if mode != "wal" { + t.Fatalf("conn %d journal_mode = %q, want wal", i+1, mode) + } + + var timeout int + if err := conn.QueryRowContext(ctx, "PRAGMA busy_timeout;").Scan(&timeout); err != nil { + t.Fatalf("conn %d PRAGMA busy_timeout: %v", i+1, err) + } + if timeout != 5000 { + t.Fatalf("conn %d busy_timeout = %d, want 5000", i+1, timeout) + } + + var foreignKeys int + if err := conn.QueryRowContext(ctx, "PRAGMA foreign_keys;").Scan(&foreignKeys); err != nil { + t.Fatalf("conn %d PRAGMA foreign_keys: %v", i+1, err) + } + if foreignKeys != 1 { + t.Fatalf("conn %d foreign_keys = %d, want 1", i+1, foreignKeys) + } + + var synchronous int + if err := conn.QueryRowContext(ctx, "PRAGMA synchronous;").Scan(&synchronous); err != nil { + t.Fatalf("conn %d PRAGMA synchronous: %v", i+1, err) + } + if synchronous != 1 { + t.Fatalf("conn %d synchronous = %d, want 1 (NORMAL)", i+1, synchronous) + } + + var tempStore int + if err := conn.QueryRowContext(ctx, "PRAGMA temp_store;").Scan(&tempStore); err != nil { + t.Fatalf("conn %d PRAGMA temp_store: %v", i+1, err) + } + if tempStore != 2 { + t.Fatalf("conn %d temp_store = %d, want 2 (MEMORY)", i+1, tempStore) + } + } +} + func openTestStore(t *testing.T) *Store { t.Helper() store, err := Open(filepath.Join(t.TempDir(), "state.db"))