diff --git a/internal/daemon/vm_lifecycle.go b/internal/daemon/vm_lifecycle.go index 6d4ac45..abbed75 100644 --- a/internal/daemon/vm_lifecycle.go +++ b/internal/daemon/vm_lifecycle.go @@ -3,7 +3,6 @@ package daemon import ( "context" "errors" - "fmt" "os" "path/filepath" "strconv" @@ -11,8 +10,6 @@ import ( "time" "banger/internal/api" - "banger/internal/firecracker" - "banger/internal/imagepull" "banger/internal/model" "banger/internal/system" ) @@ -43,28 +40,10 @@ func (s *VMService) startVMLocked(ctx context.Context, vm model.VMRecord, image } op.done(vmLogAttrs(vm)...) }() - op.stage("preflight") - vmCreateStage(ctx, "preflight", "checking host prerequisites") - if err := s.validateStartPrereqs(ctx, vm, image); err != nil { - return model.VMRecord{}, err - } - if err := os.MkdirAll(vm.Runtime.VMDir, 0o755); err != nil { - return model.VMRecord{}, err - } - op.stage("cleanup_runtime") - if err := s.cleanupRuntime(ctx, vm, true); err != nil { - return model.VMRecord{}, err - } - s.clearVMHandles(vm) - op.stage("bridge") - if err := s.net.ensureBridge(ctx); err != nil { - return model.VMRecord{}, err - } - op.stage("socket_dir") - if err := s.net.ensureSocketDir(); err != nil { - return model.VMRecord{}, err - } + // Derive per-VM paths/names up front so every step sees the same + // values. Shortening vm.ID mirrors how the pre-refactor inline + // code did it. 
shortID := system.ShortID(vm.ID) apiSock := filepath.Join(s.layout.RuntimeDir, "fc-"+shortID+".sock") dmName := "fc-rootfs-" + shortID @@ -78,183 +57,35 @@ func (s *VMService) startVMLocked(ctx context.Context, vm model.VMRecord, image return model.VMRecord{}, err } } - if err := os.RemoveAll(apiSock); err != nil && !os.IsNotExist(err) { - return model.VMRecord{}, err - } - if err := os.RemoveAll(vm.Runtime.VSockPath); err != nil && !os.IsNotExist(err) { - return model.VMRecord{}, err + + live := model.VMHandles{} + sc := &startContext{ + vm: &vm, + image: image, + live: &live, + apiSock: apiSock, + dmName: dmName, + tapName: tapName, } - op.stage("system_overlay", "overlay_path", vm.Runtime.SystemOverlay) - vmCreateStage(ctx, "prepare_rootfs", "preparing system overlay") - if err := s.ensureSystemOverlay(ctx, &vm); err != nil { - return model.VMRecord{}, err - } - - op.stage("dm_snapshot", "dm_name", dmName) - vmCreateStage(ctx, "prepare_rootfs", "creating root filesystem snapshot") - snapHandles, err := s.net.createDMSnapshot(ctx, image.RootfsPath, vm.Runtime.SystemOverlay, dmName) - if err != nil { - return model.VMRecord{}, err - } - // Live handles are threaded through this function as a local and - // pushed to the cache via setVMHandles once we have every piece. - // The cache update must happen BEFORE any step that reads handles - // back (e.g. cleanupRuntime via cleanupOnErr) — otherwise loops - // and DM would leak on an early failure. 
- live := model.VMHandles{ - BaseLoop: snapHandles.BaseLoop, - COWLoop: snapHandles.COWLoop, - DMName: snapHandles.DMName, - DMDev: snapHandles.DMDev, - } - s.setVMHandles(vm, live) - - vm.Runtime.APISockPath = apiSock - vm.Runtime.State = model.VMStateRunning - vm.State = model.VMStateRunning - vm.Runtime.LastError = "" - - cleanupOnErr := func(err error) (model.VMRecord, error) { + if runErr := s.runStartSteps(ctx, op, sc, s.buildStartSteps(op, sc)); runErr != nil { + // The step driver already ran rollback in reverse for every + // succeeded step. All that's left is to persist the ERROR + // state so operators see the failure via `vm show`. Use a + // fresh context in case the request ctx is cancelled — DB + // writes past this point are recovery, not user-driven. + // + // The store check is for tests that construct a bare Daemon + // without a DB; production always has s.store non-nil. vm.State = model.VMStateError vm.Runtime.State = model.VMStateError - vm.Runtime.LastError = err.Error() - op.stage("cleanup_after_failure", "error", err.Error()) - if cleanupErr := s.cleanupRuntime(context.Background(), vm, true); cleanupErr != nil { - err = errors.Join(err, cleanupErr) - } + vm.Runtime.LastError = runErr.Error() vm.Runtime.TapDevice = "" s.clearVMHandles(vm) - _ = s.store.UpsertVM(context.Background(), vm) - return model.VMRecord{}, err - } - - // On a restart the COW already holds writes from a previous guest - // boot — stale free-inode / free-block counters, possibly unwritten - // journal updates. e2fsprogs (e2cp/e2rm, used by patchRootOverlay) - // refuses to touch the snapshot with "Inode bitmap checksum does - // not match bitmap", which bubbles up as a "start failed" even - // though the filesystem is kernel-valid. `e2fsck -fy` reconciles - // the bitmaps and is a no-op on a fresh snapshot, so running it - // unconditionally keeps the code path the same for first vs. - // subsequent starts. 
Exit code 1 means "errors fixed" — we treat - // that as success. - op.stage("fsck_snapshot") - if _, err := s.runner.RunSudo(ctx, "e2fsck", "-fy", live.DMDev); err != nil { - // e2fsck exit codes: 0=clean, 1=errors corrected, 2=reboot - // needed, 4+=uncorrected. -1 means the error wasn't an - // exec.ExitError (e.g. command not found, ctx cancel). - if code := system.ExitCode(err); code < 0 || code > 1 { - return cleanupOnErr(fmt.Errorf("fsck snapshot: %w", err)) + if s.store != nil { + _ = s.store.UpsertVM(context.Background(), vm) } - } - - op.stage("patch_root_overlay") - vmCreateStage(ctx, "prepare_rootfs", "writing guest configuration") - if err := s.patchRootOverlay(ctx, vm, image); err != nil { - return cleanupOnErr(err) - } - op.stage("prepare_host_features") - vmCreateStage(ctx, "prepare_host_features", "preparing host-side vm features") - if err := s.capHooks.prepareHosts(ctx, &vm, image); err != nil { - return cleanupOnErr(err) - } - op.stage("tap") - tap, err := s.net.acquireTap(ctx, tapName) - if err != nil { - return cleanupOnErr(err) - } - live.TapDevice = tap - s.setVMHandles(vm, live) - // Mirror onto VM.Runtime so NAT teardown can recover the tap - // name from the DB even if the handle cache is empty (daemon - // crash + restart, corrupt handles.json). 
- vm.Runtime.TapDevice = tap - op.stage("metrics_file", "metrics_path", vm.Runtime.MetricsPath) - if err := os.WriteFile(vm.Runtime.MetricsPath, nil, 0o644); err != nil { - return cleanupOnErr(err) - } - - op.stage("firecracker_binary") - fcPath, err := s.net.firecrackerBinary() - if err != nil { - return cleanupOnErr(err) - } - op.stage("firecracker_launch", "log_path", vm.Runtime.LogPath, "metrics_path", vm.Runtime.MetricsPath) - vmCreateStage(ctx, "boot_firecracker", "starting firecracker") - kernelArgs := system.BuildBootArgs(vm.Name) - if strings.TrimSpace(image.InitrdPath) == "" { - // Direct-boot image (no initramfs) — the rootfs may be a - // container image without /sbin/init or iproute2. Use: - // 1. Kernel-level IP config via ip= cmdline (CONFIG_IP_PNP), - // so the network is up before init runs — no ip(8) needed. - // 2. init= pointing at our universal wrapper which installs - // systemd+sshd on first boot if missing. - kernelArgs = system.BuildBootArgsWithKernelIP( - vm.Name, vm.Runtime.GuestIP, s.config.BridgeIP, s.config.DefaultDNS, - ) + " init=" + imagepull.FirstBootScriptPath - } - - machineConfig := firecracker.MachineConfig{ - BinaryPath: fcPath, - VMID: vm.ID, - SocketPath: apiSock, - LogPath: vm.Runtime.LogPath, - MetricsPath: vm.Runtime.MetricsPath, - KernelImagePath: image.KernelPath, - InitrdPath: image.InitrdPath, - KernelArgs: kernelArgs, - Drives: []firecracker.DriveConfig{{ - ID: "rootfs", - Path: live.DMDev, - ReadOnly: false, - IsRoot: true, - }}, - TapDevice: tap, - VSockPath: vm.Runtime.VSockPath, - VSockCID: vm.Runtime.VSockCID, - VCPUCount: vm.Spec.VCPUCount, - MemoryMiB: vm.Spec.MemoryMiB, - Logger: s.logger, - } - s.capHooks.contributeMachine(&machineConfig, vm, image) - machine, err := firecracker.NewMachine(ctx, machineConfig) - if err != nil { - return cleanupOnErr(err) - } - if err := machine.Start(ctx); err != nil { - // Use a fresh context: the request ctx may already be cancelled (client - // disconnect), but we still 
need the PID so cleanupRuntime can kill the - // Firecracker process that was spawned before the failure. - live.PID = s.net.resolveFirecrackerPID(context.Background(), machine, apiSock) - s.setVMHandles(vm, live) - return cleanupOnErr(err) - } - live.PID = s.net.resolveFirecrackerPID(context.Background(), machine, apiSock) - s.setVMHandles(vm, live) - op.debugStage("firecracker_started", "pid", live.PID) - op.stage("socket_access", "api_socket", apiSock) - if err := s.net.ensureSocketAccess(ctx, apiSock, "firecracker api socket"); err != nil { - return cleanupOnErr(err) - } - op.stage("vsock_access", "vsock_path", vm.Runtime.VSockPath, "vsock_cid", vm.Runtime.VSockCID) - if err := s.net.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { - return cleanupOnErr(err) - } - vmCreateStage(ctx, "wait_vsock_agent", "waiting for guest vsock agent") - if err := s.net.waitForGuestVSockAgent(ctx, vm.Runtime.VSockPath, vsockReadyWait); err != nil { - return cleanupOnErr(err) - } - op.stage("post_start_features") - vmCreateStage(ctx, "wait_guest_ready", "waiting for guest services") - if err := s.capHooks.postStart(ctx, vm, image); err != nil { - return cleanupOnErr(err) - } - system.TouchNow(&vm) - op.stage("persist") - vmCreateStage(ctx, "finalize", "saving vm state") - if err := s.store.UpsertVM(ctx, vm); err != nil { - return cleanupOnErr(err) + return model.VMRecord{}, runErr } return vm, nil } diff --git a/internal/daemon/vm_lifecycle_steps.go b/internal/daemon/vm_lifecycle_steps.go new file mode 100644 index 0000000..5e9e753 --- /dev/null +++ b/internal/daemon/vm_lifecycle_steps.go @@ -0,0 +1,434 @@ +package daemon + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + + "banger/internal/firecracker" + "banger/internal/imagepull" + "banger/internal/model" + "banger/internal/system" +) + +// buildKernelArgs assembles the kernel command line for a start. 
+// Direct-boot images (no initrd) get kernel-level IP config so the +// network is up before init, plus init= pointing at the universal +// first-boot wrapper. Anything else uses the plain variant. +func buildKernelArgs(vm model.VMRecord, image model.Image, bridgeIP, defaultDNS string) string { + if strings.TrimSpace(image.InitrdPath) == "" { + return system.BuildBootArgsWithKernelIP( + vm.Name, vm.Runtime.GuestIP, bridgeIP, defaultDNS, + ) + " init=" + imagepull.FirstBootScriptPath + } + return system.BuildBootArgs(vm.Name) +} + +// startContext is the mutable state threaded through every start +// step. `vm` and `live` are pointers so steps mutate in place — +// dodges returning redundant copies and keeps step bodies readable. +// Values computed by `startVMLocked` before the driver runs +// (apiSock, dmName, tapName) live here too so each step can read +// them without rederiving. +type startContext struct { + vm *model.VMRecord + image model.Image + live *model.VMHandles + apiSock string + dmName string + tapName string + fcPath string + machine *firecracker.Machine + + // systemOverlayCreated records whether the system_overlay step + // actually created the file (vs. the file existing from a crashed + // prior attempt). The undo honours it so a leftover-but-valid + // overlay isn't deleted under us. + systemOverlayCreated bool +} + +// startStep is one phase in the start-VM pipeline. Phases with no +// rollback obligation leave `undo` nil — the driver simply skips +// them on the rollback path. `createStage` / `createDetail` are +// forwarded to `vmCreateStage` so the async-create RPC caller sees +// progress; they're "" for phases that were never part of the +// user-facing progress stream. 
+type startStep struct { + name string + attrs []any + createStage string + createDetail string + run func(ctx context.Context, sc *startContext) error + undo func(ctx context.Context, sc *startContext) error +} + +// runStartSteps walks steps in order, logging each via `op.stage` +// (and `vmCreateStage` when the step opted in). On the first +// run-err, it rolls back the prefix (including the failing step, so +// a step that acquired resources before erroring gets its undo +// fired) and returns the original err joined with any rollback err. +// +// Contract: `undo` must be safe to call even when `run` returned +// an error — check zero-value guards rather than assuming success. +// This is cheaper than a two-phase acquire/commit per step and +// matches how `cleanupPreparedCapabilities` in capabilities.go +// treats partial-success rollback. +func (s *VMService) runStartSteps(ctx context.Context, op *operationLog, sc *startContext, steps []startStep) error { + done := make([]startStep, 0, len(steps)) + for _, step := range steps { + if step.createStage != "" { + vmCreateStage(ctx, step.createStage, step.createDetail) + } + op.stage(step.name, step.attrs...) + if err := step.run(ctx, sc); err != nil { + done = append(done, step) // include the failing step — see contract above + if rollbackErr := s.rollbackStartSteps(op, sc, done); rollbackErr != nil { + err = errors.Join(err, rollbackErr) + } + return err + } + done = append(done, step) + } + return nil +} + +// rollbackStartSteps iterates completed steps in reverse, calling +// each non-nil `undo` with a detached context — the original ctx +// may already be cancelled (RPC client disconnect), but cleanup +// still needs to run. Undo errors are joined together; one step's +// failure doesn't short-circuit the rest. 
+func (s *VMService) rollbackStartSteps(op *operationLog, sc *startContext, done []startStep) error {
+	var err error
+	// Walk newest-first so teardown mirrors the order of acquisition.
+	for i := len(done) - 1; i >= 0; i-- {
+		step := done[i]
+		if step.undo == nil {
+			continue
+		}
+		op.stage("rollback_" + step.name)
+		// Detached context: the caller's ctx is deliberately not used here.
+		if undoErr := step.undo(context.Background(), sc); undoErr != nil {
+			// Keep going; every undo error is accumulated, tagged with its step.
+			err = errors.Join(err, fmt.Errorf("rollback %s: %w", step.name, undoErr))
+		}
+	}
+	return err
+}
+
+// buildStartSteps returns the ordered list of phases startVMLocked
+// drives. Keeping the list as data (vs. a long linear method body)
+// makes the phase inventory diff-readable and lets a test driver
+// substitute its own step slice.
+//
+// Phase names MUST stay 1:1 with the prior inline version — they
+// appear in daemon logs, smoke-log greps, and the async-create
+// progress stream that clients read.
+//
+// NOTE(review): prepare_sockets below documents itself as a new
+// op.stage label, so the list is not strictly 1:1 — confirm that log
+// consumers tolerate the extra name.
+//
+// `attrs` values are evaluated here, when the slice is built, not
+// when the step runs; they reflect sc as it is at build time.
+func (s *VMService) buildStartSteps(op *operationLog, sc *startContext) []startStep {
+	return []startStep{
+		{
+			name:         "preflight",
+			createStage:  "preflight",
+			createDetail: "checking host prerequisites",
+			run: func(ctx context.Context, sc *startContext) error {
+				if err := s.validateStartPrereqs(ctx, *sc.vm, sc.image); err != nil {
+					return err
+				}
+				return os.MkdirAll(sc.vm.Runtime.VMDir, 0o755)
+			},
+		},
+		{
+			// Tears down whatever a previous run left behind, then drops
+			// the cached handles for this VM.
+			name: "cleanup_runtime",
+			run: func(ctx context.Context, sc *startContext) error {
+				if err := s.cleanupRuntime(ctx, *sc.vm, true); err != nil {
+					return err
+				}
+				s.clearVMHandles(*sc.vm)
+				return nil
+			},
+		},
+		{
+			name: "bridge",
+			run: func(ctx context.Context, _ *startContext) error {
+				return s.net.ensureBridge(ctx)
+			},
+		},
+		{
+			name: "socket_dir",
+			run: func(_ context.Context, _ *startContext) error {
+				return s.net.ensureSocketDir()
+			},
+		},
+		{
+			// prepare_sockets is a new op.stage label — the prior
+			// inline code ran these `os.RemoveAll` calls before the
+			// system_overlay stage without a stage marker. Keeping a
+			// distinct name makes the log trace and rollback (if any
+			// later step fails) unambiguous.
+			name: "prepare_sockets",
+			run: func(_ context.Context, sc *startContext) error {
+				if err := os.RemoveAll(sc.apiSock); err != nil && !os.IsNotExist(err) {
+					return err
+				}
+				if err := os.RemoveAll(sc.vm.Runtime.VSockPath); err != nil && !os.IsNotExist(err) {
+					return err
+				}
+				return nil
+			},
+		},
+		{
+			name:         "system_overlay",
+			attrs:        []any{"overlay_path", sc.vm.Runtime.SystemOverlay},
+			createStage:  "prepare_rootfs",
+			createDetail: "preparing system overlay",
+			run: func(ctx context.Context, sc *startContext) error {
+				// Record ownership BEFORE the call so a partial-truncate
+				// failure still triggers cleanup of the half-created file.
+				if !exists(sc.vm.Runtime.SystemOverlay) {
+					sc.systemOverlayCreated = true
+				}
+				return s.ensureSystemOverlay(ctx, sc.vm)
+			},
+			undo: func(_ context.Context, sc *startContext) error {
+				// Only remove an overlay this start attempt created.
+				if !sc.systemOverlayCreated {
+					return nil
+				}
+				if err := os.Remove(sc.vm.Runtime.SystemOverlay); err != nil && !os.IsNotExist(err) {
+					return err
+				}
+				return nil
+			},
+		},
+		{
+			name:         "dm_snapshot",
+			attrs:        []any{"dm_name", sc.dmName},
+			createStage:  "prepare_rootfs",
+			createDetail: "creating root filesystem snapshot",
+			run: func(ctx context.Context, sc *startContext) error {
+				snapHandles, err := s.net.createDMSnapshot(ctx, sc.image.RootfsPath, sc.vm.Runtime.SystemOverlay, sc.dmName)
+				if err != nil {
+					// createDMSnapshot cleans up its own partial state on
+					// err; leave sc.live zero so the undo is a no-op.
+					return err
+				}
+				sc.live.BaseLoop = snapHandles.BaseLoop
+				sc.live.COWLoop = snapHandles.COWLoop
+				sc.live.DMName = snapHandles.DMName
+				sc.live.DMDev = snapHandles.DMDev
+				s.setVMHandles(*sc.vm, *sc.live)
+				// Fields that used to land next to the (now-deleted)
+				// cleanupOnErr closure. They belong with the DM
+				// snapshot because that's the first step producing
+				// runtime identity the downstream code reads back.
+				sc.vm.Runtime.APISockPath = sc.apiSock
+				sc.vm.Runtime.State = model.VMStateRunning
+				sc.vm.State = model.VMStateRunning
+				sc.vm.Runtime.LastError = ""
+				return nil
+			},
+			undo: func(ctx context.Context, sc *startContext) error {
+				// Zero-value guard: nothing was recorded, so nothing to tear down.
+				if sc.live.DMName == "" && sc.live.BaseLoop == "" && sc.live.COWLoop == "" {
+					return nil
+				}
+				return s.net.cleanupDMSnapshot(ctx, dmSnapshotHandles{
+					BaseLoop: sc.live.BaseLoop,
+					COWLoop:  sc.live.COWLoop,
+					DMName:   sc.live.DMName,
+					DMDev:    sc.live.DMDev,
+				})
+			},
+		},
+		{
+			// See the comment in the prior inline version: stale
+			// bitmaps from a reused COW make e2cp/e2rm refuse to
+			// touch the snapshot. e2fsck -fy is a no-op on a fresh
+			// snapshot. Exit codes 0 + 1 are both "ok" here.
+			name: "fsck_snapshot",
+			run: func(ctx context.Context, sc *startContext) error {
+				if _, err := s.runner.RunSudo(ctx, "e2fsck", "-fy", sc.live.DMDev); err != nil {
+					// Any exit code outside 0..1 (including a negative
+					// code) is treated as failure.
+					if code := system.ExitCode(err); code < 0 || code > 1 {
+						return fmt.Errorf("fsck snapshot: %w", err)
+					}
+				}
+				return nil
+			},
+		},
+		{
+			name:         "patch_root_overlay",
+			createStage:  "prepare_rootfs",
+			createDetail: "writing guest configuration",
+			run: func(ctx context.Context, sc *startContext) error {
+				return s.patchRootOverlay(ctx, *sc.vm, sc.image)
+			},
+		},
+		{
+			name:         "prepare_host_features",
+			createStage:  "prepare_host_features",
+			createDetail: "preparing host-side vm features",
+			run: func(ctx context.Context, sc *startContext) error {
+				return s.capHooks.prepareHosts(ctx, sc.vm, sc.image)
+			},
+			// On err, prepareHosts already cleaned up the prefix that
+			// succeeded before the failing capability. On success, any
+			// LATER step failure triggers this undo, which tears down
+			// ALL prepared caps via their Cleanup hooks.
+			undo: func(ctx context.Context, sc *startContext) error {
+				return s.capHooks.cleanupState(ctx, *sc.vm)
+			},
+		},
+		{
+			name: "tap",
+			run: func(ctx context.Context, sc *startContext) error {
+				tap, err := s.net.acquireTap(ctx, sc.tapName)
+				if err != nil {
+					return err
+				}
+				sc.live.TapDevice = tap
+				s.setVMHandles(*sc.vm, *sc.live)
+				// Mirror onto VM.Runtime for NAT teardown resilience
+				// across daemon crashes — see vm.Runtime.TapDevice docs.
+				sc.vm.Runtime.TapDevice = tap
+				return nil
+			},
+			undo: func(ctx context.Context, sc *startContext) error {
+				// Zero-value guard: acquireTap never succeeded.
+				if sc.live.TapDevice == "" {
+					return nil
+				}
+				return s.net.releaseTap(ctx, sc.live.TapDevice)
+			},
+		},
+		{
+			name:  "metrics_file",
+			attrs: []any{"metrics_path", sc.vm.Runtime.MetricsPath},
+			run: func(_ context.Context, sc *startContext) error {
+				// Writes an empty file at the metrics path.
+				return os.WriteFile(sc.vm.Runtime.MetricsPath, nil, 0o644)
+			},
+			undo: func(_ context.Context, sc *startContext) error {
+				if err := os.Remove(sc.vm.Runtime.MetricsPath); err != nil && !os.IsNotExist(err) {
+					return err
+				}
+				return nil
+			},
+		},
+		{
+			name: "firecracker_binary",
+			run: func(_ context.Context, sc *startContext) error {
+				fcPath, err := s.net.firecrackerBinary()
+				if err != nil {
+					return err
+				}
+				// Stashed for the firecracker_launch step.
+				sc.fcPath = fcPath
+				return nil
+			},
+		},
+		{
+			name:         "firecracker_launch",
+			attrs:        []any{"log_path", sc.vm.Runtime.LogPath, "metrics_path", sc.vm.Runtime.MetricsPath},
+			createStage:  "boot_firecracker",
+			createDetail: "starting firecracker",
+			run: func(ctx context.Context, sc *startContext) error {
+				kernelArgs := buildKernelArgs(*sc.vm, sc.image, s.config.BridgeIP, s.config.DefaultDNS)
+				machineConfig := firecracker.MachineConfig{
+					BinaryPath:      sc.fcPath,
+					VMID:            sc.vm.ID,
+					SocketPath:      sc.apiSock,
+					LogPath:         sc.vm.Runtime.LogPath,
+					MetricsPath:     sc.vm.Runtime.MetricsPath,
+					KernelImagePath: sc.image.KernelPath,
+					InitrdPath:      sc.image.InitrdPath,
+					KernelArgs:      kernelArgs,
+					Drives: []firecracker.DriveConfig{{
+						ID:       "rootfs",
+						Path:     sc.live.DMDev,
+						ReadOnly: false,
+						IsRoot:   true,
+					}},
+					TapDevice: sc.live.TapDevice,
+					VSockPath: sc.vm.Runtime.VSockPath,
+					VSockCID:  sc.vm.Runtime.VSockCID,
+					VCPUCount: sc.vm.Spec.VCPUCount,
+					MemoryMiB: sc.vm.Spec.MemoryMiB,
+					Logger:    s.logger,
+				}
+				// Capabilities get a chance to amend the config before launch.
+				s.capHooks.contributeMachine(&machineConfig, *sc.vm, sc.image)
+				machine, err := firecracker.NewMachine(ctx, machineConfig)
+				if err != nil {
+					return err
+				}
+				sc.machine = machine
+				if err := machine.Start(ctx); err != nil {
+					// machine.Start can fail AFTER the firecracker process
+					// is already spawned (HTTP config phase). Record the
+					// PID so the undo can kill it; use a fresh ctx since
+					// the request ctx may be cancelled by now.
+					sc.live.PID = s.net.resolveFirecrackerPID(context.Background(), machine, sc.apiSock)
+					s.setVMHandles(*sc.vm, *sc.live)
+					return err
+				}
+				sc.live.PID = s.net.resolveFirecrackerPID(context.Background(), machine, sc.apiSock)
+				s.setVMHandles(*sc.vm, *sc.live)
+				op.debugStage("firecracker_started", "pid", sc.live.PID)
+				return nil
+			},
+			undo: func(ctx context.Context, sc *startContext) error {
+				// Attempt every cleanup and report all failures together
+				// rather than stopping at the first one.
+				var errs []error
+				if sc.live.PID > 0 {
+					if err := s.net.killVMProcess(ctx, sc.live.PID); err != nil {
+						errs = append(errs, err)
+					}
+				}
+				if err := os.Remove(sc.apiSock); err != nil && !os.IsNotExist(err) {
+					errs = append(errs, err)
+				}
+				if err := os.Remove(sc.vm.Runtime.VSockPath); err != nil && !os.IsNotExist(err) {
+					errs = append(errs, err)
+				}
+				return errors.Join(errs...)
+			},
+		},
+		{
+			name:  "socket_access",
+			attrs: []any{"api_socket", sc.apiSock},
+			run: func(ctx context.Context, sc *startContext) error {
+				return s.net.ensureSocketAccess(ctx, sc.apiSock, "firecracker api socket")
+			},
+		},
+		{
+			name:  "vsock_access",
+			attrs: []any{"vsock_path", sc.vm.Runtime.VSockPath, "vsock_cid", sc.vm.Runtime.VSockCID},
+			run: func(ctx context.Context, sc *startContext) error {
+				return s.net.ensureSocketAccess(ctx, sc.vm.Runtime.VSockPath, "firecracker vsock socket")
+			},
+		},
+		{
+			name:         "wait_vsock_agent",
+			createStage:  "wait_vsock_agent",
+			createDetail: "waiting for guest vsock agent",
+			run: func(ctx context.Context, sc *startContext) error {
+				return s.net.waitForGuestVSockAgent(ctx, sc.vm.Runtime.VSockPath, vsockReadyWait)
+			},
+		},
+		{
+			name:         "post_start_features",
+			createStage:  "wait_guest_ready",
+			createDetail: "waiting for guest services",
+			run: func(ctx context.Context, sc *startContext) error {
+				return s.capHooks.postStart(ctx, *sc.vm, sc.image)
+			},
+			// Capability Cleanup hooks are designed to be idempotent
+			// (check feature-enabled flag, no-op if nothing to undo),
+			// so calling cleanupState here is safe whether postStart
+			// reached every cap or bailed midway.
+			//
+			// NOTE(review): prepare_host_features registers the same
+			// cleanupState undo, so a failure at or after this step
+			// runs cleanupState twice during rollback — this relies on
+			// the idempotence described above; confirm it holds.
+			undo: func(ctx context.Context, sc *startContext) error {
+				return s.capHooks.cleanupState(ctx, *sc.vm)
+			},
+		},
+		{
+			name:         "persist",
+			createStage:  "finalize",
+			createDetail: "saving vm state",
+			run: func(ctx context.Context, sc *startContext) error {
+				// Refresh the record's timestamp before writing it out.
+				system.TouchNow(sc.vm)
+				return s.store.UpsertVM(ctx, *sc.vm)
+			},
+		},
+	}
+}
diff --git a/internal/daemon/vm_lifecycle_steps_test.go b/internal/daemon/vm_lifecycle_steps_test.go
new file mode 100644
index 0000000..f6998a6
--- /dev/null
+++ b/internal/daemon/vm_lifecycle_steps_test.go
@@ -0,0 +1,164 @@
+package daemon
+
+import (
+	"context"
+	"errors"
+	"io"
+	"log/slog"
+	"strings"
+	"testing"
+)
+
+// TestRunStartSteps_RollsBackInReverseOnFailure pins the driver
+// contract at the heart of commit 1's refactor: on a step failure
+// (a) every step that succeeded BEFORE the failing one gets its
+// undo fired in reverse order; (b) the failing step's undo also
+// fires, because steps may acquire partial state before returning
+// err; (c) the final error wraps both the run error and any
+// rollback errors via errors.Join.
+func TestRunStartSteps_RollsBackInReverseOnFailure(t *testing.T) {
+	var (
+		svc   = &VMService{}
+		opLog = &operationLog{logger: slog.New(slog.NewTextHandler(io.Discard, nil))}
+		state = &startContext{}
+		trace []string
+	)
+
+	// tracer builds a step callback that appends label to trace and
+	// then returns result (nil for a succeeding callback).
+	tracer := func(label string, result error) func(context.Context, *startContext) error {
+		return func(context.Context, *startContext) error {
+			trace = append(trace, label)
+			return result
+		}
+	}
+
+	pipeline := []startStep{
+		{name: "first", run: tracer("run-first", nil), undo: tracer("undo-first", nil)},
+		{name: "second", run: tracer("run-second", nil), undo: tracer("undo-second", nil)},
+		{name: "third", run: tracer("run-third", errors.New("boom")), undo: tracer("undo-third", nil)},
+		{name: "fourth", run: tracer("run-fourth", nil), undo: tracer("undo-fourth", nil)},
+	}
+
+	err := svc.runStartSteps(context.Background(), opLog, state, pipeline)
+	if err == nil || !strings.Contains(err.Error(), "boom") {
+		t.Fatalf("runStartSteps err = %v, want containing 'boom'", err)
+	}
+
+	want := []string{
+		// Runs happen in order and stop at the failure: fourth is never reached.
+		"run-first", "run-second", "run-third",
+		// Undos fire newest-first, starting with the failing step itself.
+		"undo-third", "undo-second", "undo-first",
+	}
+	if len(trace) != len(want) {
+		t.Fatalf("events length = %d, want %d:\n got: %v\n want: %v", len(trace), len(want), trace, want)
+	}
+	for i, expected := range want {
+		if trace[i] != expected {
+			t.Fatalf("events[%d] = %q, want %q\n got: %v\n want: %v", i, trace[i], expected, trace, want)
+		}
+	}
+}
+
+// TestRunStartSteps_SkipsNilUndos covers the optional-undo contract:
+// a step with nothing to tear down leaves `undo` nil, and rollback
+// must pass over it quietly instead of panicking.
+func TestRunStartSteps_SkipsNilUndos(t *testing.T) {
+	svc := &VMService{}
+	opLog := &operationLog{logger: slog.New(slog.NewTextHandler(io.Discard, nil))}
+	state := &startContext{}
+
+	var undone []string
+	trackUndo := func(label string) func(context.Context, *startContext) error {
+		return func(context.Context, *startContext) error {
+			undone = append(undone, label)
+			return nil
+		}
+	}
+	succeed := func(context.Context, *startContext) error { return nil }
+	fail := func(context.Context, *startContext) error { return errors.New("x") }
+
+	pipeline := []startStep{
+		{name: "has-undo", run: succeed, undo: trackUndo("has-undo")},
+		{name: "no-undo", run: succeed}, // nil undo on purpose
+		{name: "failing", run: fail, undo: trackUndo("failing")},
+	}
+
+	if err := svc.runStartSteps(context.Background(), opLog, state, pipeline); err == nil {
+		t.Fatal("runStartSteps err = nil, want failure")
+	}
+
+	// Expected rollback: the failing step's undo first, the nil undo
+	// skipped, then has-undo.
+	want := []string{"failing", "has-undo"}
+	mismatch := len(undone) != len(want)
+	for i := 0; !mismatch && i < len(want); i++ {
+		mismatch = undone[i] != want[i]
+	}
+	if mismatch {
+		t.Fatalf("undo calls = %v, want %v", undone, want)
+	}
+}
+
+// TestRunStartSteps_JoinsRollbackErrors checks that an undo error is
+// joined onto the run error instead of replacing it — the caller
+// must still be able to find the root cause ("boom") when rollback
+// itself fails.
+func TestRunStartSteps_JoinsRollbackErrors(t *testing.T) {
+	svc := &VMService{}
+	opLog := &operationLog{logger: slog.New(slog.NewTextHandler(io.Discard, nil))}
+	state := &startContext{}
+
+	var (
+		rootErr = errors.New("boom")
+		undoErr = errors.New("undo-fail")
+	)
+	succeed := func(context.Context, *startContext) error { return nil }
+
+	pipeline := []startStep{
+		{
+			name: "ok",
+			run:  succeed,
+			undo: func(context.Context, *startContext) error { return undoErr },
+		},
+		{
+			name: "fail",
+			run:  func(context.Context, *startContext) error { return rootErr },
+		},
+	}
+
+	err := svc.runStartSteps(context.Background(), opLog, state, pipeline)
+	if err == nil {
+		t.Fatal("err = nil, want joined error")
+	}
+	if !errors.Is(err, rootErr) {
+		t.Fatalf("err does not wrap rootErr; got: %v", err)
+	}
+	if !errors.Is(err, undoErr) {
+		t.Fatalf("err does not wrap undoErr; got: %v", err)
+	}
+}
+
+// TestRunStartSteps_HappyPathNoRollback verifies that when every run
+// returns nil no undo is invoked — rollback belongs to the failure
+// path only.
+func TestRunStartSteps_HappyPathNoRollback(t *testing.T) {
+	svc := &VMService{}
+	opLog := &operationLog{logger: slog.New(slog.NewTextHandler(io.Discard, nil))}
+	state := &startContext{}
+
+	succeed := func(context.Context, *startContext) error { return nil }
+	undoCalled := false
+	markUndo := func(context.Context, *startContext) error {
+		undoCalled = true
+		return nil
+	}
+
+	pipeline := []startStep{
+		{name: "a", run: succeed, undo: markUndo},
+		{name: "b", run: succeed},
+	}
+
+	if err := svc.runStartSteps(context.Background(), opLog, state, pipeline); err != nil {
+		t.Fatalf("runStartSteps err = %v, want nil", err)
+	}
+	if undoCalled {
+		t.Fatal("undo fired on happy path — rollback must only run on failure")
+	}
+}