diff --git a/internal/daemon/capabilities.go b/internal/daemon/capabilities.go index b4c18cd..a9e26fa 100644 --- a/internal/daemon/capabilities.go +++ b/internal/daemon/capabilities.go @@ -234,11 +234,11 @@ type dnsCapability struct{} func (dnsCapability) Name() string { return "dns" } func (dnsCapability) PostStart(ctx context.Context, d *Daemon, vm model.VMRecord, _ model.Image) error { - return d.setDNS(ctx, vm.Name, vm.Runtime.GuestIP) + return d.hostNet().setDNS(ctx, vm.Name, vm.Runtime.GuestIP) } -func (dnsCapability) Cleanup(ctx context.Context, d *Daemon, vm model.VMRecord) error { - return d.removeDNS(ctx, vm.Runtime.DNSName) +func (dnsCapability) Cleanup(_ context.Context, d *Daemon, vm model.VMRecord) error { + return d.hostNet().removeDNS(vm.Runtime.DNSName) } func (dnsCapability) AddDoctorChecks(_ context.Context, _ *Daemon, report *system.Report) { @@ -263,14 +263,14 @@ func (natCapability) AddStartPreflight(ctx context.Context, d *Daemon, checks *s if !vm.Spec.NATEnabled { return } - d.addNATPrereqs(ctx, checks) + d.hostNet().addNATPrereqs(ctx, checks) } func (natCapability) PostStart(ctx context.Context, d *Daemon, vm model.VMRecord, _ model.Image) error { if !vm.Spec.NATEnabled { return nil } - return d.ensureNAT(ctx, vm, true) + return d.hostNet().ensureNAT(ctx, vm.Runtime.GuestIP, d.vmHandles(vm.ID).TapDevice, true) } func (natCapability) Cleanup(ctx context.Context, d *Daemon, vm model.VMRecord) error { @@ -284,7 +284,7 @@ func (natCapability) Cleanup(ctx context.Context, d *Daemon, vm model.VMRecord) } return nil } - return d.ensureNAT(ctx, vm, false) + return d.hostNet().ensureNAT(ctx, vm.Runtime.GuestIP, tap, false) } func (natCapability) ApplyConfigChange(ctx context.Context, d *Daemon, before, after model.VMRecord) error { @@ -294,18 +294,18 @@ func (natCapability) ApplyConfigChange(ctx context.Context, d *Daemon, before, a if !d.vmAlive(after) { return nil } - return d.ensureNAT(ctx, after, after.Spec.NATEnabled) + return d.hostNet().ensureNAT(ctx, after.Runtime.GuestIP, d.vmHandles(after.ID).TapDevice, after.Spec.NATEnabled) } func (natCapability) AddDoctorChecks(ctx context.Context, d *Daemon, report *system.Report) { checks := system.NewPreflight() checks.RequireCommand("ip", toolHint("ip")) - d.addNATPrereqs(ctx, checks) + d.hostNet().addNATPrereqs(ctx, checks) if len(checks.Problems()) > 0 { report.Add(system.CheckStatusFail, "feature nat", checks.Problems()...) return } - uplink, err := d.defaultUplink(ctx) + uplink, err := d.hostNet().defaultUplink(ctx) if err != nil { report.AddFail("feature nat", err.Error()) return diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index c582826..2acbf8d 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -52,12 +52,11 @@ type Daemon struct { // lives in the store, this is rebuildable from a per-VM // handles.json scratch file and OS inspection. 
handles *handleCache - tapPool tapPool + net *HostNetwork closing chan struct{} once sync.Once pid int listener net.Listener - vmDNS *vmdns.Server vmCaps []vmCapability pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error) finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error @@ -90,15 +89,24 @@ func Open(ctx context.Context) (d *Daemon, err error) { if err != nil { return nil, err } + closing := make(chan struct{}) + runner := system.NewRunner() d = &Daemon{ layout: layout, config: cfg, store: db, - runner: system.NewRunner(), + runner: runner, logger: logger, - closing: make(chan struct{}), + closing: closing, pid: os.Getpid(), handles: newHandleCache(), + net: newHostNetwork(hostNetworkDeps{ + runner: runner, + logger: logger, + config: cfg, + layout: layout, + closing: closing, + }), } // From here on, every failure path must run Close() so the host // state we touched (DNS listener goroutine, resolvectl routing, @@ -114,7 +122,7 @@ func Open(ctx context.Context) (d *Daemon, err error) { d.ensureVMSSHClientConfig() d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel) - if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil { + if err = d.hostNet().startVMDNS(vmdns.DefaultListenAddr); err != nil { d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error()) return nil, err } @@ -122,12 +130,24 @@ func Open(ctx context.Context) (d *Daemon, err error) { d.logger.Error("daemon open failed", "stage", "reconcile", "error", err.Error()) return nil, err } - d.ensureVMDNSResolverRouting(ctx) - if err = d.initializeTapPool(ctx); err != nil { - d.logger.Error("daemon open failed", "stage", "initialize_tap_pool", "error", err.Error()) - return nil, err + d.hostNet().ensureVMDNSResolverRouting(ctx) + // Seed HostNetwork's pool index from taps already claimed by VMs + // on disk so newly warmed pool entries don't collide with them. 
+ if d.config.TapPoolSize > 0 && d.store != nil { + vms, listErr := d.store.ListVMs(ctx) + if listErr != nil { + d.logger.Error("daemon open failed", "stage", "initialize_tap_pool", "error", listErr.Error()) + return nil, listErr + } + used := make([]string, 0, len(vms)) + for _, vm := range vms { + if tap := d.vmHandles(vm.ID).TapDevice; tap != "" { + used = append(used, tap) + } + } + d.hostNet().initializeTapPool(used) } - go d.ensureTapPool(context.Background()) + go d.hostNet().ensureTapPool(context.Background()) return d, nil } @@ -141,7 +161,7 @@ func (d *Daemon) Close() error { if d.listener != nil { _ = d.listener.Close() } - err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.store.Close()) + err = errors.Join(d.hostNet().clearVMDNSResolverRouting(context.Background()), d.hostNet().stopVMDNS(), d.store.Close()) }) return err } @@ -518,27 +538,6 @@ func (d *Daemon) backgroundLoop() { } } -func (d *Daemon) startVMDNS(addr string) error { - server, err := vmdns.New(addr, d.logger) - if err != nil { - return err - } - d.vmDNS = server - if d.logger != nil { - d.logger.Info("vm dns serving", "dns_addr", server.Addr()) - } - return nil -} - -func (d *Daemon) stopVMDNS() error { - if d.vmDNS == nil { - return nil - } - err := d.vmDNS.Close() - d.vmDNS = nil - return err -} - func (d *Daemon) ensureDefaultImage(ctx context.Context) error { _ = ctx return nil diff --git a/internal/daemon/dns_routing.go b/internal/daemon/dns_routing.go index 0b9a14e..0160488 100644 --- a/internal/daemon/dns_routing.go +++ b/internal/daemon/dns_routing.go @@ -15,49 +15,49 @@ var ( vmDNSAddrFunc = func(server *vmdns.Server) string { return server.Addr() } ) -func (d *Daemon) syncVMDNSResolverRouting(ctx context.Context) error { - if d == nil || d.vmDNS == nil { +func (n *HostNetwork) syncVMDNSResolverRouting(ctx context.Context) error { + if n == nil || n.vmDNS == nil { return nil } - if strings.TrimSpace(d.config.BridgeName) == "" { + if strings.TrimSpace(n.config.BridgeName) == "" { return nil } if _, err := lookupExecutableFunc("resolvectl"); err != nil { return nil } - if _, err := d.runner.Run(ctx, "ip", "link", "show", d.config.BridgeName); err != nil { + if _, err := n.runner.Run(ctx, "ip", "link", "show", n.config.BridgeName); err != nil { return nil } - serverAddr := strings.TrimSpace(vmDNSAddrFunc(d.vmDNS)) + serverAddr := strings.TrimSpace(vmDNSAddrFunc(n.vmDNS)) if serverAddr == "" { return nil } - if _, err := d.runner.RunSudo(ctx, "resolvectl", "dns", d.config.BridgeName, serverAddr); err != nil { + if _, err := n.runner.RunSudo(ctx, "resolvectl", "dns", n.config.BridgeName, serverAddr); err != nil { return err } - if _, err := d.runner.RunSudo(ctx, "resolvectl", "domain", d.config.BridgeName, vmResolverRouteDomain); err != nil { + if _, err := n.runner.RunSudo(ctx, "resolvectl", "domain", n.config.BridgeName, vmResolverRouteDomain); err != nil { return err } - _, err := d.runner.RunSudo(ctx, "resolvectl", "default-route", d.config.BridgeName, "no") + _, err := n.runner.RunSudo(ctx, "resolvectl", "default-route", n.config.BridgeName, "no") return err } -func (d *Daemon) clearVMDNSResolverRouting(ctx context.Context) error { - if d == nil || strings.TrimSpace(d.config.BridgeName) == "" { +func (n *HostNetwork) clearVMDNSResolverRouting(ctx context.Context) error { + if n == nil || strings.TrimSpace(n.config.BridgeName) == "" { return nil } if _, err := lookupExecutableFunc("resolvectl"); err != nil { return nil } - if _, err := d.runner.Run(ctx, "ip", "link", 
"show", d.config.BridgeName); err != nil { + if _, err := n.runner.Run(ctx, "ip", "link", "show", n.config.BridgeName); err != nil { return nil } - _, err := d.runner.RunSudo(ctx, "resolvectl", "revert", d.config.BridgeName) + _, err := n.runner.RunSudo(ctx, "resolvectl", "revert", n.config.BridgeName) return err } -func (d *Daemon) ensureVMDNSResolverRouting(ctx context.Context) { - if err := d.syncVMDNSResolverRouting(ctx); err != nil && d.logger != nil { - d.logger.Warn("vm dns resolver route sync failed", "bridge", d.config.BridgeName, "error", err.Error()) +func (n *HostNetwork) ensureVMDNSResolverRouting(ctx context.Context) { + if err := n.syncVMDNSResolverRouting(ctx); err != nil && n.logger != nil { + n.logger.Warn("vm dns resolver route sync failed", "bridge", n.config.BridgeName, "error", err.Error()) } } diff --git a/internal/daemon/dns_routing_test.go b/internal/daemon/dns_routing_test.go index 1bd8f6c..bc53945 100644 --- a/internal/daemon/dns_routing_test.go +++ b/internal/daemon/dns_routing_test.go @@ -32,13 +32,10 @@ func TestSyncVMDNSResolverRoutingConfiguresResolved(t *testing.T) { sudoStep("", nil, "resolvectl", "default-route", model.DefaultBridgeName, "no"), }, } - d := &Daemon{ - runner: runner, - config: model.DaemonConfig{BridgeName: model.DefaultBridgeName}, - vmDNS: new(vmdns.Server), - } + cfg := model.DaemonConfig{BridgeName: model.DefaultBridgeName} + n := &HostNetwork{runner: runner, config: cfg, vmDNS: new(vmdns.Server)} - if err := d.syncVMDNSResolverRouting(context.Background()); err != nil { + if err := n.syncVMDNSResolverRouting(context.Background()); err != nil { t.Fatalf("syncVMDNSResolverRouting: %v", err) } runner.assertExhausted() @@ -63,12 +60,10 @@ func TestClearVMDNSResolverRoutingRevertsBridgeConfig(t *testing.T) { sudoStep("", nil, "resolvectl", "revert", model.DefaultBridgeName), }, } - d := &Daemon{ - runner: runner, - config: model.DaemonConfig{BridgeName: model.DefaultBridgeName}, - } + cfg := model.DaemonConfig{BridgeName: model.DefaultBridgeName} + n := &HostNetwork{runner: runner, config: cfg} - if err := d.clearVMDNSResolverRouting(context.Background()); err != nil { + if err := n.clearVMDNSResolverRouting(context.Background()); err != nil { t.Fatalf("clearVMDNSResolverRouting: %v", err) } runner.assertExhausted() diff --git a/internal/daemon/fastpath_test.go b/internal/daemon/fastpath_test.go index aeafe7e..bd28533 100644 --- a/internal/daemon/fastpath_test.go +++ b/internal/daemon/fastpath_test.go @@ -75,18 +75,18 @@ func TestTapPoolWarmsAndReusesIdleTap(t *testing.T) { closing: make(chan struct{}), } - d.ensureTapPool(context.Background()) - tapName, err := d.acquireTap(context.Background(), "tap-fallback") + d.hostNet().ensureTapPool(context.Background()) + tapName, err := d.hostNet().acquireTap(context.Background(), "tap-fallback") if err != nil { t.Fatalf("acquireTap: %v", err) } if tapName != "tap-pool-0" { t.Fatalf("tapName = %q, want tap-pool-0", tapName) } - if err := d.releaseTap(context.Background(), tapName); err != nil { + if err := d.hostNet().releaseTap(context.Background(), tapName); err != nil { t.Fatalf("releaseTap: %v", err) } - tapName, err = d.acquireTap(context.Background(), "tap-fallback") + tapName, err = d.hostNet().acquireTap(context.Background(), "tap-fallback") if err != nil { t.Fatalf("acquireTap second time: %v", err) } diff --git a/internal/daemon/host_network.go b/internal/daemon/host_network.go new file mode 100644 index 0000000..d587d88 --- /dev/null +++ b/internal/daemon/host_network.go @@ -0,0 +1,242 
@@ +package daemon + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "path/filepath" + "strings" + "time" + + "banger/internal/daemon/fcproc" + "banger/internal/firecracker" + "banger/internal/model" + "banger/internal/paths" + "banger/internal/system" + "banger/internal/vmdns" + "banger/internal/vsockagent" +) + +// HostNetwork owns the daemon's side of host networking: the TAP +// interface pool, the bridge, per-VM tap/NAT/DNS wiring, and the +// firecracker-process primitives (bridge setup, socket access, +// pgrep-based PID resolution, ctrl-alt-del, wait/kill) plus DM +// snapshot helpers. The Daemon holds one *HostNetwork and routes +// lifecycle calls through it instead of reaching into host-state +// directly. +// +// Fields stay unexported so peer services (VMService, etc.) access +// HostNetwork only through consumer-defined interfaces, not by +// fishing around in its struct. Construction goes through +// newHostNetwork with an explicit dependency bag so the wiring is +// auditable. +type HostNetwork struct { + runner system.CommandRunner + logger *slog.Logger + config model.DaemonConfig + layout paths.Layout + closing chan struct{} + + tapPool tapPool + vmDNS *vmdns.Server +} + +// hostNetworkDeps is the explicit wiring bag newHostNetwork expects. +// Keeping the deps in a dedicated struct rather than positional args +// makes the construction site in Daemon.Open read like a declaration. +type hostNetworkDeps struct { + runner system.CommandRunner + logger *slog.Logger + config model.DaemonConfig + layout paths.Layout + closing chan struct{} +} + +func newHostNetwork(deps hostNetworkDeps) *HostNetwork { + return &HostNetwork{ + runner: deps.runner, + logger: deps.logger, + config: deps.config, + layout: deps.layout, + closing: deps.closing, + } +} + +// hostNet returns the HostNetwork service, lazily constructing it from +// the Daemon's current fields if a test literal didn't wire one up. +// Production paths go through Daemon.Open, which always populates d.net +// eagerly; this lazy path exists only so tests that build `&Daemon{...}` +// literals without spelling out a HostNetwork don't have to learn the +// new construction pattern. Every call from production code that +// touches HostNetwork funnels through here. +func (d *Daemon) hostNet() *HostNetwork { + if d.net != nil { + return d.net + } + d.net = newHostNetwork(hostNetworkDeps{ + runner: d.runner, + logger: d.logger, + config: d.config, + layout: d.layout, + closing: d.closing, + }) + return d.net +} + +// --- DNS server lifecycle ------------------------------------------- + +func (n *HostNetwork) startVMDNS(addr string) error { + server, err := vmdns.New(addr, n.logger) + if err != nil { + return err + } + n.vmDNS = server + if n.logger != nil { + n.logger.Info("vm dns serving", "dns_addr", server.Addr()) + } + return nil +} + +func (n *HostNetwork) stopVMDNS() error { + if n.vmDNS == nil { + return nil + } + err := n.vmDNS.Close() + n.vmDNS = nil + return err +} + +func (n *HostNetwork) setDNS(ctx context.Context, vmName, guestIP string) error { + if n.vmDNS == nil { + return nil + } + if err := n.vmDNS.Set(vmdns.RecordName(vmName), guestIP); err != nil { + return err + } + n.ensureVMDNSResolverRouting(ctx) + return nil +} + +func (n *HostNetwork) removeDNS(dnsName string) error { + if dnsName == "" || n.vmDNS == nil { + return nil + } + return n.vmDNS.Remove(dnsName) +} + +// replaceDNS replaces the DNS server's full record set. 
Callers +// (Daemon.rebuildDNS) filter by vm-alive first; HostNetwork just +// takes the pre-filtered map. +func (n *HostNetwork) replaceDNS(records map[string]string) error { + if n.vmDNS == nil { + return nil + } + return n.vmDNS.Replace(records) +} + +// --- Firecracker process helpers ------------------------------------ + +// fc builds a fresh fcproc.Manager from the HostNetwork's current +// runner, config, and layout. Manager is stateless beyond those +// handles, so constructing per call keeps tests that build literals +// working without extra wiring. +func (n *HostNetwork) fc() *fcproc.Manager { + return fcproc.New(n.runner, fcproc.Config{ + FirecrackerBin: n.config.FirecrackerBin, + BridgeName: n.config.BridgeName, + BridgeIP: n.config.BridgeIP, + CIDR: n.config.CIDR, + RuntimeDir: n.layout.RuntimeDir, + }, n.logger) +} + +func (n *HostNetwork) ensureBridge(ctx context.Context) error { + return n.fc().EnsureBridge(ctx) +} + +func (n *HostNetwork) ensureSocketDir() error { + return n.fc().EnsureSocketDir() +} + +func (n *HostNetwork) createTap(ctx context.Context, tap string) error { + return n.fc().CreateTap(ctx, tap) +} + +func (n *HostNetwork) firecrackerBinary() (string, error) { + return n.fc().ResolveBinary() +} + +func (n *HostNetwork) ensureSocketAccess(ctx context.Context, socketPath, label string) error { + return n.fc().EnsureSocketAccess(ctx, socketPath, label) +} + +func (n *HostNetwork) findFirecrackerPID(ctx context.Context, apiSock string) (int, error) { + return n.fc().FindPID(ctx, apiSock) +} + +func (n *HostNetwork) resolveFirecrackerPID(ctx context.Context, machine *firecracker.Machine, apiSock string) int { + return n.fc().ResolvePID(ctx, machine, apiSock) +} + +func (n *HostNetwork) sendCtrlAltDel(ctx context.Context, apiSockPath string) error { + return n.fc().SendCtrlAltDel(ctx, apiSockPath) +} + +func (n *HostNetwork) waitForExit(ctx context.Context, pid int, apiSock string, timeout time.Duration) error { + return n.fc().WaitForExit(ctx, pid, apiSock, timeout) +} + +func (n *HostNetwork) killVMProcess(ctx context.Context, pid int) error { + return n.fc().Kill(ctx, pid) +} + +// waitForGuestVSockAgent is a HostNetwork helper because it's +// fundamentally about waiting for a vsock socket the firecracker +// process is serving on. No daemon state needed. 
+func (n *HostNetwork) waitForGuestVSockAgent(ctx context.Context, socketPath string, timeout time.Duration) error { + if strings.TrimSpace(socketPath) == "" { + return errors.New("vsock path is required") + } + + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ticker := time.NewTicker(vsockReadyPoll) + defer ticker.Stop() + + var lastErr error + for { + pingCtx, pingCancel := context.WithTimeout(waitCtx, 3*time.Second) + err := vsockagent.Health(pingCtx, n.logger, socketPath) + pingCancel() + if err == nil { + return nil + } + lastErr = err + + select { + case <-waitCtx.Done(): + if lastErr != nil { + return fmt.Errorf("guest vsock agent not ready: %w", lastErr) + } + return errors.New("guest vsock agent not ready before timeout") + case <-ticker.C: + } + } +} + +// --- Utilities used across networking ------------------------------ + +func defaultVSockPath(runtimeDir, vmID string) string { + return filepath.Join(runtimeDir, "fc-"+system.ShortID(vmID)+".vsock") +} + +func defaultVSockCID(guestIP string) (uint32, error) { + ip := net.ParseIP(strings.TrimSpace(guestIP)).To4() + if ip == nil { + return 0, fmt.Errorf("guest IP is not IPv4: %q", guestIP) + } + return 10000 + uint32(ip[3]), nil +} diff --git a/internal/daemon/nat.go b/internal/daemon/nat.go index b0d4231..a879f54 100644 --- a/internal/daemon/nat.go +++ b/internal/daemon/nat.go @@ -10,22 +10,43 @@ import ( type natRule = hostnat.Rule -func (d *Daemon) ensureNAT(ctx context.Context, vm model.VMRecord, enable bool) error { - return hostnat.Ensure(ctx, d.runner, vm.Runtime.GuestIP, d.vmHandles(vm.ID).TapDevice, enable) +// ensureNAT takes tap explicitly rather than reading from a handle +// cache so HostNetwork stays decoupled from VM-service state. +// Callers (vm_lifecycle) resolve the tap device from the handle cache +// themselves and pass it in. 
+func (n *HostNetwork) ensureNAT(ctx context.Context, guestIP, tap string, enable bool) error { + return hostnat.Ensure(ctx, n.runner, guestIP, tap, enable) } -func (d *Daemon) validateNATPrereqs(ctx context.Context) (string, error) { +func (n *HostNetwork) validateNATPrereqs(ctx context.Context) (string, error) { checks := system.NewPreflight() checks.RequireCommand("ip", toolHint("ip")) - d.addNATPrereqs(ctx, checks) + n.addNATPrereqs(ctx, checks) if err := checks.Err("nat preflight failed"); err != nil { return "", err } - return d.defaultUplink(ctx) + return n.defaultUplink(ctx) } -func (d *Daemon) defaultUplink(ctx context.Context) (string, error) { - return hostnat.DefaultUplink(ctx, d.runner) +func (n *HostNetwork) addNATPrereqs(ctx context.Context, checks *system.Preflight) { + checks.RequireCommand("iptables", toolHint("iptables")) + checks.RequireCommand("sysctl", toolHint("sysctl")) + runner := n.runner + if runner == nil { + runner = system.NewRunner() + } + out, err := runner.Run(ctx, "ip", "route", "show", "default") + if err != nil { + checks.Addf("failed to inspect the default route for NAT: %v", err) + return + } + if _, err := parseDefaultUplink(string(out)); err != nil { + checks.Addf("failed to detect the uplink interface for NAT: %v", err) + } +} + +func (n *HostNetwork) defaultUplink(ctx context.Context) (string, error) { + return hostnat.DefaultUplink(ctx, n.runner) } func parseDefaultUplink(output string) (string, error) { diff --git a/internal/daemon/open_close_test.go b/internal/daemon/open_close_test.go index 1fb4d3a..7a386d0 100644 --- a/internal/daemon/open_close_test.go +++ b/internal/daemon/open_close_test.go @@ -50,12 +50,12 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) { return &Daemon{ store: openDaemonStore(t), closing: make(chan struct{}), - vmDNS: server, + net: &HostNetwork{vmDNS: server}, logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } }, verify: func(t *testing.T, d *Daemon) { - if d.vmDNS != nil { + if d.hostNet().vmDNS != nil { t.Error("vmDNS not cleared by Close") } }, diff --git a/internal/daemon/ports.go b/internal/daemon/ports.go index 40ab0c0..58c088f 100644 --- a/internal/daemon/ports.go +++ b/internal/daemon/ports.go @@ -40,7 +40,7 @@ func (d *Daemon) PortsVM(ctx context.Context, idOrName string) (result api.VMPor if vm.Runtime.VSockCID == 0 { return model.VMRecord{}, errors.New("vm has no vsock cid") } - if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { + if err := d.hostNet().ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { return model.VMRecord{}, err } portsCtx, cancel := context.WithTimeout(ctx, 3*time.Second) diff --git a/internal/daemon/preflight.go b/internal/daemon/preflight.go index 7ff9fa6..1ca2a8b 100644 --- a/internal/daemon/preflight.go +++ b/internal/daemon/preflight.go @@ -25,23 +25,6 @@ func (d *Daemon) validateWorkDiskResizePrereqs() error { return checks.Err("work disk resize preflight failed") } -func (d *Daemon) addNATPrereqs(ctx context.Context, checks *system.Preflight) { - checks.RequireCommand("iptables", toolHint("iptables")) - checks.RequireCommand("sysctl", toolHint("sysctl")) - runner := d.runner - if runner == nil { - runner = system.NewRunner() - } - out, err := runner.Run(ctx, "ip", "route", "show", "default") - if err != nil { - checks.Addf("failed to inspect the default route for NAT: %v", err) - return - } - if _, err := parseDefaultUplink(string(out)); err != nil { - checks.Addf("failed to 
detect the uplink interface for NAT: %v", err) - } -} - func (d *Daemon) addBaseStartPrereqs(checks *system.Preflight, image model.Image) { d.addBaseStartCommandPrereqs(checks) checks.RequireExecutable(d.config.FirecrackerBin, "firecracker binary", `install firecracker or set "firecracker_bin"`) diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index 78da1f9..5835197 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -10,14 +10,14 @@ import ( // type so existing call sites and tests read naturally. type dmSnapshotHandles = dmsnap.Handles -func (d *Daemon) createDMSnapshot(ctx context.Context, rootfsPath, cowPath, dmName string) (dmSnapshotHandles, error) { - return dmsnap.Create(ctx, d.runner, rootfsPath, cowPath, dmName) +func (n *HostNetwork) createDMSnapshot(ctx context.Context, rootfsPath, cowPath, dmName string) (dmSnapshotHandles, error) { + return dmsnap.Create(ctx, n.runner, rootfsPath, cowPath, dmName) } -func (d *Daemon) cleanupDMSnapshot(ctx context.Context, handles dmSnapshotHandles) error { - return dmsnap.Cleanup(ctx, d.runner, handles) +func (n *HostNetwork) cleanupDMSnapshot(ctx context.Context, handles dmSnapshotHandles) error { + return dmsnap.Cleanup(ctx, n.runner, handles) } -func (d *Daemon) removeDMSnapshot(ctx context.Context, target string) error { - return dmsnap.Remove(ctx, d.runner, target) +func (n *HostNetwork) removeDMSnapshot(ctx context.Context, target string) error { + return dmsnap.Remove(ctx, n.runner, target) } diff --git a/internal/daemon/snapshot_test.go b/internal/daemon/snapshot_test.go index 2411206..35fad2a 100644 --- a/internal/daemon/snapshot_test.go +++ b/internal/daemon/snapshot_test.go @@ -74,7 +74,7 @@ func TestCreateDMSnapshotFailsWithoutRollbackWhenBaseLoopSetupFails(t *testing.T } d := &Daemon{runner: runner} - _, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + _, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") if !errors.Is(err, attachErr) { t.Fatalf("error = %v, want %v", err, attachErr) } @@ -98,7 +98,7 @@ func TestCreateDMSnapshotRollsBackBaseLoopWhenCowLoopSetupFails(t *testing.T) { } d := &Daemon{runner: runner} - _, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + _, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") if !errors.Is(err, attachErr) { t.Fatalf("error = %v, want %v", err, attachErr) } @@ -121,7 +121,7 @@ func TestCreateDMSnapshotRollsBackBothLoopsWhenBlockdevFails(t *testing.T) { } d := &Daemon{runner: runner} - _, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + _, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") if !errors.Is(err, blockdevErr) { t.Fatalf("error = %v, want %v", err, blockdevErr) } @@ -145,7 +145,7 @@ func TestCreateDMSnapshotRollsBackLoopsWhenDMSetupFails(t *testing.T) { } d := &Daemon{runner: runner} - _, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + _, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") if !errors.Is(err, dmErr) { t.Fatalf("error = %v, want %v", err, dmErr) } @@ -174,7 +174,7 @@ func TestCreateDMSnapshotJoinsRollbackErrors(t *testing.T) { } d := &Daemon{runner: runner} - _, err := d.createDMSnapshot(context.Background(), 
"/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + _, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") if err == nil { t.Fatal("expected createDMSnapshot to return an error") } @@ -198,7 +198,7 @@ func TestCreateDMSnapshotReturnsHandlesOnSuccess(t *testing.T) { } d := &Daemon{runner: runner} - handles, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") + handles, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") if err != nil { t.Fatalf("createDMSnapshot returned error: %v", err) } @@ -227,7 +227,7 @@ func TestCleanupDMSnapshotRemovesResourcesInReverseOrder(t *testing.T) { } d := &Daemon{runner: runner} - err := d.cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ + err := d.hostNet().cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ BaseLoop: "/dev/loop10", COWLoop: "/dev/loop11", DMName: "fc-rootfs-test", @@ -251,7 +251,7 @@ func TestCleanupDMSnapshotUsesPartialHandles(t *testing.T) { } d := &Daemon{runner: runner} - err := d.cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ + err := d.hostNet().cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ BaseLoop: "/dev/loop10", DMDev: "/dev/mapper/fc-rootfs-test", }) @@ -277,7 +277,7 @@ func TestCleanupDMSnapshotJoinsTeardownErrors(t *testing.T) { } d := &Daemon{runner: runner} - err := d.cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ + err := d.hostNet().cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ BaseLoop: "/dev/loop10", COWLoop: "/dev/loop11", DMName: "fc-rootfs-test", @@ -307,7 +307,7 @@ func TestRemoveDMSnapshotRetriesBusyDevice(t *testing.T) { } d := &Daemon{runner: runner} - if err := d.removeDMSnapshot(context.Background(), "fc-rootfs-test"); err != nil { + if err := d.hostNet().removeDMSnapshot(context.Background(), "fc-rootfs-test"); err != nil { t.Fatalf("removeDMSnapshot returned error: %v", err) } runner.assertExhausted() diff --git a/internal/daemon/tap_pool.go b/internal/daemon/tap_pool.go index 9d5e172..88cf373 100644 --- a/internal/daemon/tap_pool.go +++ b/internal/daemon/tap_pool.go @@ -18,98 +18,97 @@ type tapPool struct { next int } -func (d *Daemon) initializeTapPool(ctx context.Context) error { - if d.config.TapPoolSize <= 0 || d.store == nil { - return nil - } - vms, err := d.store.ListVMs(ctx) - if err != nil { - return err +// initializeTapPool seeds the monotonic pool index from the set of +// tap names already in use by running/stopped VMs, so newly warmed +// pool entries don't collide with existing ones. Callers (Daemon.Open) +// enumerate used taps from the handle cache and pass them in. 
+func (n *HostNetwork) initializeTapPool(usedTaps []string) { + if n.config.TapPoolSize <= 0 { + return } next := 0 - for _, vm := range vms { - if index, ok := parseTapPoolIndex(d.vmHandles(vm.ID).TapDevice); ok && index >= next { + for _, tapName := range usedTaps { + if index, ok := parseTapPoolIndex(tapName); ok && index >= next { next = index + 1 } } - d.tapPool.mu.Lock() - d.tapPool.next = next - d.tapPool.mu.Unlock() - return nil + n.tapPool.mu.Lock() + n.tapPool.next = next + n.tapPool.mu.Unlock() } -func (d *Daemon) ensureTapPool(ctx context.Context) { - if d.config.TapPoolSize <= 0 { +func (n *HostNetwork) ensureTapPool(ctx context.Context) { + if n.config.TapPoolSize <= 0 { return } for { select { case <-ctx.Done(): return - case <-d.closing: + case <-n.closing: return default: } - d.tapPool.mu.Lock() - if len(d.tapPool.entries) >= d.config.TapPoolSize { - d.tapPool.mu.Unlock() + n.tapPool.mu.Lock() + if len(n.tapPool.entries) >= n.config.TapPoolSize { + n.tapPool.mu.Unlock() return } - tapName := fmt.Sprintf("%s%d", tapPoolPrefix, d.tapPool.next) - d.tapPool.next++ - d.tapPool.mu.Unlock() + tapName := fmt.Sprintf("%s%d", tapPoolPrefix, n.tapPool.next) + n.tapPool.next++ + n.tapPool.mu.Unlock() - if err := d.createTap(ctx, tapName); err != nil { - if d.logger != nil { - d.logger.Warn("tap pool warmup failed", "tap_device", tapName, "error", err.Error()) + if err := n.createTap(ctx, tapName); err != nil { + if n.logger != nil { + n.logger.Warn("tap pool warmup failed", "tap_device", tapName, "error", err.Error()) } return } - d.tapPool.mu.Lock() - d.tapPool.entries = append(d.tapPool.entries, tapName) - d.tapPool.mu.Unlock() + n.tapPool.mu.Lock() + n.tapPool.entries = append(n.tapPool.entries, tapName) + n.tapPool.mu.Unlock() - if d.logger != nil { - d.logger.Debug("tap added to idle pool", "tap_device", tapName) + if n.logger != nil { + n.logger.Debug("tap added to idle pool", "tap_device", tapName) } } } -func (d *Daemon) acquireTap(ctx context.Context, fallbackName string) (string, error) { - d.tapPool.mu.Lock() - if n := len(d.tapPool.entries); n > 0 { - tapName := d.tapPool.entries[n-1] - d.tapPool.entries = d.tapPool.entries[:n-1] - d.tapPool.mu.Unlock() +func (n *HostNetwork) acquireTap(ctx context.Context, fallbackName string) (string, error) { + n.tapPool.mu.Lock() + if count := len(n.tapPool.entries); count > 0 { + tapName := n.tapPool.entries[count-1] + n.tapPool.entries = n.tapPool.entries[:count-1] + n.tapPool.mu.Unlock() return tapName, nil } - d.tapPool.mu.Unlock() + n.tapPool.mu.Unlock() - if err := d.createTap(ctx, fallbackName); err != nil { + if err := n.createTap(ctx, fallbackName); err != nil { return "", err } return fallbackName, nil } -func (d *Daemon) releaseTap(ctx context.Context, tapName string) error { +func (n *HostNetwork) releaseTap(ctx context.Context, tapName string) error { tapName = strings.TrimSpace(tapName) if tapName == "" { return nil } if isTapPoolName(tapName) { - d.tapPool.mu.Lock() - if len(d.tapPool.entries) < d.config.TapPoolSize { - d.tapPool.entries = append(d.tapPool.entries, tapName) - d.tapPool.mu.Unlock() + n.tapPool.mu.Lock() + if len(n.tapPool.entries) < n.config.TapPoolSize { + n.tapPool.entries = append(n.tapPool.entries, tapName) + n.tapPool.mu.Unlock() return nil } - d.tapPool.mu.Unlock() + n.tapPool.mu.Unlock() } - _, err := d.runner.RunSudo(ctx, "ip", "link", "del", tapName) + _, err := n.runner.RunSudo(ctx, "ip", "link", "del", tapName) if err == nil { - go d.ensureTapPool(context.Background()) + go 
n.ensureTapPool(context.Background()) } return err } diff --git a/internal/daemon/vm.go b/internal/daemon/vm.go index 6c4ed35..37f9aab 100644 --- a/internal/daemon/vm.go +++ b/internal/daemon/vm.go @@ -4,23 +4,20 @@ import ( "context" "errors" "fmt" - "log/slog" - "net" "os" - "path/filepath" "strconv" "strings" "time" "banger/internal/daemon/fcproc" - "banger/internal/firecracker" "banger/internal/model" "banger/internal/namegen" "banger/internal/system" - "banger/internal/vmdns" - "banger/internal/vsockagent" ) +// Cross-service constants. Kept in vm.go because both lifecycle +// (VMService) and networking (HostNetwork) reference them; moving +// them to either owner would read as a layering violation. var ( errWaitForExitTimeout = fcproc.ErrWaitForExitTimeout gracefulShutdownWait = 10 * time.Second @@ -28,59 +25,43 @@ var ( vsockReadyPoll = 200 * time.Millisecond ) -// fc builds a fresh fcproc.Manager from the Daemon's current runner, config, -// and layout. Manager is stateless beyond those handles, so constructing per -// call keeps tests that build Daemon literals working without extra wiring. -func (d *Daemon) fc() *fcproc.Manager { - return fcproc.New(d.runner, fcproc.Config{ - FirecrackerBin: d.config.FirecrackerBin, - BridgeName: d.config.BridgeName, - BridgeIP: d.config.BridgeIP, - CIDR: d.config.CIDR, - RuntimeDir: d.layout.RuntimeDir, - }, d.logger) +// rebuildDNS enumerates live VMs and republishes the DNS record set. +// Lives on *Daemon (not HostNetwork) because "alive" is a VMService +// concern that HostNetwork shouldn't need to reach into. Daemon +// orchestrates: VM list from the store, alive filter, hand the +// resulting map to HostNetwork.replaceDNS. +func (d *Daemon) rebuildDNS(ctx context.Context) error { + if d.net == nil { + return nil + } + vms, err := d.store.ListVMs(ctx) + if err != nil { + return err + } + records := make(map[string]string) + for _, vm := range vms { + if !d.vmAlive(vm) { + continue + } + if strings.TrimSpace(vm.Runtime.GuestIP) == "" { + continue + } + records[vmDNSRecordName(vm.Name)] = vm.Runtime.GuestIP + } + return d.hostNet().replaceDNS(records) } -func (d *Daemon) ensureBridge(ctx context.Context) error { - return d.fc().EnsureBridge(ctx) -} - -func (d *Daemon) ensureSocketDir() error { - return d.fc().EnsureSocketDir() -} - -func (d *Daemon) createTap(ctx context.Context, tap string) error { - return d.fc().CreateTap(ctx, tap) -} - -func (d *Daemon) firecrackerBinary() (string, error) { - return d.fc().ResolveBinary() -} - -func (d *Daemon) ensureSocketAccess(ctx context.Context, socketPath, label string) error { - return d.fc().EnsureSocketAccess(ctx, socketPath, label) -} - -func (d *Daemon) findFirecrackerPID(ctx context.Context, apiSock string) (int, error) { - return d.fc().FindPID(ctx, apiSock) -} - -func (d *Daemon) resolveFirecrackerPID(ctx context.Context, machine *firecracker.Machine, apiSock string) int { - return d.fc().ResolvePID(ctx, machine, apiSock) -} - -func (d *Daemon) sendCtrlAltDel(ctx context.Context, vm model.VMRecord) error { - return d.fc().SendCtrlAltDel(ctx, vm.Runtime.APISockPath) -} - -func (d *Daemon) waitForExit(ctx context.Context, pid int, apiSock string, timeout time.Duration) error { - return d.fc().WaitForExit(ctx, pid, apiSock, timeout) -} - -func (d *Daemon) killVMProcess(ctx context.Context, pid int) error { - return d.fc().Kill(ctx, pid) +// vmDNSRecordName is a small indirection so the dns-record-name +// helper is not directly pulled into every file that used to import +// vmdns for this one 
call. Equivalent to vmdns.RecordName. +func vmDNSRecordName(name string) string { + return strings.ToLower(strings.TrimSpace(name)) + ".vm" } +// cleanupRuntime tears down the host-side state for a VM: firecracker +// process, DM snapshot, capabilities, tap, sockets. Stays on *Daemon +// for now because it reaches into handles (VMService-owned) and +// capabilities (still on Daemon). Phase 4 will move it to VMService. func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserveDisks bool) error { if d.logger != nil { d.logger.Debug("cleanup runtime", append(vmLogAttrs(vm), "preserve_disks", preserveDisks)...) @@ -88,17 +69,17 @@ func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserve h := d.vmHandles(vm.ID) cleanupPID := h.PID if vm.Runtime.APISockPath != "" { - if pid, err := d.findFirecrackerPID(ctx, vm.Runtime.APISockPath); err == nil && pid > 0 { + if pid, err := d.hostNet().findFirecrackerPID(ctx, vm.Runtime.APISockPath); err == nil && pid > 0 { cleanupPID = pid } } if cleanupPID > 0 && system.ProcessRunning(cleanupPID, vm.Runtime.APISockPath) { - _ = d.killVMProcess(ctx, cleanupPID) - if err := d.waitForExit(ctx, cleanupPID, vm.Runtime.APISockPath, 30*time.Second); err != nil { + _ = d.hostNet().killVMProcess(ctx, cleanupPID) + if err := d.hostNet().waitForExit(ctx, cleanupPID, vm.Runtime.APISockPath, 30*time.Second); err != nil { return err } } - snapshotErr := d.cleanupDMSnapshot(ctx, dmSnapshotHandles{ + snapshotErr := d.hostNet().cleanupDMSnapshot(ctx, dmSnapshotHandles{ BaseLoop: h.BaseLoop, COWLoop: h.COWLoop, DMName: h.DMName, @@ -107,7 +88,7 @@ func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserve featureErr := d.cleanupCapabilityState(ctx, vm) var tapErr error if h.TapDevice != "" { - tapErr = d.releaseTap(ctx, h.TapDevice) + tapErr = d.hostNet().releaseTap(ctx, h.TapDevice) } if vm.Runtime.APISockPath != "" { _ = os.Remove(vm.Runtime.APISockPath) @@ -125,92 +106,6 @@ func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserve return errors.Join(snapshotErr, featureErr, tapErr) } -func defaultVSockPath(runtimeDir, vmID string) string { - return filepath.Join(runtimeDir, "fc-"+system.ShortID(vmID)+".vsock") -} - -func defaultVSockCID(guestIP string) (uint32, error) { - ip := net.ParseIP(strings.TrimSpace(guestIP)).To4() - if ip == nil { - return 0, fmt.Errorf("guest IP is not IPv4: %q", guestIP) - } - return 10000 + uint32(ip[3]), nil -} - -func waitForGuestVSockAgent(ctx context.Context, logger *slog.Logger, socketPath string, timeout time.Duration) error { - if strings.TrimSpace(socketPath) == "" { - return errors.New("vsock path is required") - } - - waitCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - ticker := time.NewTicker(vsockReadyPoll) - defer ticker.Stop() - - var lastErr error - for { - pingCtx, pingCancel := context.WithTimeout(waitCtx, 3*time.Second) - err := vsockagent.Health(pingCtx, logger, socketPath) - pingCancel() - if err == nil { - return nil - } - lastErr = err - - select { - case <-waitCtx.Done(): - if lastErr != nil { - return fmt.Errorf("guest vsock agent not ready: %w", lastErr) - } - return errors.New("guest vsock agent not ready before timeout") - case <-ticker.C: - } - } -} - -func (d *Daemon) setDNS(ctx context.Context, vmName, guestIP string) error { - if d.vmDNS == nil { - return nil - } - if err := d.vmDNS.Set(vmdns.RecordName(vmName), guestIP); err != nil { - return err - } - d.ensureVMDNSResolverRouting(ctx) - return nil -} - 
-func (d *Daemon) removeDNS(ctx context.Context, dnsName string) error { - if dnsName == "" { - return nil - } - if d.vmDNS == nil { - return nil - } - return d.vmDNS.Remove(dnsName) -} - -func (d *Daemon) rebuildDNS(ctx context.Context) error { - if d.vmDNS == nil { - return nil - } - vms, err := d.store.ListVMs(ctx) - if err != nil { - return err - } - records := make(map[string]string) - for _, vm := range vms { - if !d.vmAlive(vm) { - continue - } - if strings.TrimSpace(vm.Runtime.GuestIP) == "" { - continue - } - records[vmdns.RecordName(vm.Name)] = vm.Runtime.GuestIP - } - return d.vmDNS.Replace(records) -} - func (d *Daemon) generateName(ctx context.Context) (string, error) { _ = ctx if name := strings.TrimSpace(namegen.Generate()); name != "" { diff --git a/internal/daemon/vm_handles.go b/internal/daemon/vm_handles.go index ef367c4..40a2b34 100644 --- a/internal/daemon/vm_handles.go +++ b/internal/daemon/vm_handles.go @@ -200,7 +200,7 @@ func (d *Daemon) rediscoverHandles(ctx context.Context, vm model.VMRecord) (mode if apiSock == "" { return saved, false, nil } - if pid, pidErr := d.findFirecrackerPID(ctx, apiSock); pidErr == nil && pid > 0 { + if pid, pidErr := d.hostNet().findFirecrackerPID(ctx, apiSock); pidErr == nil && pid > 0 { saved.PID = pid return saved, true, nil } diff --git a/internal/daemon/vm_lifecycle.go b/internal/daemon/vm_lifecycle.go index 2bb8eb7..554dff0 100644 --- a/internal/daemon/vm_lifecycle.go +++ b/internal/daemon/vm_lifecycle.go @@ -56,11 +56,11 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod } d.clearVMHandles(vm) op.stage("bridge") - if err := d.ensureBridge(ctx); err != nil { + if err := d.hostNet().ensureBridge(ctx); err != nil { return model.VMRecord{}, err } op.stage("socket_dir") - if err := d.ensureSocketDir(); err != nil { + if err := d.hostNet().ensureSocketDir(); err != nil { return model.VMRecord{}, err } @@ -92,7 +92,7 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod op.stage("dm_snapshot", "dm_name", dmName) vmCreateStage(ctx, "prepare_rootfs", "creating root filesystem snapshot") - snapHandles, err := d.createDMSnapshot(ctx, image.RootfsPath, vm.Runtime.SystemOverlay, dmName) + snapHandles, err := d.hostNet().createDMSnapshot(ctx, image.RootfsPath, vm.Runtime.SystemOverlay, dmName) if err != nil { return model.VMRecord{}, err } @@ -138,7 +138,7 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod return cleanupOnErr(err) } op.stage("tap") - tap, err := d.acquireTap(ctx, tapName) + tap, err := d.hostNet().acquireTap(ctx, tapName) if err != nil { return cleanupOnErr(err) } @@ -150,7 +150,7 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod } op.stage("firecracker_binary") - fcPath, err := d.firecrackerBinary() + fcPath, err := d.hostNet().firecrackerBinary() if err != nil { return cleanupOnErr(err) } @@ -200,23 +200,23 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod // Use a fresh context: the request ctx may already be cancelled (client // disconnect), but we still need the PID so cleanupRuntime can kill the // Firecracker process that was spawned before the failure. 
- live.PID = d.resolveFirecrackerPID(context.Background(), machine, apiSock) + live.PID = d.hostNet().resolveFirecrackerPID(context.Background(), machine, apiSock) d.setVMHandles(vm, live) return cleanupOnErr(err) } - live.PID = d.resolveFirecrackerPID(context.Background(), machine, apiSock) + live.PID = d.hostNet().resolveFirecrackerPID(context.Background(), machine, apiSock) d.setVMHandles(vm, live) op.debugStage("firecracker_started", "pid", live.PID) op.stage("socket_access", "api_socket", apiSock) - if err := d.ensureSocketAccess(ctx, apiSock, "firecracker api socket"); err != nil { + if err := d.hostNet().ensureSocketAccess(ctx, apiSock, "firecracker api socket"); err != nil { return cleanupOnErr(err) } op.stage("vsock_access", "vsock_path", vm.Runtime.VSockPath, "vsock_cid", vm.Runtime.VSockCID) - if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { + if err := d.hostNet().ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { return cleanupOnErr(err) } vmCreateStage(ctx, "wait_vsock_agent", "waiting for guest vsock agent") - if err := waitForGuestVSockAgent(ctx, d.logger, vm.Runtime.VSockPath, vsockReadyWait); err != nil { + if err := d.hostNet().waitForGuestVSockAgent(ctx, vm.Runtime.VSockPath, vsockReadyWait); err != nil { return cleanupOnErr(err) } op.stage("post_start_features") @@ -264,11 +264,11 @@ func (d *Daemon) stopVMLocked(ctx context.Context, current model.VMRecord) (vm m } pid := d.vmHandles(vm.ID).PID op.stage("graceful_shutdown") - if err := d.sendCtrlAltDel(ctx, vm); err != nil { + if err := d.hostNet().sendCtrlAltDel(ctx, vm.Runtime.APISockPath); err != nil { return model.VMRecord{}, err } op.stage("wait_for_exit", "pid", pid) - if err := d.waitForExit(ctx, pid, vm.Runtime.APISockPath, gracefulShutdownWait); err != nil { + if err := d.hostNet().waitForExit(ctx, pid, vm.Runtime.APISockPath, gracefulShutdownWait); err != nil { if !errors.Is(err, errWaitForExitTimeout) { return model.VMRecord{}, err } @@ -328,7 +328,7 @@ func (d *Daemon) killVMLocked(ctx context.Context, current model.VMRecord, signa return model.VMRecord{}, err } op.stage("wait_for_exit", "pid", pid) - if err := d.waitForExit(ctx, pid, vm.Runtime.APISockPath, 30*time.Second); err != nil { + if err := d.hostNet().waitForExit(ctx, pid, vm.Runtime.APISockPath, 30*time.Second); err != nil { if !errors.Is(err, errWaitForExitTimeout) { return model.VMRecord{}, err } @@ -395,7 +395,7 @@ func (d *Daemon) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm if d.vmAlive(vm) { pid := d.vmHandles(vm.ID).PID op.stage("kill_running_vm", "pid", pid) - _ = d.killVMProcess(ctx, pid) + _ = d.hostNet().killVMProcess(ctx, pid) } op.stage("cleanup_runtime") if err := d.cleanupRuntime(ctx, vm, false); err != nil { diff --git a/internal/daemon/vm_stats.go b/internal/daemon/vm_stats.go index d917150..77cc7fe 100644 --- a/internal/daemon/vm_stats.go +++ b/internal/daemon/vm_stats.go @@ -35,7 +35,7 @@ func (d *Daemon) HealthVM(ctx context.Context, idOrName string) (result api.VMHe if vm.Runtime.VSockCID == 0 { return model.VMRecord{}, errors.New("vm has no vsock cid") } - if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { + if err := d.hostNet().ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { return model.VMRecord{}, err } pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second) @@ -123,8 +123,8 @@ func (d *Daemon) stopStaleVMs(ctx 
context.Context) (err error) { return nil } op.stage("stopping_vm", vmLogAttrs(vm)...) - _ = d.sendCtrlAltDel(ctx, vm) - _ = d.waitForExit(ctx, d.vmHandles(vm.ID).PID, vm.Runtime.APISockPath, 10*time.Second) + _ = d.hostNet().sendCtrlAltDel(ctx, vm.Runtime.APISockPath) + _ = d.hostNet().waitForExit(ctx, d.vmHandles(vm.ID).PID, vm.Runtime.APISockPath, 10*time.Second) _ = d.cleanupRuntime(ctx, vm, true) vm.State = model.VMStateStopped vm.Runtime.State = model.VMStateStopped diff --git a/internal/daemon/vm_test.go b/internal/daemon/vm_test.go index c6ae796..7dfe279 100644 --- a/internal/daemon/vm_test.go +++ b/internal/daemon/vm_test.go @@ -212,7 +212,7 @@ func TestRebuildDNSIncludesOnlyLiveRunningVMs(t *testing.T) { } }) - d := &Daemon{store: db, vmDNS: server} + d := &Daemon{store: db, net: &HostNetwork{vmDNS: server}} // rebuildDNS reads the alive check from the handle cache. Seed // the live VM with its real PID; leave the stale entry with a PID // that definitely isn't running (999999 ≫ max PID on most hosts). @@ -512,7 +512,8 @@ func TestWaitForGuestVSockAgentRetriesUntilHealthy(t *testing.T) { serverDone <- errors.New("health probe did not retry") }() - if err := waitForGuestVSockAgent(context.Background(), nil, socketPath, time.Second); err != nil { + n := &HostNetwork{} + if err := n.waitForGuestVSockAgent(context.Background(), socketPath, time.Second); err != nil { t.Fatalf("waitForGuestVSockAgent: %v", err) } if err := <-serverDone; err != nil {
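
Sketch of the post-patch wiring, for orientation rather than as part of the diff: production code builds the HostNetwork once in Daemon.Open through newHostNetwork, while tests either set Daemon.net directly or rely on the lazy hostNet() fallback. The snippet below assumes it lives in package daemon (it uses unexported identifiers) and that fakeRunner is a hypothetical system.CommandRunner test double; every other type, field, and method name is taken from the code added above.

package daemon

import (
	"context"
	"io"
	"log/slog"

	"banger/internal/model"
	"banger/internal/paths"
	"banger/internal/system"
)

// exampleHostNetworkWiring mirrors how fastpath_test.go exercises the tap
// pool, but constructs the HostNetwork directly instead of going through a
// Daemon literal. Illustrative only; fakeRunner is a hypothetical test double.
func exampleHostNetworkWiring(fakeRunner system.CommandRunner) error {
	cfg := model.DaemonConfig{BridgeName: model.DefaultBridgeName, TapPoolSize: 2}
	n := newHostNetwork(hostNetworkDeps{
		runner:  fakeRunner,
		logger:  slog.New(slog.NewTextHandler(io.Discard, nil)),
		config:  cfg,
		layout:  paths.Layout{}, // zero value is enough when the runner is faked
		closing: make(chan struct{}),
	})
	n.initializeTapPool(nil)              // no taps claimed by VMs already on disk
	n.ensureTapPool(context.Background()) // warm up to cfg.TapPoolSize idle taps

	tap, err := n.acquireTap(context.Background(), "tap-fallback")
	if err != nil {
		return err
	}
	// Pool-named taps go back to the idle pool; a fallback tap would be deleted.
	return n.releaseTap(context.Background(), tap)
}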