daemon split (4/5): extract *VMService service

Phase 4 of the daemon god-struct refactor. VM lifecycle, create-op
registry, handle cache, disk provisioning, stats polling, ports
query, and the per-VM lock set all move off *Daemon onto *VMService.

Daemon keeps thin forwarders only for FindVM / TouchVM (dispatch
surface) and is otherwise out of VM lifecycle. Lazy-init via
d.vmSvc() mirrors the earlier services so test literals like
`&Daemon{store: db, runner: r}` still get a functional service
without spelling one out.

Three small cleanups along the way:

  * preflight helpers (validateStartPrereqs / addBaseStartPrereqs
    / addBaseStartCommandPrereqs / validateWorkDiskResizePrereqs)
    move with the VM methods that call them.
  * cleanupRuntime / rebuildDNS move to *VMService, with
    HostNetwork primitives (findFirecrackerPID, cleanupDMSnapshot,
    killVMProcess, releaseTap, waitForExit, sendCtrlAltDel)
    reached through s.net instead of the hostNet() facade.
  * vsockAgentBinary becomes a package-level function so both
    *Daemon (doctor) and *VMService (preflight) call one entry
    point instead of each owning a forwarder method.

WorkspaceService's peer deps switch from eager method values to
closures — vmSvc() constructs VMService with WorkspaceService as a
peer, so resolving d.vmSvc().FindVM at construction time recursed
through workspaceSvc() → vmSvc(). Closures defer the lookup to call
time.

Pure code motion: build + unit tests green, lint clean. No RPC
surface or lock-ordering changes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Thales Maciel 2026-04-20 20:57:05 -03:00
parent c0d456e734
commit 466a7c30c4
No known key found for this signature in database
GPG key ID: 33112E6833C34679
23 changed files with 655 additions and 463 deletions

View file

@ -38,7 +38,7 @@ func TestFindOrAutoPullImageReturnsLocalWithoutPulling(t *testing.T) {
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
image, err := d.findOrAutoPullImage(context.Background(), "my-local-image") image, err := d.vmSvc().findOrAutoPullImage(context.Background(), "my-local-image")
if err != nil { if err != nil {
t.Fatalf("findOrAutoPullImage: %v", err) t.Fatalf("findOrAutoPullImage: %v", err)
} }
@ -68,7 +68,7 @@ func TestFindOrAutoPullImagePullsFromCatalog(t *testing.T) {
}, },
} }
// "debian-bookworm" is in the embedded imagecat catalog. // "debian-bookworm" is in the embedded imagecat catalog.
image, err := d.findOrAutoPullImage(context.Background(), "debian-bookworm") image, err := d.vmSvc().findOrAutoPullImage(context.Background(), "debian-bookworm")
if err != nil { if err != nil {
t.Fatalf("findOrAutoPullImage: %v", err) t.Fatalf("findOrAutoPullImage: %v", err)
} }
@ -86,7 +86,7 @@ func TestFindOrAutoPullImageReturnsOriginalErrorWhenNotInCatalog(t *testing.T) {
store: openDaemonStore(t), store: openDaemonStore(t),
runner: system.NewRunner(), runner: system.NewRunner(),
} }
_, err := d.findOrAutoPullImage(context.Background(), "not-in-catalog-or-store") _, err := d.vmSvc().findOrAutoPullImage(context.Background(), "not-in-catalog-or-store")
if err == nil || !strings.Contains(err.Error(), "not found") { if err == nil || !strings.Contains(err.Error(), "not found") {
t.Fatalf("err = %v, want not-found", err) t.Fatalf("err = %v, want not-found", err)
} }

View file

