diff --git a/internal/daemon/vm_test.go b/internal/daemon/vm_test.go new file mode 100644 index 0000000..9e0f9e9 --- /dev/null +++ b/internal/daemon/vm_test.go @@ -0,0 +1,394 @@ +package daemon + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "banger/internal/api" + "banger/internal/model" + "banger/internal/store" +) + +func TestFindVMPrefixResolution(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openDaemonStore(t) + d := &Daemon{store: db} + + for _, vm := range []model.VMRecord{ + testVM("alpha", "image-alpha", "172.16.0.2"), + testVM("alpine", "image-alpha", "172.16.0.3"), + testVM("bravo", "image-alpha", "172.16.0.4"), + } { + if err := db.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM(%s): %v", vm.Name, err) + } + } + + vm, err := d.FindVM(ctx, "alpha") + if err != nil || vm.Name != "alpha" { + t.Fatalf("FindVM(alpha) = %+v, %v", vm, err) + } + + vm, err = d.FindVM(ctx, "br") + if err != nil || vm.Name != "bravo" { + t.Fatalf("FindVM(br) = %+v, %v", vm, err) + } + + _, err = d.FindVM(ctx, "al") + if err == nil || !strings.Contains(err.Error(), "multiple VMs match") { + t.Fatalf("FindVM(al) error = %v, want ambiguity", err) + } + + _, err = d.FindVM(ctx, "missing") + if err == nil || !strings.Contains(err.Error(), `vm "missing" not found`) { + t.Fatalf("FindVM(missing) error = %v, want not-found", err) + } +} + +func TestFindImagePrefixResolution(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openDaemonStore(t) + d := &Daemon{store: db} + + for _, image := range []model.Image{ + testImage("base"), + testImage("basic"), + testImage("docker"), + } { + if err := db.UpsertImage(ctx, image); err != nil { + t.Fatalf("UpsertImage(%s): %v", image.Name, err) + } + } + + image, err := d.FindImage(ctx, "base") + if err != nil || image.Name != "base" { + t.Fatalf("FindImage(base) = %+v, %v", image, err) + } + + image, err = d.FindImage(ctx, "dock") + 
if err != nil || image.Name != "docker" { + t.Fatalf("FindImage(dock) = %+v, %v", image, err) + } + + _, err = d.FindImage(ctx, "ba") + if err == nil || !strings.Contains(err.Error(), "multiple images match") { + t.Fatalf("FindImage(ba) error = %v, want ambiguity", err) + } + + _, err = d.FindImage(ctx, "missing") + if err == nil || !strings.Contains(err.Error(), `image "missing" not found`) { + t.Fatalf("FindImage(missing) error = %v, want not-found", err) + } +} + +func TestReconcileStopsStaleRunningVMAndClearsRuntimeHandles(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openDaemonStore(t) + apiSock := filepath.Join(t.TempDir(), "fc.sock") + if err := os.WriteFile(apiSock, []byte{}, 0o644); err != nil { + t.Fatalf("WriteFile(api sock): %v", err) + } + vm := testVM("stale", "image-stale", "172.16.0.9") + vm.State = model.VMStateRunning + vm.Runtime.State = model.VMStateRunning + vm.Runtime.PID = 999999 + vm.Runtime.APISockPath = apiSock + vm.Runtime.DMName = "fc-rootfs-stale" + vm.Runtime.DMDev = "/dev/mapper/fc-rootfs-stale" + vm.Runtime.COWLoop = "/dev/loop11" + vm.Runtime.BaseLoop = "/dev/loop10" + vm.Runtime.DNSName = "" + if err := db.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM: %v", err) + } + + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("", nil, "dmsetup", "remove", "fc-rootfs-stale"), + sudoStep("", nil, "losetup", "-d", "/dev/loop11"), + sudoStep("", nil, "losetup", "-d", "/dev/loop10"), + }, + } + d := &Daemon{store: db, runner: runner} + + if err := d.reconcile(ctx); err != nil { + t.Fatalf("reconcile: %v", err) + } + runner.assertExhausted() + + got, err := db.GetVM(ctx, vm.ID) + if err != nil { + t.Fatalf("GetVM: %v", err) + } + if got.State != model.VMStateStopped || got.Runtime.State != model.VMStateStopped { + t.Fatalf("vm state after reconcile = %s/%s, want stopped", got.State, got.Runtime.State) + } + if got.Runtime.PID != 0 || got.Runtime.APISockPath != "" || got.Runtime.DMName != "" || 
got.Runtime.COWLoop != "" || got.Runtime.BaseLoop != "" { + t.Fatalf("runtime handles not cleared after reconcile: %+v", got.Runtime) + } +} + +func TestSetVMRejectsStoppedOnlyChangesForRunningVM(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openDaemonStore(t) + apiSock := filepath.Join(t.TempDir(), "running.sock") + cmd := startFakeFirecrackerProcess(t, apiSock) + t.Cleanup(func() { + _ = cmd.Process.Kill() + _ = cmd.Wait() + }) + + vm := testVM("running", "image-run", "172.16.0.10") + vm.State = model.VMStateRunning + vm.Runtime.State = model.VMStateRunning + vm.Runtime.PID = cmd.Process.Pid + vm.Runtime.APISockPath = apiSock + if err := db.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM: %v", err) + } + + d := &Daemon{store: db} + tests := []struct { + name string + params api.VMSetParams + want string + }{ + { + name: "vcpu", + params: api.VMSetParams{IDOrName: vm.ID, VCPUCount: ptr(4)}, + want: "vcpu changes require the VM to be stopped", + }, + { + name: "memory", + params: api.VMSetParams{IDOrName: vm.ID, MemoryMiB: ptr(2048)}, + want: "memory changes require the VM to be stopped", + }, + { + name: "disk", + params: api.VMSetParams{IDOrName: vm.ID, WorkDiskSize: "16G"}, + want: "disk changes require the VM to be stopped", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := d.SetVM(ctx, tt.params) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("SetVM(%s) error = %v, want %q", tt.name, err, tt.want) + } + }) + } +} + +func TestSetVMDiskResizeFailsPreflightWhenToolsMissing(t *testing.T) { + ctx := context.Background() + db := openDaemonStore(t) + workDisk := filepath.Join(t.TempDir(), "root.ext4") + if err := os.WriteFile(workDisk, []byte("disk"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + vm := testVM("resize", "image-resize", "172.16.0.11") + vm.Runtime.WorkDiskPath = workDisk + vm.Spec.WorkDiskSizeBytes = 8 * 1024 * 1024 * 1024 + if err := 
db.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM: %v", err) + } + + t.Setenv("PATH", t.TempDir()) + d := &Daemon{store: db} + _, err := d.SetVM(ctx, api.VMSetParams{IDOrName: vm.ID, WorkDiskSize: "16G"}) + if err == nil || !strings.Contains(err.Error(), "work disk resize preflight failed") { + t.Fatalf("SetVM() error = %v, want preflight failure", err) + } +} + +func TestCollectStatsIgnoresMalformedMetricsFile(t *testing.T) { + t.Parallel() + + overlay := filepath.Join(t.TempDir(), "system.cow") + workDisk := filepath.Join(t.TempDir(), "root.ext4") + metrics := filepath.Join(t.TempDir(), "metrics.json") + for _, path := range []string{overlay, workDisk} { + if err := os.WriteFile(path, []byte("allocated"), 0o644); err != nil { + t.Fatalf("WriteFile(%s): %v", path, err) + } + } + if err := os.WriteFile(metrics, []byte("{not-json}\n"), 0o644); err != nil { + t.Fatalf("WriteFile(metrics): %v", err) + } + + d := &Daemon{} + stats, err := d.collectStats(context.Background(), model.VMRecord{ + Runtime: model.VMRuntime{ + SystemOverlay: overlay, + WorkDiskPath: workDisk, + MetricsPath: metrics, + }, + }) + if err != nil { + t.Fatalf("collectStats: %v", err) + } + if stats.MetricsRaw != nil { + t.Fatalf("MetricsRaw = %v, want nil for malformed metrics", stats.MetricsRaw) + } + if stats.SystemOverlayBytes == 0 || stats.WorkDiskBytes == 0 { + t.Fatalf("allocated bytes not captured: %+v", stats) + } +} + +func TestValidateStartPrereqsReportsNATUplinkFailure(t *testing.T) { + ctx := context.Background() + binDir := t.TempDir() + for _, name := range []string{ + "sudo", "ip", "dmsetup", "losetup", "blockdev", "truncate", "pgrep", "ps", + "chown", "chmod", "kill", "e2cp", "e2rm", "debugfs", "mkfs.ext4", "mount", + "umount", "cp", "iptables", "sysctl", "mapdns", + } { + writeFakeExecutable(t, filepath.Join(binDir, name)) + } + t.Setenv("PATH", binDir) + + firecrackerBin := filepath.Join(t.TempDir(), "firecracker") + rootfsPath := filepath.Join(t.TempDir(), "rootfs.ext4") + 
kernelPath := filepath.Join(t.TempDir(), "vmlinux") + for _, path := range []string{firecrackerBin, rootfsPath, kernelPath} { + writeFakeExecutable(t, path) + } + + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + {call: runnerCall{name: "ip", args: []string{"route", "show", "default"}}, out: []byte("10.0.0.0/24 dev br-fc\n")}, + }, + } + d := &Daemon{ + runner: runner, + config: model.DaemonConfig{ + FirecrackerBin: firecrackerBin, + MapDNSBin: "mapdns", + }, + } + vm := testVM("nat", "image-nat", "172.16.0.12") + vm.Spec.NATEnabled = true + vm.Runtime.WorkDiskPath = filepath.Join(t.TempDir(), "missing-root.ext4") + image := testImage("image-nat") + image.RootfsPath = rootfsPath + image.KernelPath = kernelPath + + err := d.validateStartPrereqs(ctx, vm, image) + if err == nil || !strings.Contains(err.Error(), "uplink interface for NAT") { + t.Fatalf("validateStartPrereqs() error = %v, want NAT uplink failure", err) + } + runner.assertExhausted() +} + +func openDaemonStore(t *testing.T) *store.Store { + t.Helper() + db, err := store.Open(filepath.Join(t.TempDir(), "state.db")) + if err != nil { + t.Fatalf("store.Open: %v", err) + } + t.Cleanup(func() { + _ = db.Close() + }) + return db +} + +func testVM(name, imageID, guestIP string) model.VMRecord { + now := time.Date(2026, time.March, 16, 12, 0, 0, 0, time.UTC) + return model.VMRecord{ + ID: name + "-id", + Name: name, + ImageID: imageID, + State: model.VMStateStopped, + CreatedAt: now, + UpdatedAt: now, + LastTouchedAt: now, + Spec: model.VMSpec{ + VCPUCount: 2, + MemoryMiB: 1024, + SystemOverlaySizeByte: model.DefaultSystemOverlaySize, + WorkDiskSizeBytes: model.DefaultWorkDiskSize, + }, + Runtime: model.VMRuntime{ + State: model.VMStateStopped, + GuestIP: guestIP, + DNSName: name + ".vm", + VMDir: filepath.Join("/state", name), + SystemOverlay: filepath.Join("/state", name, "system.cow"), + WorkDiskPath: filepath.Join("/state", name, "root.ext4"), + }, + } +} + +func testImage(name string) 
model.Image { + now := time.Date(2026, time.March, 16, 12, 0, 0, 0, time.UTC) + return model.Image{ + ID: name + "-id", + Name: name, + RootfsPath: filepath.Join("/images", name+".ext4"), + KernelPath: filepath.Join("/kernels", name), + CreatedAt: now, + UpdatedAt: now, + } +} + +func startFakeFirecrackerProcess(t *testing.T, apiSock string) *exec.Cmd { + t.Helper() + + cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 30", "firecracker --api-sock "+apiSock)) + if err := cmd.Start(); err != nil { + t.Fatalf("start fake firecracker: %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if cmd.Process != nil && cmd.Process.Pid > 0 && systemProcessRunning(cmd.Process.Pid, apiSock) { + return cmd + } + time.Sleep(20 * time.Millisecond) + } + _ = cmd.Process.Kill() + _ = cmd.Wait() + t.Fatalf("fake firecracker process never looked running for %s", apiSock) + return nil +} + +func systemProcessRunning(pid int, apiSock string) bool { + data, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(pid), "cmdline")) + if err != nil { + return false + } + cmdline := strings.ReplaceAll(string(data), "\x00", " ") + return strings.Contains(cmdline, "firecracker") && strings.Contains(cmdline, apiSock) +} + +func writeFakeExecutable(t *testing.T, path string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("MkdirAll(%s): %v", filepath.Dir(path), err) + } + if err := os.WriteFile(path, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("WriteFile(%s): %v", path, err) + } +} + +func ptr[T any](value T) *T { + return &value +} diff --git a/internal/rpc/rpc_test.go b/internal/rpc/rpc_test.go new file mode 100644 index 0000000..b469fcd --- /dev/null +++ b/internal/rpc/rpc_test.go @@ -0,0 +1,175 @@ +package rpc + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "net" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestDecodeParams(t *testing.T) 
{ + t.Parallel() + + type payload struct { + Name string `json:"name"` + } + + got, err := DecodeParams[payload](Request{}) + if err != nil { + t.Fatalf("DecodeParams(empty): %v", err) + } + if got.Name != "" { + t.Fatalf("DecodeParams(empty) = %+v, want zero value", got) + } + + _, err = DecodeParams[payload](Request{Params: json.RawMessage(`{"name":`)}) + if err == nil { + t.Fatal("DecodeParams(malformed) returned nil error") + } +} + +func TestCallRoundTripSuccess(t *testing.T) { + t.Parallel() + + socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) { + defer conn.Close() + var req Request + if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + if req.Version != Version || req.Method != "ping" { + t.Fatalf("unexpected request: %+v", req) + } + var params map[string]string + if err := json.Unmarshal(req.Params, &params); err != nil { + t.Fatalf("unmarshal params: %v", err) + } + if params["name"] != "devbox" { + t.Fatalf("params = %v, want name=devbox", params) + } + resp, err := NewResult(map[string]string{"status": "ok"}) + if err != nil { + t.Fatalf("NewResult: %v", err) + } + if err := json.NewEncoder(conn).Encode(resp); err != nil { + t.Fatalf("encode response: %v", err) + } + }) + defer cleanup() + + result, err := Call[map[string]string](context.Background(), socketPath, "ping", map[string]string{"name": "devbox"}) + if err != nil { + t.Fatalf("Call: %v", err) + } + if result["status"] != "ok" { + t.Fatalf("Call() result = %v, want status=ok", result) + } +} + +func TestCallReturnsRemoteError(t *testing.T) { + t.Parallel() + + socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) { + defer conn.Close() + var req Request + if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + if err := json.NewEncoder(conn).Encode(NewError("operation_failed", "boom")); err != nil { + t.Fatalf("encode error response: %v", err) + } + }) + 
defer cleanup() + + _, err := Call[map[string]string](context.Background(), socketPath, "ping", nil) + if err == nil || !strings.Contains(err.Error(), "operation_failed: boom") { + t.Fatalf("Call() error = %v, want remote error", err) + } +} + +func TestCallRejectsMalformedResponse(t *testing.T) { + t.Parallel() + + socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) { + defer conn.Close() + _, _ = conn.Write([]byte("{not-json}\n")) + }) + defer cleanup() + + _, err := Call[map[string]string](context.Background(), socketPath, "ping", nil) + if err == nil { + t.Fatal("Call() returned nil error for malformed response") + } +} + +func TestCallHonorsContextDeadline(t *testing.T) { + t.Parallel() + + socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) { + defer conn.Close() + var req Request + if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + time.Sleep(100 * time.Millisecond) + }) + defer cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + _, err := Call[map[string]string](ctx, socketPath, "ping", nil) + if err == nil { + t.Fatal("Call() returned nil error for deadline") + } +} + +func TestWaitForSocket(t *testing.T) { + t.Parallel() + + socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) { + _ = conn.Close() + }) + defer cleanup() + + if err := WaitForSocket(socketPath, 2*time.Second); err != nil { + t.Fatalf("WaitForSocket(success): %v", err) + } + + err := WaitForSocket(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond) + if err == nil || !strings.Contains(err.Error(), "not ready") { + t.Fatalf("WaitForSocket(timeout) error = %v, want timeout", err) + } +} + +func serveRPCOnce(t *testing.T, handler func(net.Conn)) (string, func()) { + t.Helper() + + socketPath := filepath.Join(t.TempDir(), "rpc.sock") + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Listen: %v", err) + } + 
done := make(chan struct{}) + go func() { + defer close(done) + conn, err := listener.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + t.Errorf("Accept: %v", err) + } + return + } + handler(conn) + }() + + cleanup := func() { + _ = listener.Close() + <-done + } + return socketPath, cleanup +} diff --git a/internal/store/store_test.go b/internal/store/store_test.go new file mode 100644 index 0000000..9efb813 --- /dev/null +++ b/internal/store/store_test.go @@ -0,0 +1,276 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "path/filepath" + "reflect" + "strconv" + "strings" + "testing" + "time" + + "banger/internal/model" +) + +func TestStoreImageAndVMRoundTrip(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestStore(t) + + image := sampleImage("image-one") + if err := store.UpsertImage(ctx, image); err != nil { + t.Fatalf("UpsertImage: %v", err) + } + + vm := sampleVM("vm-one", image.ID, "172.16.0.8") + if err := store.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM: %v", err) + } + + gotImage, err := store.GetImageByName(ctx, image.Name) + if err != nil { + t.Fatalf("GetImageByName: %v", err) + } + if !reflect.DeepEqual(gotImage, image) { + t.Fatalf("GetImageByName = %+v, want %+v", gotImage, image) + } + + gotVM, err := store.GetVM(ctx, vm.Name) + if err != nil { + t.Fatalf("GetVM: %v", err) + } + if !reflect.DeepEqual(gotVM, vm) { + t.Fatalf("GetVM = %+v, want %+v", gotVM, vm) + } + + images, err := store.ListImages(ctx) + if err != nil { + t.Fatalf("ListImages: %v", err) + } + if len(images) != 1 || !reflect.DeepEqual(images[0], image) { + t.Fatalf("ListImages = %+v, want [%+v]", images, image) + } + + vms, err := store.ListVMs(ctx) + if err != nil { + t.Fatalf("ListVMs: %v", err) + } + if len(vms) != 1 || !reflect.DeepEqual(vms[0], vm) { + t.Fatalf("ListVMs = %+v, want [%+v]", vms, vm) + } + + users, err := store.FindVMsUsingImage(ctx, image.ID) + if err != nil { + 
t.Fatalf("FindVMsUsingImage: %v", err) + } + if len(users) != 1 || users[0].ID != vm.ID { + t.Fatalf("FindVMsUsingImage = %+v, want vm %s", users, vm.ID) + } + + if err := store.DeleteVM(ctx, vm.ID); err != nil { + t.Fatalf("DeleteVM: %v", err) + } + if _, err := store.GetVM(ctx, vm.ID); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("GetVM after delete error = %v, want sql.ErrNoRows", err) + } + + if err := store.DeleteImage(ctx, image.ID); err != nil { + t.Fatalf("DeleteImage: %v", err) + } + if _, err := store.GetImageByID(ctx, image.ID); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("GetImageByID after delete error = %v, want sql.ErrNoRows", err) + } +} + +func TestNextGuestIPSkipsAllocatedAddresses(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestStore(t) + image := sampleImage("image-next-ip") + if err := store.UpsertImage(ctx, image); err != nil { + t.Fatalf("UpsertImage: %v", err) + } + for i, ip := range []string{"172.16.0.2", "172.16.0.3", "172.16.0.5"} { + vm := sampleVM("vm-next-"+strconv.Itoa(i), image.ID, ip) + if err := store.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM(%s): %v", ip, err) + } + } + + got, err := store.NextGuestIP(ctx, "172.16.0") + if err != nil { + t.Fatalf("NextGuestIP: %v", err) + } + if got != "172.16.0.4" { + t.Fatalf("NextGuestIP = %q, want 172.16.0.4", got) + } +} + +func TestNextGuestIPReturnsErrorWhenRangeExhausted(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestStore(t) + image := sampleImage("image-full") + if err := store.UpsertImage(ctx, image); err != nil { + t.Fatalf("UpsertImage: %v", err) + } + for i := 2; i < 255; i++ { + vm := sampleVM("vm-"+strconv.Itoa(i), image.ID, "172.16.0."+strconv.Itoa(i)) + if err := store.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM(%d): %v", i, err) + } + } + + _, err := store.NextGuestIP(ctx, "172.16.0") + if err == nil || !strings.Contains(err.Error(), "no guest IPs available") { + t.Fatalf("NextGuestIP() 
error = %v, want exhaustion error", err) + } +} + +func TestGetVMRejectsMalformedRuntimeJSON(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestStore(t) + now := fixedTime() + _, err := store.db.ExecContext(ctx, ` + INSERT INTO vms ( + id, name, image_id, guest_ip, state, created_at, updated_at, last_touched_at, + spec_json, runtime_json, stats_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "vm-malformed-runtime", + "vm-malformed-runtime", + "image-id", + "172.16.0.8", + string(model.VMStateCreated), + now.Format(time.RFC3339), + now.Format(time.RFC3339), + now.Format(time.RFC3339), + `{"vcpu_count":2}`, + `{"guest_ip":`, + `{}`, + ) + if err != nil { + t.Fatalf("insert malformed vm: %v", err) + } + + _, err = store.GetVM(ctx, "vm-malformed-runtime") + if err == nil || !strings.Contains(err.Error(), "unexpected end of JSON input") { + t.Fatalf("GetVM() error = %v, want runtime JSON failure", err) + } +} + +func TestGetImageRejectsMalformedTimestamp(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestStore(t) + _, err := store.db.ExecContext(ctx, ` + INSERT INTO images ( + id, name, managed, artifact_dir, rootfs_path, kernel_path, initrd_path, + modules_dir, packages_path, build_size, docker, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "image-bad-time", + "image-bad-time", + 0, + "", + "/rootfs.ext4", + "/vmlinux", + "", + "", + "", + "", + 0, + "not-a-time", + "not-a-time", + ) + if err != nil { + t.Fatalf("insert malformed image: %v", err) + } + + _, err = store.GetImageByName(ctx, "image-bad-time") + if err == nil || !strings.Contains(err.Error(), "cannot parse") { + t.Fatalf("GetImageByName() error = %v, want timestamp parse failure", err) + } +} + +func openTestStore(t *testing.T) *Store { + t.Helper() + store, err := Open(filepath.Join(t.TempDir(), "state.db")) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { + _ = store.Close() + }) + 
return store +} + +func sampleImage(name string) model.Image { + now := fixedTime() + return model.Image{ + ID: name + "-id", + Name: name, + Managed: true, + ArtifactDir: "/artifacts/" + name, + RootfsPath: "/images/" + name + ".ext4", + KernelPath: "/kernels/" + name, + InitrdPath: "/initrd/" + name, + ModulesDir: "/modules/" + name, + PackagesPath: "/packages/" + name + ".apt", + BuildSize: "8G", + Docker: true, + CreatedAt: now, + UpdatedAt: now, + } +} + +func sampleVM(name, imageID, guestIP string) model.VMRecord { + now := fixedTime() + return model.VMRecord{ + ID: name + "-id", + Name: name, + ImageID: imageID, + State: model.VMStateStopped, + CreatedAt: now, + UpdatedAt: now, + LastTouchedAt: now, + Spec: model.VMSpec{ + VCPUCount: 2, + MemoryMiB: 1024, + SystemOverlaySizeByte: 8 * 1024 * 1024 * 1024, + WorkDiskSizeBytes: 8 * 1024 * 1024 * 1024, + NATEnabled: true, + }, + Runtime: model.VMRuntime{ + State: model.VMStateStopped, + GuestIP: guestIP, + TapDevice: "tap-" + name, + APISockPath: "/tmp/" + name + ".sock", + LogPath: "/tmp/" + name + ".log", + MetricsPath: "/tmp/" + name + ".metrics", + DNSName: name + ".vm", + VMDir: "/state/" + name, + SystemOverlay: "/state/" + name + "/system.cow", + WorkDiskPath: "/state/" + name + "/root.ext4", + }, + Stats: model.VMStats{ + CPUPercent: 1.25, + RSSBytes: 1024, + VSZBytes: 2048, + SystemOverlayBytes: 4096, + WorkDiskBytes: 8192, + MetricsRaw: map[string]any{"uptime": 12.0}, + CollectedAt: now, + }, + } +} + +func fixedTime() time.Time { + return time.Date(2026, time.March, 16, 12, 0, 0, 0, time.UTC) +} diff --git a/internal/system/system_test.go b/internal/system/system_test.go new file mode 100644 index 0000000..d5a7f27 --- /dev/null +++ b/internal/system/system_test.go @@ -0,0 +1,312 @@ +package system + +import ( + "context" + "errors" + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +type systemCall struct { + sudo bool + name string + args []string +} + +type systemStep struct { + call 
systemCall + out []byte + err error +} + +type scriptedRunner struct { + t *testing.T + steps []systemStep + calls []systemCall +} + +func (r *scriptedRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) { + return r.next(systemCall{name: name, args: append([]string(nil), args...)}) +} + +func (r *scriptedRunner) RunSudo(ctx context.Context, args ...string) ([]byte, error) { + return r.next(systemCall{sudo: true, args: append([]string(nil), args...)}) +} + +func (r *scriptedRunner) next(call systemCall) ([]byte, error) { + r.t.Helper() + r.calls = append(r.calls, call) + if len(r.steps) == 0 { + r.t.Fatalf("unexpected call: %+v", call) + } + step := r.steps[0] + r.steps = r.steps[1:] + if step.call.sudo != call.sudo || step.call.name != call.name || !reflect.DeepEqual(step.call.args, call.args) { + r.t.Fatalf("call mismatch:\n got: %+v\n want: %+v", call, step.call) + } + return step.out, step.err +} + +func (r *scriptedRunner) assertExhausted() { + r.t.Helper() + if len(r.steps) != 0 { + r.t.Fatalf("unconsumed steps: %+v", r.steps) + } +} + +type funcRunner struct { + run func(ctx context.Context, name string, args ...string) ([]byte, error) + runSudo func(ctx context.Context, args ...string) ([]byte, error) +} + +func (r funcRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) { + if r.run == nil { + return nil, errors.New("unexpected Run call") + } + return r.run(ctx, name, args...) +} + +func (r funcRunner) RunSudo(ctx context.Context, args ...string) ([]byte, error) { + if r.runSudo == nil { + return nil, errors.New("unexpected RunSudo call") + } + return r.runSudo(ctx, args...) 
+} + +func TestResizeExt4ImageStopsAtFirstFailure(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + steps []systemStep + wantErr string + wantCalls int + }{ + { + name: "truncate failure", + steps: []systemStep{ + {call: systemCall{name: "truncate", args: []string{"-s", "4096", "/tmp/root.ext4"}}, err: errors.New("truncate failed")}, + }, + wantErr: "truncate failed", + wantCalls: 1, + }, + { + name: "e2fsck failure", + steps: []systemStep{ + {call: systemCall{name: "truncate", args: []string{"-s", "4096", "/tmp/root.ext4"}}}, + {call: systemCall{name: "e2fsck", args: []string{"-p", "-f", "/tmp/root.ext4"}}, err: errors.New("e2fsck failed")}, + }, + wantErr: "e2fsck failed", + wantCalls: 2, + }, + { + name: "resize2fs failure", + steps: []systemStep{ + {call: systemCall{name: "truncate", args: []string{"-s", "4096", "/tmp/root.ext4"}}}, + {call: systemCall{name: "e2fsck", args: []string{"-p", "-f", "/tmp/root.ext4"}}}, + {call: systemCall{name: "resize2fs", args: []string{"/tmp/root.ext4"}}, err: errors.New("resize2fs failed")}, + }, + wantErr: "resize2fs failed", + wantCalls: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + runner := &scriptedRunner{t: t, steps: tt.steps} + err := ResizeExt4Image(context.Background(), runner, "/tmp/root.ext4", 4096) + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("ResizeExt4Image() error = %v, want %q", err, tt.wantErr) + } + if len(runner.calls) != tt.wantCalls { + t.Fatalf("call count = %d, want %d", len(runner.calls), tt.wantCalls) + } + runner.assertExhausted() + }) + } +} + +func TestWriteExt4FileRemovesTempFileAndReturnsCopyError(t *testing.T) { + t.Parallel() + + copyErr := errors.New("e2cp failed") + var tempPath string + runner := funcRunner{ + runSudo: func(ctx context.Context, args ...string) ([]byte, error) { + switch args[0] { + case "e2rm": + return nil, errors.New("ignore remove") + case "e2cp": + tempPath = args[1] + 
if _, err := os.Stat(tempPath); err != nil { + t.Fatalf("temp file missing during e2cp: %v", err) + } + return nil, copyErr + default: + t.Fatalf("unexpected sudo call: %v", args) + return nil, nil + } + }, + } + + err := WriteExt4File(context.Background(), runner, "/tmp/root.ext4", "/etc/hostname", []byte("devbox\n")) + if !errors.Is(err, copyErr) { + t.Fatalf("WriteExt4File() error = %v, want %v", err, copyErr) + } + if tempPath == "" { + t.Fatal("expected e2cp temp path to be recorded") + } + if _, err := os.Stat(tempPath); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("temp file still exists after WriteExt4File: %v", err) + } +} + +func TestMountTempDirUsesLoopForRegularFilesAndCleanupUsesBackgroundContext(t *testing.T) { + t.Parallel() + + source := filepath.Join(t.TempDir(), "root.ext4") + if err := os.WriteFile(source, []byte("rootfs"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + var mountDir string + calls := 0 + runner := funcRunner{ + runSudo: func(callCtx context.Context, args ...string) ([]byte, error) { + calls++ + switch calls { + case 1: + mountDir = args[len(args)-1] + want := []string{"mount", "-o", "ro,loop", source, mountDir} + if !reflect.DeepEqual(args, want) { + t.Fatalf("mount args = %v, want %v", args, want) + } + case 2: + if callCtx.Err() != nil { + t.Fatalf("cleanup context should not be canceled: %v", callCtx.Err()) + } + want := []string{"umount", mountDir} + if !reflect.DeepEqual(args, want) { + t.Fatalf("cleanup args = %v, want %v", args, want) + } + default: + t.Fatalf("unexpected RunSudo call %d: %v", calls, args) + } + return nil, nil + }, + } + + gotMountDir, cleanup, err := MountTempDir(ctx, runner, source, true) + if err != nil { + t.Fatalf("MountTempDir: %v", err) + } + if gotMountDir != mountDir { + t.Fatalf("mount dir = %q, want %q", gotMountDir, mountDir) + } + cancel() + if err := cleanup(); err != nil { + t.Fatalf("cleanup: %v", err) + } + if _, err 
:= os.Stat(mountDir); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("mount dir still exists after cleanup: %v", err) + } +} + +func TestMountTempDirRemovesTempDirWhenMountFails(t *testing.T) { + t.Parallel() + + source := t.TempDir() + var mountDir string + runner := funcRunner{ + runSudo: func(ctx context.Context, args ...string) ([]byte, error) { + mountDir = args[len(args)-1] + return nil, errors.New("mount failed") + }, + } + + _, _, err := MountTempDir(context.Background(), runner, source, false) + if err == nil || !strings.Contains(err.Error(), "mount failed") { + t.Fatalf("MountTempDir() error = %v, want mount failure", err) + } + if mountDir == "" { + t.Fatal("expected mount path to be recorded") + } + if _, statErr := os.Stat(mountDir); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("mount dir still exists after failed mount: %v", statErr) + } +} + +func TestParseMetricsFileHandlesWholeAndLineDelimitedJSON(t *testing.T) { + t.Parallel() + + full := filepath.Join(t.TempDir(), "metrics-full.json") + if err := os.WriteFile(full, []byte(`{"uptime":1,"signals":{"sigterm":0}}`), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + got := ParseMetricsFile(full) + if got["uptime"] != float64(1) { + t.Fatalf("ParseMetricsFile(full) = %v", got) + } + + lines := filepath.Join(t.TempDir(), "metrics-lines.json") + if err := os.WriteFile(lines, []byte("junk\n{\"uptime\":2}\n"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + got = ParseMetricsFile(lines) + if got["uptime"] != float64(2) { + t.Fatalf("ParseMetricsFile(lines) = %v", got) + } +} + +func TestUpdateFSTabStripsLegacyMountsAndAddsDefaultsOnce(t *testing.T) { + t.Parallel() + + input := strings.Join([]string{ + "/dev/vda / ext4 defaults 0 1", + "/dev/vdb /home ext4 defaults 0 2", + "/dev/vdc /var ext4 defaults 0 2", + "/dev/vdb /root ext4 defaults 0 2", + "tmpfs /run tmpfs defaults,nodev,nosuid,mode=0755 0 0", + "", + }, "\n") + + got := UpdateFSTab(input) + if strings.Contains(got, "/home") 
|| strings.Contains(got, "/var") { + t.Fatalf("UpdateFSTab() kept legacy mounts: %q", got) + } + if strings.Count(got, "/dev/vdb /root ext4 defaults 0 2") != 1 { + t.Fatalf("UpdateFSTab() duplicated /root mount: %q", got) + } + if strings.Count(got, "tmpfs /run tmpfs defaults,nodev,nosuid,mode=0755 0 0") != 1 { + t.Fatalf("UpdateFSTab() duplicated /run mount: %q", got) + } + if !strings.Contains(got, "tmpfs /tmp tmpfs defaults,nodev,nosuid,mode=1777 0 0") { + t.Fatalf("UpdateFSTab() missing /tmp mount: %q", got) + } +} + +func TestUseLoopMount(t *testing.T) { + t.Parallel() + + file := filepath.Join(t.TempDir(), "root.ext4") + if err := os.WriteFile(file, []byte("rootfs"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + dir := t.TempDir() + + if !useLoopMount(file) { + t.Fatalf("useLoopMount(%s) = false, want true", file) + } + if useLoopMount(dir) { + t.Fatalf("useLoopMount(%s) = true, want false", dir) + } + if useLoopMount(filepath.Join(dir, "missing")) { + t.Fatalf("useLoopMount(missing) = true, want false") + } +} diff --git a/verify.sh b/verify.sh index 6fcaaee..015e3f5 100755 --- a/verify.sh +++ b/verify.sh @@ -24,7 +24,7 @@ fi wait_for_ssh() { local guest_ip="$1" - local deadline=$((SECONDS + 60)) + local deadline="$2" while ((SECONDS < deadline)); do if ssh -i "$SSH_KEY" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ @@ -37,6 +37,62 @@ wait_for_ssh() { return 1 } +refresh_vm_metadata() { + if ! 
VM_JSON="$(./banger vm show "$VM_NAME" 2>/dev/null)"; then + return 1 + fi + TAP="$(printf '%s\n' "$VM_JSON" | jq -r '.runtime.tap_device // empty')" + VM_DIR="$(printf '%s\n' "$VM_JSON" | jq -r '.runtime.vm_dir // empty')" + GUEST_IP="$(printf '%s\n' "$VM_JSON" | jq -r '.runtime.guest_ip // empty')" + API_SOCK="$(printf '%s\n' "$VM_JSON" | jq -r '.runtime.api_sock_path // empty')" + PID="$(printf '%s\n' "$VM_JSON" | jq -r '.runtime.pid // 0')" + VM_STATE="$(printf '%s\n' "$VM_JSON" | jq -r '.state // empty')" + LAST_ERROR="$(printf '%s\n' "$VM_JSON" | jq -r '.runtime.last_error // empty')" + return 0 +} + +wait_for_vm_ready() { + local deadline="$1" + + while ((SECONDS < deadline)); do + if ! refresh_vm_metadata; then + sleep 1 + continue + fi + if [[ "$VM_STATE" == "error" || -n "$LAST_ERROR" ]]; then + return 2 + fi + if [[ "$VM_STATE" == "running" && -n "$GUEST_IP" && -n "$TAP" && -n "$VM_DIR" && -n "$API_SOCK" && "${PID:-0}" -gt 0 ]]; then + if [[ -S "$API_SOCK" ]] && ip link show "$TAP" >/dev/null 2>&1; then + return 0 + fi + fi + sleep 1 + done + + return 1 +} + +dump_diagnostics() { + log "diagnostics for $VM_NAME" + ./banger vm show "$VM_NAME" || true + log "recent firecracker log" + ./banger vm logs "$VM_NAME" 2>/dev/null | tail -n 200 || true + if [[ -n "${TAP:-}" ]]; then + log "tap state for $TAP" + ip link show "$TAP" || true + fi + if [[ -n "${API_SOCK:-}" ]]; then + log "api socket $API_SOCK" + ls -l "$API_SOCK" 2>/dev/null || true + fi + if (( NAT_ENABLED )) && [[ -n "${UPLINK:-}" && -n "${GUEST_IP:-}" && -n "${TAP:-}" ]]; then + log "nat rules for ${GUEST_IP} via ${UPLINK}" + sudo iptables -t nat -S POSTROUTING | grep "${GUEST_IP}/32" || true + sudo iptables -S FORWARD | grep "$TAP" || true + fi +} + usage() { cat <<'EOF' Usage: ./verify.sh [--nat] @@ -47,6 +103,7 @@ EOF } NAT_ENABLED=0 +BOOT_TIMEOUT_SECS="${VERIFY_BOOT_TIMEOUT_SECS:-90}" if [[ "${1:-}" == "--nat" ]]; then NAT_ENABLED=1 shift @@ -62,6 +119,10 @@ TAP="" VM_DIR="" GUEST_IP="" 
UPLINK="" +API_SOCK="" +PID="0" +VM_STATE="" +LAST_ERROR="" cleanup() { if [[ -n "${VM_NAME:-}" ]]; then @@ -78,21 +139,15 @@ if (( NAT_ENABLED )); then fi "${CREATE_ARGS[@]}" >/dev/null -VM_JSON="$(./banger vm show "$VM_NAME")" -name="$(printf '%s\n' "$VM_JSON" | jq -r '.name // empty')" -guest_ip="$(printf '%s\n' "$VM_JSON" | jq -r '.runtime.guest_ip // empty')" -tap="$(printf '%s\n' "$VM_JSON" | jq -r '.runtime.tap_device // empty')" -vm_dir="$(printf '%s\n' "$VM_JSON" | jq -r '.runtime.vm_dir // empty')" +BOOT_DEADLINE=$((SECONDS + BOOT_TIMEOUT_SECS)) -if [[ -z "$name" || -z "$guest_ip" || -z "$tap" || -z "$vm_dir" ]]; then - log "missing VM metadata from banger vm show" +log "waiting for VM runtime readiness" +if ! wait_for_vm_ready "$BOOT_DEADLINE"; then + log "vm did not become ready before timeout" + dump_diagnostics exit 1 fi -TAP="$tap" -VM_DIR="$vm_dir" -GUEST_IP="$guest_ip" - if (( NAT_ENABLED )); then UPLINK="$(ip route show default 2>/dev/null | awk '/default/ {print $5; exit}')" if [[ -z "$UPLINK" ]]; then @@ -106,17 +161,18 @@ if (( NAT_ENABLED )); then fi log "asserting VM is reachable via SSH" -if ! wait_for_ssh "$guest_ip"; then - log "ssh did not become ready for ${guest_ip}" +if ! wait_for_ssh "$GUEST_IP" "$BOOT_DEADLINE"; then + log "ssh did not become ready for ${GUEST_IP}" + dump_diagnostics exit 1 fi ssh -i "$SSH_KEY" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ - "root@${guest_ip}" "uname -a" >/dev/null + "root@${GUEST_IP}" "uname -a" >/dev/null if (( NAT_ENABLED )); then log "asserting VM has outbound network access" ssh -i "$SSH_KEY" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ - "root@${guest_ip}" "curl -fsS https://example.com >/dev/null" >/dev/null + "root@${GUEST_IP}" "curl -fsS https://example.com >/dev/null" >/dev/null fi log "cleaning up VM"