package daemon

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"banger/internal/firecracker"
	"banger/internal/model"
	"banger/internal/namegen"
	"banger/internal/system"
	"banger/internal/vmdns"
	"banger/internal/vsockagent"
)

var (
	// errWaitForExitTimeout is returned by waitForExit when the process is
	// still running once the timeout has elapsed.
	errWaitForExitTimeout = errors.New("timed out waiting for VM to exit")

	// Timing knobs. vsockReadyPoll is the retry interval used by
	// waitForGuestVSockAgent; the other two are not referenced in this part
	// of the file.
	gracefulShutdownWait = 10 * time.Second
	vsockReadyWait       = 30 * time.Second
	vsockReadyPoll       = 200 * time.Millisecond
)

// ensureBridge makes sure the configured bridge interface exists and is up.
// When `ip link show` fails for the bridge name, the bridge is created and
// assigned the configured address first; in both cases the link is then set
// up.
func (d *Daemon) ensureBridge(ctx context.Context) error {
	bridge := d.config.BridgeName
	if _, err := d.runner.Run(ctx, "ip", "link", "show", bridge); err != nil {
		// The lookup failed, so treat the bridge as missing and create it.
		if _, err := d.runner.RunSudo(ctx, "ip", "link", "add", "name", bridge, "type", "bridge"); err != nil {
			return err
		}
		addr := fmt.Sprintf("%s/%s", d.config.BridgeIP, d.config.CIDR)
		if _, err := d.runner.RunSudo(ctx, "ip", "addr", "add", addr, "dev", bridge); err != nil {
			return err
		}
	}
	_, err := d.runner.RunSudo(ctx, "ip", "link", "set", bridge, "up")
	return err
}

// ensureSocketDir creates the runtime directory (and any missing parents)
// with mode 0755.
func (d *Daemon) ensureSocketDir() error {
	return os.MkdirAll(d.layout.RuntimeDir, 0o755)
}

// createTap (re)creates the named tap device owned by the current user and
// group, attaches it to the configured bridge, and brings both the tap and
// the bridge up. A pre-existing device with the same name is deleted first.
func (d *Daemon) createTap(ctx context.Context, tap string) error {
	if _, err := d.runner.Run(ctx, "ip", "link", "show", tap); err == nil {
		// Best effort: if the delete fails, the tuntap add below reports
		// the problem.
		_, _ = d.runner.RunSudo(ctx, "ip", "link", "del", tap)
	}
	if _, err := d.runner.RunSudo(ctx, "ip", "tuntap", "add", "dev", tap, "mode", "tap",
		"user", strconv.Itoa(os.Getuid()), "group", strconv.Itoa(os.Getgid())); err != nil {
		return err
	}
	if _, err := d.runner.RunSudo(ctx, "ip", "link", "set", tap, "master", d.config.BridgeName); err != nil {
		return err
	}
	if _, err := d.runner.RunSudo(ctx, "ip", "link", "set", tap, "up"); err != nil {
		return err
	}
	_, err := d.runner.RunSudo(ctx, "ip", "link", "set", d.config.BridgeName, "up")
	return err
}

// firecrackerBinary resolves the configured firecracker executable.
// A value containing a path separator is used as a path and must exist;
// any other value is looked up via system.LookupExecutable. An error is
// returned when nothing is configured or the binary cannot be found.
func (d *Daemon) firecrackerBinary() (string, error) {
	if d.config.FirecrackerBin == "" {
		return "", fmt.Errorf("firecracker binary not configured; install firecracker or set firecracker_bin")
	}
	path := d.config.FirecrackerBin
	if strings.ContainsRune(path, os.PathSeparator) {
		if !exists(path) {
			return "", fmt.Errorf("firecracker binary not found at %s; install firecracker or set firecracker_bin", path)
		}
		return path, nil
	}
	resolved, err := system.LookupExecutable(path)
	if err != nil {
		return "", fmt.Errorf("firecracker binary %q not found in PATH; install firecracker or set firecracker_bin", path)
	}
	return resolved, nil
}

