daemon split (1/5): extract *HostNetwork service

First phase of splitting the daemon god-struct into focused services
with explicit ownership.

HostNetwork now owns everything host-networking: the TAP interface
pool (initializeTapPool / ensureTapPool / acquireTap / releaseTap /
createTap), bridge + socket dir setup, firecracker process primitives
(find/resolve/kill/wait/ensureSocketAccess/sendCtrlAltDel), DM
snapshot lifecycle, NAT rule enforcement, guest DNS server lifecycle
+ routing setup, and the vsock-agent readiness probe. That's 7 files
whose receivers flipped from *Daemon to *HostNetwork, plus a new
host_network.go that declares the struct, its hostNetworkDeps, and
the factored firecracker + DNS helpers that used to live in vm.go.

Daemon gives up the tapPool and vmDNS fields entirely; they're now
HostNetwork's business. Construction goes through newHostNetwork in
Daemon.Open with an explicit dependency bag (runner, logger, config,
layout, closing). A lazy-init hostNet() helper on Daemon supports
test literals that don't wire net explicitly — production always
populates it eagerly.

Signature tightenings where the old receiver reached into VM-service
state:
 - ensureNAT(ctx, vm, enable) → ensureNAT(ctx, guestIP, tap, enable).
   Callers resolve tap from the handle cache themselves.
 - initializeTapPool(ctx) → initializeTapPool(usedTaps []string).
   Daemon.Open enumerates VMs, collects taps from handles, hands the
   slice in.

rebuildDNS stays on *Daemon as the orchestrator — it filters VMs by
liveness (a VM-service concern; the handles cache moves to VMService
in phase 4) and then calls HostNetwork.replaceDNS with the
already-filtered map.

Capability hooks continue to take *Daemon; they now use it as a
facade to reach services (d.net.ensureNAT, d.hostNet().*). Planned
CapabilityHost interface extraction is orthogonal, left for later.

Tests: dns_routing_test.go + fastpath_test.go + nat_test.go +
snapshot_test.go + open_close_test.go were touched to construct
HostNetwork literals where they exercise its methods directly, or
route through d.hostNet() where they exercise the Daemon entry
points.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Thales Maciel 2026-04-20 20:11:46 -03:00
parent eba9a553bf
commit 362009d747
No known key found for this signature in database
GPG key ID: 33112E6833C34679
18 changed files with 461 additions and 326 deletions

View file

