Roll back partial dm snapshot startup

Prevent partial VM startup failures from leaking loop devices and dm state on the host.

Move root snapshot setup into a rollback-safe helper that records loop and mapper handles incrementally, tears them down in reverse order on failure, and reuses the same dm/loop cleanup path during normal runtime teardown. Also switch the daemon runner field to a small command-runner interface so the snapshot path can be tested with injected failures.

Add failure-injection coverage for losetup, blockdev, dmsetup, partial teardown, and joined rollback errors. Validated with go test ./... and make build.
This commit is contained in:
Thales Maciel 2026-03-16 14:06:17 -03:00
parent 171009b30b
commit 375900cf65
No known key found for this signature in database
GPG key ID: 33112E6833C34679
5 changed files with 401 additions and 47 deletions

View file

@ -26,7 +26,7 @@ type Daemon struct {
layout paths.Layout layout paths.Layout
config model.DaemonConfig config model.DaemonConfig
store *store.Store store *store.Store
runner system.Runner runner system.CommandRunner
mu sync.Mutex mu sync.Mutex
closing chan struct{} closing chan struct{}
once sync.Once once sync.Once

View file

@ -0,0 +1,79 @@
package daemon
import (
"context"
"errors"
"fmt"
"strings"
)
// dmSnapshotHandles records the host resources that back one device-mapper
// snapshot. createDMSnapshot fills the fields in as each resource is acquired,
// and cleanupDMSnapshot skips any field that is still empty, so a partially
// populated value describes exactly what needs to be torn down.
type dmSnapshotHandles struct {
// BaseLoop is the loop device path printed by losetup for the read-only rootfs image.
BaseLoop string
// COWLoop is the loop device path printed by losetup for the copy-on-write overlay.
COWLoop string
// DMName is the mapper name passed to "dmsetup create".
DMName string
// DMDev is the mapper device path, "/dev/mapper/" followed by DMName.
DMDev string
}
// createDMSnapshot attaches rootfsPath (read-only) and cowPath to loop devices
// and creates a device-mapper snapshot named dmName on top of them.
//
// Handles are recorded as each resource is acquired. If any step fails, the
// deferred rollback hands whatever was acquired so far to cleanupDMSnapshot,
// joins any teardown error onto the returned error, and resets the returned
// handles to the zero value. Callers therefore never receive paths to loop
// devices that have already been detached on their behalf.
//
// Every step error is wrapped with the step that produced it; the underlying
// error stays reachable through errors.Is / errors.As.
func (d *Daemon) createDMSnapshot(ctx context.Context, rootfsPath, cowPath, dmName string) (handles dmSnapshotHandles, err error) {
	defer func() {
		if err == nil {
			return
		}
		// The rollback deliberately does not reuse ctx, so a cancelled or
		// expired request context cannot stop the teardown commands.
		if cleanupErr := d.cleanupDMSnapshot(context.Background(), handles); cleanupErr != nil {
			err = errors.Join(err, cleanupErr)
		}
		// The handles are meaningless once rolled back; do not leak them.
		handles = dmSnapshotHandles{}
	}()

	baseBytes, err := d.runner.RunSudo(ctx, "losetup", "-f", "--show", "--read-only", rootfsPath)
	if err != nil {
		return handles, fmt.Errorf("attach base loop for %q: %w", rootfsPath, err)
	}
	handles.BaseLoop = strings.TrimSpace(string(baseBytes))
	if handles.BaseLoop == "" {
		return handles, fmt.Errorf("attach base loop for %q: losetup reported no device", rootfsPath)
	}

	cowBytes, err := d.runner.RunSudo(ctx, "losetup", "-f", "--show", cowPath)
	if err != nil {
		return handles, fmt.Errorf("attach cow loop for %q: %w", cowPath, err)
	}
	handles.COWLoop = strings.TrimSpace(string(cowBytes))
	if handles.COWLoop == "" {
		return handles, fmt.Errorf("attach cow loop for %q: losetup reported no device", cowPath)
	}

	sectorsBytes, err := d.runner.RunSudo(ctx, "blockdev", "--getsz", handles.BaseLoop)
	if err != nil {
		return handles, fmt.Errorf("read size of %s: %w", handles.BaseLoop, err)
	}
	sectors := strings.TrimSpace(string(sectorsBytes))
	if sectors == "" {
		return handles, fmt.Errorf("read size of %s: blockdev reported no size", handles.BaseLoop)
	}

	// dm snapshot table: <start> <length> snapshot <origin> <cow> <P|N> <chunk>.
	// "P" requests a persistent snapshot and 8 is the chunk size in sectors.
	table := fmt.Sprintf("0 %s snapshot %s %s P 8", sectors, handles.BaseLoop, handles.COWLoop)
	if _, err := d.runner.RunSudo(ctx, "dmsetup", "create", dmName, "--table", table); err != nil {
		return handles, fmt.Errorf("create dm snapshot %q: %w", dmName, err)
	}

	// Record the mapper handles only after dmsetup succeeded, so the rollback
	// never tries to remove a mapping this call did not create.
	handles.DMName = dmName
	handles.DMDev = "/dev/mapper/" + dmName
	return handles, nil
}
// cleanupDMSnapshot releases the resources described by handles in reverse
// order of acquisition: the device-mapper target first, then the COW loop
// device, then the base loop device. Empty handles are skipped. A failing
// step does not stop the remaining steps; all failures are joined into the
// returned error, which is nil when every attempted step succeeded.
func (d *Daemon) cleanupDMSnapshot(ctx context.Context, handles dmSnapshotHandles) error {
	var joined error
	// runStep executes one privileged teardown command and accumulates its error.
	runStep := func(args ...string) {
		if _, stepErr := d.runner.RunSudo(ctx, args...); stepErr != nil {
			joined = errors.Join(joined, stepErr)
		}
	}

	// The mapper name is preferred; the device path is only a fallback for
	// handles that carry no name.
	target := handles.DMName
	if target == "" {
		target = handles.DMDev
	}
	if target != "" {
		runStep("dmsetup", "remove", target)
	}

	for _, loopDev := range [...]string{handles.COWLoop, handles.BaseLoop} {
		if loopDev == "" {
			continue
		}
		runStep("losetup", "-d", loopDev)
	}
	return joined
}

View file

@ -0,0 +1,294 @@
package daemon
import (
"context"
"errors"
"slices"
"testing"
)
// runnerCall describes a single command invocation seen (or expected) by
// scriptedRunner.
type runnerCall struct {
// sudo is true for calls made through RunSudo and false for calls made through Run.
sudo bool
// name is the command name given to Run; RunSudo leaves it empty.
name string
// args holds a copy of the variadic arguments of the call.
args []string
}
// runnerStep is one scripted expectation: the call that must arrive next and
// the result scriptedRunner hands back for it.
type runnerStep struct {
// call is the invocation this step expects.
call runnerCall
// out is returned as the command output.
out []byte
// err is returned as the command error; nil scripts a success.
err error
}
// scriptedRunner is a test double for the daemon's command runner. It replays
// a fixed list of steps in order and fails the test on any call that does not
// match the next step.
type scriptedRunner struct {
// t is the test that is failed on an unexpected or mismatched call.
t *testing.T
// steps holds the expectations not yet consumed; the head is matched against the next call.
steps []runnerStep
// calls records every invocation received, in arrival order.
calls []runnerCall
}
// Run records an unprivileged invocation of name with a private copy of args
// and returns the scripted result for it. ctx is not consulted.
func (r *scriptedRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) {
	recorded := runnerCall{name: name}
	recorded.args = append(recorded.args, args...)
	return r.next(recorded)
}
// RunSudo records a privileged invocation with a private copy of args and
// returns the scripted result for it. ctx is not consulted.
func (r *scriptedRunner) RunSudo(ctx context.Context, args ...string) ([]byte, error) {
	recorded := runnerCall{sudo: true}
	recorded.args = append(recorded.args, args...)
	return r.next(recorded)
}
// next logs call, pops the head of the script, and returns that step's output
// and error. The test is failed immediately when the script is already empty
// or when call differs from the expected invocation.
func (r *scriptedRunner) next(call runnerCall) ([]byte, error) {
	r.t.Helper()
	r.calls = append(r.calls, call)
	if len(r.steps) == 0 {
		r.t.Fatalf("unexpected call: %+v", call)
	}

	expected, remaining := r.steps[0], r.steps[1:]
	r.steps = remaining

	want := expected.call
	matches := want.sudo == call.sudo &&
		want.name == call.name &&
		slices.Equal(want.args, call.args)
	if !matches {
		r.t.Fatalf("call mismatch:\n got: %+v\n want: %+v", call, want)
	}
	return expected.out, expected.err
}
// assertExhausted fails the test when scripted steps are left over, meaning
// the code under test issued fewer commands than expected.
func (r *scriptedRunner) assertExhausted() {
	r.t.Helper()
	if leftover := r.steps; len(leftover) > 0 {
		r.t.Fatalf("unconsumed steps: %+v", leftover)
	}
}
// sudoStep builds a scripted step that expects a RunSudo call with exactly
// args and answers it with out (as bytes) and err.
func sudoStep(out string, err error, args ...string) runnerStep {
	var step runnerStep
	step.call.sudo = true
	step.call.args = append(step.call.args, args...)
	step.out = []byte(out)
	step.err = err
	return step
}
// A failure of the very first losetup call leaves nothing to roll back, so
// the runner must see exactly that one command and the error must surface.
func TestCreateDMSnapshotFailsWithoutRollbackWhenBaseLoopSetupFails(t *testing.T) {
	t.Parallel()

	wantErr := errors.New("attach base loop")
	script := &scriptedRunner{t: t}
	script.steps = append(script.steps,
		sudoStep("", wantErr, "losetup", "-f", "--show", "--read-only", "/rootfs.ext4"),
	)
	subject := &Daemon{runner: script}

	if _, err := subject.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test"); !errors.Is(err, wantErr) {
		t.Fatalf("error = %v, want %v", err, wantErr)
	}
	script.assertExhausted()
	if got := len(script.calls); got != 1 {
		t.Fatalf("call count = %d, want 1", got)
	}
}
// When attaching the COW loop fails, the already attached base loop must be
// detached again and the attach error returned.
func TestCreateDMSnapshotRollsBackBaseLoopWhenCowLoopSetupFails(t *testing.T) {
	t.Parallel()

	const rootfs, cow, mapper = "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test"
	wantErr := errors.New("attach cow loop")
	script := &scriptedRunner{t: t, steps: []runnerStep{
		sudoStep("/dev/loop10\n", nil, "losetup", "-f", "--show", "--read-only", rootfs),
		sudoStep("", wantErr, "losetup", "-f", "--show", cow),
		// Rollback: only the base loop exists at this point.
		sudoStep("", nil, "losetup", "-d", "/dev/loop10"),
	}}
	subject := &Daemon{runner: script}

	if _, err := subject.createDMSnapshot(context.Background(), rootfs, cow, mapper); !errors.Is(err, wantErr) {
		t.Fatalf("error = %v, want %v", err, wantErr)
	}
	script.assertExhausted()
}
// When the size query fails, both loop devices must be detached, COW first
// and base second, and the blockdev error returned.
func TestCreateDMSnapshotRollsBackBothLoopsWhenBlockdevFails(t *testing.T) {
	t.Parallel()

	const rootfs, cow, mapper = "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test"
	wantErr := errors.New("read sectors")
	script := &scriptedRunner{t: t, steps: []runnerStep{
		sudoStep("/dev/loop10\n", nil, "losetup", "-f", "--show", "--read-only", rootfs),
		sudoStep("/dev/loop11\n", nil, "losetup", "-f", "--show", cow),
		sudoStep("", wantErr, "blockdev", "--getsz", "/dev/loop10"),
		// Rollback in reverse order of acquisition.
		sudoStep("", nil, "losetup", "-d", "/dev/loop11"),
		sudoStep("", nil, "losetup", "-d", "/dev/loop10"),
	}}
	subject := &Daemon{runner: script}

	if _, err := subject.createDMSnapshot(context.Background(), rootfs, cow, mapper); !errors.Is(err, wantErr) {
		t.Fatalf("error = %v, want %v", err, wantErr)
	}
	script.assertExhausted()
}
// When "dmsetup create" fails, both loops must be detached and no
// "dmsetup remove" may be issued, because no mapping was recorded.
func TestCreateDMSnapshotRollsBackLoopsWhenDMSetupFails(t *testing.T) {
	t.Parallel()

	const rootfs, cow, mapper = "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test"
	wantErr := errors.New("create dm snapshot")
	script := &scriptedRunner{t: t, steps: []runnerStep{
		sudoStep("/dev/loop10\n", nil, "losetup", "-f", "--show", "--read-only", rootfs),
		sudoStep("/dev/loop11\n", nil, "losetup", "-f", "--show", cow),
		sudoStep("12345\n", nil, "blockdev", "--getsz", "/dev/loop10"),
		sudoStep("", wantErr, "dmsetup", "create", mapper, "--table", "0 12345 snapshot /dev/loop10 /dev/loop11 P 8"),
		sudoStep("", nil, "losetup", "-d", "/dev/loop11"),
		sudoStep("", nil, "losetup", "-d", "/dev/loop10"),
	}}
	subject := &Daemon{runner: script}

	if _, err := subject.createDMSnapshot(context.Background(), rootfs, cow, mapper); !errors.Is(err, wantErr) {
		t.Fatalf("error = %v, want %v", err, wantErr)
	}
	script.assertExhausted()

	// Scan the recorded calls for a privileged "dmsetup remove".
	isDMRemove := func(c runnerCall) bool {
		return c.sudo && len(c.args) >= 2 && c.args[0] == "dmsetup" && c.args[1] == "remove"
	}
	if at := slices.IndexFunc(script.calls, isDMRemove); at >= 0 {
		t.Fatalf("unexpected dmsetup remove call: %+v", script.calls[at])
	}
}
// A rollback step that itself fails must not hide the original failure: the
// returned error has to match both the setup error and the teardown error,
// and the remaining teardown step must still run.
func TestCreateDMSnapshotJoinsRollbackErrors(t *testing.T) {
	t.Parallel()

	const rootfs, cow, mapper = "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test"
	var (
		setupErr    = errors.New("read sectors")
		teardownErr = errors.New("detach cow loop")
	)
	script := &scriptedRunner{t: t, steps: []runnerStep{
		sudoStep("/dev/loop10\n", nil, "losetup", "-f", "--show", "--read-only", rootfs),
		sudoStep("/dev/loop11\n", nil, "losetup", "-f", "--show", cow),
		sudoStep("", setupErr, "blockdev", "--getsz", "/dev/loop10"),
		sudoStep("", teardownErr, "losetup", "-d", "/dev/loop11"),
		sudoStep("", nil, "losetup", "-d", "/dev/loop10"),
	}}
	subject := &Daemon{runner: script}

	_, err := subject.createDMSnapshot(context.Background(), rootfs, cow, mapper)
	switch {
	case err == nil:
		t.Fatal("expected createDMSnapshot to return an error")
	case !(errors.Is(err, setupErr) && errors.Is(err, teardownErr)):
		t.Fatalf("error = %v, want joined blockdev and rollback errors", err)
	}
	script.assertExhausted()
}
// On the happy path all four commands run, no teardown happens, and the
// returned handles carry the trimmed loop paths plus the mapper name and path.
func TestCreateDMSnapshotReturnsHandlesOnSuccess(t *testing.T) {
	t.Parallel()

	const rootfs, cow, mapper = "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test"
	script := &scriptedRunner{t: t, steps: []runnerStep{
		sudoStep("/dev/loop10\n", nil, "losetup", "-f", "--show", "--read-only", rootfs),
		sudoStep("/dev/loop11\n", nil, "losetup", "-f", "--show", cow),
		sudoStep("12345\n", nil, "blockdev", "--getsz", "/dev/loop10"),
		sudoStep("", nil, "dmsetup", "create", mapper, "--table", "0 12345 snapshot /dev/loop10 /dev/loop11 P 8"),
	}}
	subject := &Daemon{runner: script}

	got, err := subject.createDMSnapshot(context.Background(), rootfs, cow, mapper)
	if err != nil {
		t.Fatalf("createDMSnapshot returned error: %v", err)
	}

	var expected dmSnapshotHandles
	expected.BaseLoop = "/dev/loop10"
	expected.COWLoop = "/dev/loop11"
	expected.DMName = mapper
	expected.DMDev = "/dev/mapper/fc-rootfs-test"
	if got != expected {
		t.Fatalf("handles = %+v, want %+v", got, expected)
	}
	script.assertExhausted()
}
// With every handle set, teardown must remove the mapping by name first and
// then detach the COW loop followed by the base loop.
func TestCleanupDMSnapshotRemovesResourcesInReverseOrder(t *testing.T) {
	t.Parallel()

	script := &scriptedRunner{t: t, steps: []runnerStep{
		sudoStep("", nil, "dmsetup", "remove", "fc-rootfs-test"),
		sudoStep("", nil, "losetup", "-d", "/dev/loop11"),
		sudoStep("", nil, "losetup", "-d", "/dev/loop10"),
	}}
	subject := &Daemon{runner: script}

	full := dmSnapshotHandles{
		BaseLoop: "/dev/loop10",
		COWLoop:  "/dev/loop11",
		DMName:   "fc-rootfs-test",
		DMDev:    "/dev/mapper/fc-rootfs-test",
	}
	if err := subject.cleanupDMSnapshot(context.Background(), full); err != nil {
		t.Fatalf("cleanupDMSnapshot returned error: %v", err)
	}
	script.assertExhausted()
}
// With no mapper name and no COW loop, teardown must fall back to the mapper
// device path and issue no command for the missing COW loop.
func TestCleanupDMSnapshotUsesPartialHandles(t *testing.T) {
	t.Parallel()

	script := &scriptedRunner{t: t, steps: []runnerStep{
		sudoStep("", nil, "dmsetup", "remove", "/dev/mapper/fc-rootfs-test"),
		sudoStep("", nil, "losetup", "-d", "/dev/loop10"),
	}}
	subject := &Daemon{runner: script}

	partial := dmSnapshotHandles{
		BaseLoop: "/dev/loop10",
		DMDev:    "/dev/mapper/fc-rootfs-test",
	}
	if err := subject.cleanupDMSnapshot(context.Background(), partial); err != nil {
		t.Fatalf("cleanupDMSnapshot returned error: %v", err)
	}
	script.assertExhausted()
}
// Every teardown step is attempted even when earlier ones fail, and each
// failure must be discoverable in the returned error via errors.Is.
func TestCleanupDMSnapshotJoinsTeardownErrors(t *testing.T) {
	t.Parallel()

	var (
		dmErr   = errors.New("remove dm")
		cowErr  = errors.New("detach cow")
		baseErr = errors.New("detach base")
	)
	script := &scriptedRunner{t: t, steps: []runnerStep{
		sudoStep("", dmErr, "dmsetup", "remove", "fc-rootfs-test"),
		sudoStep("", cowErr, "losetup", "-d", "/dev/loop11"),
		sudoStep("", baseErr, "losetup", "-d", "/dev/loop10"),
	}}
	subject := &Daemon{runner: script}

	target := dmSnapshotHandles{
		BaseLoop: "/dev/loop10",
		COWLoop:  "/dev/loop11",
		DMName:   "fc-rootfs-test",
	}
	err := subject.cleanupDMSnapshot(context.Background(), target)
	if err == nil {
		t.Fatal("expected cleanupDMSnapshot to return an error")
	}

	wanted := [...]error{dmErr, cowErr, baseErr}
	for i := range wanted {
		if errors.Is(err, wanted[i]) {
			continue
		}
		t.Fatalf("cleanup error %q not joined into %v", wanted[i], err)
	}
	script.assertExhausted()
}