// ensureSocketAccess waits up to five seconds for socketPath to appear, then
// chowns it to the current uid:gid and restricts its mode to 600. label is
// used in the error reported when the path does not appear in time.
func (d *Daemon) ensureSocketAccess(ctx context.Context, socketPath, label string) error {
	if err := waitForPath(ctx, socketPath, 5*time.Second, label); err != nil {
		return err
	}
	owner := fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid())
	if _, err := d.runner.RunSudo(ctx, "chown", owner, socketPath); err != nil {
		return err
	}
	_, err := d.runner.RunSudo(ctx, "chmod", "600", socketPath)
	return err
}

// findFirecrackerPID returns the PID of the newest process whose command
// line matches apiSock, as reported by `pgrep -n -f`.
//
// A blank apiSock is rejected: pgrep treats an empty pattern as matching
// every process, which would yield the PID of an unrelated process.
func (d *Daemon) findFirecrackerPID(ctx context.Context, apiSock string) (int, error) {
	if strings.TrimSpace(apiSock) == "" {
		return 0, errors.New("firecracker api socket path is required")
	}
	out, err := d.runner.Run(ctx, "pgrep", "-n", "-f", apiSock)
	if err != nil {
		return 0, err
	}
	return strconv.Atoi(strings.TrimSpace(string(out)))
}

// resolveFirecrackerPID returns the firecracker PID for apiSock, preferring
// the pgrep-based lookup and falling back to machine.PID() when machine is
// non-nil. It returns 0 when neither source yields a positive PID.
func (d *Daemon) resolveFirecrackerPID(ctx context.Context, machine *firecracker.Machine, apiSock string) int {
	if pid, err := d.findFirecrackerPID(ctx, apiSock); err == nil && pid > 0 {
		return pid
	}
	if machine != nil {
		if pid, err := machine.PID(); err == nil && pid > 0 {
			return pid
		}
	}
	return 0
}

// sendCtrlAltDel makes the VM's API socket accessible to the current user
// and then sends Ctrl+Alt+Del through a firecracker client bound to it.
func (d *Daemon) sendCtrlAltDel(ctx context.Context, vm model.VMRecord) error {
	if err := d.ensureSocketAccess(ctx, vm.Runtime.APISockPath, "firecracker api socket"); err != nil {
		return err
	}
	client := firecracker.New(vm.Runtime.APISockPath, d.logger)
	return client.SendCtrlAltDel(ctx)
}

// waitForExit polls every 100ms until system.ProcessRunning reports that the
// process is gone. It returns errWaitForExitTimeout once timeout has elapsed
// and ctx.Err() if the context ends first.
func (d *Daemon) waitForExit(ctx context.Context, pid int, apiSock string, timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	for {
		if !system.ProcessRunning(pid, apiSock) {
			return nil
		}
		if time.Now().After(deadline) {
			return errWaitForExitTimeout
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(100 * time.Millisecond):
		}
	}
}