@ -234,11 +234,11 @@ type dnsCapability struct{}
func (dnsCapability) Name() string { return "dns" } func (dnsCapability) Name() string { return "dns" }
func (dnsCapability) PostStart(ctx context.Context, d *Daemon, vm model.VMRecord, _ model.Image) error { func (dnsCapability) PostStart(ctx context.Context, d *Daemon, vm model.VMRecord, _ model.Image) error {
return d.setDNS(ctx, vm.Name, vm.Runtime.GuestIP) return d.hostNet().setDNS(ctx, vm.Name, vm.Runtime.GuestIP)
} }
func (dnsCapability) Cleanup(ctx context.Context, d *Daemon, vm model.VMRecord) error { func (dnsCapability) Cleanup(_ context.Context, d *Daemon, vm model.VMRecord) error {
return d.removeDNS(ctx, vm.Runtime.DNSName) return d.hostNet().removeDNS(vm.Runtime.DNSName)
} }
func (dnsCapability) AddDoctorChecks(_ context.Context, _ *Daemon, report *system.Report) { func (dnsCapability) AddDoctorChecks(_ context.Context, _ *Daemon, report *system.Report) {
@ -263,14 +263,14 @@ func (natCapability) AddStartPreflight(ctx context.Context, d *Daemon, checks *s
if !vm.Spec.NATEnabled { if !vm.Spec.NATEnabled {
return return
} }
d.addNATPrereqs(ctx, checks) d.hostNet().addNATPrereqs(ctx, checks)
} }
func (natCapability) PostStart(ctx context.Context, d *Daemon, vm model.VMRecord, _ model.Image) error { func (natCapability) PostStart(ctx context.Context, d *Daemon, vm model.VMRecord, _ model.Image) error {
if !vm.Spec.NATEnabled { if !vm.Spec.NATEnabled {
return nil return nil
} }
return d.ensureNAT(ctx, vm, true) return d.hostNet().ensureNAT(ctx, vm.Runtime.GuestIP, d.vmHandles(vm.ID).TapDevice, true)
} }
func (natCapability) Cleanup(ctx context.Context, d *Daemon, vm model.VMRecord) error { func (natCapability) Cleanup(ctx context.Context, d *Daemon, vm model.VMRecord) error {
@ -284,7 +284,7 @@ func (natCapability) Cleanup(ctx context.Context, d *Daemon, vm model.VMRecord)
} }
return nil return nil
} }
return d.ensureNAT(ctx, vm, false) return d.hostNet().ensureNAT(ctx, vm.Runtime.GuestIP, tap, false)
} }
func (natCapability) ApplyConfigChange(ctx context.Context, d *Daemon, before, after model.VMRecord) error { func (natCapability) ApplyConfigChange(ctx context.Context, d *Daemon, before, after model.VMRecord) error {
@ -294,18 +294,18 @@ func (natCapability) ApplyConfigChange(ctx context.Context, d *Daemon, before, a
if !d.vmAlive(after) { if !d.vmAlive(after) {
return nil return nil
} }
return d.ensureNAT(ctx, after, after.Spec.NATEnabled) return d.hostNet().ensureNAT(ctx, after.Runtime.GuestIP, d.vmHandles(after.ID).TapDevice, after.Spec.NATEnabled)
} }
func (natCapability) AddDoctorChecks(ctx context.Context, d *Daemon, report *system.Report) { func (natCapability) AddDoctorChecks(ctx context.Context, d *Daemon, report *system.Report) {
checks := system.NewPreflight() checks := system.NewPreflight()
checks.RequireCommand("ip", toolHint("ip")) checks.RequireCommand("ip", toolHint("ip"))
d.addNATPrereqs(ctx, checks) d.hostNet().addNATPrereqs(ctx, checks)
if len(checks.Problems()) > 0 { if len(checks.Problems()) > 0 {
report.Add(system.CheckStatusFail, "feature nat", checks.Problems()...) report.Add(system.CheckStatusFail, "feature nat", checks.Problems()...)
return return
} }
uplink, err := d.defaultUplink(ctx) uplink, err := d.hostNet().defaultUplink(ctx)
if err != nil { if err != nil {
report.AddFail("feature nat", err.Error()) report.AddFail("feature nat", err.Error())
return return

View file

@ -52,12 +52,11 @@ type Daemon struct {
// lives in the store, this is rebuildable from a per-VM // lives in the store, this is rebuildable from a per-VM
// handles.json scratch file and OS inspection. // handles.json scratch file and OS inspection.
handles *handleCache handles *handleCache
tapPool tapPool net *HostNetwork
closing chan struct{} closing chan struct{}
once sync.Once once sync.Once
pid int pid int
listener net.Listener listener net.Listener
vmDNS *vmdns.Server
vmCaps []vmCapability vmCaps []vmCapability
pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error) pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error)
finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error
@ -90,15 +89,24 @@ func Open(ctx context.Context) (d *Daemon, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
closing := make(chan struct{})
runner := system.NewRunner()
d = &Daemon{ d = &Daemon{
layout: layout, layout: layout,
config: cfg, config: cfg,
store: db, store: db,
runner: system.NewRunner(), runner: runner,
logger: logger, logger: logger,
closing: make(chan struct{}), closing: closing,
pid: os.Getpid(), pid: os.Getpid(),
handles: newHandleCache(), handles: newHandleCache(),
net: newHostNetwork(hostNetworkDeps{
runner: runner,
logger: logger,
config: cfg,
layout: layout,
closing: closing,
}),
} }
// From here on, every failure path must run Close() so the host // From here on, every failure path must run Close() so the host
// state we touched (DNS listener goroutine, resolvectl routing, // state we touched (DNS listener goroutine, resolvectl routing,
@ -114,7 +122,7 @@ func Open(ctx context.Context) (d *Daemon, err error) {
d.ensureVMSSHClientConfig() d.ensureVMSSHClientConfig()
d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel) d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel)
if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil { if err = d.hostNet().startVMDNS(vmdns.DefaultListenAddr); err != nil {
d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error()) d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error())
return nil, err return nil, err
} }
@ -122,12 +130,24 @@ func Open(ctx context.Context) (d *Daemon, err error) {
d.logger.Error("daemon open failed", "stage", "reconcile", "error", err.Error()) d.logger.Error("daemon open failed", "stage", "reconcile", "error", err.Error())
return nil, err return nil, err
} }
d.ensureVMDNSResolverRouting(ctx) d.hostNet().ensureVMDNSResolverRouting(ctx)
if err = d.initializeTapPool(ctx); err != nil { // Seed HostNetwork's pool index from taps already claimed by VMs
d.logger.Error("daemon open failed", "stage", "initialize_tap_pool", "error", err.Error()) // on disk so newly warmed pool entries don't collide with them.
return nil, err if d.config.TapPoolSize > 0 && d.store != nil {
vms, listErr := d.store.ListVMs(ctx)
if listErr != nil {
d.logger.Error("daemon open failed", "stage", "initialize_tap_pool", "error", listErr.Error())
return nil, listErr
} }
go d.ensureTapPool(context.Background()) used := make([]string, 0, len(vms))
for _, vm := range vms {
if tap := d.vmHandles(vm.ID).TapDevice; tap != "" {
used = append(used, tap)
}
}
d.hostNet().initializeTapPool(used)
}
go d.hostNet().ensureTapPool(context.Background())
return d, nil return d, nil
} }
@ -141,7 +161,7 @@ func (d *Daemon) Close() error {
if d.listener != nil { if d.listener != nil {
_ = d.listener.Close() _ = d.listener.Close()
} }
err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.store.Close()) err = errors.Join(d.hostNet().clearVMDNSResolverRouting(context.Background()), d.hostNet().stopVMDNS(), d.store.Close())
}) })
return err return err
} }
@ -518,27 +538,6 @@ func (d *Daemon) backgroundLoop() {
} }
} }
func (d *Daemon) startVMDNS(addr string) error {
server, err := vmdns.New(addr, d.logger)
if err != nil {
return err
}
d.vmDNS = server
if d.logger != nil {
d.logger.Info("vm dns serving", "dns_addr", server.Addr())
}
return nil
}
func (d *Daemon) stopVMDNS() error {
if d.vmDNS == nil {
return nil
}
err := d.vmDNS.Close()
d.vmDNS = nil
return err
}
func (d *Daemon) ensureDefaultImage(ctx context.Context) error { func (d *Daemon) ensureDefaultImage(ctx context.Context) error {
_ = ctx _ = ctx
return nil return nil

View file

@ -15,49 +15,49 @@ var (
vmDNSAddrFunc = func(server *vmdns.Server) string { return server.Addr() } vmDNSAddrFunc = func(server *vmdns.Server) string { return server.Addr() }
) )
func (d *Daemon) syncVMDNSResolverRouting(ctx context.Context) error { func (n *HostNetwork) syncVMDNSResolverRouting(ctx context.Context) error {
if d == nil || d.vmDNS == nil { if n == nil || n.vmDNS == nil {
return nil return nil
} }
if strings.TrimSpace(d.config.BridgeName) == "" { if strings.TrimSpace(n.config.BridgeName) == "" {
return nil return nil
} }
if _, err := lookupExecutableFunc("resolvectl"); err != nil { if _, err := lookupExecutableFunc("resolvectl"); err != nil {
return nil return nil
} }
if _, err := d.runner.Run(ctx, "ip", "link", "show", d.config.BridgeName); err != nil { if _, err := n.runner.Run(ctx, "ip", "link", "show", n.config.BridgeName); err != nil {
return nil return nil
} }
serverAddr := strings.TrimSpace(vmDNSAddrFunc(d.vmDNS)) serverAddr := strings.TrimSpace(vmDNSAddrFunc(n.vmDNS))
if serverAddr == "" { if serverAddr == "" {
return nil return nil
} }
if _, err := d.runner.RunSudo(ctx, "resolvectl", "dns", d.config.BridgeName, serverAddr); err != nil { if _, err := n.runner.RunSudo(ctx, "resolvectl", "dns", n.config.BridgeName, serverAddr); err != nil {
return err return err
} }
if _, err := d.runner.RunSudo(ctx, "resolvectl", "domain", d.config.BridgeName, vmResolverRouteDomain); err != nil { if _, err := n.runner.RunSudo(ctx, "resolvectl", "domain", n.config.BridgeName, vmResolverRouteDomain); err != nil {
return err return err
} }
_, err := d.runner.RunSudo(ctx, "resolvectl", "default-route", d.config.BridgeName, "no") _, err := n.runner.RunSudo(ctx, "resolvectl", "default-route", n.config.BridgeName, "no")
return err return err
} }
func (d *Daemon) clearVMDNSResolverRouting(ctx context.Context) error { func (n *HostNetwork) clearVMDNSResolverRouting(ctx context.Context) error {
if d == nil || strings.TrimSpace(d.config.BridgeName) == "" { if n == nil || strings.TrimSpace(n.config.BridgeName) == "" {
return nil return nil
} }
if _, err := lookupExecutableFunc("resolvectl"); err != nil { if _, err := lookupExecutableFunc("resolvectl"); err != nil {
return nil return nil
} }
if _, err := d.runner.Run(ctx, "ip", "link", "show", d.config.BridgeName); err != nil { if _, err := n.runner.Run(ctx, "ip", "link", "show", n.config.BridgeName); err != nil {
return nil return nil
} }
_, err := d.runner.RunSudo(ctx, "resolvectl", "revert", d.config.BridgeName) _, err := n.runner.RunSudo(ctx, "resolvectl", "revert", n.config.BridgeName)
return err return err
} }
func (d *Daemon) ensureVMDNSResolverRouting(ctx context.Context) { func (n *HostNetwork) ensureVMDNSResolverRouting(ctx context.Context) {
if err := d.syncVMDNSResolverRouting(ctx); err != nil && d.logger != nil { if err := n.syncVMDNSResolverRouting(ctx); err != nil && n.logger != nil {
d.logger.Warn("vm dns resolver route sync failed", "bridge", d.config.BridgeName, "error", err.Error()) n.logger.Warn("vm dns resolver route sync failed", "bridge", n.config.BridgeName, "error", err.Error())
} }
} }

View file

@ -32,13 +32,10 @@ func TestSyncVMDNSResolverRoutingConfiguresResolved(t *testing.T) {
sudoStep("", nil, "resolvectl", "default-route", model.DefaultBridgeName, "no"), sudoStep("", nil, "resolvectl", "default-route", model.DefaultBridgeName, "no"),
}, },
} }
d := &Daemon{ cfg := model.DaemonConfig{BridgeName: model.DefaultBridgeName}
runner: runner, n := &HostNetwork{runner: runner, config: cfg, vmDNS: new(vmdns.Server)}
config: model.DaemonConfig{BridgeName: model.DefaultBridgeName},
vmDNS: new(vmdns.Server),
}
if err := d.syncVMDNSResolverRouting(context.Background()); err != nil { if err := n.syncVMDNSResolverRouting(context.Background()); err != nil {
t.Fatalf("syncVMDNSResolverRouting: %v", err) t.Fatalf("syncVMDNSResolverRouting: %v", err)
} }
runner.assertExhausted() runner.assertExhausted()
@ -63,12 +60,10 @@ func TestClearVMDNSResolverRoutingRevertsBridgeConfig(t *testing.T) {
sudoStep("", nil, "resolvectl", "revert", model.DefaultBridgeName), sudoStep("", nil, "resolvectl", "revert", model.DefaultBridgeName),
}, },
} }
d := &Daemon{ cfg := model.DaemonConfig{BridgeName: model.DefaultBridgeName}
runner: runner, n := &HostNetwork{runner: runner, config: cfg}
config: model.DaemonConfig{BridgeName: model.DefaultBridgeName},
}
if err := d.clearVMDNSResolverRouting(context.Background()); err != nil { if err := n.clearVMDNSResolverRouting(context.Background()); err != nil {
t.Fatalf("clearVMDNSResolverRouting: %v", err) t.Fatalf("clearVMDNSResolverRouting: %v", err)
} }
runner.assertExhausted() runner.assertExhausted()