@ -199,7 +199,7 @@ func (workDiskCapability) ContributeMachine(cfg *firecracker.MachineConfig, vm m
} }
func (workDiskCapability) PrepareHost(ctx context.Context, d *Daemon, vm *model.VMRecord, image model.Image) error { func (workDiskCapability) PrepareHost(ctx context.Context, d *Daemon, vm *model.VMRecord, image model.Image) error {
prep, err := d.ensureWorkDisk(ctx, vm, image) prep, err := d.vmSvc().ensureWorkDisk(ctx, vm, image)
if err != nil { if err != nil {
return err return err
} }
@ -270,14 +270,14 @@ func (natCapability) PostStart(ctx context.Context, d *Daemon, vm model.VMRecord
if !vm.Spec.NATEnabled { if !vm.Spec.NATEnabled {
return nil return nil
} }
return d.hostNet().ensureNAT(ctx, vm.Runtime.GuestIP, d.vmHandles(vm.ID).TapDevice, true) return d.hostNet().ensureNAT(ctx, vm.Runtime.GuestIP, d.vmSvc().vmHandles(vm.ID).TapDevice, true)
} }
func (natCapability) Cleanup(ctx context.Context, d *Daemon, vm model.VMRecord) error { func (natCapability) Cleanup(ctx context.Context, d *Daemon, vm model.VMRecord) error {
if !vm.Spec.NATEnabled { if !vm.Spec.NATEnabled {
return nil return nil
} }
tap := d.vmHandles(vm.ID).TapDevice tap := d.vmSvc().vmHandles(vm.ID).TapDevice
if strings.TrimSpace(vm.Runtime.GuestIP) == "" || strings.TrimSpace(tap) == "" { if strings.TrimSpace(vm.Runtime.GuestIP) == "" || strings.TrimSpace(tap) == "" {
if d.logger != nil { if d.logger != nil {
d.logger.Debug("skipping nat cleanup without runtime network handles", append(vmLogAttrs(vm), "guest_ip", vm.Runtime.GuestIP, "tap_device", tap)...) d.logger.Debug("skipping nat cleanup without runtime network handles", append(vmLogAttrs(vm), "guest_ip", vm.Runtime.GuestIP, "tap_device", tap)...)
@ -291,10 +291,10 @@ func (natCapability) ApplyConfigChange(ctx context.Context, d *Daemon, before, a
if before.Spec.NATEnabled == after.Spec.NATEnabled { if before.Spec.NATEnabled == after.Spec.NATEnabled {
return nil return nil
} }
if !d.vmAlive(after) { if !d.vmSvc().vmAlive(after) {
return nil return nil
} }
return d.hostNet().ensureNAT(ctx, after.Runtime.GuestIP, d.vmHandles(after.ID).TapDevice, after.Spec.NATEnabled) return d.hostNet().ensureNAT(ctx, after.Runtime.GuestIP, d.vmSvc().vmHandles(after.ID).TapDevice, after.Spec.NATEnabled)
} }
func (natCapability) AddDoctorChecks(ctx context.Context, d *Daemon, report *system.Report) { func (natCapability) AddDoctorChecks(ctx context.Context, d *Daemon, report *system.Report) {

View file

@ -3,21 +3,18 @@ package daemon
import ( import (
"bufio" "bufio"
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
"net" "net"
"os" "os"
"strings"
"sync" "sync"
"time" "time"
"banger/internal/api" "banger/internal/api"
"banger/internal/buildinfo" "banger/internal/buildinfo"
"banger/internal/config" "banger/internal/config"
"banger/internal/daemon/opstate"
"banger/internal/model" "banger/internal/model"
"banger/internal/paths" "banger/internal/paths"
"banger/internal/rpc" "banger/internal/rpc"
@ -26,31 +23,23 @@ import (
"banger/internal/vmdns" "banger/internal/vmdns"
) )
// Daemon is the composition root: shared infrastructure (store,
// runner, logger, layout, config) plus pointers to the four focused
// services that own behavior. Open wires the services; the dispatch
// loop forwards RPCs to them. No lifecycle / image / workspace /
// networking behavior lives on *Daemon itself — it's wiring.
type Daemon struct { type Daemon struct {
layout paths.Layout layout paths.Layout
config model.DaemonConfig config model.DaemonConfig
store *store.Store store *store.Store
runner system.CommandRunner runner system.CommandRunner
logger *slog.Logger logger *slog.Logger
createVMMu sync.Mutex
createOps opstate.Registry[*vmCreateOperationState] net *HostNetwork
vmLocks vmLockSet img *ImageService
// workspaceLocks serialises workspace.prepare / workspace.export ws *WorkspaceService
// calls on the same VM (two concurrent prepares would clobber each vm *VMService
// other's tar streams). It is a SEPARATE scope from vmLocks so
// slow guest I/O — SSH dial, tar upload, chmod — does not block
// vm stop/delete/restart. See ARCHITECTURE.md.
workspaceLocks vmLockSet
// handles caches per-VM transient kernel/process handles (PID,
// tap device, loop devices, DM name/device). Populated at vm
// start and at daemon startup reconcile; cleared on stop/delete.
// See internal/daemon/vm_handles.go — persistent durable state
// lives in the store, this is rebuildable from a per-VM
// handles.json scratch file and OS inspection.
handles *handleCache
net *HostNetwork
img *ImageService
ws *WorkspaceService
closing chan struct{} closing chan struct{}
once sync.Once once sync.Once
pid int pid int
@ -92,7 +81,6 @@ func Open(ctx context.Context) (d *Daemon, err error) {
logger: logger, logger: logger,
closing: closing, closing: closing,
pid: os.Getpid(), pid: os.Getpid(),
handles: newHandleCache(),
net: newHostNetwork(hostNetworkDeps{ net: newHostNetwork(hostNetworkDeps{
runner: runner, runner: runner,
logger: logger, logger: logger,
@ -134,7 +122,7 @@ func Open(ctx context.Context) (d *Daemon, err error) {
} }
used := make([]string, 0, len(vms)) used := make([]string, 0, len(vms))
for _, vm := range vms { for _, vm := range vms {
if tap := d.vmHandles(vm.ID).TapDevice; tap != "" { if tap := d.vmSvc().vmHandles(vm.ID).TapDevice; tap != "" {
used = append(used, tap) used = append(used, tap)
} }
} }
@ -294,28 +282,28 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
vm, err := d.CreateVM(ctx, params) vm, err := d.vmSvc().CreateVM(ctx, params)
return marshalResultOrError(api.VMShowResult{VM: vm}, err) return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.create.begin": case "vm.create.begin":
params, err := rpc.DecodeParams[api.VMCreateParams](req) params, err := rpc.DecodeParams[api.VMCreateParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
op, err := d.BeginVMCreate(ctx, params) op, err := d.vmSvc().BeginVMCreate(ctx, params)
return marshalResultOrError(api.VMCreateBeginResult{Operation: op}, err) return marshalResultOrError(api.VMCreateBeginResult{Operation: op}, err)
case "vm.create.status": case "vm.create.status":
params, err := rpc.DecodeParams[api.VMCreateStatusParams](req) params, err := rpc.DecodeParams[api.VMCreateStatusParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
op, err := d.VMCreateStatus(ctx, params.ID) op, err := d.vmSvc().VMCreateStatus(ctx, params.ID)
return marshalResultOrError(api.VMCreateStatusResult{Operation: op}, err) return marshalResultOrError(api.VMCreateStatusResult{Operation: op}, err)
case "vm.create.cancel": case "vm.create.cancel":
params, err := rpc.DecodeParams[api.VMCreateStatusParams](req) params, err := rpc.DecodeParams[api.VMCreateStatusParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
err = d.CancelVMCreate(ctx, params.ID) err = d.vmSvc().CancelVMCreate(ctx, params.ID)
return marshalResultOrError(api.Empty{}, err) return marshalResultOrError(api.Empty{}, err)
case "vm.list": case "vm.list":
vms, err := d.store.ListVMs(ctx) vms, err := d.store.ListVMs(ctx)
@ -325,63 +313,63 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
vm, err := d.FindVM(ctx, params.IDOrName) vm, err := d.vmSvc().FindVM(ctx, params.IDOrName)
return marshalResultOrError(api.VMShowResult{VM: vm}, err) return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.start": case "vm.start":
params, err := rpc.DecodeParams[api.VMRefParams](req) params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
vm, err := d.StartVM(ctx, params.IDOrName) vm, err := d.vmSvc().StartVM(ctx, params.IDOrName)
return marshalResultOrError(api.VMShowResult{VM: vm}, err) return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.stop": case "vm.stop":
params, err := rpc.DecodeParams[api.VMRefParams](req) params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
vm, err := d.StopVM(ctx, params.IDOrName) vm, err := d.vmSvc().StopVM(ctx, params.IDOrName)
return marshalResultOrError(api.VMShowResult{VM: vm}, err) return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.kill": case "vm.kill":
params, err := rpc.DecodeParams[api.VMKillParams](req) params, err := rpc.DecodeParams[api.VMKillParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
vm, err := d.KillVM(ctx, params) vm, err := d.vmSvc().KillVM(ctx, params)
return marshalResultOrError(api.VMShowResult{VM: vm}, err) return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.restart": case "vm.restart":
params, err := rpc.DecodeParams[api.VMRefParams](req) params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
vm, err := d.RestartVM(ctx, params.IDOrName) vm, err := d.vmSvc().RestartVM(ctx, params.IDOrName)
return marshalResultOrError(api.VMShowResult{VM: vm}, err) return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.delete": case "vm.delete":
params, err := rpc.DecodeParams[api.VMRefParams](req) params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
vm, err := d.DeleteVM(ctx, params.IDOrName) vm, err := d.vmSvc().DeleteVM(ctx, params.IDOrName)
return marshalResultOrError(api.VMShowResult{VM: vm}, err) return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.set": case "vm.set":
params, err := rpc.DecodeParams[api.VMSetParams](req) params, err := rpc.DecodeParams[api.VMSetParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
vm, err := d.SetVM(ctx, params) vm, err := d.vmSvc().SetVM(ctx, params)
return marshalResultOrError(api.VMShowResult{VM: vm}, err) return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.stats": case "vm.stats":
params, err := rpc.DecodeParams[api.VMRefParams](req) params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
vm, stats, err := d.GetVMStats(ctx, params.IDOrName) vm, stats, err := d.vmSvc().GetVMStats(ctx, params.IDOrName)
return marshalResultOrError(api.VMStatsResult{VM: vm, Stats: stats}, err) return marshalResultOrError(api.VMStatsResult{VM: vm, Stats: stats}, err)
case "vm.logs": case "vm.logs":
params, err := rpc.DecodeParams[api.VMRefParams](req) params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
vm, err := d.FindVM(ctx, params.IDOrName) vm, err := d.vmSvc().FindVM(ctx, params.IDOrName)
if err != nil { if err != nil {
return rpc.NewError("not_found", err.Error()) return rpc.NewError("not_found", err.Error())
} }
@ -391,11 +379,11 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
vm, err := d.TouchVM(ctx, params.IDOrName) vm, err := d.vmSvc().TouchVM(ctx, params.IDOrName)
if err != nil { if err != nil {
return rpc.NewError("not_found", err.Error()) return rpc.NewError("not_found", err.Error())
} }
if !d.vmAlive(vm) { if !d.vmSvc().vmAlive(vm) {
return rpc.NewError("not_running", fmt.Sprintf("vm %s is not running", vm.Name)) return rpc.NewError("not_running", fmt.Sprintf("vm %s is not running", vm.Name))
} }
return marshalResultOrError(api.VMSSHResult{Name: vm.Name, GuestIP: vm.Runtime.GuestIP}, nil) return marshalResultOrError(api.VMSSHResult{Name: vm.Name, GuestIP: vm.Runtime.GuestIP}, nil)
@ -404,21 +392,21 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
result, err := d.HealthVM(ctx, params.IDOrName) result, err := d.vmSvc().HealthVM(ctx, params.IDOrName)
return marshalResultOrError(result, err) return marshalResultOrError(result, err)
case "vm.ping": case "vm.ping":
params, err := rpc.DecodeParams[api.VMRefParams](req) params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
result, err := d.PingVM(ctx, params.IDOrName) result, err := d.vmSvc().PingVM(ctx, params.IDOrName)
return marshalResultOrError(result, err) return marshalResultOrError(result, err)
case "vm.ports": case "vm.ports":
params, err := rpc.DecodeParams[api.VMRefParams](req) params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil { if err != nil {
return rpc.NewError("bad_request", err.Error()) return rpc.NewError("bad_request", err.Error())
} }
result, err := d.PortsVM(ctx, params.IDOrName) result, err := d.vmSvc().PortsVM(ctx, params.IDOrName)
return marshalResultOrError(result, err) return marshalResultOrError(result, err)
case "vm.workspace.prepare": case "vm.workspace.prepare":
params, err := rpc.DecodeParams[api.VMWorkspacePrepareParams](req) params, err := rpc.DecodeParams[api.VMWorkspacePrepareParams](req)
@ -519,14 +507,14 @@ func (d *Daemon) backgroundLoop() {
case <-d.closing: case <-d.closing:
return return
case <-statsTicker.C: case <-statsTicker.C:
if err := d.pollStats(context.Background()); err != nil && d.logger != nil { if err := d.vmSvc().pollStats(context.Background()); err != nil && d.logger != nil {
d.logger.Error("background stats poll failed", "error", err.Error()) d.logger.Error("background stats poll failed", "error", err.Error())
} }
case <-staleTicker.C: case <-staleTicker.C:
if err := d.stopStaleVMs(context.Background()); err != nil && d.logger != nil { if err := d.vmSvc().stopStaleVMs(context.Background()); err != nil && d.logger != nil {
d.logger.Error("background stale sweep failed", "error", err.Error()) d.logger.Error("background stale sweep failed", "error", err.Error())
} }
d.pruneVMCreateOperations(time.Now().Add(-10 * time.Minute)) d.vmSvc().pruneVMCreateOperations(time.Now().Add(-10 * time.Minute))
} }
} }
} }
@ -543,18 +531,18 @@ func (d *Daemon) reconcile(ctx context.Context) error {
return op.fail(err) return op.fail(err)
} }
for _, vm := range vms { for _, vm := range vms {
if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error { if err := d.vmSvc().withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error {
if vm.State != model.VMStateRunning { if vm.State != model.VMStateRunning {
// Belt-and-braces: a stopped VM should never have a // Belt-and-braces: a stopped VM should never have a
// scratch file or a cache entry. Clean up anything // scratch file or a cache entry. Clean up anything
// left by an ungraceful previous daemon crash. // left by an ungraceful previous daemon crash.
d.clearVMHandles(vm) d.vmSvc().clearVMHandles(vm)
return nil return nil
} }
// Rebuild the in-memory handle cache by loading the per-VM // Rebuild the in-memory handle cache by loading the per-VM
// scratch file and verifying the firecracker process is // scratch file and verifying the firecracker process is
// still alive. // still alive.
h, alive, err := d.rediscoverHandles(ctx, vm) h, alive, err := d.vmSvc().rediscoverHandles(ctx, vm)
if err != nil && d.logger != nil { if err != nil && d.logger != nil {
d.logger.Warn("rediscover handles failed", "vm_id", vm.ID, "error", err.Error()) d.logger.Warn("rediscover handles failed", "vm_id", vm.ID, "error", err.Error())
} }
@ -562,54 +550,33 @@ func (d *Daemon) reconcile(ctx context.Context) error {
// claimed. If alive, subsequent vmAlive() calls pass; if // claimed. If alive, subsequent vmAlive() calls pass; if
// not, cleanupRuntime needs these handles to know which // not, cleanupRuntime needs these handles to know which
// kernel resources (DM / loops / tap) to tear down. // kernel resources (DM / loops / tap) to tear down.
d.setVMHandlesInMemory(vm.ID, h) d.vmSvc().setVMHandlesInMemory(vm.ID, h)
if alive { if alive {
return nil return nil
} }
op.stage("stale_vm", vmLogAttrs(vm)...) op.stage("stale_vm", vmLogAttrs(vm)...)
_ = d.cleanupRuntime(ctx, vm, true) _ = d.vmSvc().cleanupRuntime(ctx, vm, true)
vm.State = model.VMStateStopped vm.State = model.VMStateStopped
vm.Runtime.State = model.VMStateStopped vm.Runtime.State = model.VMStateStopped
d.clearVMHandles(vm) d.vmSvc().clearVMHandles(vm)
vm.UpdatedAt = model.Now() vm.UpdatedAt = model.Now()
return d.store.UpsertVM(ctx, vm) return d.store.UpsertVM(ctx, vm)
}); err != nil { }); err != nil {
return op.fail(err, "vm_id", vm.ID) return op.fail(err, "vm_id", vm.ID)
} }
} }
if err := d.rebuildDNS(ctx); err != nil { if err := d.vmSvc().rebuildDNS(ctx); err != nil {
return op.fail(err) return op.fail(err)
} }
op.done() op.done()
return nil return nil
} }
// FindVM stays on Daemon as a thin forwarder to the VM service lookup.
// Dispatch code reads the facade directly; tests that pre-date the
// service split keep compiling.
func (d *Daemon) FindVM(ctx context.Context, idOrName string) (model.VMRecord, error) { func (d *Daemon) FindVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
if idOrName == "" { return d.vmSvc().FindVM(ctx, idOrName)
return model.VMRecord{}, errors.New("vm id or name is required")
}
if vm, err := d.store.GetVM(ctx, idOrName); err == nil {
return vm, nil
}
vms, err := d.store.ListVMs(ctx)
if err != nil {
return model.VMRecord{}, err
}
matchCount := 0
var match model.VMRecord
for _, vm := range vms {
if strings.HasPrefix(vm.ID, idOrName) || strings.HasPrefix(vm.Name, idOrName) {
match = vm
matchCount++
}
}
if matchCount == 1 {
return match, nil
}
if matchCount > 1 {
return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", idOrName)
}
return model.VMRecord{}, fmt.Errorf("vm %q not found", idOrName)
} }
// FindImage stays on Daemon as a thin forwarder to the image service // FindImage stays on Daemon as a thin forwarder to the image service
@ -620,52 +587,7 @@ func (d *Daemon) FindImage(ctx context.Context, idOrName string) (model.Image, e
} }
func (d *Daemon) TouchVM(ctx context.Context, idOrName string) (model.VMRecord, error) { func (d *Daemon) TouchVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { return d.vmSvc().TouchVM(ctx, idOrName)
system.TouchNow(&vm)
if err := d.store.UpsertVM(ctx, vm); err != nil {
return model.VMRecord{}, err
}
return vm, nil
})
}
func (d *Daemon) withVMLockByRef(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
vm, err := d.FindVM(ctx, idOrName)
if err != nil {
return model.VMRecord{}, err
}
return d.withVMLockByID(ctx, vm.ID, fn)
}
func (d *Daemon) withVMLockByID(ctx context.Context, id string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
if strings.TrimSpace(id) == "" {
return model.VMRecord{}, errors.New("vm id is required")
}
unlock := d.lockVMID(id)
defer unlock()
vm, err := d.store.GetVMByID(ctx, id)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return model.VMRecord{}, fmt.Errorf("vm %q not found", id)
}
return model.VMRecord{}, err
}
return fn(vm)
}
func (d *Daemon) withVMLockByIDErr(ctx context.Context, id string, fn func(model.VMRecord) error) error {
_, err := d.withVMLockByID(ctx, id, func(vm model.VMRecord) (model.VMRecord, error) {
if err := fn(vm); err != nil {
return model.VMRecord{}, err
}
return vm, nil
})
return err
}
func (d *Daemon) lockVMID(id string) func() {
return d.vmLocks.lock(id)
} }
func marshalResultOrError(v any, err error) rpc.Response { func marshalResultOrError(v any, err error) rpc.Response {

View file

@ -133,7 +133,7 @@ func (d *Daemon) runtimeChecks() *system.Preflight {
checks := system.NewPreflight() checks := system.NewPreflight()
checks.RequireExecutable(d.config.FirecrackerBin, "firecracker binary", `install firecracker or set "firecracker_bin"`) checks.RequireExecutable(d.config.FirecrackerBin, "firecracker binary", `install firecracker or set "firecracker_bin"`)
checks.RequireFile(d.config.SSHKeyPath, "ssh private key", `set "ssh_key_path" or let banger create its default key`) checks.RequireFile(d.config.SSHKeyPath, "ssh private key", `set "ssh_key_path" or let banger create its default key`)
if helper, err := d.vsockAgentBinary(); err == nil { if helper, err := vsockAgentBinary(d.layout); err == nil {
checks.RequireExecutable(helper, "vsock agent helper", `run 'make build' or reinstall banger`) checks.RequireExecutable(helper, "vsock agent helper", `run 'make build' or reinstall banger`)
} else { } else {
checks.Addf("%v", err) checks.Addf("%v", err)
@ -167,13 +167,13 @@ func defaultImageInCatalog(name string) bool {
func (d *Daemon) coreVMLifecycleChecks() *system.Preflight { func (d *Daemon) coreVMLifecycleChecks() *system.Preflight {
checks := system.NewPreflight() checks := system.NewPreflight()
d.addBaseStartCommandPrereqs(checks) d.vmSvc().addBaseStartCommandPrereqs(checks)
return checks return checks
} }
func (d *Daemon) vsockChecks() *system.Preflight { func (d *Daemon) vsockChecks() *system.Preflight {
checks := system.NewPreflight() checks := system.NewPreflight()
if helper, err := d.vsockAgentBinary(); err == nil { if helper, err := vsockAgentBinary(d.layout); err == nil {
checks.RequireExecutable(helper, "vsock agent helper", `run 'make build' or reinstall banger`) checks.RequireExecutable(helper, "vsock agent helper", `run 'make build' or reinstall banger`)
} else { } else {
checks.Addf("%v", err) checks.Addf("%v", err)

View file

@ -39,7 +39,7 @@ func TestEnsureWorkDiskClonesSeedImageAndResizes(t *testing.T) {
image := testImage("image-seeded") image := testImage("image-seeded")
image.WorkSeedPath = seedPath image.WorkSeedPath = seedPath
if _, err := d.ensureWorkDisk(context.Background(), &vm, image); err != nil { if _, err := d.vmSvc().ensureWorkDisk(context.Background(), &vm, image); err != nil {
t.Fatalf("ensureWorkDisk: %v", err) t.Fatalf("ensureWorkDisk: %v", err)
} }
runner.assertExhausted() runner.assertExhausted()

View file

@ -115,7 +115,7 @@ func TestStartVMLockedLogsBridgeFailure(t *testing.T) {
logger: logger, logger: logger,
} }
_, err = d.startVMLocked(ctx, vm, image) _, err = d.vmSvc().startVMLocked(ctx, vm, image)
if err == nil || !strings.Contains(err.Error(), "bridge up failed") { if err == nil || !strings.Contains(err.Error(), "bridge up failed") {
t.Fatalf("startVMLocked() error = %v, want bridge failure", err) t.Fatalf("startVMLocked() error = %v, want bridge failure", err)
} }

View file

@ -21,14 +21,14 @@ import (
const httpProbeTimeout = 750 * time.Millisecond const httpProbeTimeout = 750 * time.Millisecond
func (d *Daemon) PortsVM(ctx context.Context, idOrName string) (result api.VMPortsResult, err error) { func (s *VMService) PortsVM(ctx context.Context, idOrName string) (result api.VMPortsResult, err error) {
_, err = d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { _, err = s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
result.Name = vm.Name result.Name = vm.Name
result.DNSName = strings.TrimSpace(vm.Runtime.DNSName) result.DNSName = strings.TrimSpace(vm.Runtime.DNSName)
if result.DNSName == "" && strings.TrimSpace(vm.Name) != "" { if result.DNSName == "" && strings.TrimSpace(vm.Name) != "" {
result.DNSName = vmdns.RecordName(vm.Name) result.DNSName = vmdns.RecordName(vm.Name)
} }
if !d.vmAlive(vm) { if !s.vmAlive(vm) {
return model.VMRecord{}, fmt.Errorf("vm %s is not running", vm.Name) return model.VMRecord{}, fmt.Errorf("vm %s is not running", vm.Name)
} }
if strings.TrimSpace(vm.Runtime.GuestIP) == "" { if strings.TrimSpace(vm.Runtime.GuestIP) == "" {
@ -40,12 +40,12 @@ func (d *Daemon) PortsVM(ctx context.Context, idOrName string) (result api.VMPor
if vm.Runtime.VSockCID == 0 { if vm.Runtime.VSockCID == 0 {
return model.VMRecord{}, errors.New("vm has no vsock cid") return model.VMRecord{}, errors.New("vm has no vsock cid")
} }
if err := d.hostNet().ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { if err := s.net.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
portsCtx, cancel := context.WithTimeout(ctx, 3*time.Second) portsCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel() defer cancel()
listeners, err := vsockagent.Ports(portsCtx, d.logger, vm.Runtime.VSockPath) listeners, err := vsockagent.Ports(portsCtx, s.logger, vm.Runtime.VSockPath)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }

View file

@ -10,14 +10,14 @@ import (
var vsockHostDevicePath = "/dev/vhost-vsock" var vsockHostDevicePath = "/dev/vhost-vsock"
func (d *Daemon) validateStartPrereqs(ctx context.Context, vm model.VMRecord, image model.Image) error { func (s *VMService) validateStartPrereqs(ctx context.Context, vm model.VMRecord, image model.Image) error {
checks := system.NewPreflight() checks := system.NewPreflight()
d.addBaseStartPrereqs(checks, image) s.addBaseStartPrereqs(checks, image)
d.addCapabilityStartPrereqs(ctx, checks, vm, image) s.capHooks.addStartPrereqs(ctx, checks, vm, image)
return checks.Err("vm start preflight failed") return checks.Err("vm start preflight failed")
} }
func (d *Daemon) validateWorkDiskResizePrereqs() error { func (s *VMService) validateWorkDiskResizePrereqs() error {
checks := system.NewPreflight() checks := system.NewPreflight()
checks.RequireCommand("truncate", toolHint("truncate")) checks.RequireCommand("truncate", toolHint("truncate"))
checks.RequireCommand("e2fsck", `install e2fsprogs`) checks.RequireCommand("e2fsck", `install e2fsprogs`)
@ -25,10 +25,10 @@ func (d *Daemon) validateWorkDiskResizePrereqs() error {
return checks.Err("work disk resize preflight failed") return checks.Err("work disk resize preflight failed")
} }
func (d *Daemon) addBaseStartPrereqs(checks *system.Preflight, image model.Image) { func (s *VMService) addBaseStartPrereqs(checks *system.Preflight, image model.Image) {
d.addBaseStartCommandPrereqs(checks) s.addBaseStartCommandPrereqs(checks)
checks.RequireExecutable(d.config.FirecrackerBin, "firecracker binary", `install firecracker or set "firecracker_bin"`) checks.RequireExecutable(s.config.FirecrackerBin, "firecracker binary", `install firecracker or set "firecracker_bin"`)
if helper, err := d.vsockAgentBinary(); err == nil { if helper, err := vsockAgentBinary(s.layout); err == nil {
checks.RequireExecutable(helper, "vsock agent helper", `run 'make build' or reinstall banger`) checks.RequireExecutable(helper, "vsock agent helper", `run 'make build' or reinstall banger`)
} else { } else {
checks.Addf("%v", err) checks.Addf("%v", err)
@ -41,7 +41,7 @@ func (d *Daemon) addBaseStartPrereqs(checks *system.Preflight, image model.Image
} }
} }
func (d *Daemon) addBaseStartCommandPrereqs(checks *system.Preflight) { func (s *VMService) addBaseStartCommandPrereqs(checks *system.Preflight) {
for _, command := range []string{"sudo", "ip", "dmsetup", "losetup", "blockdev", "truncate", "pgrep", "chown", "chmod", "kill", "e2cp", "e2rm", "debugfs"} { for _, command := range []string{"sudo", "ip", "dmsetup", "losetup", "blockdev", "truncate", "pgrep", "chown", "chmod", "kill", "e2cp", "e2rm", "debugfs"} {
checks.RequireCommand(command, toolHint(command)) checks.RequireCommand(command, toolHint(command))
} }

View file

@ -6,7 +6,11 @@ import (
"banger/internal/paths" "banger/internal/paths"
) )
func (d *Daemon) vsockAgentBinary() (string, error) { // vsockAgentBinary resolves the companion helper the daemon ships
// alongside its own binary. It's stateless — the signature takes no
// argument so callers on *Daemon / *VMService / doctor all share one
// entry point instead of each owning a forwarder method.
func vsockAgentBinary(_ paths.Layout) (string, error) {
path, err := paths.CompanionBinaryPath("banger-vsock-agent") path, err := paths.CompanionBinaryPath("banger-vsock-agent")
if err != nil { if err != nil {
return "", fmt.Errorf("vsock agent helper not available: %w", err) return "", fmt.Errorf("vsock agent helper not available: %w", err)

View file

@ -26,21 +26,21 @@ var (
) )
// rebuildDNS enumerates live VMs and republishes the DNS record set. // rebuildDNS enumerates live VMs and republishes the DNS record set.
// Lives on *Daemon (not HostNetwork) because "alive" is a VMService // Lives on VMService because "alive" is a VM-state concern that
// concern that HostNetwork shouldn't need to reach into. Daemon // HostNetwork shouldn't need to reach into. VMService orchestrates:
// orchestrates: VM list from the store, alive filter, hand the // VM list from the store, alive filter, hand the resulting map to
// resulting map to HostNetwork.replaceDNS. // HostNetwork.replaceDNS.
func (d *Daemon) rebuildDNS(ctx context.Context) error { func (s *VMService) rebuildDNS(ctx context.Context) error {
if d.net == nil { if s.net == nil {
return nil return nil
} }
vms, err := d.store.ListVMs(ctx) vms, err := s.store.ListVMs(ctx)
if err != nil { if err != nil {
return err return err
} }
records := make(map[string]string) records := make(map[string]string)
for _, vm := range vms { for _, vm := range vms {
if !d.vmAlive(vm) { if !s.vmAlive(vm) {
continue continue
} }
if strings.TrimSpace(vm.Runtime.GuestIP) == "" { if strings.TrimSpace(vm.Runtime.GuestIP) == "" {
@ -48,7 +48,7 @@ func (d *Daemon) rebuildDNS(ctx context.Context) error {
} }
records[vmDNSRecordName(vm.Name)] = vm.Runtime.GuestIP records[vmDNSRecordName(vm.Name)] = vm.Runtime.GuestIP
} }
return d.hostNet().replaceDNS(records) return s.net.replaceDNS(records)
} }
// vmDNSRecordName is a small indirection so the dns-record-name // vmDNSRecordName is a small indirection so the dns-record-name
@ -59,36 +59,37 @@ func vmDNSRecordName(name string) string {
} }
// cleanupRuntime tears down the host-side state for a VM: firecracker // cleanupRuntime tears down the host-side state for a VM: firecracker
// process, DM snapshot, capabilities, tap, sockets. Stays on *Daemon // process, DM snapshot, capabilities, tap, sockets. Lives on VMService
// for now because it reaches into handles (VMService-owned) and // because it reaches into handles (VMService-owned); the capability
// capabilities (still on Daemon). Phase 4 will move it to VMService. // teardown goes through the capHooks seam to keep Daemon out of the
func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserveDisks bool) error { // dependency chain.
if d.logger != nil { func (s *VMService) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserveDisks bool) error {
d.logger.Debug("cleanup runtime", append(vmLogAttrs(vm), "preserve_disks", preserveDisks)...) if s.logger != nil {
s.logger.Debug("cleanup runtime", append(vmLogAttrs(vm), "preserve_disks", preserveDisks)...)
} }
h := d.vmHandles(vm.ID) h := s.vmHandles(vm.ID)
cleanupPID := h.PID cleanupPID := h.PID
if vm.Runtime.APISockPath != "" { if vm.Runtime.APISockPath != "" {
if pid, err := d.hostNet().findFirecrackerPID(ctx, vm.Runtime.APISockPath); err == nil && pid > 0 { if pid, err := s.net.findFirecrackerPID(ctx, vm.Runtime.APISockPath); err == nil && pid > 0 {
cleanupPID = pid cleanupPID = pid
} }
} }
if cleanupPID > 0 && system.ProcessRunning(cleanupPID, vm.Runtime.APISockPath) { if cleanupPID > 0 && system.ProcessRunning(cleanupPID, vm.Runtime.APISockPath) {
_ = d.hostNet().killVMProcess(ctx, cleanupPID) _ = s.net.killVMProcess(ctx, cleanupPID)
if err := d.hostNet().waitForExit(ctx, cleanupPID, vm.Runtime.APISockPath, 30*time.Second); err != nil { if err := s.net.waitForExit(ctx, cleanupPID, vm.Runtime.APISockPath, 30*time.Second); err != nil {
return err return err
} }
} }
snapshotErr := d.hostNet().cleanupDMSnapshot(ctx, dmSnapshotHandles{ snapshotErr := s.net.cleanupDMSnapshot(ctx, dmSnapshotHandles{
BaseLoop: h.BaseLoop, BaseLoop: h.BaseLoop,
COWLoop: h.COWLoop, COWLoop: h.COWLoop,
DMName: h.DMName, DMName: h.DMName,
DMDev: h.DMDev, DMDev: h.DMDev,
}) })
featureErr := d.cleanupCapabilityState(ctx, vm) featureErr := s.capHooks.cleanupState(ctx, vm)
var tapErr error var tapErr error
if h.TapDevice != "" { if h.TapDevice != "" {
tapErr = d.hostNet().releaseTap(ctx, h.TapDevice) tapErr = s.net.releaseTap(ctx, h.TapDevice)
} }
if vm.Runtime.APISockPath != "" { if vm.Runtime.APISockPath != "" {
_ = os.Remove(vm.Runtime.APISockPath) _ = os.Remove(vm.Runtime.APISockPath)
@ -99,14 +100,14 @@ func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserve
// The handles are only meaningful while the kernel objects exist; // The handles are only meaningful while the kernel objects exist;
// dropping them here keeps the cache in sync with reality even // dropping them here keeps the cache in sync with reality even
// when the caller forgets to call clearVMHandles explicitly. // when the caller forgets to call clearVMHandles explicitly.
d.clearVMHandles(vm) s.clearVMHandles(vm)
if !preserveDisks && vm.Runtime.VMDir != "" { if !preserveDisks && vm.Runtime.VMDir != "" {
return errors.Join(snapshotErr, featureErr, tapErr, os.RemoveAll(vm.Runtime.VMDir)) return errors.Join(snapshotErr, featureErr, tapErr, os.RemoveAll(vm.Runtime.VMDir))
} }
return errors.Join(snapshotErr, featureErr, tapErr) return errors.Join(snapshotErr, featureErr, tapErr)
} }
func (d *Daemon) generateName(ctx context.Context) (string, error) { func (s *VMService) generateName(ctx context.Context) (string, error) {
_ = ctx _ = ctx
if name := strings.TrimSpace(namegen.Generate()); name != "" { if name := strings.TrimSpace(namegen.Generate()); name != "" {
return name, nil return name, nil

View file

@ -27,8 +27,8 @@ import (
// won. // won.
// 3. Boot. Only the per-VM lock is held — parallel creates against // 3. Boot. Only the per-VM lock is held — parallel creates against
// different VMs fully overlap. // different VMs fully overlap.
func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm model.VMRecord, err error) { func (s *VMService) CreateVM(ctx context.Context, params api.VMCreateParams) (vm model.VMRecord, err error) {
op := d.beginOperation("vm.create") op := s.beginOperation("vm.create")
defer func() { defer func() {
if err != nil { if err != nil {
op.fail(err) op.fail(err)
@ -45,10 +45,10 @@ func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm mo
imageName := params.ImageName imageName := params.ImageName
if imageName == "" { if imageName == "" {
imageName = d.config.DefaultImageName imageName = s.config.DefaultImageName
} }
vmCreateStage(ctx, "resolve_image", "resolving image") vmCreateStage(ctx, "resolve_image", "resolving image")
image, err := d.findOrAutoPullImage(ctx, imageName) image, err := s.findOrAutoPullImage(ctx, imageName)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
@ -77,7 +77,7 @@ func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm mo
NATEnabled: params.NATEnabled, NATEnabled: params.NATEnabled,
} }
vm, err = d.reserveVM(ctx, strings.TrimSpace(params.Name), image, spec) vm, err = s.reserveVM(ctx, strings.TrimSpace(params.Name), image, spec)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
@ -85,31 +85,31 @@ func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm mo
vmCreateBindVM(ctx, vm) vmCreateBindVM(ctx, vm)
vmCreateStage(ctx, "reserve_vm", fmt.Sprintf("allocated %s (%s)", vm.Name, vm.Runtime.GuestIP)) vmCreateStage(ctx, "reserve_vm", fmt.Sprintf("allocated %s (%s)", vm.Name, vm.Runtime.GuestIP))
unlockVM := d.lockVMID(vm.ID) unlockVM := s.lockVMID(vm.ID)
defer unlockVM() defer unlockVM()
if params.NoStart { if params.NoStart {
vm.State = model.VMStateStopped vm.State = model.VMStateStopped
vm.Runtime.State = model.VMStateStopped vm.Runtime.State = model.VMStateStopped
if err := d.store.UpsertVM(ctx, vm); err != nil { if err := s.store.UpsertVM(ctx, vm); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
return vm, nil return vm, nil
} }
return d.startVMLocked(ctx, vm, image) return s.startVMLocked(ctx, vm, image)
} }
// reserveVM holds createVMMu only long enough to verify the name is // reserveVM holds createVMMu only long enough to verify the name is
// free, allocate a guest IP from the store, and persist the "created" // free, allocate a guest IP from the store, and persist the "created"
// reservation row. Everything else (image resolution upstream, boot // reservation row. Everything else (image resolution upstream, boot
// downstream) runs outside this lock. // downstream) runs outside this lock.
func (d *Daemon) reserveVM(ctx context.Context, requestedName string, image model.Image, spec model.VMSpec) (model.VMRecord, error) { func (s *VMService) reserveVM(ctx context.Context, requestedName string, image model.Image, spec model.VMSpec) (model.VMRecord, error) {
d.createVMMu.Lock() s.createVMMu.Lock()
defer d.createVMMu.Unlock() defer s.createVMMu.Unlock()
name := requestedName name := requestedName
if name == "" { if name == "" {
generated, err := d.generateName(ctx) generated, err := s.generateName(ctx)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
@ -118,7 +118,7 @@ func (d *Daemon) reserveVM(ctx context.Context, requestedName string, image mode
// Exact-name lookup. Using FindVM here would also match a new name // Exact-name lookup. Using FindVM here would also match a new name
// that merely prefixes some existing VM's id or another VM's name, // that merely prefixes some existing VM's id or another VM's name,
// falsely rejecting perfectly valid names. // falsely rejecting perfectly valid names.
if _, err := d.store.GetVMByName(ctx, name); err == nil { if _, err := s.store.GetVMByName(ctx, name); err == nil {
return model.VMRecord{}, fmt.Errorf("vm name already exists: %s", name) return model.VMRecord{}, fmt.Errorf("vm name already exists: %s", name)
} else if !errors.Is(err, sql.ErrNoRows) { } else if !errors.Is(err, sql.ErrNoRows) {
return model.VMRecord{}, err return model.VMRecord{}, err
@ -128,11 +128,11 @@ func (d *Daemon) reserveVM(ctx context.Context, requestedName string, image mode
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
guestIP, err := d.store.NextGuestIP(ctx, bridgePrefix(d.config.BridgeIP)) guestIP, err := s.store.NextGuestIP(ctx, bridgePrefix(s.config.BridgeIP))
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
vmDir := filepath.Join(d.layout.VMsDir, id) vmDir := filepath.Join(s.layout.VMsDir, id)
if err := os.MkdirAll(vmDir, 0o755); err != nil { if err := os.MkdirAll(vmDir, 0o755); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
@ -155,7 +155,7 @@ func (d *Daemon) reserveVM(ctx context.Context, requestedName string, image mode
GuestIP: guestIP, GuestIP: guestIP,
DNSName: vmdns.RecordName(name), DNSName: vmdns.RecordName(name),
VMDir: vmDir, VMDir: vmDir,
VSockPath: defaultVSockPath(d.layout.RuntimeDir, id), VSockPath: defaultVSockPath(s.layout.RuntimeDir, id),
VSockCID: vsockCID, VSockCID: vsockCID,
SystemOverlay: filepath.Join(vmDir, "system.cow"), SystemOverlay: filepath.Join(vmDir, "system.cow"),
WorkDiskPath: filepath.Join(vmDir, "root.ext4"), WorkDiskPath: filepath.Join(vmDir, "root.ext4"),
@ -163,7 +163,7 @@ func (d *Daemon) reserveVM(ctx context.Context, requestedName string, image mode
MetricsPath: filepath.Join(vmDir, "metrics.json"), MetricsPath: filepath.Join(vmDir, "metrics.json"),
}, },
} }
if err := d.store.UpsertVM(ctx, vm); err != nil { if err := s.store.UpsertVM(ctx, vm); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
return vm, nil return vm, nil
@ -174,8 +174,8 @@ func (d *Daemon) reserveVM(ctx context.Context, requestedName string, image mode
// catalog, it auto-pulls the bundle so `vm create --image foo` (and // catalog, it auto-pulls the bundle so `vm create --image foo` (and
// therefore `vm run`) works on a fresh host without the user having // therefore `vm run`) works on a fresh host without the user having
// to run `image pull` first. // to run `image pull` first.
func (d *Daemon) findOrAutoPullImage(ctx context.Context, idOrName string) (model.Image, error) { func (s *VMService) findOrAutoPullImage(ctx context.Context, idOrName string) (model.Image, error) {
image, err := d.imageSvc().FindImage(ctx, idOrName) image, err := s.img.FindImage(ctx, idOrName)
if err == nil { if err == nil {
return image, nil return image, nil
} }
@ -189,8 +189,8 @@ func (d *Daemon) findOrAutoPullImage(ctx context.Context, idOrName string) (mode
return model.Image{}, err return model.Image{}, err
} }
vmCreateStage(ctx, "auto_pull_image", fmt.Sprintf("pulling %s from image catalog", entry.Name)) vmCreateStage(ctx, "auto_pull_image", fmt.Sprintf("pulling %s from image catalog", entry.Name))
if _, pullErr := d.imageSvc().PullImage(ctx, api.ImagePullParams{Ref: entry.Name}); pullErr != nil { if _, pullErr := s.img.PullImage(ctx, api.ImagePullParams{Ref: entry.Name}); pullErr != nil {
return model.Image{}, fmt.Errorf("auto-pull image %q: %w", entry.Name, pullErr) return model.Image{}, fmt.Errorf("auto-pull image %q: %w", entry.Name, pullErr)
} }
return d.imageSvc().FindImage(ctx, idOrName) return s.img.FindImage(ctx, idOrName)
} }

View file

@ -146,20 +146,20 @@ func (op *vmCreateOperationState) cancelOperation() {
} }
} }
func (d *Daemon) BeginVMCreate(_ context.Context, params api.VMCreateParams) (api.VMCreateOperation, error) { func (s *VMService) BeginVMCreate(_ context.Context, params api.VMCreateParams) (api.VMCreateOperation, error) {
op, err := newVMCreateOperationState() op, err := newVMCreateOperationState()
if err != nil { if err != nil {
return api.VMCreateOperation{}, err return api.VMCreateOperation{}, err
} }
createCtx, cancel := context.WithCancel(context.Background()) createCtx, cancel := context.WithCancel(context.Background())
op.setCancel(cancel) op.setCancel(cancel)
d.createOps.Insert(op) s.createOps.Insert(op)
go d.runVMCreateOperation(withVMCreateProgress(createCtx, op), op, params) go s.runVMCreateOperation(withVMCreateProgress(createCtx, op), op, params)
return op.snapshot(), nil return op.snapshot(), nil
} }
func (d *Daemon) runVMCreateOperation(ctx context.Context, op *vmCreateOperationState, params api.VMCreateParams) { func (s *VMService) runVMCreateOperation(ctx context.Context, op *vmCreateOperationState, params api.VMCreateParams) {
vm, err := d.CreateVM(ctx, params) vm, err := s.CreateVM(ctx, params)
if err != nil { if err != nil {
op.fail(err) op.fail(err)
return return
@ -167,16 +167,16 @@ func (d *Daemon) runVMCreateOperation(ctx context.Context, op *vmCreateOperation
op.done(vm) op.done(vm)
} }
func (d *Daemon) VMCreateStatus(_ context.Context, id string) (api.VMCreateOperation, error) { func (s *VMService) VMCreateStatus(_ context.Context, id string) (api.VMCreateOperation, error) {
op, ok := d.createOps.Get(strings.TrimSpace(id)) op, ok := s.createOps.Get(strings.TrimSpace(id))
if !ok { if !ok {
return api.VMCreateOperation{}, fmt.Errorf("vm create operation not found: %s", id) return api.VMCreateOperation{}, fmt.Errorf("vm create operation not found: %s", id)
} }
return op.snapshot(), nil return op.snapshot(), nil
} }
func (d *Daemon) CancelVMCreate(_ context.Context, id string) error { func (s *VMService) CancelVMCreate(_ context.Context, id string) error {
op, ok := d.createOps.Get(strings.TrimSpace(id)) op, ok := s.createOps.Get(strings.TrimSpace(id))
if !ok { if !ok {
return fmt.Errorf("vm create operation not found: %s", id) return fmt.Errorf("vm create operation not found: %s", id)
} }
@ -184,6 +184,6 @@ func (d *Daemon) CancelVMCreate(_ context.Context, id string) error {
return nil return nil
} }
func (d *Daemon) pruneVMCreateOperations(olderThan time.Time) { func (s *VMService) pruneVMCreateOperations(olderThan time.Time) {
d.createOps.Prune(olderThan) s.createOps.Prune(olderThan)
} }

View file

@ -41,14 +41,14 @@ func TestReserveVMAllowsNameThatPrefixesExistingVM(t *testing.T) {
// New VM name is a prefix of the existing id (which is // New VM name is a prefix of the existing id (which is
// "longname-sandbox-foobar-id" per testVM). Old FindVM-based check // "longname-sandbox-foobar-id" per testVM). Old FindVM-based check
// would reject this. // would reject this.
if vm, err := d.reserveVM(ctx, "longname", image, model.VMSpec{VCPUCount: 1, MemoryMiB: 128}); err != nil { if vm, err := d.vmSvc().reserveVM(ctx, "longname", image, model.VMSpec{VCPUCount: 1, MemoryMiB: 128}); err != nil {
t.Fatalf("reserveVM(prefix of id): %v", err) t.Fatalf("reserveVM(prefix of id): %v", err)
} else if vm.Name != "longname" { } else if vm.Name != "longname" {
t.Fatalf("reserveVM returned name=%q, want longname", vm.Name) t.Fatalf("reserveVM returned name=%q, want longname", vm.Name)
} }
// Prefix of the existing name ("longname-sandbox") must also work. // Prefix of the existing name ("longname-sandbox") must also work.
if vm, err := d.reserveVM(ctx, "longname-sandbox", image, model.VMSpec{VCPUCount: 1, MemoryMiB: 128}); err != nil { if vm, err := d.vmSvc().reserveVM(ctx, "longname-sandbox", image, model.VMSpec{VCPUCount: 1, MemoryMiB: 128}); err != nil {
t.Fatalf("reserveVM(prefix of name): %v", err) t.Fatalf("reserveVM(prefix of name): %v", err)
} else if vm.Name != "longname-sandbox" { } else if vm.Name != "longname-sandbox" {
t.Fatalf("reserveVM returned name=%q, want longname-sandbox", vm.Name) t.Fatalf("reserveVM returned name=%q, want longname-sandbox", vm.Name)
@ -76,7 +76,7 @@ func TestReserveVMRejectsExactDuplicateName(t *testing.T) {
t.Fatalf("UpsertImage: %v", err) t.Fatalf("UpsertImage: %v", err)
} }
_, err := d.reserveVM(ctx, "sandbox", image, model.VMSpec{VCPUCount: 1, MemoryMiB: 128}) _, err := d.vmSvc().reserveVM(ctx, "sandbox", image, model.VMSpec{VCPUCount: 1, MemoryMiB: 128})
if err == nil { if err == nil {
t.Fatal("reserveVM with duplicate name should have failed") t.Fatal("reserveVM with duplicate name should have failed")
} }

View file

@ -18,11 +18,11 @@ type workDiskPreparation struct {
ClonedFromSeed bool ClonedFromSeed bool
} }
func (d *Daemon) ensureSystemOverlay(ctx context.Context, vm *model.VMRecord) error { func (s *VMService) ensureSystemOverlay(ctx context.Context, vm *model.VMRecord) error {
if exists(vm.Runtime.SystemOverlay) { if exists(vm.Runtime.SystemOverlay) {
return nil return nil
} }
_, err := d.runner.Run(ctx, "truncate", "-s", strconv.FormatInt(vm.Spec.SystemOverlaySizeByte, 10), vm.Runtime.SystemOverlay) _, err := s.runner.Run(ctx, "truncate", "-s", strconv.FormatInt(vm.Spec.SystemOverlaySizeByte, 10), vm.Runtime.SystemOverlay)
return err return err
} }
@ -30,16 +30,16 @@ func (d *Daemon) ensureSystemOverlay(ctx context.Context, vm *model.VMRecord) er
// hostname, hosts, sshd drop-in, network bootstrap, fstab) into the // hostname, hosts, sshd drop-in, network bootstrap, fstab) into the
// rootfs overlay. Reads the DM device path from the handle cache, // rootfs overlay. Reads the DM device path from the handle cache,
// which the start flow populates before calling this. // which the start flow populates before calling this.
func (d *Daemon) patchRootOverlay(ctx context.Context, vm model.VMRecord, image model.Image) error { func (s *VMService) patchRootOverlay(ctx context.Context, vm model.VMRecord, image model.Image) error {
dmDev := d.vmHandles(vm.ID).DMDev dmDev := s.vmHandles(vm.ID).DMDev
if dmDev == "" { if dmDev == "" {
return fmt.Errorf("vm %q: DM device not in handle cache — start flow out of order?", vm.ID) return fmt.Errorf("vm %q: DM device not in handle cache — start flow out of order?", vm.ID)
} }
resolv := []byte(fmt.Sprintf("nameserver %s\n", d.config.DefaultDNS)) resolv := []byte(fmt.Sprintf("nameserver %s\n", s.config.DefaultDNS))
hostname := []byte(vm.Name + "\n") hostname := []byte(vm.Name + "\n")
hosts := []byte(fmt.Sprintf("127.0.0.1 localhost\n127.0.1.1 %s\n", vm.Name)) hosts := []byte(fmt.Sprintf("127.0.0.1 localhost\n127.0.1.1 %s\n", vm.Name))
sshdConfig := []byte(sshdGuestConfig()) sshdConfig := []byte(sshdGuestConfig())
fstab, err := system.ReadDebugFSText(ctx, d.runner, dmDev, "/etc/fstab") fstab, err := system.ReadDebugFSText(ctx, s.runner, dmDev, "/etc/fstab")
if err != nil { if err != nil {
fstab = "" fstab = ""
} }
@ -47,7 +47,7 @@ func (d *Daemon) patchRootOverlay(ctx context.Context, vm model.VMRecord, image
builder.WriteFile("/etc/resolv.conf", resolv) builder.WriteFile("/etc/resolv.conf", resolv)
builder.WriteFile("/etc/hostname", hostname) builder.WriteFile("/etc/hostname", hostname)
builder.WriteFile("/etc/hosts", hosts) builder.WriteFile("/etc/hosts", hosts)
builder.WriteFile(guestnet.ConfigPath, guestnet.ConfigFile(vm.Runtime.GuestIP, d.config.BridgeIP, d.config.DefaultDNS)) builder.WriteFile(guestnet.ConfigPath, guestnet.ConfigFile(vm.Runtime.GuestIP, s.config.BridgeIP, s.config.DefaultDNS))
builder.WriteFile(guestnet.GuestScriptPath, []byte(guestnet.BootstrapScript())) builder.WriteFile(guestnet.GuestScriptPath, []byte(guestnet.BootstrapScript()))
builder.WriteFile("/etc/ssh/sshd_config.d/99-banger.conf", sshdConfig) builder.WriteFile("/etc/ssh/sshd_config.d/99-banger.conf", sshdConfig)
builder.DropMountTarget("/home") builder.DropMountTarget("/home")
@ -68,25 +68,25 @@ func (d *Daemon) patchRootOverlay(ctx context.Context, vm model.VMRecord, image
Dump: 0, Dump: 0,
Pass: 0, Pass: 0,
}) })
d.contributeGuestConfig(builder, vm, image) s.capHooks.contributeGuest(builder, vm, image)
builder.WriteFile("/etc/fstab", []byte(builder.RenderFSTab(fstab))) builder.WriteFile("/etc/fstab", []byte(builder.RenderFSTab(fstab)))
files := builder.Files() files := builder.Files()
for _, guestPath := range builder.FilePaths() { for _, guestPath := range builder.FilePaths() {
data := files[guestPath] data := files[guestPath]
if guestPath == guestnet.GuestScriptPath { if guestPath == guestnet.GuestScriptPath {
if err := system.WriteExt4FileMode(ctx, d.runner, dmDev, guestPath, 0o755, data); err != nil { if err := system.WriteExt4FileMode(ctx, s.runner, dmDev, guestPath, 0o755, data); err != nil {
return err return err
} }
continue continue
} }
if err := system.WriteExt4File(ctx, d.runner, dmDev, guestPath, data); err != nil { if err := system.WriteExt4File(ctx, s.runner, dmDev, guestPath, data); err != nil {
return err return err
} }
} }
return nil return nil
} }
func (d *Daemon) ensureWorkDisk(ctx context.Context, vm *model.VMRecord, image model.Image) (workDiskPreparation, error) { func (s *VMService) ensureWorkDisk(ctx context.Context, vm *model.VMRecord, image model.Image) (workDiskPreparation, error) {
if exists(vm.Runtime.WorkDiskPath) { if exists(vm.Runtime.WorkDiskPath) {
return workDiskPreparation{}, nil return workDiskPreparation{}, nil
} }
@ -104,38 +104,38 @@ func (d *Daemon) ensureWorkDisk(ctx context.Context, vm *model.VMRecord, image m
} }
if vm.Spec.WorkDiskSizeBytes > seedInfo.Size() { if vm.Spec.WorkDiskSizeBytes > seedInfo.Size() {
vmCreateStage(ctx, "prepare_work_disk", "resizing work disk") vmCreateStage(ctx, "prepare_work_disk", "resizing work disk")
if err := system.ResizeExt4Image(ctx, d.runner, vm.Runtime.WorkDiskPath, vm.Spec.WorkDiskSizeBytes); err != nil { if err := system.ResizeExt4Image(ctx, s.runner, vm.Runtime.WorkDiskPath, vm.Spec.WorkDiskSizeBytes); err != nil {
return workDiskPreparation{}, err return workDiskPreparation{}, err
} }
} }
return workDiskPreparation{ClonedFromSeed: true}, nil return workDiskPreparation{ClonedFromSeed: true}, nil
} }
vmCreateStage(ctx, "prepare_work_disk", "creating empty work disk") vmCreateStage(ctx, "prepare_work_disk", "creating empty work disk")
if _, err := d.runner.Run(ctx, "truncate", "-s", strconv.FormatInt(vm.Spec.WorkDiskSizeBytes, 10), vm.Runtime.WorkDiskPath); err != nil { if _, err := s.runner.Run(ctx, "truncate", "-s", strconv.FormatInt(vm.Spec.WorkDiskSizeBytes, 10), vm.Runtime.WorkDiskPath); err != nil {
return workDiskPreparation{}, err return workDiskPreparation{}, err
} }
if _, err := d.runner.Run(ctx, "mkfs.ext4", "-F", vm.Runtime.WorkDiskPath); err != nil { if _, err := s.runner.Run(ctx, "mkfs.ext4", "-F", vm.Runtime.WorkDiskPath); err != nil {
return workDiskPreparation{}, err return workDiskPreparation{}, err
} }
dmDev := d.vmHandles(vm.ID).DMDev dmDev := s.vmHandles(vm.ID).DMDev
if dmDev == "" { if dmDev == "" {
return workDiskPreparation{}, fmt.Errorf("vm %q: DM device not in handle cache", vm.ID) return workDiskPreparation{}, fmt.Errorf("vm %q: DM device not in handle cache", vm.ID)
} }
rootMount, cleanupRoot, err := system.MountTempDir(ctx, d.runner, dmDev, true) rootMount, cleanupRoot, err := system.MountTempDir(ctx, s.runner, dmDev, true)
if err != nil { if err != nil {
return workDiskPreparation{}, err return workDiskPreparation{}, err
} }
defer cleanupRoot() defer cleanupRoot()
workMount, cleanupWork, err := system.MountTempDir(ctx, d.runner, vm.Runtime.WorkDiskPath, false) workMount, cleanupWork, err := system.MountTempDir(ctx, s.runner, vm.Runtime.WorkDiskPath, false)
if err != nil { if err != nil {
return workDiskPreparation{}, err return workDiskPreparation{}, err
} }
defer cleanupWork() defer cleanupWork()
vmCreateStage(ctx, "prepare_work_disk", "copying /root into work disk") vmCreateStage(ctx, "prepare_work_disk", "copying /root into work disk")
if err := system.CopyDirContents(ctx, d.runner, filepath.Join(rootMount, "root"), workMount, true); err != nil { if err := system.CopyDirContents(ctx, s.runner, filepath.Join(rootMount, "root"), workMount, true); err != nil {
return workDiskPreparation{}, err return workDiskPreparation{}, err
} }
if err := d.flattenNestedWorkHome(ctx, workMount); err != nil { if err := flattenNestedWorkHome(ctx, s.runner, workMount); err != nil {
return workDiskPreparation{}, err return workDiskPreparation{}, err
} }
return workDiskPreparation{}, nil return workDiskPreparation{}, nil
@ -214,10 +214,3 @@ func flattenNestedWorkHome(ctx context.Context, runner system.CommandRunner, wor
_, err = runner.RunSudo(ctx, "rm", "-rf", nestedHome) _, err = runner.RunSudo(ctx, "rm", "-rf", nestedHome)
return err return err
} }
// Deprecated forwarder: until every caller learns the package-level
// helper, Daemon keeps a receiver-method form. Will be deleted once
// the last caller is rewritten.
func (d *Daemon) flattenNestedWorkHome(ctx context.Context, workMount string) error {
return flattenNestedWorkHome(ctx, d.runner, workMount)
}

View file

@ -105,57 +105,57 @@ func removeHandlesFile(vmDir string) {
// ensureHandleCache lazily constructs the cache so direct // ensureHandleCache lazily constructs the cache so direct
// `&Daemon{}` literals (common in tests) don't have to initialise // `&Daemon{}` literals (common in tests) don't have to initialise
// it. Production code goes through Open(), which also builds it. // it. Production code goes through Open(), which also builds it.
func (d *Daemon) ensureHandleCache() { func (s *VMService) ensureHandleCache() {
if d.handles == nil { if s.handles == nil {
d.handles = newHandleCache() s.handles = newHandleCache()
} }
} }
// setVMHandlesInMemory is a test-only cache seed that skips the // setVMHandlesInMemory is a test-only cache seed that skips the
// scratch-file write. Production callers should use setVMHandles so // scratch-file write. Production callers should use setVMHandles so
// the filesystem survives a daemon restart. // the filesystem survives a daemon restart.
func (d *Daemon) setVMHandlesInMemory(vmID string, h model.VMHandles) { func (s *VMService) setVMHandlesInMemory(vmID string, h model.VMHandles) {
if d == nil { if s == nil {
return return
} }
d.ensureHandleCache() s.ensureHandleCache()
d.handles.set(vmID, h) s.handles.set(vmID, h)
} }
// vmHandles returns the cached handles for vm (zero-value if no // vmHandles returns the cached handles for vm (zero-value if no
// entry). Call sites that previously read `vm.Runtime.{PID,...}` // entry). Call sites that previously read `vm.Runtime.{PID,...}`
// should read through this instead. // should read through this instead.
func (d *Daemon) vmHandles(vmID string) model.VMHandles { func (s *VMService) vmHandles(vmID string) model.VMHandles {
if d == nil { if s == nil {
return model.VMHandles{} return model.VMHandles{}
} }
d.ensureHandleCache() s.ensureHandleCache()
h, _ := d.handles.get(vmID) h, _ := s.handles.get(vmID)
return h return h
} }
// setVMHandles updates the in-memory cache AND the per-VM scratch // setVMHandles updates the in-memory cache AND the per-VM scratch
// file. Scratch-file errors are logged but not returned; the cache // file. Scratch-file errors are logged but not returned; the cache
// write is authoritative while the daemon is alive. // write is authoritative while the daemon is alive.
func (d *Daemon) setVMHandles(vm model.VMRecord, h model.VMHandles) { func (s *VMService) setVMHandles(vm model.VMRecord, h model.VMHandles) {
if d == nil { if s == nil {
return return
} }
d.ensureHandleCache() s.ensureHandleCache()
d.handles.set(vm.ID, h) s.handles.set(vm.ID, h)
if err := writeHandlesFile(vm.Runtime.VMDir, h); err != nil && d.logger != nil { if err := writeHandlesFile(vm.Runtime.VMDir, h); err != nil && s.logger != nil {
d.logger.Warn("persist handles.json failed", "vm_id", vm.ID, "error", err.Error()) s.logger.Warn("persist handles.json failed", "vm_id", vm.ID, "error", err.Error())
} }
} }
// clearVMHandles drops the cache entry and removes the scratch // clearVMHandles drops the cache entry and removes the scratch
// file. Called on stop / delete / after a failed start. // file. Called on stop / delete / after a failed start.
func (d *Daemon) clearVMHandles(vm model.VMRecord) { func (s *VMService) clearVMHandles(vm model.VMRecord) {
if d == nil { if s == nil {
return return
} }
d.ensureHandleCache() s.ensureHandleCache()
d.handles.clear(vm.ID) s.handles.clear(vm.ID)
removeHandlesFile(vm.Runtime.VMDir) removeHandlesFile(vm.Runtime.VMDir)
} }
@ -164,11 +164,11 @@ func (d *Daemon) clearVMHandles(vm model.VMRecord) {
// pattern, this reads the PID from the handle cache — which is // pattern, this reads the PID from the handle cache — which is
// authoritative in-process — and verifies the PID against the api // authoritative in-process — and verifies the PID against the api
// socket so a recycled PID can't false-positive. // socket so a recycled PID can't false-positive.
func (d *Daemon) vmAlive(vm model.VMRecord) bool { func (s *VMService) vmAlive(vm model.VMRecord) bool {
if vm.State != model.VMStateRunning { if vm.State != model.VMStateRunning {
return false return false
} }
h := d.vmHandles(vm.ID) h := s.vmHandles(vm.ID)
if h.PID <= 0 { if h.PID <= 0 {
return false return false
} }
@ -191,7 +191,7 @@ func (d *Daemon) vmAlive(vm model.VMRecord) bool {
// the daemon crashed but the PID changed on respawn — unlikely for // the daemon crashed but the PID changed on respawn — unlikely for
// firecracker, but cheap insurance); fall back to verifying the // firecracker, but cheap insurance); fall back to verifying the
// scratch file's PID directly. // scratch file's PID directly.
func (d *Daemon) rediscoverHandles(ctx context.Context, vm model.VMRecord) (model.VMHandles, bool, error) { func (s *VMService) rediscoverHandles(ctx context.Context, vm model.VMRecord) (model.VMHandles, bool, error) {
saved, _, err := readHandlesFile(vm.Runtime.VMDir) saved, _, err := readHandlesFile(vm.Runtime.VMDir)
if err != nil { if err != nil {
return model.VMHandles{}, false, err return model.VMHandles{}, false, err
@ -200,7 +200,7 @@ func (d *Daemon) rediscoverHandles(ctx context.Context, vm model.VMRecord) (mode
if apiSock == "" { if apiSock == "" {
return saved, false, nil return saved, false, nil
} }
if pid, pidErr := d.hostNet().findFirecrackerPID(ctx, apiSock); pidErr == nil && pid > 0 { if pid, pidErr := s.net.findFirecrackerPID(ctx, apiSock); pidErr == nil && pid > 0 {
saved.PID = pid saved.PID = pid
return saved, true, nil return saved, true, nil
} }

View file

@ -115,7 +115,7 @@ func TestRediscoverHandlesLoadsScratchWhenProcessDead(t *testing.T) {
vm.Runtime.APISockPath = apiSock vm.Runtime.APISockPath = apiSock
vm.Runtime.VMDir = vmDir vm.Runtime.VMDir = vmDir
got, alive, err := d.rediscoverHandles(context.Background(), vm) got, alive, err := d.vmSvc().rediscoverHandles(context.Background(), vm)
if err != nil { if err != nil {
t.Fatalf("rediscoverHandles: %v", err) t.Fatalf("rediscoverHandles: %v", err)
} }
@ -152,7 +152,7 @@ func TestRediscoverHandlesPrefersLivePIDOverScratch(t *testing.T) {
vm.Runtime.APISockPath = apiSock vm.Runtime.APISockPath = apiSock
vm.Runtime.VMDir = vmDir vm.Runtime.VMDir = vmDir
got, alive, err := d.rediscoverHandles(context.Background(), vm) got, alive, err := d.vmSvc().rediscoverHandles(context.Background(), vm)
if err != nil { if err != nil {
t.Fatalf("rediscoverHandles: %v", err) t.Fatalf("rediscoverHandles: %v", err)
} }
@ -179,13 +179,13 @@ func TestClearVMHandlesRemovesScratchFile(t *testing.T) {
d := &Daemon{} d := &Daemon{}
vm := testVM("sweep", "image-sweep", "172.16.0.252") vm := testVM("sweep", "image-sweep", "172.16.0.252")
vm.Runtime.VMDir = vmDir vm.Runtime.VMDir = vmDir
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: 42}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: 42})
d.clearVMHandles(vm) d.vmSvc().clearVMHandles(vm)
if _, err := os.Stat(handlesFilePath(vmDir)); !os.IsNotExist(err) { if _, err := os.Stat(handlesFilePath(vmDir)); !os.IsNotExist(err) {
t.Fatalf("scratch file still present: %v", err) t.Fatalf("scratch file still present: %v", err)
} }
if h, ok := d.handles.get(vm.ID); ok && !h.IsZero() { if h, ok := d.vmSvc().handles.get(vm.ID); ok && !h.IsZero() {
t.Fatalf("cache entry survives clear: %+v", h) t.Fatalf("cache entry survives clear: %+v", h)
} }
} }

View file

@ -16,24 +16,24 @@ import (
"banger/internal/system" "banger/internal/system"
) )
func (d *Daemon) StartVM(ctx context.Context, idOrName string) (model.VMRecord, error) { func (s *VMService) StartVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { return s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
image, err := d.store.GetImageByID(ctx, vm.ImageID) image, err := s.store.GetImageByID(ctx, vm.ImageID)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
if d.vmAlive(vm) { if s.vmAlive(vm) {
if d.logger != nil { if s.logger != nil {
d.logger.Info("vm already running", vmLogAttrs(vm)...) s.logger.Info("vm already running", vmLogAttrs(vm)...)
} }
return vm, nil return vm, nil
} }
return d.startVMLocked(ctx, vm, image) return s.startVMLocked(ctx, vm, image)
}) })
} }
func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image model.Image) (_ model.VMRecord, err error) { func (s *VMService) startVMLocked(ctx context.Context, vm model.VMRecord, image model.Image) (_ model.VMRecord, err error) {
op := d.beginOperation("vm.start", append(vmLogAttrs(vm), imageLogAttrs(image)...)...) op := s.beginOperation("vm.start", append(vmLogAttrs(vm), imageLogAttrs(image)...)...)
defer func() { defer func() {
if err != nil { if err != nil {
err = annotateLogPath(err, vm.Runtime.LogPath) err = annotateLogPath(err, vm.Runtime.LogPath)
@ -44,32 +44,32 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
}() }()
op.stage("preflight") op.stage("preflight")
vmCreateStage(ctx, "preflight", "checking host prerequisites") vmCreateStage(ctx, "preflight", "checking host prerequisites")
if err := d.validateStartPrereqs(ctx, vm, image); err != nil { if err := s.validateStartPrereqs(ctx, vm, image); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
if err := os.MkdirAll(vm.Runtime.VMDir, 0o755); err != nil { if err := os.MkdirAll(vm.Runtime.VMDir, 0o755); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("cleanup_runtime") op.stage("cleanup_runtime")
if err := d.cleanupRuntime(ctx, vm, true); err != nil { if err := s.cleanupRuntime(ctx, vm, true); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
d.clearVMHandles(vm) s.clearVMHandles(vm)
op.stage("bridge") op.stage("bridge")
if err := d.hostNet().ensureBridge(ctx); err != nil { if err := s.net.ensureBridge(ctx); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("socket_dir") op.stage("socket_dir")
if err := d.hostNet().ensureSocketDir(); err != nil { if err := s.net.ensureSocketDir(); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
shortID := system.ShortID(vm.ID) shortID := system.ShortID(vm.ID)
apiSock := filepath.Join(d.layout.RuntimeDir, "fc-"+shortID+".sock") apiSock := filepath.Join(s.layout.RuntimeDir, "fc-"+shortID+".sock")
dmName := "fc-rootfs-" + shortID dmName := "fc-rootfs-" + shortID
tapName := "tap-fc-" + shortID tapName := "tap-fc-" + shortID
if strings.TrimSpace(vm.Runtime.VSockPath) == "" { if strings.TrimSpace(vm.Runtime.VSockPath) == "" {
vm.Runtime.VSockPath = defaultVSockPath(d.layout.RuntimeDir, vm.ID) vm.Runtime.VSockPath = defaultVSockPath(s.layout.RuntimeDir, vm.ID)
} }
if vm.Runtime.VSockCID == 0 { if vm.Runtime.VSockCID == 0 {
vm.Runtime.VSockCID, err = defaultVSockCID(vm.Runtime.GuestIP) vm.Runtime.VSockCID, err = defaultVSockCID(vm.Runtime.GuestIP)
@ -86,13 +86,13 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
op.stage("system_overlay", "overlay_path", vm.Runtime.SystemOverlay) op.stage("system_overlay", "overlay_path", vm.Runtime.SystemOverlay)
vmCreateStage(ctx, "prepare_rootfs", "preparing system overlay") vmCreateStage(ctx, "prepare_rootfs", "preparing system overlay")
if err := d.ensureSystemOverlay(ctx, &vm); err != nil { if err := s.ensureSystemOverlay(ctx, &vm); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("dm_snapshot", "dm_name", dmName) op.stage("dm_snapshot", "dm_name", dmName)
vmCreateStage(ctx, "prepare_rootfs", "creating root filesystem snapshot") vmCreateStage(ctx, "prepare_rootfs", "creating root filesystem snapshot")
snapHandles, err := d.hostNet().createDMSnapshot(ctx, image.RootfsPath, vm.Runtime.SystemOverlay, dmName) snapHandles, err := s.net.createDMSnapshot(ctx, image.RootfsPath, vm.Runtime.SystemOverlay, dmName)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
@ -107,7 +107,7 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
DMName: snapHandles.DMName, DMName: snapHandles.DMName,
DMDev: snapHandles.DMDev, DMDev: snapHandles.DMDev,
} }
d.setVMHandles(vm, live) s.setVMHandles(vm, live)
vm.Runtime.APISockPath = apiSock vm.Runtime.APISockPath = apiSock
vm.Runtime.State = model.VMStateRunning vm.Runtime.State = model.VMStateRunning
@ -119,38 +119,38 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
vm.Runtime.State = model.VMStateError vm.Runtime.State = model.VMStateError
vm.Runtime.LastError = err.Error() vm.Runtime.LastError = err.Error()
op.stage("cleanup_after_failure", "error", err.Error()) op.stage("cleanup_after_failure", "error", err.Error())
if cleanupErr := d.cleanupRuntime(context.Background(), vm, true); cleanupErr != nil { if cleanupErr := s.cleanupRuntime(context.Background(), vm, true); cleanupErr != nil {
err = errors.Join(err, cleanupErr) err = errors.Join(err, cleanupErr)
} }
d.clearVMHandles(vm) s.clearVMHandles(vm)
_ = d.store.UpsertVM(context.Background(), vm) _ = s.store.UpsertVM(context.Background(), vm)
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("patch_root_overlay") op.stage("patch_root_overlay")
vmCreateStage(ctx, "prepare_rootfs", "writing guest configuration") vmCreateStage(ctx, "prepare_rootfs", "writing guest configuration")
if err := d.patchRootOverlay(ctx, vm, image); err != nil { if err := s.patchRootOverlay(ctx, vm, image); err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
op.stage("prepare_host_features") op.stage("prepare_host_features")
vmCreateStage(ctx, "prepare_host_features", "preparing host-side vm features") vmCreateStage(ctx, "prepare_host_features", "preparing host-side vm features")
if err := d.prepareCapabilityHosts(ctx, &vm, image); err != nil { if err := s.capHooks.prepareHosts(ctx, &vm, image); err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
op.stage("tap") op.stage("tap")
tap, err := d.hostNet().acquireTap(ctx, tapName) tap, err := s.net.acquireTap(ctx, tapName)
if err != nil { if err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
live.TapDevice = tap live.TapDevice = tap
d.setVMHandles(vm, live) s.setVMHandles(vm, live)
op.stage("metrics_file", "metrics_path", vm.Runtime.MetricsPath) op.stage("metrics_file", "metrics_path", vm.Runtime.MetricsPath)
if err := os.WriteFile(vm.Runtime.MetricsPath, nil, 0o644); err != nil { if err := os.WriteFile(vm.Runtime.MetricsPath, nil, 0o644); err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
op.stage("firecracker_binary") op.stage("firecracker_binary")
fcPath, err := d.hostNet().firecrackerBinary() fcPath, err := s.net.firecrackerBinary()
if err != nil { if err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
@ -165,7 +165,7 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
// 2. init= pointing at our universal wrapper which installs // 2. init= pointing at our universal wrapper which installs
// systemd+sshd on first boot if missing. // systemd+sshd on first boot if missing.
kernelArgs = system.BuildBootArgsWithKernelIP( kernelArgs = system.BuildBootArgsWithKernelIP(
vm.Name, vm.Runtime.GuestIP, d.config.BridgeIP, d.config.DefaultDNS, vm.Name, vm.Runtime.GuestIP, s.config.BridgeIP, s.config.DefaultDNS,
) + " init=" + imagepull.FirstBootScriptPath ) + " init=" + imagepull.FirstBootScriptPath
} }
@ -189,9 +189,9 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
VSockCID: vm.Runtime.VSockCID, VSockCID: vm.Runtime.VSockCID,
VCPUCount: vm.Spec.VCPUCount, VCPUCount: vm.Spec.VCPUCount,
MemoryMiB: vm.Spec.MemoryMiB, MemoryMiB: vm.Spec.MemoryMiB,
Logger: d.logger, Logger: s.logger,
} }
d.contributeMachineConfig(&machineConfig, vm, image) s.capHooks.contributeMachine(&machineConfig, vm, image)
machine, err := firecracker.NewMachine(ctx, machineConfig) machine, err := firecracker.NewMachine(ctx, machineConfig)
if err != nil { if err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
@ -200,48 +200,48 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
// Use a fresh context: the request ctx may already be cancelled (client // Use a fresh context: the request ctx may already be cancelled (client
// disconnect), but we still need the PID so cleanupRuntime can kill the // disconnect), but we still need the PID so cleanupRuntime can kill the
// Firecracker process that was spawned before the failure. // Firecracker process that was spawned before the failure.
live.PID = d.hostNet().resolveFirecrackerPID(context.Background(), machine, apiSock) live.PID = s.net.resolveFirecrackerPID(context.Background(), machine, apiSock)
d.setVMHandles(vm, live) s.setVMHandles(vm, live)
return cleanupOnErr(err) return cleanupOnErr(err)
} }
live.PID = d.hostNet().resolveFirecrackerPID(context.Background(), machine, apiSock) live.PID = s.net.resolveFirecrackerPID(context.Background(), machine, apiSock)
d.setVMHandles(vm, live) s.setVMHandles(vm, live)
op.debugStage("firecracker_started", "pid", live.PID) op.debugStage("firecracker_started", "pid", live.PID)
op.stage("socket_access", "api_socket", apiSock) op.stage("socket_access", "api_socket", apiSock)
if err := d.hostNet().ensureSocketAccess(ctx, apiSock, "firecracker api socket"); err != nil { if err := s.net.ensureSocketAccess(ctx, apiSock, "firecracker api socket"); err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
op.stage("vsock_access", "vsock_path", vm.Runtime.VSockPath, "vsock_cid", vm.Runtime.VSockCID) op.stage("vsock_access", "vsock_path", vm.Runtime.VSockPath, "vsock_cid", vm.Runtime.VSockCID)
if err := d.hostNet().ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { if err := s.net.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
vmCreateStage(ctx, "wait_vsock_agent", "waiting for guest vsock agent") vmCreateStage(ctx, "wait_vsock_agent", "waiting for guest vsock agent")
if err := d.hostNet().waitForGuestVSockAgent(ctx, vm.Runtime.VSockPath, vsockReadyWait); err != nil { if err := s.net.waitForGuestVSockAgent(ctx, vm.Runtime.VSockPath, vsockReadyWait); err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
op.stage("post_start_features") op.stage("post_start_features")
vmCreateStage(ctx, "wait_guest_ready", "waiting for guest services") vmCreateStage(ctx, "wait_guest_ready", "waiting for guest services")
if err := d.postStartCapabilities(ctx, vm, image); err != nil { if err := s.capHooks.postStart(ctx, vm, image); err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
system.TouchNow(&vm) system.TouchNow(&vm)
op.stage("persist") op.stage("persist")
vmCreateStage(ctx, "finalize", "saving vm state") vmCreateStage(ctx, "finalize", "saving vm state")
if err := d.store.UpsertVM(ctx, vm); err != nil { if err := s.store.UpsertVM(ctx, vm); err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
return vm, nil return vm, nil
} }
func (d *Daemon) StopVM(ctx context.Context, idOrName string) (model.VMRecord, error) { func (s *VMService) StopVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { return s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
return d.stopVMLocked(ctx, vm) return s.stopVMLocked(ctx, vm)
}) })
} }
func (d *Daemon) stopVMLocked(ctx context.Context, current model.VMRecord) (vm model.VMRecord, err error) { func (s *VMService) stopVMLocked(ctx context.Context, current model.VMRecord) (vm model.VMRecord, err error) {
vm = current vm = current
op := d.beginOperation("vm.stop", "vm_ref", vm.ID) op := s.beginOperation("vm.stop", "vm_ref", vm.ID)
defer func() { defer func() {
if err != nil { if err != nil {
op.fail(err, vmLogAttrs(vm)...) op.fail(err, vmLogAttrs(vm)...)
@ -249,54 +249,54 @@ func (d *Daemon) stopVMLocked(ctx context.Context, current model.VMRecord) (vm m
} }
op.done(vmLogAttrs(vm)...) op.done(vmLogAttrs(vm)...)
}() }()
if !d.vmAlive(vm) { if !s.vmAlive(vm) {
op.stage("cleanup_stale_runtime") op.stage("cleanup_stale_runtime")
if err := d.cleanupRuntime(ctx, vm, true); err != nil { if err := s.cleanupRuntime(ctx, vm, true); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
vm.State = model.VMStateStopped vm.State = model.VMStateStopped
vm.Runtime.State = model.VMStateStopped vm.Runtime.State = model.VMStateStopped
d.clearVMHandles(vm) s.clearVMHandles(vm)
if err := d.store.UpsertVM(ctx, vm); err != nil { if err := s.store.UpsertVM(ctx, vm); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
return vm, nil return vm, nil
} }
pid := d.vmHandles(vm.ID).PID pid := s.vmHandles(vm.ID).PID
op.stage("graceful_shutdown") op.stage("graceful_shutdown")
if err := d.hostNet().sendCtrlAltDel(ctx, vm.Runtime.APISockPath); err != nil { if err := s.net.sendCtrlAltDel(ctx, vm.Runtime.APISockPath); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("wait_for_exit", "pid", pid) op.stage("wait_for_exit", "pid", pid)
if err := d.hostNet().waitForExit(ctx, pid, vm.Runtime.APISockPath, gracefulShutdownWait); err != nil { if err := s.net.waitForExit(ctx, pid, vm.Runtime.APISockPath, gracefulShutdownWait); err != nil {
if !errors.Is(err, errWaitForExitTimeout) { if !errors.Is(err, errWaitForExitTimeout) {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("graceful_shutdown_timeout", "pid", pid) op.stage("graceful_shutdown_timeout", "pid", pid)
} }
op.stage("cleanup_runtime") op.stage("cleanup_runtime")
if err := d.cleanupRuntime(ctx, vm, true); err != nil { if err := s.cleanupRuntime(ctx, vm, true); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
vm.State = model.VMStateStopped vm.State = model.VMStateStopped
vm.Runtime.State = model.VMStateStopped vm.Runtime.State = model.VMStateStopped
d.clearVMHandles(vm) s.clearVMHandles(vm)
system.TouchNow(&vm) system.TouchNow(&vm)
if err := d.store.UpsertVM(ctx, vm); err != nil { if err := s.store.UpsertVM(ctx, vm); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
return vm, nil return vm, nil
} }
func (d *Daemon) KillVM(ctx context.Context, params api.VMKillParams) (model.VMRecord, error) { func (s *VMService) KillVM(ctx context.Context, params api.VMKillParams) (model.VMRecord, error) {
return d.withVMLockByRef(ctx, params.IDOrName, func(vm model.VMRecord) (model.VMRecord, error) { return s.withVMLockByRef(ctx, params.IDOrName, func(vm model.VMRecord) (model.VMRecord, error) {
return d.killVMLocked(ctx, vm, params.Signal) return s.killVMLocked(ctx, vm, params.Signal)
}) })
} }
func (d *Daemon) killVMLocked(ctx context.Context, current model.VMRecord, signalValue string) (vm model.VMRecord, err error) { func (s *VMService) killVMLocked(ctx context.Context, current model.VMRecord, signalValue string) (vm model.VMRecord, err error) {
vm = current vm = current
op := d.beginOperation("vm.kill", "vm_ref", vm.ID, "signal", signalValue) op := s.beginOperation("vm.kill", "vm_ref", vm.ID, "signal", signalValue)
defer func() { defer func() {
if err != nil { if err != nil {
op.fail(err, vmLogAttrs(vm)...) op.fail(err, vmLogAttrs(vm)...)
@ -304,15 +304,15 @@ func (d *Daemon) killVMLocked(ctx context.Context, current model.VMRecord, signa
} }
op.done(vmLogAttrs(vm)...) op.done(vmLogAttrs(vm)...)
}() }()
if !d.vmAlive(vm) { if !s.vmAlive(vm) {
op.stage("cleanup_stale_runtime") op.stage("cleanup_stale_runtime")
if err := d.cleanupRuntime(ctx, vm, true); err != nil { if err := s.cleanupRuntime(ctx, vm, true); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
vm.State = model.VMStateStopped vm.State = model.VMStateStopped
vm.Runtime.State = model.VMStateStopped vm.Runtime.State = model.VMStateStopped
d.clearVMHandles(vm) s.clearVMHandles(vm)
if err := d.store.UpsertVM(ctx, vm); err != nil { if err := s.store.UpsertVM(ctx, vm); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
return vm, nil return vm, nil
@ -322,34 +322,34 @@ func (d *Daemon) killVMLocked(ctx context.Context, current model.VMRecord, signa
if signal == "" { if signal == "" {
signal = "TERM" signal = "TERM"
} }
pid := d.vmHandles(vm.ID).PID pid := s.vmHandles(vm.ID).PID
op.stage("send_signal", "pid", pid, "signal", signal) op.stage("send_signal", "pid", pid, "signal", signal)
if _, err := d.runner.RunSudo(ctx, "kill", "-"+signal, strconv.Itoa(pid)); err != nil { if _, err := s.runner.RunSudo(ctx, "kill", "-"+signal, strconv.Itoa(pid)); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("wait_for_exit", "pid", pid) op.stage("wait_for_exit", "pid", pid)
if err := d.hostNet().waitForExit(ctx, pid, vm.Runtime.APISockPath, 30*time.Second); err != nil { if err := s.net.waitForExit(ctx, pid, vm.Runtime.APISockPath, 30*time.Second); err != nil {
if !errors.Is(err, errWaitForExitTimeout) { if !errors.Is(err, errWaitForExitTimeout) {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("signal_timeout", "pid", pid, "signal", signal) op.stage("signal_timeout", "pid", pid, "signal", signal)
} }
op.stage("cleanup_runtime") op.stage("cleanup_runtime")
if err := d.cleanupRuntime(ctx, vm, true); err != nil { if err := s.cleanupRuntime(ctx, vm, true); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
vm.State = model.VMStateStopped vm.State = model.VMStateStopped
vm.Runtime.State = model.VMStateStopped vm.Runtime.State = model.VMStateStopped
d.clearVMHandles(vm) s.clearVMHandles(vm)
system.TouchNow(&vm) system.TouchNow(&vm)
if err := d.store.UpsertVM(ctx, vm); err != nil { if err := s.store.UpsertVM(ctx, vm); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
return vm, nil return vm, nil
} }
func (d *Daemon) RestartVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) { func (s *VMService) RestartVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) {
op := d.beginOperation("vm.restart", "vm_ref", idOrName) op := s.beginOperation("vm.restart", "vm_ref", idOrName)
defer func() { defer func() {
if err != nil { if err != nil {
op.fail(err, vmLogAttrs(vm)...) op.fail(err, vmLogAttrs(vm)...)
@ -357,34 +357,34 @@ func (d *Daemon) RestartVM(ctx context.Context, idOrName string) (vm model.VMRec
} }
op.done(vmLogAttrs(vm)...) op.done(vmLogAttrs(vm)...)
}() }()
resolved, err := d.FindVM(ctx, idOrName) resolved, err := s.FindVM(ctx, idOrName)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
return d.withVMLockByID(ctx, resolved.ID, func(vm model.VMRecord) (model.VMRecord, error) { return s.withVMLockByID(ctx, resolved.ID, func(vm model.VMRecord) (model.VMRecord, error) {
op.stage("stop") op.stage("stop")
vm, err = d.stopVMLocked(ctx, vm) vm, err = s.stopVMLocked(ctx, vm)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
image, err := d.store.GetImageByID(ctx, vm.ImageID) image, err := s.store.GetImageByID(ctx, vm.ImageID)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("start", vmLogAttrs(vm)...) op.stage("start", vmLogAttrs(vm)...)
return d.startVMLocked(ctx, vm, image) return s.startVMLocked(ctx, vm, image)
}) })
} }
func (d *Daemon) DeleteVM(ctx context.Context, idOrName string) (model.VMRecord, error) { func (s *VMService) DeleteVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { return s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
return d.deleteVMLocked(ctx, vm) return s.deleteVMLocked(ctx, vm)
}) })
} }
func (d *Daemon) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm model.VMRecord, err error) { func (s *VMService) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm model.VMRecord, err error) {
vm = current vm = current
op := d.beginOperation("vm.delete", "vm_ref", vm.ID) op := s.beginOperation("vm.delete", "vm_ref", vm.ID)
defer func() { defer func() {
if err != nil { if err != nil {
op.fail(err, vmLogAttrs(vm)...) op.fail(err, vmLogAttrs(vm)...)
@ -392,17 +392,17 @@ func (d *Daemon) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm
} }
op.done(vmLogAttrs(vm)...) op.done(vmLogAttrs(vm)...)
}() }()
if d.vmAlive(vm) { if s.vmAlive(vm) {
pid := d.vmHandles(vm.ID).PID pid := s.vmHandles(vm.ID).PID
op.stage("kill_running_vm", "pid", pid) op.stage("kill_running_vm", "pid", pid)
_ = d.hostNet().killVMProcess(ctx, pid) _ = s.net.killVMProcess(ctx, pid)
} }
op.stage("cleanup_runtime") op.stage("cleanup_runtime")
if err := d.cleanupRuntime(ctx, vm, false); err != nil { if err := s.cleanupRuntime(ctx, vm, false); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("delete_store_record") op.stage("delete_store_record")
if err := d.store.DeleteVM(ctx, vm.ID); err != nil { if err := s.store.DeleteVM(ctx, vm.ID); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
if vm.Runtime.VMDir != "" { if vm.Runtime.VMDir != "" {
@ -414,6 +414,6 @@ func (d *Daemon) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm
// Drop any host-key pins. A future VM reusing this IP or name // Drop any host-key pins. A future VM reusing this IP or name
// would otherwise trip the TOFU mismatch branch in // would otherwise trip the TOFU mismatch branch in
// TOFUHostKeyCallback and fail to connect. // TOFUHostKeyCallback and fail to connect.
removeVMKnownHosts(d.layout.KnownHostsPath, vm, d.logger) removeVMKnownHosts(s.layout.KnownHostsPath, vm, s.logger)
return vm, nil return vm, nil
} }

View file

@ -0,0 +1,256 @@
package daemon
import (
"context"
"database/sql"
"errors"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"banger/internal/daemon/opstate"
"banger/internal/firecracker"
"banger/internal/guestconfig"
"banger/internal/model"
"banger/internal/paths"
"banger/internal/store"
"banger/internal/system"
)
// VMService owns VM lifecycle — create / start / stop / restart /
// kill / delete / set — plus the handle cache, create-operation
// registry, stats polling, disk provisioning, ports query, and the
// SSH-client test seams.
//
// It holds pointers to its peer services (HostNetwork, ImageService,
// WorkspaceService) because VM lifecycle really does orchestrate
// across them (start needs bridge + tap + firecracker + auth sync +
// boot). Defining narrow function-typed interfaces for every peer
// method VMService calls would balloon the diff for no real win —
// services remain unexported within the package so nothing outside
// the daemon can see them.
//
// Capability invocation still runs through Daemon because the hook
// interfaces take *Daemon directly. VMService calls back via the
// capHooks seam rather than holding a *Daemon pointer, to keep the
// dependency graph acyclic.
type VMService struct {
	// Host plumbing shared with the owning Daemon: command execution,
	// structured logging, static daemon config, and filesystem layout
	// (runtime dir, known-hosts path, ...). logger may be nil — call
	// sites nil-check before logging.
	runner system.CommandRunner
	logger *slog.Logger
	config model.DaemonConfig
	layout paths.Layout
	// store is the persistent record database for VMs and images.
	store  *store.Store
	// vmLocks is the per-VM mutex set. Held across entire lifecycle
	// ops (start, stop, delete, set) — not just the validation window.
	// Workspace.prepare intentionally splits off onto its own lock
	// scope; see WorkspaceService.
	vmLocks vmLockSet
	// createVMMu and createOps back the create-operation registry:
	// createOps holds per-create operation state keyed by the opstate
	// registry. NOTE(review): createVMMu presumably serialises VM
	// creation itself — confirm against the create path.
	createVMMu sync.Mutex
	createOps  opstate.Registry[*vmCreateOperationState]
	// handles caches per-VM transient kernel/process state (PID, tap,
	// loop devices, DM name/device). Rebuildable at daemon startup
	// from a per-VM handles.json scratch file plus OS inspection.
	handles *handleCache
	// Peer services. VMService orchestrates across all three during
	// start/stop/delete; pointer fields keep call sites direct without
	// promoting the peer API to package-level interfaces.
	net *HostNetwork
	img *ImageService
	ws  *WorkspaceService
	// Test seams: overridable SSH wait/dial hooks so tests can stub
	// guest connectivity.
	guestWaitForSSH func(context.Context, string, string, time.Duration) error
	guestDial       func(context.Context, string, string) (guestSSHClient, error)
	// Capability hook dispatch. Capabilities themselves live on
	// *Daemon (their interface takes *Daemon as receiver); VMService
	// invokes them via these seams so it doesn't need a *Daemon
	// pointer.
	capHooks capabilityHooks
	// beginOperation opens a structured operation log (op.stage /
	// op.fail / op.done); seam to Daemon.beginOperation.
	beginOperation func(name string, attrs ...any) *operationLog
}
// capabilityHooks bundles the capability-dispatch entry points that
// VMService needs. Populated by Daemon.buildCapabilityHooks() at
// service construction; stubbable in tests that don't care about
// capability side effects.
type capabilityHooks struct {
	// addStartPrereqs appends capability-specific host preflight
	// checks before a VM start.
	addStartPrereqs func(ctx context.Context, checks *system.Preflight, vm model.VMRecord, image model.Image)
	// contributeGuest lets capabilities extend the guest config built
	// for the VM.
	contributeGuest func(builder *guestconfig.Builder, vm model.VMRecord, image model.Image)
	// contributeMachine mutates the firecracker machine config just
	// before firecracker.NewMachine.
	contributeMachine func(cfg *firecracker.MachineConfig, vm model.VMRecord, image model.Image)
	// prepareHosts runs host-side feature preparation during start
	// (the "prepare_host_features" stage); may mutate the record.
	prepareHosts func(ctx context.Context, vm *model.VMRecord, image model.Image) error
	// postStart runs after the guest vsock agent is ready (the
	// "post_start_features" stage).
	postStart func(ctx context.Context, vm model.VMRecord, image model.Image) error
	// cleanupState tears down capability-owned state for a VM.
	cleanupState func(ctx context.Context, vm model.VMRecord) error
	// applyConfigChanges reacts to a VM record changing (before →
	// after), e.g. on a set operation.
	applyConfigChanges func(ctx context.Context, before, after model.VMRecord) error
}
// vmServiceDeps is the constructor-argument bag for newVMService.
// Field-for-field it mirrors VMService's own dependencies; Daemon
// fills it in vmSvc(). Kept as a struct so adding a dependency does
// not ripple through positional constructor arguments.
type vmServiceDeps struct {
	runner          system.CommandRunner
	logger          *slog.Logger
	config          model.DaemonConfig
	layout          paths.Layout
	store           *store.Store
	net             *HostNetwork
	img             *ImageService
	ws              *WorkspaceService
	guestWaitForSSH func(context.Context, string, string, time.Duration) error
	guestDial       func(context.Context, string, string) (guestSSHClient, error)
	capHooks        capabilityHooks
	beginOperation  func(name string, attrs ...any) *operationLog
}
// newVMService wires a VMService from its dependency bag. Everything
// is taken verbatim from deps except the handle cache, which is
// always freshly constructed here so a zero deps value still yields
// a usable cache.
func newVMService(deps vmServiceDeps) *VMService {
	var svc VMService
	svc.runner = deps.runner
	svc.logger = deps.logger
	svc.config = deps.config
	svc.layout = deps.layout
	svc.store = deps.store
	svc.net = deps.net
	svc.img = deps.img
	svc.ws = deps.ws
	svc.guestWaitForSSH = deps.guestWaitForSSH
	svc.guestDial = deps.guestDial
	svc.capHooks = deps.capHooks
	svc.beginOperation = deps.beginOperation
	svc.handles = newHandleCache()
	return &svc
}
// vmSvc returns the Daemon's VMService, constructing and memoising it
// on first call. Like hostNet() / imageSvc() / workspaceSvc(), the
// lazy init keeps bare test literals such as
// `&Daemon{store: db, runner: r}` functional without wiring a
// service by hand.
func (d *Daemon) vmSvc() *VMService {
	if d.vm == nil {
		d.vm = newVMService(vmServiceDeps{
			runner:          d.runner,
			logger:          d.logger,
			config:          d.config,
			layout:          d.layout,
			store:           d.store,
			net:             d.hostNet(),
			img:             d.imageSvc(),
			ws:              d.workspaceSvc(),
			guestWaitForSSH: d.guestWaitForSSH,
			guestDial:       d.guestDial,
			capHooks:        d.buildCapabilityHooks(),
			beginOperation:  d.beginOperation,
		})
	}
	return d.vm
}
// buildCapabilityHooks adapts Daemon's existing capability-dispatch
// methods into the capabilityHooks bag VMService takes. Keeps the
// registry + capability types on *Daemon while letting VMService call
// into them through explicit function seams.
func (d *Daemon) buildCapabilityHooks() capabilityHooks {
	var hooks capabilityHooks
	hooks.addStartPrereqs = d.addCapabilityStartPrereqs
	hooks.contributeGuest = d.contributeGuestConfig
	hooks.contributeMachine = d.contributeMachineConfig
	hooks.prepareHosts = d.prepareCapabilityHosts
	hooks.postStart = d.postStartCapabilities
	hooks.cleanupState = d.cleanupCapabilityState
	hooks.applyConfigChanges = d.applyCapabilityConfigChanges
	return hooks
}
// FindVM resolves an ID-or-name against the store with the historical
// precedence: exact-ID / exact-name first, then unambiguous prefix
// match. Returns an error when no match is found or when a prefix
// matches more than one record.
func (s *VMService) FindVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
	if idOrName == "" {
		return model.VMRecord{}, errors.New("vm id or name is required")
	}
	// Exact ID or name wins outright.
	if exact, lookupErr := s.store.GetVM(ctx, idOrName); lookupErr == nil {
		return exact, nil
	}
	all, err := s.store.ListVMs(ctx)
	if err != nil {
		return model.VMRecord{}, err
	}
	// Fall back to a prefix scan over IDs and names; each record
	// counts at most once even if both its ID and name match.
	var hits []model.VMRecord
	for _, candidate := range all {
		if strings.HasPrefix(candidate.ID, idOrName) || strings.HasPrefix(candidate.Name, idOrName) {
			hits = append(hits, candidate)
		}
	}
	switch len(hits) {
	case 1:
		return hits[0], nil
	case 0:
		return model.VMRecord{}, fmt.Errorf("vm %q not found", idOrName)
	default:
		return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", idOrName)
	}
}
// TouchVM bumps a VM's updated-at timestamp under the per-VM lock and
// persists the refreshed record.
func (s *VMService) TouchVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
	bump := func(vm model.VMRecord) (model.VMRecord, error) {
		system.TouchNow(&vm)
		if upsertErr := s.store.UpsertVM(ctx, vm); upsertErr != nil {
			return model.VMRecord{}, upsertErr
		}
		return vm, nil
	}
	return s.withVMLockByRef(ctx, idOrName, bump)
}
// withVMLockByRef resolves idOrName then serialises fn under the
// per-VM lock. Every mutating VM operation funnels through here.
func (s *VMService) withVMLockByRef(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
	resolved, lookupErr := s.FindVM(ctx, idOrName)
	if lookupErr != nil {
		return model.VMRecord{}, lookupErr
	}
	// Lock on the stable ID of the resolved record, not the caller's
	// free-form reference.
	return s.withVMLockByID(ctx, resolved.ID, fn)
}
// withVMLockByID locks on the stable VM ID (so a rename mid-flight
// doesn't drop the lock) and re-reads the record under the lock so
// fn sees the committed state.
func (s *VMService) withVMLockByID(ctx context.Context, id string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
	if strings.TrimSpace(id) == "" {
		return model.VMRecord{}, errors.New("vm id is required")
	}
	release := s.lockVMID(id)
	defer release()
	record, getErr := s.store.GetVMByID(ctx, id)
	switch {
	case getErr == nil:
		return fn(record)
	case errors.Is(getErr, sql.ErrNoRows):
		// A missing row surfaces as not-found rather than a raw SQL error.
		return model.VMRecord{}, fmt.Errorf("vm %q not found", id)
	default:
		return model.VMRecord{}, getErr
	}
}
// withVMLockByIDErr is the error-only variant of withVMLockByID for
// callers that don't need the returned record.
func (s *VMService) withVMLockByIDErr(ctx context.Context, id string, fn func(model.VMRecord) error) error {
	// Adapt the error-only callback to the record-returning shape the
	// underlying helper expects; the record itself is discarded.
	wrap := func(vm model.VMRecord) (model.VMRecord, error) {
		if innerErr := fn(vm); innerErr != nil {
			return model.VMRecord{}, innerErr
		}
		return vm, nil
	}
	_, err := s.withVMLockByID(ctx, id, wrap)
	return err
}
// lockVMID exposes the per-VM mutex for callers that need to hold it
// outside the usual withVMLockByRef/withVMLockByID helpers
// (workspace prepare, for example). The returned func releases the lock.
func (s *VMService) lockVMID(id string) func() {
	unlock := s.vmLocks.lock(id)
	return unlock
}

View file

@ -9,15 +9,15 @@ import (
"banger/internal/system" "banger/internal/system"
) )
func (d *Daemon) SetVM(ctx context.Context, params api.VMSetParams) (model.VMRecord, error) { func (s *VMService) SetVM(ctx context.Context, params api.VMSetParams) (model.VMRecord, error) {
return d.withVMLockByRef(ctx, params.IDOrName, func(vm model.VMRecord) (model.VMRecord, error) { return s.withVMLockByRef(ctx, params.IDOrName, func(vm model.VMRecord) (model.VMRecord, error) {
return d.setVMLocked(ctx, vm, params) return s.setVMLocked(ctx, vm, params)
}) })
} }
func (d *Daemon) setVMLocked(ctx context.Context, current model.VMRecord, params api.VMSetParams) (vm model.VMRecord, err error) { func (s *VMService) setVMLocked(ctx context.Context, current model.VMRecord, params api.VMSetParams) (vm model.VMRecord, err error) {
vm = current vm = current
op := d.beginOperation("vm.set", "vm_ref", vm.ID) op := s.beginOperation("vm.set", "vm_ref", vm.ID)
defer func() { defer func() {
if err != nil { if err != nil {
op.fail(err, vmLogAttrs(vm)...) op.fail(err, vmLogAttrs(vm)...)
@ -25,7 +25,7 @@ func (d *Daemon) setVMLocked(ctx context.Context, current model.VMRecord, params
} }
op.done(vmLogAttrs(vm)...) op.done(vmLogAttrs(vm)...)
}() }()
running := d.vmAlive(vm) running := s.vmAlive(vm)
if params.VCPUCount != nil { if params.VCPUCount != nil {
if err := validateOptionalPositiveSetting("vcpu", params.VCPUCount); err != nil { if err := validateOptionalPositiveSetting("vcpu", params.VCPUCount); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
@ -60,10 +60,10 @@ func (d *Daemon) setVMLocked(ctx context.Context, current model.VMRecord, params
if size > vm.Spec.WorkDiskSizeBytes { if size > vm.Spec.WorkDiskSizeBytes {
if exists(vm.Runtime.WorkDiskPath) { if exists(vm.Runtime.WorkDiskPath) {
op.stage("resize_work_disk", "from_bytes", vm.Spec.WorkDiskSizeBytes, "to_bytes", size) op.stage("resize_work_disk", "from_bytes", vm.Spec.WorkDiskSizeBytes, "to_bytes", size)
if err := d.validateWorkDiskResizePrereqs(); err != nil { if err := s.validateWorkDiskResizePrereqs(); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
if err := system.ResizeExt4Image(ctx, d.runner, vm.Runtime.WorkDiskPath, size); err != nil { if err := system.ResizeExt4Image(ctx, s.runner, vm.Runtime.WorkDiskPath, size); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
} }
@ -75,12 +75,12 @@ func (d *Daemon) setVMLocked(ctx context.Context, current model.VMRecord, params
vm.Spec.NATEnabled = *params.NATEnabled vm.Spec.NATEnabled = *params.NATEnabled
} }
if running { if running {
if err := d.applyCapabilityConfigChanges(ctx, current, vm); err != nil { if err := s.capHooks.applyConfigChanges(ctx, current, vm); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
} }
system.TouchNow(&vm) system.TouchNow(&vm)
if err := d.store.UpsertVM(ctx, vm); err != nil { if err := s.store.UpsertVM(ctx, vm); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
return vm, nil return vm, nil

View file

@ -12,9 +12,9 @@ import (
"banger/internal/vsockagent" "banger/internal/vsockagent"
) )
func (d *Daemon) GetVMStats(ctx context.Context, idOrName string) (model.VMRecord, model.VMStats, error) { func (s *VMService) GetVMStats(ctx context.Context, idOrName string) (model.VMRecord, model.VMStats, error) {
vm, err := d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { vm, err := s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
return d.getVMStatsLocked(ctx, vm) return s.getVMStatsLocked(ctx, vm)
}) })
if err != nil { if err != nil {
return model.VMRecord{}, model.VMStats{}, err return model.VMRecord{}, model.VMStats{}, err
@ -22,10 +22,10 @@ func (d *Daemon) GetVMStats(ctx context.Context, idOrName string) (model.VMRecor
return vm, vm.Stats, nil return vm, vm.Stats, nil
} }
func (d *Daemon) HealthVM(ctx context.Context, idOrName string) (result api.VMHealthResult, err error) { func (s *VMService) HealthVM(ctx context.Context, idOrName string) (result api.VMHealthResult, err error) {
_, err = d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { _, err = s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
result.Name = vm.Name result.Name = vm.Name
if !d.vmAlive(vm) { if !s.vmAlive(vm) {
result.Healthy = false result.Healthy = false
return vm, nil return vm, nil
} }
@ -35,12 +35,12 @@ func (d *Daemon) HealthVM(ctx context.Context, idOrName string) (result api.VMHe
if vm.Runtime.VSockCID == 0 { if vm.Runtime.VSockCID == 0 {
return model.VMRecord{}, errors.New("vm has no vsock cid") return model.VMRecord{}, errors.New("vm has no vsock cid")
} }
if err := d.hostNet().ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { if err := s.net.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second) pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel() defer cancel()
if err := vsockagent.Health(pingCtx, d.logger, vm.Runtime.VSockPath); err != nil { if err := vsockagent.Health(pingCtx, s.logger, vm.Runtime.VSockPath); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
result.Healthy = true result.Healthy = true
@ -49,47 +49,47 @@ func (d *Daemon) HealthVM(ctx context.Context, idOrName string) (result api.VMHe
return result, err return result, err
} }
func (d *Daemon) PingVM(ctx context.Context, idOrName string) (result api.VMPingResult, err error) { func (s *VMService) PingVM(ctx context.Context, idOrName string) (result api.VMPingResult, err error) {
health, err := d.HealthVM(ctx, idOrName) health, err := s.HealthVM(ctx, idOrName)
if err != nil { if err != nil {
return api.VMPingResult{}, err return api.VMPingResult{}, err
} }
return api.VMPingResult{Name: health.Name, Alive: health.Healthy}, nil return api.VMPingResult{Name: health.Name, Alive: health.Healthy}, nil
} }
func (d *Daemon) getVMStatsLocked(ctx context.Context, vm model.VMRecord) (model.VMRecord, error) { func (s *VMService) getVMStatsLocked(ctx context.Context, vm model.VMRecord) (model.VMRecord, error) {
stats, err := d.collectStats(ctx, vm) stats, err := s.collectStats(ctx, vm)
if err == nil { if err == nil {
vm.Stats = stats vm.Stats = stats
vm.UpdatedAt = model.Now() vm.UpdatedAt = model.Now()
_ = d.store.UpsertVM(ctx, vm) _ = s.store.UpsertVM(ctx, vm)
if d.logger != nil { if s.logger != nil {
d.logger.Debug("vm stats collected", append(vmLogAttrs(vm), "rss_bytes", stats.RSSBytes, "vsz_bytes", stats.VSZBytes, "cpu_percent", stats.CPUPercent)...) s.logger.Debug("vm stats collected", append(vmLogAttrs(vm), "rss_bytes", stats.RSSBytes, "vsz_bytes", stats.VSZBytes, "cpu_percent", stats.CPUPercent)...)
} }
} }
return vm, nil return vm, nil
} }
func (d *Daemon) pollStats(ctx context.Context) error { func (s *VMService) pollStats(ctx context.Context) error {
vms, err := d.store.ListVMs(ctx) vms, err := s.store.ListVMs(ctx)
if err != nil { if err != nil {
return err return err
} }
for _, vm := range vms { for _, vm := range vms {
if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error { if err := s.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error {
if !d.vmAlive(vm) { if !s.vmAlive(vm) {
return nil return nil
} }
stats, err := d.collectStats(ctx, vm) stats, err := s.collectStats(ctx, vm)
if err != nil { if err != nil {
if d.logger != nil { if s.logger != nil {
d.logger.Debug("vm stats collection failed", append(vmLogAttrs(vm), "error", err.Error())...) s.logger.Debug("vm stats collection failed", append(vmLogAttrs(vm), "error", err.Error())...)
} }
return nil return nil
} }
vm.Stats = stats vm.Stats = stats
vm.UpdatedAt = model.Now() vm.UpdatedAt = model.Now()
return d.store.UpsertVM(ctx, vm) return s.store.UpsertVM(ctx, vm)
}); err != nil { }); err != nil {
return err return err
} }
@ -97,11 +97,11 @@ func (d *Daemon) pollStats(ctx context.Context) error {
return nil return nil
} }
func (d *Daemon) stopStaleVMs(ctx context.Context) (err error) { func (s *VMService) stopStaleVMs(ctx context.Context) (err error) {
if d.config.AutoStopStaleAfter <= 0 { if s.config.AutoStopStaleAfter <= 0 {
return nil return nil
} }
op := d.beginOperation("vm.stop_stale") op := s.beginOperation("vm.stop_stale")
defer func() { defer func() {
if err != nil { if err != nil {
op.fail(err) op.fail(err)
@ -109,28 +109,28 @@ func (d *Daemon) stopStaleVMs(ctx context.Context) (err error) {
} }
op.done() op.done()
}() }()
vms, err := d.store.ListVMs(ctx) vms, err := s.store.ListVMs(ctx)
if err != nil { if err != nil {
return err return err
} }
now := model.Now() now := model.Now()
for _, vm := range vms { for _, vm := range vms {
if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error { if err := s.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error {
if !d.vmAlive(vm) { if !s.vmAlive(vm) {
return nil return nil
} }
if now.Sub(vm.LastTouchedAt) < d.config.AutoStopStaleAfter { if now.Sub(vm.LastTouchedAt) < s.config.AutoStopStaleAfter {
return nil return nil
} }
op.stage("stopping_vm", vmLogAttrs(vm)...) op.stage("stopping_vm", vmLogAttrs(vm)...)
_ = d.hostNet().sendCtrlAltDel(ctx, vm.Runtime.APISockPath) _ = s.net.sendCtrlAltDel(ctx, vm.Runtime.APISockPath)
_ = d.hostNet().waitForExit(ctx, d.vmHandles(vm.ID).PID, vm.Runtime.APISockPath, 10*time.Second) _ = s.net.waitForExit(ctx, s.vmHandles(vm.ID).PID, vm.Runtime.APISockPath, 10*time.Second)
_ = d.cleanupRuntime(ctx, vm, true) _ = s.cleanupRuntime(ctx, vm, true)
vm.State = model.VMStateStopped vm.State = model.VMStateStopped
vm.Runtime.State = model.VMStateStopped vm.Runtime.State = model.VMStateStopped
d.clearVMHandles(vm) s.clearVMHandles(vm)
vm.UpdatedAt = model.Now() vm.UpdatedAt = model.Now()
return d.store.UpsertVM(ctx, vm) return s.store.UpsertVM(ctx, vm)
}); err != nil { }); err != nil {
return err return err
} }
@ -138,15 +138,15 @@ func (d *Daemon) stopStaleVMs(ctx context.Context) (err error) {
return nil return nil
} }
func (d *Daemon) collectStats(ctx context.Context, vm model.VMRecord) (model.VMStats, error) { func (s *VMService) collectStats(ctx context.Context, vm model.VMRecord) (model.VMStats, error) {
stats := model.VMStats{ stats := model.VMStats{
CollectedAt: model.Now(), CollectedAt: model.Now(),
SystemOverlayBytes: system.AllocatedBytes(vm.Runtime.SystemOverlay), SystemOverlayBytes: system.AllocatedBytes(vm.Runtime.SystemOverlay),
WorkDiskBytes: system.AllocatedBytes(vm.Runtime.WorkDiskPath), WorkDiskBytes: system.AllocatedBytes(vm.Runtime.WorkDiskPath),
MetricsRaw: system.ParseMetricsFile(vm.Runtime.MetricsPath), MetricsRaw: system.ParseMetricsFile(vm.Runtime.MetricsPath),
} }
if d.vmAlive(vm) { if s.vmAlive(vm) {
if ps, err := system.ReadProcessStats(ctx, d.vmHandles(vm.ID).PID); err == nil { if ps, err := system.ReadProcessStats(ctx, s.vmHandles(vm.ID).PID); err == nil {
stats.CPUPercent = ps.CPUPercent stats.CPUPercent = ps.CPUPercent
stats.RSSBytes = ps.RSSBytes stats.RSSBytes = ps.RSSBytes
stats.VSZBytes = ps.VSZBytes stats.VSZBytes = ps.VSZBytes

View file

@ -167,7 +167,7 @@ func TestReconcileStopsStaleRunningVMAndClearsRuntimeHandles(t *testing.T) {
t.Fatalf("handles.json still present after reconcile: %v", err) t.Fatalf("handles.json still present after reconcile: %v", err)
} }
// And the in-memory cache must be empty. // And the in-memory cache must be empty.
if h, ok := d.handles.get(vm.ID); ok && !h.IsZero() { if h, ok := d.vmSvc().handles.get(vm.ID); ok && !h.IsZero() {
t.Fatalf("handle cache not cleared after reconcile: %+v", h) t.Fatalf("handle cache not cleared after reconcile: %+v", h)
} }
} }
@ -216,9 +216,9 @@ func TestRebuildDNSIncludesOnlyLiveRunningVMs(t *testing.T) {
// rebuildDNS reads the alive check from the handle cache. Seed // rebuildDNS reads the alive check from the handle cache. Seed
// the live VM with its real PID; leave the stale entry with a PID // the live VM with its real PID; leave the stale entry with a PID
// that definitely isn't running (999999 ≫ max PID on most hosts). // that definitely isn't running (999999 ≫ max PID on most hosts).
d.setVMHandlesInMemory(live.ID, model.VMHandles{PID: liveCmd.Process.Pid}) d.vmSvc().setVMHandlesInMemory(live.ID, model.VMHandles{PID: liveCmd.Process.Pid})
d.setVMHandlesInMemory(stale.ID, model.VMHandles{PID: 999999}) d.vmSvc().setVMHandlesInMemory(stale.ID, model.VMHandles{PID: 999999})
if err := d.rebuildDNS(ctx); err != nil { if err := d.vmSvc().rebuildDNS(ctx); err != nil {
t.Fatalf("rebuildDNS: %v", err) t.Fatalf("rebuildDNS: %v", err)
} }
@ -252,7 +252,7 @@ func TestSetVMRejectsStoppedOnlyChangesForRunningVM(t *testing.T) {
upsertDaemonVM(t, ctx, db, vm) upsertDaemonVM(t, ctx, db, vm)
d := &Daemon{store: db} d := &Daemon{store: db}
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: cmd.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: cmd.Process.Pid})
tests := []struct { tests := []struct {
name string name string
params api.VMSetParams params api.VMSetParams
@ -277,7 +277,7 @@ func TestSetVMRejectsStoppedOnlyChangesForRunningVM(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := d.SetVM(ctx, tt.params) _, err := d.vmSvc().SetVM(ctx, tt.params)
if err == nil || !strings.Contains(err.Error(), tt.want) { if err == nil || !strings.Contains(err.Error(), tt.want) {
t.Fatalf("SetVM(%s) error = %v, want %q", tt.name, err, tt.want) t.Fatalf("SetVM(%s) error = %v, want %q", tt.name, err, tt.want)
} }
@ -367,8 +367,8 @@ func TestHealthVMReturnsHealthyForRunningGuest(t *testing.T) {
}, },
} }
d := &Daemon{store: db, runner: runner} d := &Daemon{store: db, runner: runner}
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: handlePID}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: handlePID})
result, err := d.HealthVM(ctx, vm.Name) result, err := d.vmSvc().HealthVM(ctx, vm.Name)
if err != nil { if err != nil {
t.Fatalf("HealthVM: %v", err) t.Fatalf("HealthVM: %v", err)
} }
@ -430,8 +430,8 @@ func TestPingVMAliasReturnsAliveForHealthyVM(t *testing.T) {
}, },
} }
d := &Daemon{store: db, runner: runner} d := &Daemon{store: db, runner: runner}
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: fake.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: fake.Process.Pid})
result, err := d.PingVM(ctx, vm.Name) result, err := d.vmSvc().PingVM(ctx, vm.Name)
if err != nil { if err != nil {
t.Fatalf("PingVM: %v", err) t.Fatalf("PingVM: %v", err)
} }
@ -530,7 +530,7 @@ func TestHealthVMReturnsFalseForStoppedVM(t *testing.T) {
upsertDaemonVM(t, ctx, db, vm) upsertDaemonVM(t, ctx, db, vm)
d := &Daemon{store: db} d := &Daemon{store: db}
result, err := d.HealthVM(ctx, vm.Name) result, err := d.vmSvc().HealthVM(ctx, vm.Name)
if err != nil { if err != nil {
t.Fatalf("HealthVM: %v", err) t.Fatalf("HealthVM: %v", err)
} }
@ -628,9 +628,9 @@ func TestPortsVMReturnsEnrichedPortsAndWebSchemes(t *testing.T) {
}, },
} }
d := &Daemon{store: db, runner: runner} d := &Daemon{store: db, runner: runner}
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: fake.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: fake.Process.Pid})
result, err := d.PortsVM(ctx, vm.Name) result, err := d.vmSvc().PortsVM(ctx, vm.Name)
if err != nil { if err != nil {
t.Fatalf("PortsVM: %v", err) t.Fatalf("PortsVM: %v", err)
} }
@ -677,7 +677,7 @@ func TestPortsVMReturnsErrorForStoppedVM(t *testing.T) {
upsertDaemonVM(t, ctx, db, vm) upsertDaemonVM(t, ctx, db, vm)
d := &Daemon{store: db} d := &Daemon{store: db}
_, err := d.PortsVM(ctx, vm.Name) _, err := d.vmSvc().PortsVM(ctx, vm.Name)
if err == nil || !strings.Contains(err.Error(), "is not running") { if err == nil || !strings.Contains(err.Error(), "is not running") {
t.Fatalf("PortsVM error = %v, want not running", err) t.Fatalf("PortsVM error = %v, want not running", err)
} }
@ -740,7 +740,7 @@ func TestSetVMDiskResizeFailsPreflightWhenToolsMissing(t *testing.T) {
t.Setenv("PATH", t.TempDir()) t.Setenv("PATH", t.TempDir())
d := &Daemon{store: db} d := &Daemon{store: db}
_, err := d.SetVM(ctx, api.VMSetParams{IDOrName: vm.ID, WorkDiskSize: "16G"}) _, err := d.vmSvc().SetVM(ctx, api.VMSetParams{IDOrName: vm.ID, WorkDiskSize: "16G"})
if err == nil || !strings.Contains(err.Error(), "work disk resize preflight failed") { if err == nil || !strings.Contains(err.Error(), "work disk resize preflight failed") {
t.Fatalf("SetVM() error = %v, want preflight failure", err) t.Fatalf("SetVM() error = %v, want preflight failure", err)
} }
@ -769,7 +769,7 @@ func TestFlattenNestedWorkHomeCopiesEntriesIndividually(t *testing.T) {
} }
d := &Daemon{runner: runner} d := &Daemon{runner: runner}
if err := d.flattenNestedWorkHome(context.Background(), workMount); err != nil { if err := flattenNestedWorkHome(context.Background(), d.runner, workMount); err != nil {
t.Fatalf("flattenNestedWorkHome: %v", err) t.Fatalf("flattenNestedWorkHome: %v", err)
} }
runner.assertExhausted() runner.assertExhausted()
@ -1157,10 +1157,10 @@ func TestRunFileSyncCopiesDirectoryRecursively(t *testing.T) {
func TestCreateVMRejectsNonPositiveCPUAndMemory(t *testing.T) { func TestCreateVMRejectsNonPositiveCPUAndMemory(t *testing.T) {
d := &Daemon{} d := &Daemon{}
if _, err := d.CreateVM(context.Background(), api.VMCreateParams{VCPUCount: ptr(0)}); err == nil || !strings.Contains(err.Error(), "vcpu must be a positive integer") { if _, err := d.vmSvc().CreateVM(context.Background(), api.VMCreateParams{VCPUCount: ptr(0)}); err == nil || !strings.Contains(err.Error(), "vcpu must be a positive integer") {
t.Fatalf("CreateVM(vcpu=0) error = %v", err) t.Fatalf("CreateVM(vcpu=0) error = %v", err)
} }
if _, err := d.CreateVM(context.Background(), api.VMCreateParams{MemoryMiB: ptr(-1)}); err == nil || !strings.Contains(err.Error(), "memory must be a positive integer") { if _, err := d.vmSvc().CreateVM(context.Background(), api.VMCreateParams{MemoryMiB: ptr(-1)}); err == nil || !strings.Contains(err.Error(), "memory must be a positive integer") {
t.Fatalf("CreateVM(memory=-1) error = %v", err) t.Fatalf("CreateVM(memory=-1) error = %v", err)
} }
} }
@ -1188,7 +1188,7 @@ func TestBeginVMCreateCompletesAndReturnsStatus(t *testing.T) {
}, },
} }
op, err := d.BeginVMCreate(ctx, api.VMCreateParams{Name: "queued", NoStart: true}) op, err := d.vmSvc().BeginVMCreate(ctx, api.VMCreateParams{Name: "queued", NoStart: true})
if err != nil { if err != nil {
t.Fatalf("BeginVMCreate: %v", err) t.Fatalf("BeginVMCreate: %v", err)
} }
@ -1198,7 +1198,7 @@ func TestBeginVMCreateCompletesAndReturnsStatus(t *testing.T) {
deadline := time.Now().Add(2 * time.Second) deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) { for time.Now().Before(deadline) {
status, err := d.VMCreateStatus(ctx, op.ID) status, err := d.vmSvc().VMCreateStatus(ctx, op.ID)
if err != nil { if err != nil {
t.Fatalf("VMCreateStatus: %v", err) t.Fatalf("VMCreateStatus: %v", err)
} }
@ -1238,7 +1238,7 @@ func TestCreateVMUsesDefaultsWhenCPUAndMemoryOmitted(t *testing.T) {
}, },
} }
vm, err := d.CreateVM(ctx, api.VMCreateParams{Name: "defaults", ImageName: image.Name, NoStart: true}) vm, err := d.vmSvc().CreateVM(ctx, api.VMCreateParams{Name: "defaults", ImageName: image.Name, NoStart: true})
if err != nil { if err != nil {
t.Fatalf("CreateVM: %v", err) t.Fatalf("CreateVM: %v", err)
} }
@ -1257,10 +1257,10 @@ func TestSetVMRejectsNonPositiveCPUAndMemory(t *testing.T) {
upsertDaemonVM(t, ctx, db, vm) upsertDaemonVM(t, ctx, db, vm)
d := &Daemon{store: db} d := &Daemon{store: db}
if _, err := d.SetVM(ctx, api.VMSetParams{IDOrName: vm.ID, VCPUCount: ptr(0)}); err == nil || !strings.Contains(err.Error(), "vcpu must be a positive integer") { if _, err := d.vmSvc().SetVM(ctx, api.VMSetParams{IDOrName: vm.ID, VCPUCount: ptr(0)}); err == nil || !strings.Contains(err.Error(), "vcpu must be a positive integer") {
t.Fatalf("SetVM(vcpu=0) error = %v", err) t.Fatalf("SetVM(vcpu=0) error = %v", err)
} }
if _, err := d.SetVM(ctx, api.VMSetParams{IDOrName: vm.ID, MemoryMiB: ptr(0)}); err == nil || !strings.Contains(err.Error(), "memory must be a positive integer") { if _, err := d.vmSvc().SetVM(ctx, api.VMSetParams{IDOrName: vm.ID, MemoryMiB: ptr(0)}); err == nil || !strings.Contains(err.Error(), "memory must be a positive integer") {
t.Fatalf("SetVM(memory=0) error = %v", err) t.Fatalf("SetVM(memory=0) error = %v", err)
} }
} }
@ -1281,7 +1281,7 @@ func TestCollectStatsIgnoresMalformedMetricsFile(t *testing.T) {
} }
d := &Daemon{} d := &Daemon{}
stats, err := d.collectStats(context.Background(), model.VMRecord{ stats, err := d.vmSvc().collectStats(context.Background(), model.VMRecord{
Runtime: model.VMRuntime{ Runtime: model.VMRuntime{
SystemOverlay: overlay, SystemOverlay: overlay,
WorkDiskPath: workDisk, WorkDiskPath: workDisk,
@ -1337,7 +1337,7 @@ func TestValidateStartPrereqsReportsNATUplinkFailure(t *testing.T) {
image.RootfsPath = rootfsPath image.RootfsPath = rootfsPath
image.KernelPath = kernelPath image.KernelPath = kernelPath
err := d.validateStartPrereqs(ctx, vm, image) err := d.vmSvc().validateStartPrereqs(ctx, vm, image)
if err == nil || !strings.Contains(err.Error(), "uplink interface for NAT") { if err == nil || !strings.Contains(err.Error(), "uplink interface for NAT") {
t.Fatalf("validateStartPrereqs() error = %v, want NAT uplink failure", err) t.Fatalf("validateStartPrereqs() error = %v, want NAT uplink failure", err)
} }
@ -1369,9 +1369,9 @@ func TestCleanupRuntimeRediscoversLiveFirecrackerPID(t *testing.T) {
vm.Runtime.APISockPath = apiSock vm.Runtime.APISockPath = apiSock
// Seed a stale PID so cleanupRuntime's findFirecrackerPID pgrep // Seed a stale PID so cleanupRuntime's findFirecrackerPID pgrep
// fallback wins — it rediscovers fake.Process.Pid from apiSock. // fallback wins — it rediscovers fake.Process.Pid from apiSock.
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: fake.Process.Pid + 999}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: fake.Process.Pid + 999})
if err := d.cleanupRuntime(context.Background(), vm, true); err != nil { if err := d.vmSvc().cleanupRuntime(context.Background(), vm, true); err != nil {
t.Fatalf("cleanupRuntime returned error: %v", err) t.Fatalf("cleanupRuntime returned error: %v", err)
} }
runner.assertExhausted() runner.assertExhausted()
@ -1398,7 +1398,7 @@ func TestDeleteStoppedNATVMDoesNotFailWithoutTapDevice(t *testing.T) {
upsertDaemonVM(t, ctx, db, vm) upsertDaemonVM(t, ctx, db, vm)
d := &Daemon{store: db} d := &Daemon{store: db}
deleted, err := d.DeleteVM(ctx, vm.Name) deleted, err := d.vmSvc().DeleteVM(ctx, vm.Name)
if err != nil { if err != nil {
t.Fatalf("DeleteVM: %v", err) t.Fatalf("DeleteVM: %v", err)
} }
@ -1452,9 +1452,9 @@ func TestStopVMFallsBackToForcedCleanupAfterGracefulTimeout(t *testing.T) {
proc: fake, proc: fake,
} }
d := &Daemon{store: db, runner: runner} d := &Daemon{store: db, runner: runner}
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: fake.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: fake.Process.Pid})
got, err := d.StopVM(ctx, vm.ID) got, err := d.vmSvc().StopVM(ctx, vm.ID)
if err != nil { if err != nil {
t.Fatalf("StopVM returned error: %v", err) t.Fatalf("StopVM returned error: %v", err)
} }
@ -1465,7 +1465,7 @@ func TestStopVMFallsBackToForcedCleanupAfterGracefulTimeout(t *testing.T) {
// APISockPath + VSock paths are deterministic — they stay on the // APISockPath + VSock paths are deterministic — they stay on the
// record for debugging and next-start reuse even after stop. The // record for debugging and next-start reuse even after stop. The
// post-stop invariant is that the in-memory cache is empty. // post-stop invariant is that the in-memory cache is empty.
if h, ok := d.handles.get(vm.ID); ok && !h.IsZero() { if h, ok := d.vmSvc().handles.get(vm.ID); ok && !h.IsZero() {
t.Fatalf("handle cache not cleared: %+v", h) t.Fatalf("handle cache not cleared: %+v", h)
} }
} }
@ -1483,7 +1483,7 @@ func TestWithVMLockByIDSerializesSameVM(t *testing.T) {
errCh := make(chan error, 2) errCh := make(chan error, 2)
go func() { go func() {
_, err := d.withVMLockByID(ctx, vm.ID, func(vm model.VMRecord) (model.VMRecord, error) { _, err := d.vmSvc().withVMLockByID(ctx, vm.ID, func(vm model.VMRecord) (model.VMRecord, error) {
close(firstEntered) close(firstEntered)
<-releaseFirst <-releaseFirst
return vm, nil return vm, nil
@ -1498,7 +1498,7 @@ func TestWithVMLockByIDSerializesSameVM(t *testing.T) {
} }
go func() { go func() {
_, err := d.withVMLockByID(ctx, vm.ID, func(vm model.VMRecord) (model.VMRecord, error) { _, err := d.vmSvc().withVMLockByID(ctx, vm.ID, func(vm model.VMRecord) (model.VMRecord, error) {
close(secondEntered) close(secondEntered)
return vm, nil return vm, nil
}) })
@ -1540,7 +1540,7 @@ func TestWithVMLockByIDAllowsDifferentVMsConcurrently(t *testing.T) {
release := make(chan struct{}) release := make(chan struct{})
errCh := make(chan error, 2) errCh := make(chan error, 2)
run := func(id string) { run := func(id string) {
_, err := d.withVMLockByID(ctx, id, func(vm model.VMRecord) (model.VMRecord, error) { _, err := d.vmSvc().withVMLockByID(ctx, id, func(vm model.VMRecord) (model.VMRecord, error) {
started <- vm.ID started <- vm.ID
<-release <-release
return vm, nil return vm, nil

View file

@ -91,20 +91,36 @@ func (d *Daemon) workspaceSvc() *WorkspaceService {
if d.ws != nil { if d.ws != nil {
return d.ws return d.ws
} }
// Peer seams capture d by closure instead of pointing to
// d.vmSvc() / d.imageSvc() directly. vmSvc() constructs VMService
// with WorkspaceService as a peer, so resolving the peer service
// eagerly here would recurse. Closures defer the lookup to call
// time, by which point the cycle is broken because d.vm / d.img
// are already populated.
d.ws = newWorkspaceService(workspaceServiceDeps{ d.ws = newWorkspaceService(workspaceServiceDeps{
runner: d.runner, runner: d.runner,
logger: d.logger, logger: d.logger,
config: d.config, config: d.config,
layout: d.layout, layout: d.layout,
store: d.store, store: d.store,
vmResolver: d.FindVM, vmResolver: func(ctx context.Context, idOrName string) (model.VMRecord, error) {
aliveChecker: d.vmAlive, return d.vmSvc().FindVM(ctx, idOrName)
waitGuestSSH: d.waitForGuestSSH, },
dialGuest: d.dialGuest, aliveChecker: func(vm model.VMRecord) bool {
imageResolver: d.FindImage, return d.vmSvc().vmAlive(vm)
imageWorkSeed: d.imageSvc().refreshManagedWorkSeedFingerprint, },
withVMLockByRef: d.withVMLockByRef, waitGuestSSH: d.waitForGuestSSH,
beginOperation: d.beginOperation, dialGuest: d.dialGuest,
imageResolver: func(ctx context.Context, idOrName string) (model.Image, error) {
return d.FindImage(ctx, idOrName)
},
imageWorkSeed: func(ctx context.Context, image model.Image, fingerprint string) error {
return d.imageSvc().refreshManagedWorkSeedFingerprint(ctx, image, fingerprint)
},
withVMLockByRef: func(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
return d.vmSvc().withVMLockByRef(ctx, idOrName, fn)
},
beginOperation: d.beginOperation,
}) })
return d.ws return d.ws
} }

View file

@ -94,7 +94,7 @@ func TestExportVMWorkspace_HappyPath(t *testing.T) {
} }
d := newExportTestDaemonStore(t, fake) d := newExportTestDaemonStore(t, fake)
upsertDaemonVM(t, ctx, d.store, vm) upsertDaemonVM(t, ctx, d.store, vm)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{ result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{
IDOrName: vm.Name, IDOrName: vm.Name,
@ -155,7 +155,7 @@ func TestExportVMWorkspace_WithBaseCommit(t *testing.T) {
} }
d := newExportTestDaemonStore(t, fake) d := newExportTestDaemonStore(t, fake)
upsertDaemonVM(t, ctx, d.store, vm) upsertDaemonVM(t, ctx, d.store, vm)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
const prepareCommit = "abc1234deadbeef" const prepareCommit = "abc1234deadbeef"
result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{ result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{
@ -202,7 +202,7 @@ func TestExportVMWorkspace_BaseCommitFallsBackToHEAD(t *testing.T) {
} }
d := newExportTestDaemonStore(t, fake) d := newExportTestDaemonStore(t, fake)
upsertDaemonVM(t, ctx, d.store, vm) upsertDaemonVM(t, ctx, d.store, vm)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{ result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{
IDOrName: vm.Name, IDOrName: vm.Name,
@ -242,7 +242,7 @@ func TestExportVMWorkspace_NoChanges(t *testing.T) {
} }
d := newExportTestDaemonStore(t, fake) d := newExportTestDaemonStore(t, fake)
upsertDaemonVM(t, ctx, d.store, vm) upsertDaemonVM(t, ctx, d.store, vm)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{ result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{
IDOrName: vm.Name, IDOrName: vm.Name,
@ -281,7 +281,7 @@ func TestExportVMWorkspace_DefaultGuestPath(t *testing.T) {
} }
d := newExportTestDaemonStore(t, fake) d := newExportTestDaemonStore(t, fake)
upsertDaemonVM(t, ctx, d.store, vm) upsertDaemonVM(t, ctx, d.store, vm)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
// GuestPath omitted — should default to /root/repo. // GuestPath omitted — should default to /root/repo.
result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{ result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{
@ -341,7 +341,7 @@ func TestExportVMWorkspace_MultipleChangedFiles(t *testing.T) {
} }
d := newExportTestDaemonStore(t, fake) d := newExportTestDaemonStore(t, fake)
upsertDaemonVM(t, ctx, d.store, vm) upsertDaemonVM(t, ctx, d.store, vm)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{ result, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{
IDOrName: vm.Name, IDOrName: vm.Name,
@ -391,7 +391,7 @@ func TestPrepareVMWorkspace_ReleasesVMLockDuringGuestIO(t *testing.T) {
return &exportGuestClient{}, nil return &exportGuestClient{}, nil
} }
upsertDaemonVM(t, ctx, d.store, vm) upsertDaemonVM(t, ctx, d.store, vm)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
// Install the workspace seams on this daemon instance. InspectRepo // Install the workspace seams on this daemon instance. InspectRepo
// returns a trivial spec so the real filesystem isn't touched; // returns a trivial spec so the real filesystem isn't touched;
@ -429,7 +429,7 @@ func TestPrepareVMWorkspace_ReleasesVMLockDuringGuestIO(t *testing.T) {
// import is in flight. Acquiring it must not wait. // import is in flight. Acquiring it must not wait.
acquired := make(chan struct{}) acquired := make(chan struct{})
go func() { go func() {
unlock := d.lockVMID(vm.ID) unlock := d.vmSvc().lockVMID(vm.ID)
close(acquired) close(acquired)
unlock() unlock()
}() }()
@ -478,7 +478,7 @@ func TestPrepareVMWorkspace_SerialisesConcurrentPreparesOnSameVM(t *testing.T) {
return &exportGuestClient{}, nil return &exportGuestClient{}, nil
} }
upsertDaemonVM(t, ctx, d.store, vm) upsertDaemonVM(t, ctx, d.store, vm)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
d.workspaceSvc().workspaceInspectRepo = func(context.Context, string, string, string) (workspace.RepoSpec, error) { d.workspaceSvc().workspaceInspectRepo = func(context.Context, string, string, string) (workspace.RepoSpec, error) {
return workspace.RepoSpec{RepoName: "fake", RepoRoot: "/tmp/fake"}, nil return workspace.RepoSpec{RepoName: "fake", RepoRoot: "/tmp/fake"}, nil
@ -565,7 +565,7 @@ func TestExportVMWorkspace_DoesNotMutateRealIndex(t *testing.T) {
} }
d := newExportTestDaemonStore(t, fake) d := newExportTestDaemonStore(t, fake)
upsertDaemonVM(t, ctx, d.store, vm) upsertDaemonVM(t, ctx, d.store, vm)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) d.vmSvc().setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
if _, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{IDOrName: vm.Name}); err != nil { if _, err := d.workspaceSvc().ExportVMWorkspace(ctx, api.WorkspaceExportParams{IDOrName: vm.Name}); err != nil {
t.Fatalf("ExportVMWorkspace: %v", err) t.Fatalf("ExportVMWorkspace: %v", err)