// cleanupRuntime tears down the runtime resources recorded on vm: it kills
// the firecracker process if it is still running, then releases the
// device-mapper snapshot, capability state and tap device, and removes the
// API and vsock socket files. Unless preserveDisks is set, the VM directory
// is removed as well.
//
// If the process does not exit within 30 seconds the function returns early
// without touching the remaining resources. Errors from the later steps are
// collected and returned joined.
func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserveDisks bool) error {
	if d.logger != nil {
		d.logger.Debug("cleanup runtime", append(vmLogAttrs(vm), "preserve_disks", preserveDisks)...)
	}

	// Prefer a live lookup by socket path over the recorded PID.
	cleanupPID := vm.Runtime.PID
	if vm.Runtime.APISockPath != "" {
		if pid, err := d.findFirecrackerPID(ctx, vm.Runtime.APISockPath); err == nil && pid > 0 {
			cleanupPID = pid
		}
	}

	if cleanupPID > 0 && system.ProcessRunning(cleanupPID, vm.Runtime.APISockPath) {
		// A failed kill is tolerated when the process exits anyway; it is
		// only reported alongside the wait failure so the root cause is
		// not lost behind a bare timeout.
		killErr := d.killVMProcess(ctx, cleanupPID)
		if err := d.waitForExit(ctx, cleanupPID, vm.Runtime.APISockPath, 30*time.Second); err != nil {
			if killErr != nil {
				return errors.Join(err, killErr)
			}
			return err
		}
	}

	snapshotErr := d.cleanupDMSnapshot(ctx, dmSnapshotHandles{
		BaseLoop: vm.Runtime.BaseLoop,
		COWLoop:  vm.Runtime.COWLoop,
		DMName:   vm.Runtime.DMName,
		DMDev:    vm.Runtime.DMDev,
	})
	featureErr := d.cleanupCapabilityState(ctx, vm)

	var tapErr error
	if vm.Runtime.TapDevice != "" {
		tapErr = d.releaseTap(ctx, vm.Runtime.TapDevice)
	}

	// Socket files are removed on a best-effort basis.
	if vm.Runtime.APISockPath != "" {
		_ = os.Remove(vm.Runtime.APISockPath)
	}
	if vm.Runtime.VSockPath != "" {
		_ = os.Remove(vm.Runtime.VSockPath)
	}

	var dirErr error
	if !preserveDisks && vm.Runtime.VMDir != "" {
		dirErr = os.RemoveAll(vm.Runtime.VMDir)
	}
	return errors.Join(snapshotErr, featureErr, tapErr, dirErr)
}

// clearRuntimeHandles zeroes the PID, API socket path, tap device, loop
// device and device-mapper fields of vm.Runtime. Other Runtime fields are
// left untouched.
func clearRuntimeHandles(vm *model.VMRecord) {
	vm.Runtime.PID = 0
	vm.Runtime.APISockPath = ""
	vm.Runtime.TapDevice = ""
	vm.Runtime.BaseLoop = ""
	vm.Runtime.COWLoop = ""
	vm.Runtime.DMName = ""
	vm.Runtime.DMDev = ""
}

// defaultVSockPath returns "<runtimeDir>/fc-<short id>.vsock", where the
// short id is derived from vmID by system.ShortID.
func defaultVSockPath(runtimeDir, vmID string) string {
	return filepath.Join(runtimeDir, "fc-"+system.ShortID(vmID)+".vsock")
}

// defaultVSockCID derives a vsock context ID from the guest's IPv4 address
// as 10000 plus the last octet. It returns an error when guestIP (after
// trimming whitespace) is not a valid IPv4 address.
func defaultVSockCID(guestIP string) (uint32, error) {
	ip := net.ParseIP(strings.TrimSpace(guestIP)).To4()
	if ip == nil {
		return 0, fmt.Errorf("guest IP is not IPv4: %q", guestIP)
	}
	return 10000 + uint32(ip[3]), nil
}

// waitForPath polls every 100ms until path can be stat'ed. A stat error
// other than "does not exist" is returned immediately. When timeout elapses
// the returned error wraps context.DeadlineExceeded and mentions label; if
// ctx ends first, ctx.Err() is returned.
func waitForPath(ctx context.Context, path string, timeout time.Duration, label string) error {
	deadline := time.Now().Add(timeout)
	for {
		_, err := os.Stat(path)
		if err == nil {
			return nil
		}
		if !errors.Is(err, os.ErrNotExist) {
			return err
		}
		if time.Now().After(deadline) {
			return fmt.Errorf("%s not ready: %s: %w", label, path, context.DeadlineExceeded)
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(100 * time.Millisecond):
		}
	}
}