View file

@ -75,18 +75,18 @@ func TestTapPoolWarmsAndReusesIdleTap(t *testing.T) {
closing: make(chan struct{}), closing: make(chan struct{}),
} }
d.ensureTapPool(context.Background()) d.hostNet().ensureTapPool(context.Background())
tapName, err := d.acquireTap(context.Background(), "tap-fallback") tapName, err := d.hostNet().acquireTap(context.Background(), "tap-fallback")
if err != nil { if err != nil {
t.Fatalf("acquireTap: %v", err) t.Fatalf("acquireTap: %v", err)
} }
if tapName != "tap-pool-0" { if tapName != "tap-pool-0" {
t.Fatalf("tapName = %q, want tap-pool-0", tapName) t.Fatalf("tapName = %q, want tap-pool-0", tapName)
} }
if err := d.releaseTap(context.Background(), tapName); err != nil { if err := d.hostNet().releaseTap(context.Background(), tapName); err != nil {
t.Fatalf("releaseTap: %v", err) t.Fatalf("releaseTap: %v", err)
} }
tapName, err = d.acquireTap(context.Background(), "tap-fallback") tapName, err = d.hostNet().acquireTap(context.Background(), "tap-fallback")
if err != nil { if err != nil {
t.Fatalf("acquireTap second time: %v", err) t.Fatalf("acquireTap second time: %v", err)
} }

View file

@ -0,0 +1,242 @@
package daemon
import (
"context"
"errors"
"fmt"
"log/slog"
"net"
"path/filepath"
"strings"
"time"
"banger/internal/daemon/fcproc"
"banger/internal/firecracker"
"banger/internal/model"
"banger/internal/paths"
"banger/internal/system"
"banger/internal/vmdns"
"banger/internal/vsockagent"
)
// HostNetwork owns the daemon's side of host networking: the TAP
// interface pool, the bridge, per-VM tap/NAT/DNS wiring, and the
// firecracker-process primitives (bridge setup, socket access,
// pgrep-based PID resolution, ctrl-alt-del, wait/kill) plus DM
// snapshot helpers. The Daemon holds one *HostNetwork and routes
// lifecycle calls through it instead of reaching into host-state
// directly.
//
// Fields stay unexported so peer services (VMService, etc.) access
// HostNetwork only through consumer-defined interfaces, not by
// fishing around in its struct. Construction goes through
// newHostNetwork with an explicit dependency bag so the wiring is
// auditable.
type HostNetwork struct {
	runner  system.CommandRunner // shell-out primitive for ip/resolvectl/dmsetup invocations
	logger  *slog.Logger         // may be nil in test literals; nil-guarded at use sites
	config  model.DaemonConfig
	layout  paths.Layout
	closing chan struct{} // same channel Daemon.Open creates; signals shutdown
	tapPool tapPool       // warm pool of pre-created TAP devices
	vmDNS   *vmdns.Server // nil until startVMDNS succeeds; cleared by stopVMDNS
}
// hostNetworkDeps is the explicit wiring bag newHostNetwork expects.
// Keeping the deps in a dedicated struct rather than positional args
// makes the construction site in Daemon.Open read like a declaration.
// closing is shared with the Daemon (the same channel Daemon.Open
// creates) so HostNetwork background work can observe shutdown.
type hostNetworkDeps struct {
	runner  system.CommandRunner
	logger  *slog.Logger
	config  model.DaemonConfig
	layout  paths.Layout
	closing chan struct{}
}
// newHostNetwork builds a HostNetwork from an explicit dependency bag.
// tapPool and vmDNS start at their zero values; the DNS server comes
// up later via startVMDNS.
func newHostNetwork(deps hostNetworkDeps) *HostNetwork {
	n := &HostNetwork{
		runner:  deps.runner,
		logger:  deps.logger,
		config:  deps.config,
		layout:  deps.layout,
		closing: deps.closing,
	}
	return n
}
// hostNet returns the Daemon's HostNetwork service, constructing one
// on the spot from the Daemon's current fields when a test literal
// left d.net unset. Daemon.Open always populates d.net eagerly, so
// production never takes the lazy path; it exists only so tests that
// build `&Daemon{...}` literals without spelling out a HostNetwork
// don't have to learn the new construction pattern. Every production
// call that touches HostNetwork funnels through here.
func (d *Daemon) hostNet() *HostNetwork {
	if d.net == nil {
		d.net = newHostNetwork(hostNetworkDeps{
			runner:  d.runner,
			logger:  d.logger,
			config:  d.config,
			layout:  d.layout,
			closing: d.closing,
		})
	}
	return d.net
}
// --- DNS server lifecycle -------------------------------------------
// startVMDNS boots the guest-facing DNS server on addr and records it
// on the service. Logs the bound address when a logger is wired.
func (n *HostNetwork) startVMDNS(addr string) error {
	srv, err := vmdns.New(addr, n.logger)
	if err != nil {
		return err
	}
	n.vmDNS = srv
	if log := n.logger; log != nil {
		log.Info("vm dns serving", "dns_addr", srv.Addr())
	}
	return nil
}
// stopVMDNS shuts the DNS server down and clears the field so a
// second stop is a no-op. A server that was never started is also a
// no-op.
func (n *HostNetwork) stopVMDNS() error {
	srv := n.vmDNS
	if srv == nil {
		return nil
	}
	n.vmDNS = nil
	return srv.Close()
}
// setDNS publishes vmName -> guestIP in the guest DNS server and then
// re-syncs host resolver routing so the bridge keeps pointing at the
// server. A nil server (DNS never started) is a silent no-op.
func (n *HostNetwork) setDNS(ctx context.Context, vmName, guestIP string) error {
	srv := n.vmDNS
	if srv == nil {
		return nil
	}
	err := srv.Set(vmdns.RecordName(vmName), guestIP)
	if err != nil {
		return err
	}
	n.ensureVMDNSResolverRouting(ctx)
	return nil
}
// removeDNS drops dnsName from the guest DNS server. No-ops when the
// server was never started or the name is empty.
func (n *HostNetwork) removeDNS(dnsName string) error {
	srv := n.vmDNS
	if srv == nil || dnsName == "" {
		return nil
	}
	return srv.Remove(dnsName)
}
// replaceDNS swaps in a complete record set for the guest DNS server.
// Callers (Daemon.rebuildDNS) are responsible for filtering to live
// VMs first; this installs whatever pre-filtered map it is handed.
// No-op when the server isn't running.
func (n *HostNetwork) replaceDNS(records map[string]string) error {
	srv := n.vmDNS
	if srv == nil {
		return nil
	}
	return srv.Replace(records)
}
// --- Firecracker process helpers ------------------------------------
// fc assembles a throwaway fcproc.Manager from the service's current
// runner, config, and layout. The manager is stateless beyond those
// handles, so a fresh one per call is cheap and lets tests that build
// HostNetwork literals skip any extra wiring.
func (n *HostNetwork) fc() *fcproc.Manager {
	cfg := fcproc.Config{
		FirecrackerBin: n.config.FirecrackerBin,
		BridgeName:     n.config.BridgeName,
		BridgeIP:       n.config.BridgeIP,
		CIDR:           n.config.CIDR,
		RuntimeDir:     n.layout.RuntimeDir,
	}
	return fcproc.New(n.runner, cfg, n.logger)
}
// ensureBridge delegates host bridge setup to a fresh fcproc.Manager.
func (n *HostNetwork) ensureBridge(ctx context.Context) error {
	mgr := n.fc()
	return mgr.EnsureBridge(ctx)
}
// ensureSocketDir delegates socket-directory creation to fcproc.
func (n *HostNetwork) ensureSocketDir() error {
	mgr := n.fc()
	return mgr.EnsureSocketDir()
}
// createTap delegates TAP-device creation for tap to fcproc.
func (n *HostNetwork) createTap(ctx context.Context, tap string) error {
	mgr := n.fc()
	return mgr.CreateTap(ctx, tap)
}
// firecrackerBinary resolves the firecracker binary path via fcproc.
func (n *HostNetwork) firecrackerBinary() (string, error) {
	mgr := n.fc()
	return mgr.ResolveBinary()
}
// ensureSocketAccess delegates socket permission fix-up (labelled for
// error messages) to fcproc.
func (n *HostNetwork) ensureSocketAccess(ctx context.Context, socketPath, label string) error {
	mgr := n.fc()
	return mgr.EnsureSocketAccess(ctx, socketPath, label)
}
// findFirecrackerPID locates the firecracker process serving apiSock
// via fcproc.
func (n *HostNetwork) findFirecrackerPID(ctx context.Context, apiSock string) (int, error) {
	mgr := n.fc()
	return mgr.FindPID(ctx, apiSock)
}
// resolveFirecrackerPID resolves the PID for machine/apiSock via
// fcproc.
func (n *HostNetwork) resolveFirecrackerPID(ctx context.Context, machine *firecracker.Machine, apiSock string) int {
	mgr := n.fc()
	return mgr.ResolvePID(ctx, machine, apiSock)
}
// sendCtrlAltDel asks the firecracker API at apiSockPath to deliver a
// ctrl-alt-del to the guest, via fcproc.
func (n *HostNetwork) sendCtrlAltDel(ctx context.Context, apiSockPath string) error {
	mgr := n.fc()
	return mgr.SendCtrlAltDel(ctx, apiSockPath)
}
// waitForExit blocks until the firecracker process (pid / apiSock)
// exits or timeout elapses, via fcproc.
func (n *HostNetwork) waitForExit(ctx context.Context, pid int, apiSock string, timeout time.Duration) error {
	mgr := n.fc()
	return mgr.WaitForExit(ctx, pid, apiSock, timeout)
}
// killVMProcess force-kills the firecracker process pid via fcproc.
func (n *HostNetwork) killVMProcess(ctx context.Context, pid int) error {
	mgr := n.fc()
	return mgr.Kill(ctx, pid)
}
// waitForGuestVSockAgent polls the guest's vsock agent health endpoint
// until it answers, the overall timeout elapses, or ctx is cancelled.
// It is a HostNetwork helper because it's fundamentally about waiting
// for a vsock socket the firecracker process is serving on; no daemon
// state is needed.
//
// Each individual probe gets its own 3s budget (derived from waitCtx)
// so one hung connection can't consume the whole timeout; probes
// repeat every vsockReadyPoll.
func (n *HostNetwork) waitForGuestVSockAgent(ctx context.Context, socketPath string, timeout time.Duration) error {
	if strings.TrimSpace(socketPath) == "" {
		return errors.New("vsock path is required")
	}
	waitCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	ticker := time.NewTicker(vsockReadyPoll)
	defer ticker.Stop()
	for {
		pingCtx, pingCancel := context.WithTimeout(waitCtx, 3*time.Second)
		err := vsockagent.Health(pingCtx, n.logger, socketPath)
		pingCancel()
		if err == nil {
			return nil
		}
		// err is non-nil whenever we reach the select, so the deadline
		// branch always has a probe failure to wrap. (The previous
		// "not ready before timeout" fallback for a nil last error was
		// unreachable and has been removed.)
		select {
		case <-waitCtx.Done():
			return fmt.Errorf("guest vsock agent not ready: %w", err)
		case <-ticker.C:
		}
	}
}
// --- Utilities used across networking ------------------------------
// defaultVSockPath returns the conventional host-side vsock socket
// path for a VM: <runtimeDir>/fc-<shortID>.vsock.
func defaultVSockPath(runtimeDir, vmID string) string {
	name := "fc-" + system.ShortID(vmID) + ".vsock"
	return filepath.Join(runtimeDir, name)
}
// defaultVSockCID derives a deterministic vsock context ID from the
// guest's IPv4 address: 10000 plus the address's final octet. Returns
// an error when the (whitespace-trimmed) input does not parse as an
// IPv4 address.
func defaultVSockCID(guestIP string) (uint32, error) {
	trimmed := strings.TrimSpace(guestIP)
	v4 := net.ParseIP(trimmed).To4()
	if v4 == nil {
		return 0, fmt.Errorf("guest IP is not IPv4: %q", guestIP)
	}
	return 10000 + uint32(v4[3]), nil
}

