package daemon

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"banger/internal/daemon/fcproc"
	"banger/internal/firecracker"
	"banger/internal/model"
	"banger/internal/namegen"
	"banger/internal/system"
	"banger/internal/vmdns"
	"banger/internal/vsockagent"
)

var (
	// errWaitForExitTimeout aliases fcproc.ErrWaitForExitTimeout.
	errWaitForExitTimeout = fcproc.ErrWaitForExitTimeout
	gracefulShutdownWait  = 10 * time.Second
	vsockReadyWait        = 30 * time.Second
	// vsockReadyPoll is the interval between health probes in
	// waitForGuestVSockAgent.
	vsockReadyPoll = 200 * time.Millisecond
)

// fc builds a fresh fcproc.Manager from the Daemon's current runner, config,
// and layout. Manager is stateless beyond those handles, so constructing per
// call keeps tests that build Daemon literals working without extra wiring.
func (d *Daemon) fc() *fcproc.Manager {
	return fcproc.New(d.runner, fcproc.Config{
		FirecrackerBin: d.config.FirecrackerBin,
		BridgeName:     d.config.BridgeName,
		BridgeIP:       d.config.BridgeIP,
		CIDR:           d.config.CIDR,
		RuntimeDir:     d.layout.RuntimeDir,
	}, d.logger)
}

// ensureBridge delegates to fcproc.Manager.EnsureBridge.
func (d *Daemon) ensureBridge(ctx context.Context) error {
	return d.fc().EnsureBridge(ctx)
}

// ensureSocketDir delegates to fcproc.Manager.EnsureSocketDir.
func (d *Daemon) ensureSocketDir() error {
	return d.fc().EnsureSocketDir()
}

// createTap delegates to fcproc.Manager.CreateTap for the named tap device.
func (d *Daemon) createTap(ctx context.Context, tap string) error {
	return d.fc().CreateTap(ctx, tap)
}

// firecrackerBinary delegates to fcproc.Manager.ResolveBinary.
func (d *Daemon) firecrackerBinary() (string, error) {
	return d.fc().ResolveBinary()
}

// ensureSocketAccess delegates to fcproc.Manager.EnsureSocketAccess.
func (d *Daemon) ensureSocketAccess(ctx context.Context, socketPath, label string) error {
	return d.fc().EnsureSocketAccess(ctx, socketPath, label)
}

// findFirecrackerPID delegates to fcproc.Manager.FindPID for the given API
// socket path.
func (d *Daemon) findFirecrackerPID(ctx context.Context, apiSock string) (int, error) {
	return d.fc().FindPID(ctx, apiSock)
}

// resolveFirecrackerPID delegates to fcproc.Manager.ResolvePID.
func (d *Daemon) resolveFirecrackerPID(ctx context.Context, machine *firecracker.Machine, apiSock string) int {
	return d.fc().ResolvePID(ctx, machine, apiSock)
}

// sendCtrlAltDel delegates to fcproc.Manager.SendCtrlAltDel using the VM's
// recorded API socket path.
func (d *Daemon) sendCtrlAltDel(ctx context.Context, vm model.VMRecord) error {
	return d.fc().SendCtrlAltDel(ctx, vm.Runtime.APISockPath)
}

// waitForExit delegates to fcproc.Manager.WaitForExit.
func (d *Daemon) waitForExit(ctx context.Context, pid int, apiSock string, timeout time.Duration) error {
	return d.fc().WaitForExit(ctx, pid, apiSock, timeout)
}

// killVMProcess delegates to fcproc.Manager.Kill.
func (d *Daemon) killVMProcess(ctx context.Context, pid int) error {
	return d.fc().Kill(ctx, pid)
}

