package store import ( "context" "database/sql" "errors" "path/filepath" "reflect" "strconv" "strings" "sync" "testing" "time" "banger/internal/model" ) func TestStoreImageAndVMRoundTrip(t *testing.T) { t.Parallel() ctx := context.Background() store := openTestStore(t) image := sampleImage("image-one") if err := store.UpsertImage(ctx, image); err != nil { t.Fatalf("UpsertImage: %v", err) } vm := sampleVM("vm-one", image.ID, "172.16.0.8") if err := store.UpsertVM(ctx, vm); err != nil { t.Fatalf("UpsertVM: %v", err) } gotImage, err := store.GetImageByName(ctx, image.Name) if err != nil { t.Fatalf("GetImageByName: %v", err) } if !reflect.DeepEqual(gotImage, image) { t.Fatalf("GetImageByName = %+v, want %+v", gotImage, image) } gotVM, err := store.GetVM(ctx, vm.Name) if err != nil { t.Fatalf("GetVM: %v", err) } if !reflect.DeepEqual(gotVM, vm) { t.Fatalf("GetVM = %+v, want %+v", gotVM, vm) } images, err := store.ListImages(ctx) if err != nil { t.Fatalf("ListImages: %v", err) } if len(images) != 1 || !reflect.DeepEqual(images[0], image) { t.Fatalf("ListImages = %+v, want [%+v]", images, image) } vms, err := store.ListVMs(ctx) if err != nil { t.Fatalf("ListVMs: %v", err) } if len(vms) != 1 || !reflect.DeepEqual(vms[0], vm) { t.Fatalf("ListVMs = %+v, want [%+v]", vms, vm) } users, err := store.FindVMsUsingImage(ctx, image.ID) if err != nil { t.Fatalf("FindVMsUsingImage: %v", err) } if len(users) != 1 || users[0].ID != vm.ID { t.Fatalf("FindVMsUsingImage = %+v, want vm %s", users, vm.ID) } if err := store.DeleteVM(ctx, vm.ID); err != nil { t.Fatalf("DeleteVM: %v", err) } if _, err := store.GetVM(ctx, vm.ID); !errors.Is(err, sql.ErrNoRows) { t.Fatalf("GetVM after delete error = %v, want sql.ErrNoRows", err) } if err := store.DeleteImage(ctx, image.ID); err != nil { t.Fatalf("DeleteImage: %v", err) } if _, err := store.GetImageByID(ctx, image.ID); !errors.Is(err, sql.ErrNoRows) { t.Fatalf("GetImageByID after delete error = %v, want sql.ErrNoRows", err) } } func TestNextGuestIPSkipsAllocatedAddresses(t *testing.T) { t.Parallel() ctx := context.Background() store := openTestStore(t) image := sampleImage("image-next-ip") if err := store.UpsertImage(ctx, image); err != nil { t.Fatalf("UpsertImage: %v", err) } for i, ip := range []string{"172.16.0.2", "172.16.0.3", "172.16.0.5"} { vm := sampleVM("vm-next-"+strconv.Itoa(i), image.ID, ip) if err := store.UpsertVM(ctx, vm); err != nil { t.Fatalf("UpsertVM(%s): %v", ip, err) } } got, err := store.NextGuestIP(ctx, "172.16.0") if err != nil { t.Fatalf("NextGuestIP: %v", err) } if got != "172.16.0.4" { t.Fatalf("NextGuestIP = %q, want 172.16.0.4", got) } } func TestNextGuestIPReturnsErrorWhenRangeExhausted(t *testing.T) { t.Parallel() ctx := context.Background() store := openTestStore(t) image := sampleImage("image-full") if err := store.UpsertImage(ctx, image); err != nil { t.Fatalf("UpsertImage: %v", err) } for i := 2; i < 255; i++ { vm := sampleVM("vm-"+strconv.Itoa(i), image.ID, "172.16.0."+strconv.Itoa(i)) if err := store.UpsertVM(ctx, vm); err != nil { t.Fatalf("UpsertVM(%d): %v", i, err) } } _, err := store.NextGuestIP(ctx, "172.16.0") if err == nil || !strings.Contains(err.Error(), "no guest IPs available") { t.Fatalf("NextGuestIP() error = %v, want exhaustion error", err) } } func TestGetVMRejectsMalformedRuntimeJSON(t *testing.T) { t.Parallel() ctx := context.Background() store := openTestStore(t) image := sampleImage("image-malformed-runtime") if err := store.UpsertImage(ctx, image); err != nil { t.Fatalf("UpsertImage: %v", err) } now := fixedTime() _, err := store.db.ExecContext(ctx, ` INSERT INTO vms ( id, name, image_id, guest_ip, state, created_at, updated_at, last_touched_at, spec_json, runtime_json, stats_json ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, "vm-malformed-runtime", "vm-malformed-runtime", image.ID, "172.16.0.8", string(model.VMStateCreated), now.Format(time.RFC3339), now.Format(time.RFC3339), now.Format(time.RFC3339), `{"vcpu_count":2}`, `{"guest_ip":`, `{}`, ) if err != nil { t.Fatalf("insert malformed vm: %v", err) } _, err = store.GetVM(ctx, "vm-malformed-runtime") if err == nil || !strings.Contains(err.Error(), "unexpected end of JSON input") { t.Fatalf("GetVM() error = %v, want runtime JSON failure", err) } } func TestGetImageRejectsMalformedTimestamp(t *testing.T) { t.Parallel() ctx := context.Background() store := openTestStore(t) _, err := store.db.ExecContext(ctx, ` INSERT INTO images ( id, name, managed, artifact_dir, rootfs_path, kernel_path, initrd_path, modules_dir, packages_path, build_size, docker, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, "image-bad-time", "image-bad-time", 0, "", "/rootfs.ext4", "/vmlinux", "", "", "", "", 0, "not-a-time", "not-a-time", ) if err != nil { t.Fatalf("insert malformed image: %v", err) } _, err = store.GetImageByName(ctx, "image-bad-time") if err == nil || !strings.Contains(err.Error(), "cannot parse") { t.Fatalf("GetImageByName() error = %v, want timestamp parse failure", err) } } func TestStoreSerializesConcurrentVMWrites(t *testing.T) { t.Parallel() ctx := context.Background() store := openTestStore(t) image := sampleImage("image-concurrent") if err := store.UpsertImage(ctx, image); err != nil { t.Fatalf("UpsertImage: %v", err) } vms := []model.VMRecord{ sampleVM("vm-a", image.ID, "172.16.0.20"), sampleVM("vm-b", image.ID, "172.16.0.21"), } for _, vm := range vms { if err := store.UpsertVM(ctx, vm); err != nil { t.Fatalf("UpsertVM(%s): %v", vm.Name, err) } } errCh := make(chan error, 32) var wg sync.WaitGroup for i := 0; i < 8; i++ { vm := vms[i%len(vms)] wg.Add(1) go func(iter int, vm model.VMRecord) { defer wg.Done() for j := 0; j < 25; j++ { vm.UpdatedAt = fixedTime().Add(time.Duration(iter*25+j) * time.Second) if err := store.UpsertVM(ctx, vm); err != nil { errCh <- err return } if err := store.DeleteVM(ctx, vm.ID); err != nil { errCh <- err return } if err := store.UpsertVM(ctx, vm); err != nil { errCh <- err return } } }(i, vm) } wg.Wait() close(errCh) for err := range errCh { if err != nil { t.Fatalf("concurrent write error: %v", err) } } } func TestStoreConfiguresSQLitePragmasOnPooledConnections(t *testing.T) { t.Parallel() store := openTestStore(t) store.db.SetMaxOpenConns(2) ctx := context.Background() conn1, err := store.db.Conn(ctx) if err != nil { t.Fatalf("db.Conn(1): %v", err) } defer conn1.Close() conn2, err := store.db.Conn(ctx) if err != nil { t.Fatalf("db.Conn(2): %v", err) } defer conn2.Close() for i, conn := range []*sql.Conn{conn1, conn2} { var mode string if err := conn.QueryRowContext(ctx, "PRAGMA journal_mode;").Scan(&mode); err != nil { t.Fatalf("conn %d PRAGMA journal_mode: %v", i+1, err) } if mode != "wal" { t.Fatalf("conn %d journal_mode = %q, want wal", i+1, mode) } var timeout int if err := conn.QueryRowContext(ctx, "PRAGMA busy_timeout;").Scan(&timeout); err != nil { t.Fatalf("conn %d PRAGMA busy_timeout: %v", i+1, err) } if timeout != 5000 { t.Fatalf("conn %d busy_timeout = %d, want 5000", i+1, timeout) } var foreignKeys int if err := conn.QueryRowContext(ctx, "PRAGMA foreign_keys;").Scan(&foreignKeys); err != nil { t.Fatalf("conn %d PRAGMA foreign_keys: %v", i+1, err) } if foreignKeys != 1 { t.Fatalf("conn %d foreign_keys = %d, want 1", i+1, foreignKeys) } var synchronous int if err := conn.QueryRowContext(ctx, "PRAGMA synchronous;").Scan(&synchronous); err != nil { t.Fatalf("conn %d PRAGMA synchronous: %v", i+1, err) } if synchronous != 1 { t.Fatalf("conn %d synchronous = %d, want 1 (NORMAL)", i+1, synchronous) } var tempStore int if err := conn.QueryRowContext(ctx, "PRAGMA temp_store;").Scan(&tempStore); err != nil { t.Fatalf("conn %d PRAGMA temp_store: %v", i+1, err) } if tempStore != 2 { t.Fatalf("conn %d temp_store = %d, want 2 (MEMORY)", i+1, tempStore) } } } func openTestStore(t *testing.T) *Store { t.Helper() store, err := Open(filepath.Join(t.TempDir(), "state.db")) if err != nil { t.Fatalf("Open: %v", err) } t.Cleanup(func() { _ = store.Close() }) return store } func sampleImage(name string) model.Image { now := fixedTime() return model.Image{ ID: name + "-id", Name: name, Managed: true, ArtifactDir: "/artifacts/" + name, RootfsPath: "/images/" + name + ".ext4", WorkSeedPath: "/images/" + name + ".work-seed.ext4", KernelPath: "/kernels/" + name, InitrdPath: "/initrd/" + name, ModulesDir: "/modules/" + name, BuildSize: "8G", SeededSSHPublicKeyFingerprint: "seeded-fingerprint", Docker: true, CreatedAt: now, UpdatedAt: now, } } func sampleVM(name, imageID, guestIP string) model.VMRecord { now := fixedTime() return model.VMRecord{ ID: name + "-id", Name: name, ImageID: imageID, State: model.VMStateStopped, CreatedAt: now, UpdatedAt: now, LastTouchedAt: now, Spec: model.VMSpec{ VCPUCount: 2, MemoryMiB: 1024, SystemOverlaySizeByte: 8 * 1024 * 1024 * 1024, WorkDiskSizeBytes: 8 * 1024 * 1024 * 1024, NATEnabled: true, }, Runtime: model.VMRuntime{ State: model.VMStateStopped, GuestIP: guestIP, TapDevice: "tap-" + name, APISockPath: "/tmp/" + name + ".sock", LogPath: "/tmp/" + name + ".log", MetricsPath: "/tmp/" + name + ".metrics", DNSName: name + ".vm", VMDir: "/state/" + name, SystemOverlay: "/state/" + name + "/system.cow", WorkDiskPath: "/state/" + name + "/root.ext4", }, Stats: model.VMStats{ CPUPercent: 1.25, RSSBytes: 1024, VSZBytes: 2048, SystemOverlayBytes: 4096, WorkDiskBytes: 8192, MetricsRaw: map[string]any{"uptime": 12.0}, CollectedAt: now, }, } } func fixedTime() time.Time { return time.Date(2026, time.March, 16, 12, 0, 0, 0, time.UTC) }