View file

@ -10,22 +10,43 @@ import (
type natRule = hostnat.Rule type natRule = hostnat.Rule
func (d *Daemon) ensureNAT(ctx context.Context, vm model.VMRecord, enable bool) error { // ensureNAT takes tap explicitly rather than reading from a handle
return hostnat.Ensure(ctx, d.runner, vm.Runtime.GuestIP, d.vmHandles(vm.ID).TapDevice, enable) // cache so HostNetwork stays decoupled from VM-service state.
// Callers (vm_lifecycle) resolve the tap device from the handle cache
// themselves and pass it in.
func (n *HostNetwork) ensureNAT(ctx context.Context, guestIP, tap string, enable bool) error {
return hostnat.Ensure(ctx, n.runner, guestIP, tap, enable)
} }
func (d *Daemon) validateNATPrereqs(ctx context.Context) (string, error) { func (n *HostNetwork) validateNATPrereqs(ctx context.Context) (string, error) {
checks := system.NewPreflight() checks := system.NewPreflight()
checks.RequireCommand("ip", toolHint("ip")) checks.RequireCommand("ip", toolHint("ip"))
d.addNATPrereqs(ctx, checks) n.addNATPrereqs(ctx, checks)
if err := checks.Err("nat preflight failed"); err != nil { if err := checks.Err("nat preflight failed"); err != nil {
return "", err return "", err
} }
return d.defaultUplink(ctx) return n.defaultUplink(ctx)
} }
func (d *Daemon) defaultUplink(ctx context.Context) (string, error) { func (n *HostNetwork) addNATPrereqs(ctx context.Context, checks *system.Preflight) {
return hostnat.DefaultUplink(ctx, d.runner) checks.RequireCommand("iptables", toolHint("iptables"))
checks.RequireCommand("sysctl", toolHint("sysctl"))
runner := n.runner
if runner == nil {
runner = system.NewRunner()
}
out, err := runner.Run(ctx, "ip", "route", "show", "default")
if err != nil {
checks.Addf("failed to inspect the default route for NAT: %v", err)
return
}
if _, err := parseDefaultUplink(string(out)); err != nil {
checks.Addf("failed to detect the uplink interface for NAT: %v", err)
}
}
func (n *HostNetwork) defaultUplink(ctx context.Context) (string, error) {
return hostnat.DefaultUplink(ctx, n.runner)
} }
func parseDefaultUplink(output string) (string, error) { func parseDefaultUplink(output string) (string, error) {

View file

@ -50,12 +50,12 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
return &Daemon{ return &Daemon{
store: openDaemonStore(t), store: openDaemonStore(t),
closing: make(chan struct{}), closing: make(chan struct{}),
vmDNS: server, net: &HostNetwork{vmDNS: server},
logger: slog.New(slog.NewTextHandler(io.Discard, nil)), logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
} }
}, },
verify: func(t *testing.T, d *Daemon) { verify: func(t *testing.T, d *Daemon) {
if d.vmDNS != nil { if d.hostNet().vmDNS != nil {
t.Error("vmDNS not cleared by Close") t.Error("vmDNS not cleared by Close")
} }
}, },

View file

@ -40,7 +40,7 @@ func (d *Daemon) PortsVM(ctx context.Context, idOrName string) (result api.VMPor
if vm.Runtime.VSockCID == 0 { if vm.Runtime.VSockCID == 0 {
return model.VMRecord{}, errors.New("vm has no vsock cid") return model.VMRecord{}, errors.New("vm has no vsock cid")
} }
if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { if err := d.hostNet().ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
portsCtx, cancel := context.WithTimeout(ctx, 3*time.Second) portsCtx, cancel := context.WithTimeout(ctx, 3*time.Second)

View file

@ -25,23 +25,6 @@ func (d *Daemon) validateWorkDiskResizePrereqs() error {
return checks.Err("work disk resize preflight failed") return checks.Err("work disk resize preflight failed")
} }
func (d *Daemon) addNATPrereqs(ctx context.Context, checks *system.Preflight) {
checks.RequireCommand("iptables", toolHint("iptables"))
checks.RequireCommand("sysctl", toolHint("sysctl"))
runner := d.runner
if runner == nil {
runner = system.NewRunner()
}
out, err := runner.Run(ctx, "ip", "route", "show", "default")
if err != nil {
checks.Addf("failed to inspect the default route for NAT: %v", err)
return
}
if _, err := parseDefaultUplink(string(out)); err != nil {
checks.Addf("failed to detect the uplink interface for NAT: %v", err)
}
}
func (d *Daemon) addBaseStartPrereqs(checks *system.Preflight, image model.Image) { func (d *Daemon) addBaseStartPrereqs(checks *system.Preflight, image model.Image) {
d.addBaseStartCommandPrereqs(checks) d.addBaseStartCommandPrereqs(checks)
checks.RequireExecutable(d.config.FirecrackerBin, "firecracker binary", `install firecracker or set "firecracker_bin"`) checks.RequireExecutable(d.config.FirecrackerBin, "firecracker binary", `install firecracker or set "firecracker_bin"`)

View file

@ -10,14 +10,14 @@ import (
// type so existing call sites and tests read naturally. // type so existing call sites and tests read naturally.
type dmSnapshotHandles = dmsnap.Handles type dmSnapshotHandles = dmsnap.Handles
func (d *Daemon) createDMSnapshot(ctx context.Context, rootfsPath, cowPath, dmName string) (dmSnapshotHandles, error) { func (n *HostNetwork) createDMSnapshot(ctx context.Context, rootfsPath, cowPath, dmName string) (dmSnapshotHandles, error) {
return dmsnap.Create(ctx, d.runner, rootfsPath, cowPath, dmName) return dmsnap.Create(ctx, n.runner, rootfsPath, cowPath, dmName)
} }
func (d *Daemon) cleanupDMSnapshot(ctx context.Context, handles dmSnapshotHandles) error { func (n *HostNetwork) cleanupDMSnapshot(ctx context.Context, handles dmSnapshotHandles) error {
return dmsnap.Cleanup(ctx, d.runner, handles) return dmsnap.Cleanup(ctx, n.runner, handles)
} }
func (d *Daemon) removeDMSnapshot(ctx context.Context, target string) error { func (n *HostNetwork) removeDMSnapshot(ctx context.Context, target string) error {
return dmsnap.Remove(ctx, d.runner, target) return dmsnap.Remove(ctx, n.runner, target)
} }

View file

@ -74,7 +74,7 @@ func TestCreateDMSnapshotFailsWithoutRollbackWhenBaseLoopSetupFails(t *testing.T
} }
d := &Daemon{runner: runner} d := &Daemon{runner: runner}
_, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") _, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test")
if !errors.Is(err, attachErr) { if !errors.Is(err, attachErr) {
t.Fatalf("error = %v, want %v", err, attachErr) t.Fatalf("error = %v, want %v", err, attachErr)
} }
@ -98,7 +98,7 @@ func TestCreateDMSnapshotRollsBackBaseLoopWhenCowLoopSetupFails(t *testing.T) {
} }
d := &Daemon{runner: runner} d := &Daemon{runner: runner}
_, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") _, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test")
if !errors.Is(err, attachErr) { if !errors.Is(err, attachErr) {
t.Fatalf("error = %v, want %v", err, attachErr) t.Fatalf("error = %v, want %v", err, attachErr)
} }
@ -121,7 +121,7 @@ func TestCreateDMSnapshotRollsBackBothLoopsWhenBlockdevFails(t *testing.T) {
} }
d := &Daemon{runner: runner} d := &Daemon{runner: runner}
_, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") _, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test")
if !errors.Is(err, blockdevErr) { if !errors.Is(err, blockdevErr) {
t.Fatalf("error = %v, want %v", err, blockdevErr) t.Fatalf("error = %v, want %v", err, blockdevErr)
} }
@ -145,7 +145,7 @@ func TestCreateDMSnapshotRollsBackLoopsWhenDMSetupFails(t *testing.T) {
} }
d := &Daemon{runner: runner} d := &Daemon{runner: runner}
_, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") _, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test")
if !errors.Is(err, dmErr) { if !errors.Is(err, dmErr) {
t.Fatalf("error = %v, want %v", err, dmErr) t.Fatalf("error = %v, want %v", err, dmErr)
} }
@ -174,7 +174,7 @@ func TestCreateDMSnapshotJoinsRollbackErrors(t *testing.T) {
} }
d := &Daemon{runner: runner} d := &Daemon{runner: runner}
_, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") _, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test")
if err == nil { if err == nil {
t.Fatal("expected createDMSnapshot to return an error") t.Fatal("expected createDMSnapshot to return an error")
} }
@ -198,7 +198,7 @@ func TestCreateDMSnapshotReturnsHandlesOnSuccess(t *testing.T) {
} }
d := &Daemon{runner: runner} d := &Daemon{runner: runner}
handles, err := d.createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test") handles, err := d.hostNet().createDMSnapshot(context.Background(), "/rootfs.ext4", "/cow.ext4", "fc-rootfs-test")
if err != nil { if err != nil {
t.Fatalf("createDMSnapshot returned error: %v", err) t.Fatalf("createDMSnapshot returned error: %v", err)
} }
@ -227,7 +227,7 @@ func TestCleanupDMSnapshotRemovesResourcesInReverseOrder(t *testing.T) {
} }
d := &Daemon{runner: runner} d := &Daemon{runner: runner}
err := d.cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ err := d.hostNet().cleanupDMSnapshot(context.Background(), dmSnapshotHandles{
BaseLoop: "/dev/loop10", BaseLoop: "/dev/loop10",
COWLoop: "/dev/loop11", COWLoop: "/dev/loop11",
DMName: "fc-rootfs-test", DMName: "fc-rootfs-test",
@ -251,7 +251,7 @@ func TestCleanupDMSnapshotUsesPartialHandles(t *testing.T) {
} }
d := &Daemon{runner: runner} d := &Daemon{runner: runner}
err := d.cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ err := d.hostNet().cleanupDMSnapshot(context.Background(), dmSnapshotHandles{
BaseLoop: "/dev/loop10", BaseLoop: "/dev/loop10",
DMDev: "/dev/mapper/fc-rootfs-test", DMDev: "/dev/mapper/fc-rootfs-test",
}) })
@ -277,7 +277,7 @@ func TestCleanupDMSnapshotJoinsTeardownErrors(t *testing.T) {
} }
d := &Daemon{runner: runner} d := &Daemon{runner: runner}
err := d.cleanupDMSnapshot(context.Background(), dmSnapshotHandles{ err := d.hostNet().cleanupDMSnapshot(context.Background(), dmSnapshotHandles{
BaseLoop: "/dev/loop10", BaseLoop: "/dev/loop10",
COWLoop: "/dev/loop11", COWLoop: "/dev/loop11",
DMName: "fc-rootfs-test", DMName: "fc-rootfs-test",
@ -307,7 +307,7 @@ func TestRemoveDMSnapshotRetriesBusyDevice(t *testing.T) {
} }
d := &Daemon{runner: runner} d := &Daemon{runner: runner}
if err := d.removeDMSnapshot(context.Background(), "fc-rootfs-test"); err != nil { if err := d.hostNet().removeDMSnapshot(context.Background(), "fc-rootfs-test"); err != nil {
t.Fatalf("removeDMSnapshot returned error: %v", err) t.Fatalf("removeDMSnapshot returned error: %v", err)
} }
runner.assertExhausted() runner.assertExhausted()

View file

@ -18,98 +18,97 @@ type tapPool struct {
next int next int
} }
func (d *Daemon) initializeTapPool(ctx context.Context) error { // initializeTapPool seeds the monotonic pool index from the set of
if d.config.TapPoolSize <= 0 || d.store == nil { // tap names already in use by running/stopped VMs, so newly warmed
return nil // pool entries don't collide with existing ones. Callers (Daemon.Open)
} // enumerate used taps from the handle cache and pass them in.
vms, err := d.store.ListVMs(ctx) func (n *HostNetwork) initializeTapPool(usedTaps []string) {
if err != nil { if n.config.TapPoolSize <= 0 {
return err return
} }
next := 0 next := 0
for _, vm := range vms { for _, tapName := range usedTaps {
if index, ok := parseTapPoolIndex(d.vmHandles(vm.ID).TapDevice); ok && index >= next { if index, ok := parseTapPoolIndex(tapName); ok && index >= next {
next = index + 1 next = index + 1
} }
} }
d.tapPool.mu.Lock() n.tapPool.mu.Lock()
d.tapPool.next = next n.tapPool.next = next
d.tapPool.mu.Unlock() n.tapPool.mu.Unlock()
return nil
} }
func (d *Daemon) ensureTapPool(ctx context.Context) { func (n *HostNetwork) ensureTapPool(ctx context.Context) {
if d.config.TapPoolSize <= 0 { if n.config.TapPoolSize <= 0 {
return return
} }
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-d.closing: case <-n.closing:
return return
default: default:
} }
d.tapPool.mu.Lock() n.tapPool.mu.Lock()
if len(d.tapPool.entries) >= d.config.TapPoolSize { if len(n.tapPool.entries) >= n.config.TapPoolSize {
d.tapPool.mu.Unlock() n.tapPool.mu.Unlock()
return return
} }
tapName := fmt.Sprintf("%s%d", tapPoolPrefix, d.tapPool.next) tapName := fmt.Sprintf("%s%d", tapPoolPrefix, n.tapPool.next)
d.tapPool.next++ n.tapPool.next++
d.tapPool.mu.Unlock() n.tapPool.mu.Unlock()
if err := d.createTap(ctx, tapName); err != nil { if err := n.createTap(ctx, tapName); err != nil {
if d.logger != nil { if n.logger != nil {
d.logger.Warn("tap pool warmup failed", "tap_device", tapName, "error", err.Error()) n.logger.Warn("tap pool warmup failed", "tap_device", tapName, "error", err.Error())
} }
return return
} }
d.tapPool.mu.Lock() n.tapPool.mu.Lock()
d.tapPool.entries = append(d.tapPool.entries, tapName) n.tapPool.entries = append(n.tapPool.entries, tapName)
d.tapPool.mu.Unlock() n.tapPool.mu.Unlock()
if d.logger != nil { if n.logger != nil {
d.logger.Debug("tap added to idle pool", "tap_device", tapName) n.logger.Debug("tap added to idle pool", "tap_device", tapName)
} }
} }
} }
func (d *Daemon) acquireTap(ctx context.Context, fallbackName string) (string, error) { func (n *HostNetwork) acquireTap(ctx context.Context, fallbackName string) (string, error) {
d.tapPool.mu.Lock() n.tapPool.mu.Lock()
if n := len(d.tapPool.entries); n > 0 { if count := len(n.tapPool.entries); count > 0 {
tapName := d.tapPool.entries[n-1] tapName := n.tapPool.entries[count-1]
d.tapPool.entries = d.tapPool.entries[:n-1] n.tapPool.entries = n.tapPool.entries[:count-1]
d.tapPool.mu.Unlock() n.tapPool.mu.Unlock()
return tapName, nil return tapName, nil
} }
d.tapPool.mu.Unlock() n.tapPool.mu.Unlock()
if err := d.createTap(ctx, fallbackName); err != nil { if err := n.createTap(ctx, fallbackName); err != nil {
return "", err return "", err
} }
return fallbackName, nil return fallbackName, nil
} }
func (d *Daemon) releaseTap(ctx context.Context, tapName string) error { func (n *HostNetwork) releaseTap(ctx context.Context, tapName string) error {
tapName = strings.TrimSpace(tapName) tapName = strings.TrimSpace(tapName)
if tapName == "" { if tapName == "" {
return nil return nil
} }
if isTapPoolName(tapName) { if isTapPoolName(tapName) {
d.tapPool.mu.Lock() n.tapPool.mu.Lock()
if len(d.tapPool.entries) < d.config.TapPoolSize { if len(n.tapPool.entries) < n.config.TapPoolSize {
d.tapPool.entries = append(d.tapPool.entries, tapName) n.tapPool.entries = append(n.tapPool.entries, tapName)
d.tapPool.mu.Unlock() n.tapPool.mu.Unlock()
return nil return nil
} }
d.tapPool.mu.Unlock() n.tapPool.mu.Unlock()
} }
_, err := d.runner.RunSudo(ctx, "ip", "link", "del", tapName) _, err := n.runner.RunSudo(ctx, "ip", "link", "del", tapName)
if err == nil { if err == nil {
go d.ensureTapPool(context.Background()) go n.ensureTapPool(context.Background())
} }
return err return err
} }