// waitForGuestVSockAgent repeatedly calls vsockagent.Health on socketPath
// until it succeeds or timeout elapses (or ctx ends). Each health check is
// bounded to three seconds and attempts are spaced by vsockReadyPoll. On
// failure the error from the most recent health check is wrapped and
// returned. A blank socketPath is rejected up front.
func waitForGuestVSockAgent(ctx context.Context, logger *slog.Logger, socketPath string, timeout time.Duration) error {
	if strings.TrimSpace(socketPath) == "" {
		return errors.New("vsock path is required")
	}

	waitCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	ticker := time.NewTicker(vsockReadyPoll)
	defer ticker.Stop()

	for {
		pingCtx, pingCancel := context.WithTimeout(waitCtx, 3*time.Second)
		err := vsockagent.Health(pingCtx, logger, socketPath)
		pingCancel()
		if err == nil {
			return nil
		}

		select {
		case <-waitCtx.Done():
			// err is always non-nil here: it is the latest failed check.
			return fmt.Errorf("guest vsock agent not ready: %w", err)
		case <-ticker.C:
		}
	}
}

// setDNS publishes guestIP under the record name derived from vmName and
// then calls ensureVMDNSResolverRouting. It is a no-op when no VM DNS
// server is configured.
func (d *Daemon) setDNS(ctx context.Context, vmName, guestIP string) error {
	if d.vmDNS == nil {
		return nil
	}
	if err := d.vmDNS.Set(vmdns.RecordName(vmName), guestIP); err != nil {
		return err
	}
	d.ensureVMDNSResolverRouting(ctx)
	return nil
}

// removeDNS removes the record dnsName from the VM DNS server. It is a
// no-op when dnsName is empty or no VM DNS server is configured. ctx is
// currently unused.
func (d *Daemon) removeDNS(ctx context.Context, dnsName string) error {
	if dnsName == "" || d.vmDNS == nil {
		return nil
	}
	return d.vmDNS.Remove(dnsName)
}

// rebuildDNS replaces the full VM DNS record set with one entry per stored
// VM that is in the running state, whose process is actually running, and
// that has a non-blank guest IP. It is a no-op when no VM DNS server is
// configured.
func (d *Daemon) rebuildDNS(ctx context.Context) error {
	if d.vmDNS == nil {
		return nil
	}
	vms, err := d.store.ListVMs(ctx)
	if err != nil {
		return err
	}
	records := make(map[string]string, len(vms))
	for _, vm := range vms {
		if vm.State != model.VMStateRunning {
			continue
		}
		if !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
			continue
		}
		if strings.TrimSpace(vm.Runtime.GuestIP) == "" {
			continue
		}
		records[vmdns.RecordName(vm.Name)] = vm.Runtime.GuestIP
	}
	return d.vmDNS.Replace(records)
}

// killVMProcess sends SIGKILL to pid via `kill -KILL` run with sudo.
func (d *Daemon) killVMProcess(ctx context.Context, pid int) error {
	_, err := d.runner.RunSudo(ctx, "kill", "-KILL", strconv.Itoa(pid))
	return err
}

// generateName returns a name from namegen.Generate, falling back to
// "vm-<unix seconds>" when the generated name is blank. The returned error
// is always nil and ctx is currently unused.
func (d *Daemon) generateName(ctx context.Context) (string, error) {
	_ = ctx
	if name := strings.TrimSpace(namegen.Generate()); name != "" {
		return name, nil
	}
	return "vm-" + strconv.FormatInt(time.Now().Unix(), 10), nil
}

// bridgePrefix returns the first three dot-separated components of bridgeIP
// (e.g. "172.16.0" for "172.16.0.1"). Inputs with fewer than three
// components are returned unchanged.
func bridgePrefix(bridgeIP string) string {
	parts := strings.Split(bridgeIP, ".")
	if len(parts) < 3 {
		return bridgeIP
	}
	return strings.Join(parts[:3], ".")
}

// optionalIntOrDefault returns *value when value is non-nil and fallback
// otherwise.
func optionalIntOrDefault(value *int, fallback int) int {
	if value != nil {
		return *value
	}
	return fallback
}

// validateOptionalPositiveSetting accepts an unset (nil) value or any value
// greater than zero; otherwise it returns an error naming label.
func validateOptionalPositiveSetting(label string, value *int) error {
	if value != nil && *value <= 0 {
		return fmt.Errorf("%s must be a positive integer", label)
	}
	return nil
}