diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 2a0eba3..489ad86 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -26,7 +26,7 @@ type Daemon struct { layout paths.Layout config model.DaemonConfig store *store.Store - runner system.Runner + runner system.CommandRunner mu sync.Mutex closing chan struct{} once sync.Once diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go new file mode 100644 index 0000000..1a3fda2 --- /dev/null +++ b/internal/daemon/snapshot.go @@ -0,0 +1,79 @@ +package daemon + +import ( + "context" + "errors" + "fmt" + "strings" +) + +type dmSnapshotHandles struct { + BaseLoop string + COWLoop string + DMName string + DMDev string +} + +func (d *Daemon) createDMSnapshot(ctx context.Context, rootfsPath, cowPath, dmName string) (handles dmSnapshotHandles, err error) { + defer func() { + if err == nil { + return + } + if cleanupErr := d.cleanupDMSnapshot(context.Background(), handles); cleanupErr != nil { + err = errors.Join(err, cleanupErr) + } + }() + + baseBytes, err := d.runner.RunSudo(ctx, "losetup", "-f", "--show", "--read-only", rootfsPath) + if err != nil { + return handles, err + } + handles.BaseLoop = strings.TrimSpace(string(baseBytes)) + + cowBytes, err := d.runner.RunSudo(ctx, "losetup", "-f", "--show", cowPath) + if err != nil { + return handles, err + } + handles.COWLoop = strings.TrimSpace(string(cowBytes)) + + sectorsBytes, err := d.runner.RunSudo(ctx, "blockdev", "--getsz", handles.BaseLoop) + if err != nil { + return handles, err + } + sectors := strings.TrimSpace(string(sectorsBytes)) + + if _, err := d.runner.RunSudo(ctx, "dmsetup", "create", dmName, "--table", fmt.Sprintf("0 %s snapshot %s %s P 8", sectors, handles.BaseLoop, handles.COWLoop)); err != nil { + return handles, err + } + handles.DMName = dmName + handles.DMDev = "/dev/mapper/" + dmName + return handles, nil +} + +func (d *Daemon) cleanupDMSnapshot(ctx context.Context, handles dmSnapshotHandles) error { + var cleanupErr error + + switch { + case handles.DMName != "": + if _, err := d.runner.RunSudo(ctx, "dmsetup", "remove", handles.DMName); err != nil { + cleanupErr = errors.Join(cleanupErr, err) + } + case handles.DMDev != "": + if _, err := d.runner.RunSudo(ctx, "dmsetup", "remove", handles.DMDev); err != nil { + cleanupErr = errors.Join(cleanupErr, err) + } + } + + if handles.COWLoop != "" { + if _, err := d.runner.RunSudo(ctx, "losetup", "-d", handles.COWLoop); err != nil { + cleanupErr = errors.Join(cleanupErr, err) + } + } + if handles.BaseLoop != "" { + if _, err := d.runner.RunSudo(ctx, "losetup", "-d", handles.BaseLoop); err != nil { + cleanupErr = errors.Join(cleanupErr, err) + } + } + + return cleanupErr +} diff --git a/internal/daemon/snapshot_test.go b/internal/daemon/snapshot_test.go new file mode 100644 index 0000000..2a24597 --- /dev/null +++ b/internal/daemon/snapshot_test.go @@ -0,0 +1,294 @@ +package daemon + +import ( + "context" + "errors" + "slices" + "testing" +) + +type runnerCall struct { + sudo bool + name string + args []string +} + +type runnerStep struct { + call runnerCall + out []byte + err error +} + +type scriptedRunner struct { + t *testing.T + steps []runnerStep + calls []runnerCall +} + +func (r *scriptedRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) { + return r.next(runnerCall{name: name, args: append([]string(nil), args...)}) +} + +func (r *scriptedRunner) RunSudo(ctx context.Context, args ...string) ([]byte, error) { + return r.next(runnerCall{sudo: true, args: append([]string(nil), args...)}) +} + +func (r *scriptedRunner) next(call runnerCall) ([]byte, error) { + r.t.Helper() + r.calls = append(r.calls, call) + if len(r.steps) == 0 { + r.t.Fatalf("unexpected call: %+v", call) + } + step := r.steps[0] + r.steps = r.steps[1:] + if step.call.sudo != call.sudo || step.call.name != call.name || !slices.Equal(step.call.args, call.args) { + r.t.Fatalf("call mismatch:\n got: %+v\n want: %+v", call, step.call) + } + return step.out, step.err +} + +func (r *scriptedRunner) assertExhausted() { + r.t.Helper() + if len(r.steps) != 0 { + r.t.Fatalf("unconsumed steps: %+v", r.steps) + } +} + +func sudoStep(out string, err error, args ...string) runnerStep { + return runnerStep{ + call: runnerCall{sudo: true, args: append([]string(nil), args...)}, + out: []byte(out), + err: err, + } +} + +func TestCreateDMSnapshotFailsWithoutRollbackWhenBaseLoopSetupFails(t *testing.T) { + t.Parallel() + + attachErr := errors.New("attach base loop") + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("", attachErr, "losetup", "-f", "--show", "--read-only", "/rootfs.ext4"), + }, + } + d := &Daemon{runner: runner} + + _, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + if !errors.Is(err, attachErr) { + t.Fatalf("error = %v, want %v", err, attachErr) + } + runner.assertExhausted() + if len(runner.calls) != 1 { + t.Fatalf("call count = %d, want 1", len(runner.calls)) + } +} + +func TestCreateDMSnapshotRollsBackBaseLoopWhenCowLoopSetupFails(t *testing.T) { + t.Parallel() + + attachErr := errors.New("attach cow loop") + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("/dev/loop10\n", nil, "losetup", "-f", "--show", "--read-only", "/rootfs.ext4"), + sudoStep("", attachErr, "losetup", "-f", "--show", "/cow.ext4"), + sudoStep("", nil, "losetup", "-d", "/dev/loop10"), + }, + } + d := &Daemon{runner: runner} + + _, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + if !errors.Is(err, attachErr) { + t.Fatalf("error = %v, want %v", err, attachErr) + } + runner.assertExhausted() +} + +func TestCreateDMSnapshotRollsBackBothLoopsWhenBlockdevFails(t *testing.T) { + t.Parallel() + + blockdevErr := errors.New("read sectors") + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("/dev/loop10\n", nil, "losetup", "-f", "--show", "--read-only", "/rootfs.ext4"), + sudoStep("/dev/loop11\n", nil, "losetup", "-f", "--show", "/cow.ext4"), + sudoStep("", blockdevErr, "blockdev", "--getsz", "/dev/loop10"), + sudoStep("", nil, "losetup", "-d", "/dev/loop11"), + sudoStep("", nil, "losetup", "-d", "/dev/loop10"), + }, + } + d := &Daemon{runner: runner} + + _, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + if !errors.Is(err, blockdevErr) { + t.Fatalf("error = %v, want %v", err, blockdevErr) + } + runner.assertExhausted() +} + +func TestCreateDMSnapshotRollsBackLoopsWhenDMSetupFails(t *testing.T) { + t.Parallel() + + dmErr := errors.New("create dm snapshot") + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("/dev/loop10\n", nil, "losetup", "-f", "--show", "--read-only", "/rootfs.ext4"), + sudoStep("/dev/loop11\n", nil, "losetup", "-f", "--show", "/cow.ext4"), + sudoStep("12345\n", nil, "blockdev", "--getsz", "/dev/loop10"), + sudoStep("", dmErr, "dmsetup", "create", "fc-rootfs-test", "--table", "0 12345 snapshot /dev/loop10 /dev/loop11 P 8"), + sudoStep("", nil, "losetup", "-d", "/dev/loop11"), + sudoStep("", nil, "losetup", "-d", "/dev/loop10"), + }, + } + d := &Daemon{runner: runner} + + _, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + if !errors.Is(err, dmErr) { + t.Fatalf("error = %v, want %v", err, dmErr) + } + runner.assertExhausted() + for _, call := range runner.calls { + if call.sudo && len(call.args) >= 2 && call.args[0] == "dmsetup" && call.args[1] == "remove" { + t.Fatalf("unexpected dmsetup remove call: %+v", call) + } + } +} + +func TestCreateDMSnapshotJoinsRollbackErrors(t *testing.T) { + t.Parallel() + + blockdevErr := errors.New("read sectors") + detachErr := errors.New("detach cow loop") + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("/dev/loop10\n", nil, "losetup", "-f", "--show", "--read-only", "/rootfs.ext4"), + sudoStep("/dev/loop11\n", nil, "losetup", "-f", "--show", "/cow.ext4"), + sudoStep("", blockdevErr, "blockdev", "--getsz", "/dev/loop10"), + sudoStep("", detachErr, "losetup", "-d", "/dev/loop11"), + sudoStep("", nil, "losetup", "-d", "/dev/loop10"), + }, + } + d := &Daemon{runner: runner} + + _, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + if err == nil { + t.Fatal("expected createDMSnapshot to return an error") + } + if !errors.Is(err, blockdevErr) || !errors.Is(err, detachErr) { + t.Fatalf("error = %v, want joined blockdev and rollback errors", err) + } + runner.assertExhausted() +} + +func TestCreateDMSnapshotReturnsHandlesOnSuccess(t *testing.T) { + t.Parallel() + + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("/dev/loop10\n", nil, "losetup", "-f", "--show", "--read-only", "/rootfs.ext4"), + sudoStep("/dev/loop11\n", nil, "losetup", "-f", "--show", "/cow.ext4"), + sudoStep("12345\n", nil, "blockdev", "--getsz", "/dev/loop10"), + sudoStep("", nil, "dmsetup", "create", "fc-rootfs-test", "--table", "0 12345 snapshot /dev/loop10 /dev/loop11 P 8"), + }, + } + d := &Daemon{runner: runner} + + handles, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + if err != nil { + t.Fatalf("createDMSnapshot returned error: %v", err) + } + want := dmSnapshotHandles{ + BaseLoop: "/dev/loop10", + COWLoop: "/dev/loop11", + DMName: "fc-rootfs-test", + DMDev: "/dev/mapper/fc-rootfs-test", + } + if handles != want { + t.Fatalf("handles = %+v, want %+v", handles, want) + } + runner.assertExhausted() +} + +func TestCleanupDMSnapshotRemovesResourcesInReverseOrder(t *testing.T) { + t.Parallel() + + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("", nil, "dmsetup", "remove", "fc-rootfs-test"), + sudoStep("", nil, "losetup", "-d", "/dev/loop11"), + sudoStep("", nil, "losetup", "-d", "/dev/loop10"), + }, + } + d := &Daemon{runner: runner} + + err := d.cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ + BaseLoop: "/dev/loop10", + COWLoop: "/dev/loop11", + DMName: "fc-rootfs-test", + DMDev: "/dev/mapper/fc-rootfs-test", + }) + if err != nil { + t.Fatalf("cleanupDMSnapshot returned error: %v", err) + } + runner.assertExhausted() +} + +func TestCleanupDMSnapshotUsesPartialHandles(t *testing.T) { + t.Parallel() + + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("", nil, "dmsetup", "remove", "/dev/mapper/fc-rootfs-test"), + sudoStep("", nil, "losetup", "-d", "/dev/loop10"), + }, + } + d := &Daemon{runner: runner} + + err := d.cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ + BaseLoop: "/dev/loop10", + DMDev: "/dev/mapper/fc-rootfs-test", + }) + if err != nil { + t.Fatalf("cleanupDMSnapshot returned error: %v", err) + } + runner.assertExhausted() +} + +func TestCleanupDMSnapshotJoinsTeardownErrors(t *testing.T) { + t.Parallel() + + dmErr := errors.New("remove dm") + cowErr := errors.New("detach cow") + baseErr := errors.New("detach base") + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("", dmErr, "dmsetup", "remove", "fc-rootfs-test"), + sudoStep("", cowErr, "losetup", "-d", "/dev/loop11"), + sudoStep("", baseErr, "losetup", "-d", "/dev/loop10"), + }, + } + d := &Daemon{runner: runner} + + err := d.cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ + BaseLoop: "/dev/loop10", + COWLoop: "/dev/loop11", + DMName: "fc-rootfs-test", + }) + if err == nil { + t.Fatal("expected cleanupDMSnapshot to return an error") + } + for _, expected := range []error{dmErr, cowErr, baseErr} { + if !errors.Is(err, expected) { + t.Fatalf("cleanup error %q not joined into %v", expected, err) + } + } + runner.assertExhausted() +} diff --git a/internal/daemon/vm.go b/internal/daemon/vm.go index 2bceb9a..5567b1f 100644 --- a/internal/daemon/vm.go +++ b/internal/daemon/vm.go @@ -153,14 +153,14 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod return model.VMRecord{}, err } - baseLoop, cowLoop, dmDev, err := d.createDMSnapshot(ctx, image.RootfsPath, vm.Runtime.SystemOverlay, dmName) + handles, err := d.createDMSnapshot(ctx, image.RootfsPath, vm.Runtime.SystemOverlay, dmName) if err != nil { return model.VMRecord{}, err } - vm.Runtime.BaseLoop = baseLoop - vm.Runtime.COWLoop = cowLoop - vm.Runtime.DMName = dmName - vm.Runtime.DMDev = dmDev + vm.Runtime.BaseLoop = handles.BaseLoop + vm.Runtime.COWLoop = handles.COWLoop + vm.Runtime.DMName = handles.DMName + vm.Runtime.DMDev = handles.DMDev vm.Runtime.APISockPath = apiSock vm.Runtime.TapDevice = tap vm.Runtime.State = model.VMStateRunning @@ -171,7 +171,9 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod vm.State = model.VMStateError vm.Runtime.State = model.VMStateError vm.Runtime.LastError = err.Error() - _ = d.cleanupRuntime(context.Background(), vm, true) + if cleanupErr := d.cleanupRuntime(context.Background(), vm, true); cleanupErr != nil { + err = errors.Join(err, cleanupErr) + } clearRuntimeHandles(&vm) _ = d.store.UpsertVM(context.Background(), vm) return model.VMRecord{}, err @@ -273,6 +275,7 @@ func (d *Daemon) StopVM(ctx context.Context, idOrName string) (model.VMRecord, e return vm, nil } + func (d *Daemon) RestartVM(ctx context.Context, idOrName string) (model.VMRecord, error) { vm, err := d.StopVM(ctx, idOrName) if err != nil { @@ -506,28 +509,6 @@ func (d *Daemon) ensureWorkDisk(ctx context.Context, vm *model.VMRecord) error { return nil } -func (d *Daemon) createDMSnapshot(ctx context.Context, rootfsPath, cowPath, dmName string) (baseLoop, cowLoop, dmDev string, err error) { - baseBytes, err := d.runner.RunSudo(ctx, "losetup", "-f", "--show", "--read-only", rootfsPath) - if err != nil { - return "", "", "", err - } - baseLoop = strings.TrimSpace(string(baseBytes)) - cowBytes, err := d.runner.RunSudo(ctx, "losetup", "-f", "--show", cowPath) - if err != nil { - return "", "", "", err - } - cowLoop = strings.TrimSpace(string(cowBytes)) - sectorsBytes, err := d.runner.RunSudo(ctx, "blockdev", "--getsz", baseLoop) - if err != nil { - return "", "", "", err - } - sectors := strings.TrimSpace(string(sectorsBytes)) - if _, err := d.runner.RunSudo(ctx, "dmsetup", "create", dmName, "--table", fmt.Sprintf("0 %s snapshot %s %s P 8", sectors, baseLoop, cowLoop)); err != nil { - return "", "", "", err - } - return baseLoop, cowLoop, "/dev/mapper/" + dmName, nil -} - func (d *Daemon) ensureBridge(ctx context.Context) error { if _, err := d.runner.Run(ctx, "ip", "link", "show", d.config.BridgeName); err == nil { _, err = d.runner.RunSudo(ctx, "ip", "link", "set", d.config.BridgeName, "up") @@ -638,25 +619,20 @@ func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserve if vm.Runtime.APISockPath != "" { _ = os.Remove(vm.Runtime.APISockPath) } - if vm.Runtime.DMName != "" { - _, _ = d.runner.RunSudo(ctx, "dmsetup", "remove", vm.Runtime.DMName) - } else if vm.Runtime.DMDev != "" { - _, _ = d.runner.RunSudo(ctx, "dmsetup", "remove", vm.Runtime.DMDev) - } - if vm.Runtime.COWLoop != "" { - _, _ = d.runner.RunSudo(ctx, "losetup", "-d", vm.Runtime.COWLoop) - } - if vm.Runtime.BaseLoop != "" { - _, _ = d.runner.RunSudo(ctx, "losetup", "-d", vm.Runtime.BaseLoop) - } + snapshotErr := d.cleanupDMSnapshot(ctx, dmSnapshotHandles{ + BaseLoop: vm.Runtime.BaseLoop, + COWLoop: vm.Runtime.COWLoop, + DMName: vm.Runtime.DMName, + DMDev: vm.Runtime.DMDev, + }) if vm.Spec.NATEnabled { _ = d.ensureNAT(ctx, vm, false) } _ = d.removeDNS(ctx, vm.Runtime.DNSName) if !preserveDisks && vm.Runtime.VMDir != "" { - return os.RemoveAll(vm.Runtime.VMDir) + return errors.Join(snapshotErr, os.RemoveAll(vm.Runtime.VMDir)) } - return nil + return snapshotErr } func clearRuntimeHandles(vm *model.VMRecord) { diff --git a/internal/system/system.go b/internal/system/system.go index 3f153cc..c10cee3 100644 --- a/internal/system/system.go +++ b/internal/system/system.go @@ -20,6 +20,11 @@ import ( type Runner struct{} +type CommandRunner interface { + Run(ctx context.Context, name string, args ...string) ([]byte, error) + RunSudo(ctx context.Context, args ...string) ([]byte, error) +} + func NewRunner() Runner { return Runner{} } @@ -162,7 +167,7 @@ func lastJSONLine(data []byte) []byte { return last } -func CopyDirContents(ctx context.Context, runner Runner, sourceDir, targetDir string, useSudo bool) error { +func CopyDirContents(ctx context.Context, runner CommandRunner, sourceDir, targetDir string, useSudo bool) error { args := []string{"-a", filepath.Join(sourceDir, "."), targetDir + "/"} var err error if useSudo { @@ -173,7 +178,7 @@ func CopyDirContents(ctx context.Context, runner Runner, sourceDir, targetDir st return err } -func ResizeExt4Image(ctx context.Context, runner Runner, path string, bytes int64) error { +func ResizeExt4Image(ctx context.Context, runner CommandRunner, path string, bytes int64) error { if _, err := runner.Run(ctx, "truncate", "-s", strconv.FormatInt(bytes, 10), path); err != nil { return err } @@ -184,7 +189,7 @@ func ResizeExt4Image(ctx context.Context, runner Runner, path string, bytes int6 return err } -func ReadDebugFSText(ctx context.Context, runner Runner, imagePath, guestPath string) (string, error) { +func ReadDebugFSText(ctx context.Context, runner CommandRunner, imagePath, guestPath string) (string, error) { out, err := runner.Run(ctx, "debugfs", "-R", "cat "+guestPath, imagePath) if err != nil { return "", err @@ -192,7 +197,7 @@ func ReadDebugFSText(ctx context.Context, runner Runner, imagePath, guestPath st return string(out), nil } -func WriteExt4File(ctx context.Context, runner Runner, imagePath, guestPath string, data []byte) error { +func WriteExt4File(ctx context.Context, runner CommandRunner, imagePath, guestPath string, data []byte) error { tmp, err := os.CreateTemp("", "banger-ext4-*") if err != nil { return err @@ -210,7 +215,7 @@ func WriteExt4File(ctx context.Context, runner Runner, imagePath, guestPath stri return err } -func MountTempDir(ctx context.Context, runner Runner, source string, readOnly bool) (string, func() error, error) { +func MountTempDir(ctx context.Context, runner CommandRunner, source string, readOnly bool) (string, func() error, error) { mountDir, err := os.MkdirTemp("", "banger-mnt-*") if err != nil { return "", nil, err