View file

@ -4,23 +4,20 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net"
"os" "os"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"banger/internal/daemon/fcproc" "banger/internal/daemon/fcproc"
"banger/internal/firecracker"
"banger/internal/model" "banger/internal/model"
"banger/internal/namegen" "banger/internal/namegen"
"banger/internal/system" "banger/internal/system"
"banger/internal/vmdns"
"banger/internal/vsockagent"
) )
// Cross-service constants. Kept in vm.go because both lifecycle
// (VMService) and networking (HostNetwork) reference them; moving
// them to either owner would read as a layering violation.
var ( var (
errWaitForExitTimeout = fcproc.ErrWaitForExitTimeout errWaitForExitTimeout = fcproc.ErrWaitForExitTimeout
gracefulShutdownWait = 10 * time.Second gracefulShutdownWait = 10 * time.Second
@ -28,59 +25,43 @@ var (
vsockReadyPoll = 200 * time.Millisecond vsockReadyPoll = 200 * time.Millisecond
) )
// fc builds a fresh fcproc.Manager from the Daemon's current runner, config, // rebuildDNS enumerates live VMs and republishes the DNS record set.
// and layout. Manager is stateless beyond those handles, so constructing per // Lives on *Daemon (not HostNetwork) because "alive" is a VMService
// call keeps tests that build Daemon literals working without extra wiring. // concern that HostNetwork shouldn't need to reach into. Daemon
func (d *Daemon) fc() *fcproc.Manager { // orchestrates: VM list from the store, alive filter, hand the
return fcproc.New(d.runner, fcproc.Config{ // resulting map to HostNetwork.replaceDNS.
FirecrackerBin: d.config.FirecrackerBin, func (d *Daemon) rebuildDNS(ctx context.Context) error {
BridgeName: d.config.BridgeName, if d.net == nil {
BridgeIP: d.config.BridgeIP, return nil
CIDR: d.config.CIDR, }
RuntimeDir: d.layout.RuntimeDir, vms, err := d.store.ListVMs(ctx)
}, d.logger) if err != nil {
return err
}
records := make(map[string]string)
for _, vm := range vms {
if !d.vmAlive(vm) {
continue
}
if strings.TrimSpace(vm.Runtime.GuestIP) == "" {
continue
}
records[vmDNSRecordName(vm.Name)] = vm.Runtime.GuestIP
}
return d.hostNet().replaceDNS(records)
} }
func (d *Daemon) ensureBridge(ctx context.Context) error { // vmDNSRecordName is a small indirection so the dns-record-name
return d.fc().EnsureBridge(ctx) // helper is not directly pulled into every file that used to import
} // vmdns for this one call. Equivalent to vmdns.RecordName.
func vmDNSRecordName(name string) string {
func (d *Daemon) ensureSocketDir() error { return strings.ToLower(strings.TrimSpace(name)) + ".vm"
return d.fc().EnsureSocketDir()
}
func (d *Daemon) createTap(ctx context.Context, tap string) error {
return d.fc().CreateTap(ctx, tap)
}
func (d *Daemon) firecrackerBinary() (string, error) {
return d.fc().ResolveBinary()
}
func (d *Daemon) ensureSocketAccess(ctx context.Context, socketPath, label string) error {
return d.fc().EnsureSocketAccess(ctx, socketPath, label)
}
func (d *Daemon) findFirecrackerPID(ctx context.Context, apiSock string) (int, error) {
return d.fc().FindPID(ctx, apiSock)
}
func (d *Daemon) resolveFirecrackerPID(ctx context.Context, machine *firecracker.Machine, apiSock string) int {
return d.fc().ResolvePID(ctx, machine, apiSock)
}
func (d *Daemon) sendCtrlAltDel(ctx context.Context, vm model.VMRecord) error {
return d.fc().SendCtrlAltDel(ctx, vm.Runtime.APISockPath)
}
func (d *Daemon) waitForExit(ctx context.Context, pid int, apiSock string, timeout time.Duration) error {
return d.fc().WaitForExit(ctx, pid, apiSock, timeout)
}
func (d *Daemon) killVMProcess(ctx context.Context, pid int) error {
return d.fc().Kill(ctx, pid)
} }
// cleanupRuntime tears down the host-side state for a VM: firecracker
// process, DM snapshot, capabilities, tap, sockets. Stays on *Daemon
// for now because it reaches into handles (VMService-owned) and
// capabilities (still on Daemon). Phase 4 will move it to VMService.
func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserveDisks bool) error { func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserveDisks bool) error {
if d.logger != nil { if d.logger != nil {
d.logger.Debug("cleanup runtime", append(vmLogAttrs(vm), "preserve_disks", preserveDisks)...) d.logger.Debug("cleanup runtime", append(vmLogAttrs(vm), "preserve_disks", preserveDisks)...)
@ -88,17 +69,17 @@ func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserve
h := d.vmHandles(vm.ID) h := d.vmHandles(vm.ID)
cleanupPID := h.PID cleanupPID := h.PID
if vm.Runtime.APISockPath != "" { if vm.Runtime.APISockPath != "" {
if pid, err := d.findFirecrackerPID(ctx, vm.Runtime.APISockPath); err == nil && pid > 0 { if pid, err := d.hostNet().findFirecrackerPID(ctx, vm.Runtime.APISockPath); err == nil && pid > 0 {
cleanupPID = pid cleanupPID = pid
} }
} }
if cleanupPID > 0 && system.ProcessRunning(cleanupPID, vm.Runtime.APISockPath) { if cleanupPID > 0 && system.ProcessRunning(cleanupPID, vm.Runtime.APISockPath) {
_ = d.killVMProcess(ctx, cleanupPID) _ = d.hostNet().killVMProcess(ctx, cleanupPID)
if err := d.waitForExit(ctx, cleanupPID, vm.Runtime.APISockPath, 30*time.Second); err != nil { if err := d.hostNet().waitForExit(ctx, cleanupPID, vm.Runtime.APISockPath, 30*time.Second); err != nil {
return err return err
} }
} }
snapshotErr := d.cleanupDMSnapshot(ctx, dmSnapshotHandles{ snapshotErr := d.hostNet().cleanupDMSnapshot(ctx, dmSnapshotHandles{
BaseLoop: h.BaseLoop, BaseLoop: h.BaseLoop,
COWLoop: h.COWLoop, COWLoop: h.COWLoop,
DMName: h.DMName, DMName: h.DMName,
@ -107,7 +88,7 @@ func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserve
featureErr := d.cleanupCapabilityState(ctx, vm) featureErr := d.cleanupCapabilityState(ctx, vm)
var tapErr error var tapErr error
if h.TapDevice != "" { if h.TapDevice != "" {
tapErr = d.releaseTap(ctx, h.TapDevice) tapErr = d.hostNet().releaseTap(ctx, h.TapDevice)
} }
if vm.Runtime.APISockPath != "" { if vm.Runtime.APISockPath != "" {
_ = os.Remove(vm.Runtime.APISockPath) _ = os.Remove(vm.Runtime.APISockPath)
@ -125,92 +106,6 @@ func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserve
return errors.Join(snapshotErr, featureErr, tapErr) return errors.Join(snapshotErr, featureErr, tapErr)
} }
func defaultVSockPath(runtimeDir, vmID string) string {
return filepath.Join(runtimeDir, "fc-"+system.ShortID(vmID)+".vsock")
}
func defaultVSockCID(guestIP string) (uint32, error) {
ip := net.ParseIP(strings.TrimSpace(guestIP)).To4()
if ip == nil {
return 0, fmt.Errorf("guest IP is not IPv4: %q", guestIP)
}
return 10000 + uint32(ip[3]), nil
}
func waitForGuestVSockAgent(ctx context.Context, logger *slog.Logger, socketPath string, timeout time.Duration) error {
if strings.TrimSpace(socketPath) == "" {
return errors.New("vsock path is required")
}
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
ticker := time.NewTicker(vsockReadyPoll)
defer ticker.Stop()
var lastErr error
for {
pingCtx, pingCancel := context.WithTimeout(waitCtx, 3*time.Second)
err := vsockagent.Health(pingCtx, logger, socketPath)
pingCancel()
if err == nil {
return nil
}
lastErr = err
select {
case <-waitCtx.Done():
if lastErr != nil {
return fmt.Errorf("guest vsock agent not ready: %w", lastErr)
}
return errors.New("guest vsock agent not ready before timeout")
case <-ticker.C:
}
}
}
func (d *Daemon) setDNS(ctx context.Context, vmName, guestIP string) error {
if d.vmDNS == nil {
return nil
}
if err := d.vmDNS.Set(vmdns.RecordName(vmName), guestIP); err != nil {
return err
}
d.ensureVMDNSResolverRouting(ctx)
return nil
}
func (d *Daemon) removeDNS(ctx context.Context, dnsName string) error {
if dnsName == "" {
return nil
}
if d.vmDNS == nil {
return nil
}
return d.vmDNS.Remove(dnsName)
}
func (d *Daemon) rebuildDNS(ctx context.Context) error {
if d.vmDNS == nil {
return nil
}
vms, err := d.store.ListVMs(ctx)
if err != nil {
return err
}
records := make(map[string]string)
for _, vm := range vms {
if !d.vmAlive(vm) {
continue
}
if strings.TrimSpace(vm.Runtime.GuestIP) == "" {
continue
}
records[vmdns.RecordName(vm.Name)] = vm.Runtime.GuestIP
}
return d.vmDNS.Replace(records)
}
func (d *Daemon) generateName(ctx context.Context) (string, error) { func (d *Daemon) generateName(ctx context.Context) (string, error) {
_ = ctx _ = ctx
if name := strings.TrimSpace(namegen.Generate()); name != "" { if name := strings.TrimSpace(namegen.Generate()); name != "" {

View file

@ -200,7 +200,7 @@ func (d *Daemon) rediscoverHandles(ctx context.Context, vm model.VMRecord) (mode
if apiSock == "" { if apiSock == "" {
return saved, false, nil return saved, false, nil
} }
if pid, pidErr := d.findFirecrackerPID(ctx, apiSock); pidErr == nil && pid > 0 { if pid, pidErr := d.hostNet().findFirecrackerPID(ctx, apiSock); pidErr == nil && pid > 0 {
saved.PID = pid saved.PID = pid
return saved, true, nil return saved, true, nil
} }

View file

@ -56,11 +56,11 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
} }
d.clearVMHandles(vm) d.clearVMHandles(vm)
op.stage("bridge") op.stage("bridge")
if err := d.ensureBridge(ctx); err != nil { if err := d.hostNet().ensureBridge(ctx); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("socket_dir") op.stage("socket_dir")
if err := d.ensureSocketDir(); err != nil { if err := d.hostNet().ensureSocketDir(); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
@ -92,7 +92,7 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
op.stage("dm_snapshot", "dm_name", dmName) op.stage("dm_snapshot", "dm_name", dmName)
vmCreateStage(ctx, "prepare_rootfs", "creating root filesystem snapshot") vmCreateStage(ctx, "prepare_rootfs", "creating root filesystem snapshot")
snapHandles, err := d.createDMSnapshot(ctx, image.RootfsPath, vm.Runtime.SystemOverlay, dmName) snapHandles, err := d.hostNet().createDMSnapshot(ctx, image.RootfsPath, vm.Runtime.SystemOverlay, dmName)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
@ -138,7 +138,7 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
return cleanupOnErr(err) return cleanupOnErr(err)
} }
op.stage("tap") op.stage("tap")
tap, err := d.acquireTap(ctx, tapName) tap, err := d.hostNet().acquireTap(ctx, tapName)
if err != nil { if err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
@ -150,7 +150,7 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
} }
op.stage("firecracker_binary") op.stage("firecracker_binary")
fcPath, err := d.firecrackerBinary() fcPath, err := d.hostNet().firecrackerBinary()
if err != nil { if err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
@ -200,23 +200,23 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
// Use a fresh context: the request ctx may already be cancelled (client // Use a fresh context: the request ctx may already be cancelled (client
// disconnect), but we still need the PID so cleanupRuntime can kill the // disconnect), but we still need the PID so cleanupRuntime can kill the
// Firecracker process that was spawned before the failure. // Firecracker process that was spawned before the failure.
live.PID = d.resolveFirecrackerPID(context.Background(), machine, apiSock) live.PID = d.hostNet().resolveFirecrackerPID(context.Background(), machine, apiSock)
d.setVMHandles(vm, live) d.setVMHandles(vm, live)
return cleanupOnErr(err) return cleanupOnErr(err)
} }
live.PID = d.resolveFirecrackerPID(context.Background(), machine, apiSock) live.PID = d.hostNet().resolveFirecrackerPID(context.Background(), machine, apiSock)
d.setVMHandles(vm, live) d.setVMHandles(vm, live)
op.debugStage("firecracker_started", "pid", live.PID) op.debugStage("firecracker_started", "pid", live.PID)
op.stage("socket_access", "api_socket", apiSock) op.stage("socket_access", "api_socket", apiSock)
if err := d.ensureSocketAccess(ctx, apiSock, "firecracker api socket"); err != nil { if err := d.hostNet().ensureSocketAccess(ctx, apiSock, "firecracker api socket"); err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
op.stage("vsock_access", "vsock_path", vm.Runtime.VSockPath, "vsock_cid", vm.Runtime.VSockCID) op.stage("vsock_access", "vsock_path", vm.Runtime.VSockPath, "vsock_cid", vm.Runtime.VSockCID)
if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { if err := d.hostNet().ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
vmCreateStage(ctx, "wait_vsock_agent", "waiting for guest vsock agent") vmCreateStage(ctx, "wait_vsock_agent", "waiting for guest vsock agent")
if err := waitForGuestVSockAgent(ctx, d.logger, vm.Runtime.VSockPath, vsockReadyWait); err != nil { if err := d.hostNet().waitForGuestVSockAgent(ctx, vm.Runtime.VSockPath, vsockReadyWait); err != nil {
return cleanupOnErr(err) return cleanupOnErr(err)
} }
op.stage("post_start_features") op.stage("post_start_features")
@ -264,11 +264,11 @@ func (d *Daemon) stopVMLocked(ctx context.Context, current model.VMRecord) (vm m
} }
pid := d.vmHandles(vm.ID).PID pid := d.vmHandles(vm.ID).PID
op.stage("graceful_shutdown") op.stage("graceful_shutdown")
if err := d.sendCtrlAltDel(ctx, vm); err != nil { if err := d.hostNet().sendCtrlAltDel(ctx, vm.Runtime.APISockPath); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("wait_for_exit", "pid", pid) op.stage("wait_for_exit", "pid", pid)
if err := d.waitForExit(ctx, pid, vm.Runtime.APISockPath, gracefulShutdownWait); err != nil { if err := d.hostNet().waitForExit(ctx, pid, vm.Runtime.APISockPath, gracefulShutdownWait); err != nil {
if !errors.Is(err, errWaitForExitTimeout) { if !errors.Is(err, errWaitForExitTimeout) {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
@ -328,7 +328,7 @@ func (d *Daemon) killVMLocked(ctx context.Context, current model.VMRecord, signa
return model.VMRecord{}, err return model.VMRecord{}, err
} }
op.stage("wait_for_exit", "pid", pid) op.stage("wait_for_exit", "pid", pid)
if err := d.waitForExit(ctx, pid, vm.Runtime.APISockPath, 30*time.Second); err != nil { if err := d.hostNet().waitForExit(ctx, pid, vm.Runtime.APISockPath, 30*time.Second); err != nil {
if !errors.Is(err, errWaitForExitTimeout) { if !errors.Is(err, errWaitForExitTimeout) {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
@ -395,7 +395,7 @@ func (d *Daemon) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm
if d.vmAlive(vm) { if d.vmAlive(vm) {
pid := d.vmHandles(vm.ID).PID pid := d.vmHandles(vm.ID).PID
op.stage("kill_running_vm", "pid", pid) op.stage("kill_running_vm", "pid", pid)
_ = d.killVMProcess(ctx, pid) _ = d.hostNet().killVMProcess(ctx, pid)
} }
op.stage("cleanup_runtime") op.stage("cleanup_runtime")
if err := d.cleanupRuntime(ctx, vm, false); err != nil { if err := d.cleanupRuntime(ctx, vm, false); err != nil {

View file

@ -35,7 +35,7 @@ func (d *Daemon) HealthVM(ctx context.Context, idOrName string) (result api.VMHe
if vm.Runtime.VSockCID == 0 { if vm.Runtime.VSockCID == 0 {
return model.VMRecord{}, errors.New("vm has no vsock cid") return model.VMRecord{}, errors.New("vm has no vsock cid")
} }
if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { if err := d.hostNet().ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second) pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
@ -123,8 +123,8 @@ func (d *Daemon) stopStaleVMs(ctx context.Context) (err error) {
return nil return nil
} }
op.stage("stopping_vm", vmLogAttrs(vm)...) op.stage("stopping_vm", vmLogAttrs(vm)...)
_ = d.sendCtrlAltDel(ctx, vm) _ = d.hostNet().sendCtrlAltDel(ctx, vm.Runtime.APISockPath)
_ = d.waitForExit(ctx, d.vmHandles(vm.ID).PID, vm.Runtime.APISockPath, 10*time.Second) _ = d.hostNet().waitForExit(ctx, d.vmHandles(vm.ID).PID, vm.Runtime.APISockPath, 10*time.Second)
_ = d.cleanupRuntime(ctx, vm, true) _ = d.cleanupRuntime(ctx, vm, true)
vm.State = model.VMStateStopped vm.State = model.VMStateStopped
vm.Runtime.State = model.VMStateStopped vm.Runtime.State = model.VMStateStopped

View file

@ -212,7 +212,7 @@ func TestRebuildDNSIncludesOnlyLiveRunningVMs(t *testing.T) {
} }
}) })
d := &Daemon{store: db, vmDNS: server} d := &Daemon{store: db, net: &HostNetwork{vmDNS: server}}
// rebuildDNS reads the alive check from the handle cache. Seed // rebuildDNS reads the alive check from the handle cache. Seed
// the live VM with its real PID; leave the stale entry with a PID // the live VM with its real PID; leave the stale entry with a PID
// that definitely isn't running (999999 ≫ max PID on most hosts). // that definitely isn't running (999999 ≫ max PID on most hosts).
@ -512,7 +512,8 @@ func TestWaitForGuestVSockAgentRetriesUntilHealthy(t *testing.T) {
serverDone <- errors.New("health probe did not retry") serverDone <- errors.New("health probe did not retry")
}() }()
if err := waitForGuestVSockAgent(context.Background(), nil, socketPath, time.Second); err != nil { n := &HostNetwork{}
if err := n.waitForGuestVSockAgent(context.Background(), socketPath, time.Second); err != nil {
t.Fatalf("waitForGuestVSockAgent: %v", err) t.Fatalf("waitForGuestVSockAgent: %v", err)
} }
if err := <-serverDone; err != nil { if err := <-serverDone; err != nil {