View file

@ -153,14 +153,14 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
return model.VMRecord{}, err return model.VMRecord{}, err
} }
baseLoop, cowLoop, dmDev, err := d.createDMSnapshot(ctx, image.RootfsPath, vm.Runtime.SystemOverlay, dmName) handles, err := d.createDMSnapshot(ctx, image.RootfsPath, vm.Runtime.SystemOverlay, dmName)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
vm.Runtime.BaseLoop = baseLoop vm.Runtime.BaseLoop = handles.BaseLoop
vm.Runtime.COWLoop = cowLoop vm.Runtime.COWLoop = handles.COWLoop
vm.Runtime.DMName = dmName vm.Runtime.DMName = handles.DMName
vm.Runtime.DMDev = dmDev vm.Runtime.DMDev = handles.DMDev
vm.Runtime.APISockPath = apiSock vm.Runtime.APISockPath = apiSock
vm.Runtime.TapDevice = tap vm.Runtime.TapDevice = tap
vm.Runtime.State = model.VMStateRunning vm.Runtime.State = model.VMStateRunning
@ -171,7 +171,9 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
vm.State = model.VMStateError vm.State = model.VMStateError
vm.Runtime.State = model.VMStateError vm.Runtime.State = model.VMStateError
vm.Runtime.LastError = err.Error() vm.Runtime.LastError = err.Error()
_ = d.cleanupRuntime(context.Background(), vm, true) if cleanupErr := d.cleanupRuntime(context.Background(), vm, true); cleanupErr != nil {
err = errors.Join(err, cleanupErr)
}
clearRuntimeHandles(&vm) clearRuntimeHandles(&vm)
_ = d.store.UpsertVM(context.Background(), vm) _ = d.store.UpsertVM(context.Background(), vm)
return model.VMRecord{}, err return model.VMRecord{}, err
@ -273,6 +275,7 @@ func (d *Daemon) StopVM(ctx context.Context, idOrName string) (model.VMRecord, e
return vm, nil return vm, nil
} }
func (d *Daemon) RestartVM(ctx context.Context, idOrName string) (model.VMRecord, error) { func (d *Daemon) RestartVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
vm, err := d.StopVM(ctx, idOrName) vm, err := d.StopVM(ctx, idOrName)
if err != nil { if err != nil {
@ -506,28 +509,6 @@ func (d *Daemon) ensureWorkDisk(ctx context.Context, vm *model.VMRecord) error {
return nil return nil
} }
func (d *Daemon) createDMSnapshot(ctx context.Context, rootfsPath, cowPath, dmName string) (baseLoop, cowLoop, dmDev string, err error) {
baseBytes, err := d.runner.RunSudo(ctx, "losetup", "-f", "--show", "--read-only", rootfsPath)
if err != nil {
return "", "", "", err
}
baseLoop = strings.TrimSpace(string(baseBytes))
cowBytes, err := d.runner.RunSudo(ctx, "losetup", "-f", "--show", cowPath)
if err != nil {
return "", "", "", err
}
cowLoop = strings.TrimSpace(string(cowBytes))
sectorsBytes, err := d.runner.RunSudo(ctx, "blockdev", "--getsz", baseLoop)
if err != nil {
return "", "", "", err
}
sectors := strings.TrimSpace(string(sectorsBytes))
if _, err := d.runner.RunSudo(ctx, "dmsetup", "create", dmName, "--table", fmt.Sprintf("0 %s snapshot %s %s P 8", sectors, baseLoop, cowLoop)); err != nil {
return "", "", "", err
}
return baseLoop, cowLoop, "/dev/mapper/" + dmName, nil
}
func (d *Daemon) ensureBridge(ctx context.Context) error { func (d *Daemon) ensureBridge(ctx context.Context) error {
if _, err := d.runner.Run(ctx, "ip", "link", "show", d.config.BridgeName); err == nil { if _, err := d.runner.Run(ctx, "ip", "link", "show", d.config.BridgeName); err == nil {
_, err = d.runner.RunSudo(ctx, "ip", "link", "set", d.config.BridgeName, "up") _, err = d.runner.RunSudo(ctx, "ip", "link", "set", d.config.BridgeName, "up")
@ -638,25 +619,20 @@ func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserve
if vm.Runtime.APISockPath != "" { if vm.Runtime.APISockPath != "" {
_ = os.Remove(vm.Runtime.APISockPath) _ = os.Remove(vm.Runtime.APISockPath)
} }
if vm.Runtime.DMName != "" { snapshotErr := d.cleanupDMSnapshot(ctx, dmSnapshotHandles{
_, _ = d.runner.RunSudo(ctx, "dmsetup", "remove", vm.Runtime.DMName) BaseLoop: vm.Runtime.BaseLoop,
} else if vm.Runtime.DMDev != "" { COWLoop: vm.Runtime.COWLoop,
_, _ = d.runner.RunSudo(ctx, "dmsetup", "remove", vm.Runtime.DMDev) DMName: vm.Runtime.DMName,
} DMDev: vm.Runtime.DMDev,
if vm.Runtime.COWLoop != "" { })
_, _ = d.runner.RunSudo(ctx, "losetup", "-d", vm.Runtime.COWLoop)
}
if vm.Runtime.BaseLoop != "" {
_, _ = d.runner.RunSudo(ctx, "losetup", "-d", vm.Runtime.BaseLoop)
}
if vm.Spec.NATEnabled { if vm.Spec.NATEnabled {
_ = d.ensureNAT(ctx, vm, false) _ = d.ensureNAT(ctx, vm, false)
} }
_ = d.removeDNS(ctx, vm.Runtime.DNSName) _ = d.removeDNS(ctx, vm.Runtime.DNSName)
if !preserveDisks && vm.Runtime.VMDir != "" { if !preserveDisks && vm.Runtime.VMDir != "" {
return os.RemoveAll(vm.Runtime.VMDir) return errors.Join(snapshotErr, os.RemoveAll(vm.Runtime.VMDir))
} }
return nil return snapshotErr
} }
func clearRuntimeHandles(vm *model.VMRecord) { func clearRuntimeHandles(vm *model.VMRecord) {

View file

@ -20,6 +20,11 @@ import (
type Runner struct{} type Runner struct{}
type CommandRunner interface {
Run(ctx context.Context, name string, args ...string) ([]byte, error)
RunSudo(ctx context.Context, args ...string) ([]byte, error)
}
func NewRunner() Runner { func NewRunner() Runner {
return Runner{} return Runner{}
} }
@ -162,7 +167,7 @@ func lastJSONLine(data []byte) []byte {
return last return last
} }
func CopyDirContents(ctx context.Context, runner Runner, sourceDir, targetDir string, useSudo bool) error { func CopyDirContents(ctx context.Context, runner CommandRunner, sourceDir, targetDir string, useSudo bool) error {
args := []string{"-a", filepath.Join(sourceDir, "."), targetDir + "/"} args := []string{"-a", filepath.Join(sourceDir, "."), targetDir + "/"}
var err error var err error
if useSudo { if useSudo {
@ -173,7 +178,7 @@ func CopyDirContents(ctx context.Context, runner Runner, sourceDir, targetDir st
return err return err
} }
func ResizeExt4Image(ctx context.Context, runner Runner, path string, bytes int64) error { func ResizeExt4Image(ctx context.Context, runner CommandRunner, path string, bytes int64) error {
if _, err := runner.Run(ctx, "truncate", "-s", strconv.FormatInt(bytes, 10), path); err != nil { if _, err := runner.Run(ctx, "truncate", "-s", strconv.FormatInt(bytes, 10), path); err != nil {
return err return err
} }
@ -184,7 +189,7 @@ func ResizeExt4Image(ctx context.Context, runner Runner, path string, bytes int6
return err return err
} }
func ReadDebugFSText(ctx context.Context, runner Runner, imagePath, guestPath string) (string, error) { func ReadDebugFSText(ctx context.Context, runner CommandRunner, imagePath, guestPath string) (string, error) {
out, err := runner.Run(ctx, "debugfs", "-R", "cat "+guestPath, imagePath) out, err := runner.Run(ctx, "debugfs", "-R", "cat "+guestPath, imagePath)
if err != nil { if err != nil {
return "", err return "", err
@ -192,7 +197,7 @@ func ReadDebugFSText(ctx context.Context, runner Runner, imagePath, guestPath st
return string(out), nil return string(out), nil
} }
func WriteExt4File(ctx context.Context, runner Runner, imagePath, guestPath string, data []byte) error { func WriteExt4File(ctx context.Context, runner CommandRunner, imagePath, guestPath string, data []byte) error {
tmp, err := os.CreateTemp("", "banger-ext4-*") tmp, err := os.CreateTemp("", "banger-ext4-*")
if err != nil { if err != nil {
return err return err
@ -210,7 +215,7 @@ func WriteExt4File(ctx context.Context, runner Runner, imagePath, guestPath stri
return err return err
} }
func MountTempDir(ctx context.Context, runner Runner, source string, readOnly bool) (string, func() error, error) { func MountTempDir(ctx context.Context, runner CommandRunner, source string, readOnly bool) (string, func() error, error) {
mountDir, err := os.MkdirTemp("", "banger-mnt-*") mountDir, err := os.MkdirTemp("", "banger-mnt-*")
if err != nil { if err != nil {
return "", nil, err return "", nil, err