// cleanupRuntime tears down the runtime resources recorded on vm.
//
// If the VM process is still running it is killed and waited on for up to 30
// seconds; a wait failure is returned immediately and none of the remaining
// cleanup steps run. Otherwise the DM snapshot, capability state, and tap
// device are released, and the API and vsock socket files are removed. The VM
// directory is removed as well unless preserveDisks is true.
//
// Errors from the snapshot, capability, tap, and directory steps are combined
// with errors.Join; socket file removal errors are ignored.
func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserveDisks bool) error {
	if d.logger != nil {
		d.logger.Debug("cleanup runtime", append(vmLogAttrs(vm), "preserve_disks", preserveDisks)...)
	}

	// Prefer a PID looked up from the API socket over the recorded one.
	cleanupPID := vm.Runtime.PID
	if vm.Runtime.APISockPath != "" {
		if pid, err := d.findFirecrackerPID(ctx, vm.Runtime.APISockPath); err == nil && pid > 0 {
			cleanupPID = pid
		}
	}

	if cleanupPID > 0 && system.ProcessRunning(cleanupPID, vm.Runtime.APISockPath) {
		// The kill error is discarded; the wait below reports whether the
		// process actually exited.
		_ = d.killVMProcess(ctx, cleanupPID)
		if err := d.waitForExit(ctx, cleanupPID, vm.Runtime.APISockPath, 30*time.Second); err != nil {
			return err
		}
	}

	snapshotErr := d.cleanupDMSnapshot(ctx, dmSnapshotHandles{
		BaseLoop: vm.Runtime.BaseLoop,
		COWLoop:  vm.Runtime.COWLoop,
		DMName:   vm.Runtime.DMName,
		DMDev:    vm.Runtime.DMDev,
	})
	featureErr := d.cleanupCapabilityState(ctx, vm)

	var tapErr error
	if vm.Runtime.TapDevice != "" {
		tapErr = d.releaseTap(ctx, vm.Runtime.TapDevice)
	}

	// Socket file removal errors are intentionally ignored.
	if vm.Runtime.APISockPath != "" {
		_ = os.Remove(vm.Runtime.APISockPath)
	}
	if vm.Runtime.VSockPath != "" {
		_ = os.Remove(vm.Runtime.VSockPath)
	}

	if !preserveDisks && vm.Runtime.VMDir != "" {
		return errors.Join(snapshotErr, featureErr, tapErr, os.RemoveAll(vm.Runtime.VMDir))
	}
	return errors.Join(snapshotErr, featureErr, tapErr)
}

// clearRuntimeHandles zeroes the PID, API socket path, tap device, loop
// devices, and device-mapper fields on vm.Runtime. Other Runtime fields (for
// example VSockPath, GuestIP, and VMDir) are left untouched.
func clearRuntimeHandles(vm *model.VMRecord) {
	vm.Runtime.PID = 0
	vm.Runtime.APISockPath = ""
	vm.Runtime.TapDevice = ""
	vm.Runtime.BaseLoop = ""
	vm.Runtime.COWLoop = ""
	vm.Runtime.DMName = ""
	vm.Runtime.DMDev = ""
}

// defaultVSockPath returns "<runtimeDir>/fc-<short id>.vsock", where the
// short id is system.ShortID(vmID).
func defaultVSockPath(runtimeDir, vmID string) string {
	return filepath.Join(runtimeDir, "fc-"+system.ShortID(vmID)+".vsock")
}

// defaultVSockCID derives a vsock CID from guestIP as 10000 plus the last
// octet of the IPv4 address. Surrounding whitespace is trimmed before
// parsing. It returns an error when guestIP does not parse as an IPv4
// address.
func defaultVSockCID(guestIP string) (uint32, error) {
	ip := net.ParseIP(strings.TrimSpace(guestIP)).To4()
	if ip == nil {
		return 0, fmt.Errorf("guest IP is not IPv4: %q", guestIP)
	}
	return 10000 + uint32(ip[3]), nil
}

// waitForGuestVSockAgent polls vsockagent.Health on socketPath until it
// succeeds, the timeout elapses, or ctx is done.
//
// A probe is issued immediately and then once per vsockReadyPoll tick; each
// probe is individually bounded to 3 seconds. On failure the most recent
// probe error is returned wrapped in "guest vsock agent not ready". A blank
// socketPath is rejected up front.
func waitForGuestVSockAgent(ctx context.Context, logger *slog.Logger, socketPath string, timeout time.Duration) error {
	if strings.TrimSpace(socketPath) == "" {
		return errors.New("vsock path is required")
	}

	waitCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	ticker := time.NewTicker(vsockReadyPoll)
	defer ticker.Stop()

	for {
		pingCtx, pingCancel := context.WithTimeout(waitCtx, 3*time.Second)
		err := vsockagent.Health(pingCtx, logger, socketPath)
		pingCancel()
		if err == nil {
			return nil
		}

		// err is always non-nil here, so the deadline path can always report
		// the last probe failure; no generic fallback error is needed.
		select {
		case <-waitCtx.Done():
			return fmt.Errorf("guest vsock agent not ready: %w", err)
		case <-ticker.C:
		}
	}
}

// setDNS records guestIP under vmdns.RecordName(vmName) and then calls
// ensureVMDNSResolverRouting. It is a no-op when the daemon has no vmDNS.
func (d *Daemon) setDNS(ctx context.Context, vmName, guestIP string) error {
	if d.vmDNS == nil {
		return nil
	}
	if err := d.vmDNS.Set(vmdns.RecordName(vmName), guestIP); err != nil {
		return err
	}
	d.ensureVMDNSResolverRouting(ctx)
	return nil
}

// removeDNS removes the record named dnsName. It is a no-op when dnsName is
// empty or the daemon has no vmDNS. ctx is currently unused.
func (d *Daemon) removeDNS(ctx context.Context, dnsName string) error {
	if dnsName == "" {
		return nil
	}
	if d.vmDNS == nil {
		return nil
	}
	return d.vmDNS.Remove(dnsName)
}

// rebuildDNS replaces the full DNS record set with one entry per stored VM
// that is in the running state, whose process is reported running by
// system.ProcessRunning, and that has a non-blank guest IP. It is a no-op
// when the daemon has no vmDNS.
func (d *Daemon) rebuildDNS(ctx context.Context) error {
	if d.vmDNS == nil {
		return nil
	}
	vms, err := d.store.ListVMs(ctx)
	if err != nil {
		return err
	}
	records := make(map[string]string)
	for _, vm := range vms {
		if vm.State != model.VMStateRunning {
			continue
		}
		if !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
			continue
		}
		if strings.TrimSpace(vm.Runtime.GuestIP) == "" {
			continue
		}
		records[vmdns.RecordName(vm.Name)] = vm.Runtime.GuestIP
	}
	return d.vmDNS.Replace(records)
}

// generateName returns the trimmed result of namegen.Generate, falling back
// to "vm-<unix seconds>" when that is empty. The returned error is always
// nil and ctx is unused.
func (d *Daemon) generateName(ctx context.Context) (string, error) {
	_ = ctx
	if name := strings.TrimSpace(namegen.Generate()); name != "" {
		return name, nil
	}
	return "vm-" + strconv.FormatInt(time.Now().Unix(), 10), nil
}

// bridgePrefix returns the first three dot-separated components of bridgeIP
// joined by dots (e.g. "10.0.0.1" -> "10.0.0"). Inputs with fewer than three
// components are returned unchanged.
func bridgePrefix(bridgeIP string) string {
	parts := strings.Split(bridgeIP, ".")
	if len(parts) < 3 {
		return bridgeIP
	}
	return strings.Join(parts[:3], ".")
}

// optionalIntOrDefault returns *value when value is non-nil and fallback
// otherwise.
func optionalIntOrDefault(value *int, fallback int) int {
	if value != nil {
		return *value
	}
	return fallback
}

// validateOptionalPositiveSetting returns an error naming label when value
// is set but not greater than zero. A nil value is accepted.
func validateOptionalPositiveSetting(label string, value *int) error {
	if value == nil {
		return nil
	}
	if *value <= 0 {
		return fmt.Errorf("%s must be a positive integer", label)
	}
	return nil
}