Replace mapdns with daemon DNS

Serve daemon-managed .vm names directly from bangerd on 127.0.0.1:42069 instead of shelling out to mapdns. This keeps DNS state tied to VM lifecycle and lets the daemon rebuild records from running VMs after startup or reconcile.

Add a small in-process authoritative DNS server, register and remove records from the VM start/stop/delete paths, and show the listener in daemon status. Remove the mapdns config and preflight surface, stop helper-flow DNS publishing in customize.sh and interactive.sh, drop dns.sh from the runtime bundle, and update docs/tests for the new local-resolver integration model.

Validated with GOCACHE=/tmp/banger-gocache go test ./..., GOCACHE=/tmp/banger-gocache make build, and bash -n customize.sh interactive.sh.
This commit is contained in:
Thales Maciel 2026-03-17 15:49:35 -03:00
parent 430f66d5dd
commit 0a0b0b617b
No known key found for this signature in database
GPG key ID: 33112E6833C34679
24 changed files with 576 additions and 278 deletions

View file

@ -20,6 +20,7 @@ import (
"banger/internal/paths"
"banger/internal/rpc"
"banger/internal/system"
"banger/internal/vmdns"
"github.com/spf13/cobra"
)
@ -112,10 +113,10 @@ func newDaemonCommand() *cobra.Command {
}
ping, pingErr := rpc.Call[api.PingResult](cmd.Context(), layout.SocketPath, "ping", api.Empty{})
if pingErr != nil {
_, err = fmt.Fprintf(cmd.OutOrStdout(), "stopped\nsocket: %s\nlog: %s\n", layout.SocketPath, layout.DaemonLog)
_, err = fmt.Fprintf(cmd.OutOrStdout(), "stopped\nsocket: %s\nlog: %s\ndns: %s\n", layout.SocketPath, layout.DaemonLog, vmdns.DefaultListenAddr)
return err
}
_, err = fmt.Fprintf(cmd.OutOrStdout(), "running\npid: %d\nsocket: %s\nlog: %s\n", ping.PID, layout.SocketPath, layout.DaemonLog)
_, err = fmt.Fprintf(cmd.OutOrStdout(), "running\npid: %d\nsocket: %s\nlog: %s\ndns: %s\n", ping.PID, layout.SocketPath, layout.DaemonLog, vmdns.DefaultListenAddr)
return err
},
},

View file

@ -263,6 +263,9 @@ func TestDaemonStatusIncludesLogPathWhenStopped(t *testing.T) {
if !strings.Contains(output, "log: "+filepath.Join(stateHome, "banger", "bangerd.log")) {
t.Fatalf("output = %q, want daemon log path", output)
}
if !strings.Contains(output, "dns: 127.0.0.1:42069") {
t.Fatalf("output = %q, want dns listener", output)
}
}
func TestBuildDaemonCommandIsDetachedFromCallerContext(t *testing.T) {

View file

@ -18,8 +18,6 @@ type fileConfig struct {
RepoRoot string `toml:"repo_root"`
LogLevel string `toml:"log_level"`
FirecrackerBin string `toml:"firecracker_bin"`
MapDNSBin string `toml:"mapdns_bin"`
MapDNSDataFile string `toml:"mapdns_data_file"`
SSHKeyPath string `toml:"ssh_key_path"`
NamegenPath string `toml:"namegen_path"`
CustomizeScript string `toml:"customize_script"`
@ -80,12 +78,6 @@ func Load(layout paths.Layout) (model.DaemonConfig, error) {
if file.LogLevel != "" {
cfg.LogLevel = file.LogLevel
}
if file.MapDNSBin != "" {
cfg.MapDNSBin = file.MapDNSBin
}
if file.MapDNSDataFile != "" {
cfg.MapDNSDataFile = file.MapDNSDataFile
}
if file.SSHKeyPath != "" {
cfg.SSHKeyPath = file.SSHKeyPath
}
@ -149,18 +141,9 @@ func Load(layout paths.Layout) (model.DaemonConfig, error) {
}
cfg.MetricsPollInterval = duration
}
if value := os.Getenv("BANGER_MAPDNS_BIN"); value != "" {
cfg.MapDNSBin = value
}
if value := os.Getenv("BANGER_MAPDNS_DATA_FILE"); value != "" {
cfg.MapDNSDataFile = value
}
if value := os.Getenv("BANGER_LOG_LEVEL"); value != "" {
cfg.LogLevel = value
}
if cfg.MapDNSBin == "" {
cfg.MapDNSBin = "mapdns"
}
return cfg, nil
}

View file

@ -127,9 +127,7 @@ func TestLoadFallsBackToLegacyRuntimeLayoutWithoutBundleMetadata(t *testing.T) {
}
}
func TestLoadAppliesMapDNSEnvOverrides(t *testing.T) {
t.Setenv("BANGER_MAPDNS_BIN", "/opt/bin/mapdns")
t.Setenv("BANGER_MAPDNS_DATA_FILE", "/tmp/mapdns-records.json")
func TestLoadAppliesLogLevelEnvOverride(t *testing.T) {
t.Setenv("BANGER_LOG_LEVEL", "debug")
cfg, err := Load(paths.Layout{ConfigDir: t.TempDir()})
@ -137,12 +135,6 @@ func TestLoadAppliesMapDNSEnvOverrides(t *testing.T) {
t.Fatalf("Load: %v", err)
}
if cfg.MapDNSBin != "/opt/bin/mapdns" {
t.Fatalf("MapDNSBin = %q", cfg.MapDNSBin)
}
if cfg.MapDNSDataFile != "/tmp/mapdns-records.json" {
t.Fatalf("MapDNSDataFile = %q", cfg.MapDNSDataFile)
}
if cfg.LogLevel != "debug" {
t.Fatalf("LogLevel = %q", cfg.LogLevel)
}

View file

@ -22,6 +22,7 @@ import (
"banger/internal/rpc"
"banger/internal/store"
"banger/internal/system"
"banger/internal/vmdns"
)
type Daemon struct {
@ -35,10 +36,11 @@ type Daemon struct {
once sync.Once
pid int
listener net.Listener
vmDNS *vmdns.Server
requestHandler func(context.Context, rpc.Request) rpc.Response
}
func Open(ctx context.Context) (*Daemon, error) {
func Open(ctx context.Context) (d *Daemon, err error) {
layout, err := paths.Resolve()
if err != nil {
return nil, err
@ -59,7 +61,7 @@ func Open(ctx context.Context) (*Daemon, error) {
if err != nil {
return nil, err
}
d := &Daemon{
d = &Daemon{
layout: layout,
config: cfg,
store: db,
@ -69,11 +71,20 @@ func Open(ctx context.Context) (*Daemon, error) {
pid: os.Getpid(),
}
d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "runtime_dir", cfg.RuntimeDir, "log_level", cfg.LogLevel)
if err := d.ensureDefaultImage(ctx); err != nil {
if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil {
d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error())
return nil, err
}
defer func() {
if err != nil {
_ = d.stopVMDNS()
}
}()
if err = d.ensureDefaultImage(ctx); err != nil {
d.logger.Error("daemon open failed", "stage", "ensure_default_image", "error", err.Error())
return nil, err
}
if err := d.reconcile(ctx); err != nil {
if err = d.reconcile(ctx); err != nil {
d.logger.Error("daemon open failed", "stage", "reconcile", "error", err.Error())
return nil, err
}
@ -90,7 +101,7 @@ func (d *Daemon) Close() error {
if d.listener != nil {
_ = d.listener.Close()
}
err = d.store.Close()
err = errors.Join(d.stopVMDNS(), d.store.Close())
})
return err
}
@ -358,6 +369,27 @@ func (d *Daemon) backgroundLoop() {
}
}
func (d *Daemon) startVMDNS(addr string) error {
server, err := vmdns.New(addr, d.logger)
if err != nil {
return err
}
d.vmDNS = server
if d.logger != nil {
d.logger.Info("vm dns serving", "dns_addr", server.Addr())
}
return nil
}
func (d *Daemon) stopVMDNS() error {
if d.vmDNS == nil {
return nil
}
err := d.vmDNS.Close()
d.vmDNS = nil
return err
}
func (d *Daemon) ensureDefaultImage(ctx context.Context) error {
if d.config.DefaultImageName == "" {
return nil
@ -477,6 +509,9 @@ func (d *Daemon) reconcile(ctx context.Context) error {
return op.fail(err, vmLogAttrs(vm)...)
}
}
if err := d.rebuildDNS(ctx); err != nil {
return op.fail(err)
}
op.done()
return nil
}

View file

@ -283,58 +283,36 @@ func writeDefaultImageArtifacts(t *testing.T, dir string) (rootfs, kernel, initr
return rootfs, kernel, initrd, modulesDir, packages
}
func TestSetDNSUsesConfiguredMapDNSDataFile(t *testing.T) {
func TestStartVMDNSFailsWhenAddressBusy(t *testing.T) {
t.Parallel()
dataFile := filepath.Join(t.TempDir(), "mapdns", "records.json")
runner := &scriptedRunner{
t: t,
steps: []runnerStep{
{
call: runnerCall{
name: "custom-mapdns",
args: []string{"set", "--data-file", dataFile, "devbox.vm", "172.16.0.8"},
},
},
},
}
d := &Daemon{
runner: runner,
config: model.DaemonConfig{
MapDNSBin: "custom-mapdns",
MapDNSDataFile: dataFile,
},
packetConn, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ListenPacket: %v", err)
}
defer packetConn.Close()
if err := d.setDNS(context.Background(), "devbox", "172.16.0.8"); err != nil {
t.Fatalf("setDNS: %v", err)
d := &Daemon{}
if err := d.startVMDNS(packetConn.LocalAddr().String()); err == nil {
t.Fatal("startVMDNS() succeeded on occupied address, want failure")
}
runner.assertExhausted()
}
func TestSetDNSUsesMapDNSDefaultsWhenDataFileUnset(t *testing.T) {
func TestSetDNSPublishesIntoDaemonServer(t *testing.T) {
t.Parallel()
runner := &scriptedRunner{
t: t,
steps: []runnerStep{
{
call: runnerCall{
name: "mapdns",
args: []string{"set", "devbox.vm", "172.16.0.8"},
},
},
},
}
d := &Daemon{
runner: runner,
config: model.DaemonConfig{},
d := &Daemon{}
if err := d.startVMDNS("127.0.0.1:0"); err != nil {
t.Fatalf("startVMDNS: %v", err)
}
defer d.stopVMDNS()
if err := d.setDNS(context.Background(), "devbox", "172.16.0.8"); err != nil {
t.Fatalf("setDNS: %v", err)
}
runner.assertExhausted()
if _, ok := d.vmDNS.Lookup("devbox.vm"); !ok {
t.Fatal("devbox.vm missing after setDNS")
}
}
func TestDispatchUsesPassedContext(t *testing.T) {

View file

@ -109,12 +109,6 @@ func (d *Daemon) BuildImage(ctx context.Context, params api.ImageBuildParams) (i
"BANGER_RUNTIME_DIR="+d.config.RuntimeDir,
"BANGER_STATE_DIR="+filepath.Join(d.layout.StateDir, "image-build"),
)
if d.config.MapDNSBin != "" {
cmd.Env = append(cmd.Env, "BANGER_MAPDNS_BIN="+d.config.MapDNSBin)
}
if d.config.MapDNSDataFile != "" {
cmd.Env = append(cmd.Env, "BANGER_MAPDNS_DATA_FILE="+d.config.MapDNSDataFile)
}
if err := cmd.Run(); err != nil {
_ = os.RemoveAll(artifactDir)
return model.Image{}, err

View file

@ -48,7 +48,7 @@ func TestStartVMLockedLogsBridgeFailure(t *testing.T) {
for _, name := range []string{
"sudo", "ip", "dmsetup", "losetup", "blockdev", "truncate", "pgrep", "ps",
"chown", "chmod", "kill", "e2cp", "e2rm", "debugfs", "mkfs.ext4", "mount",
"umount", "cp", "mapdns",
"umount", "cp",
} {
writeFakeExecutable(t, filepath.Join(binDir, name))
}
@ -98,7 +98,6 @@ func TestStartVMLockedLogsBridgeFailure(t *testing.T) {
BridgeIP: model.DefaultBridgeIP,
DefaultDNS: model.DefaultDNS,
FirecrackerBin: firecrackerBin,
MapDNSBin: "mapdns",
StatsPollInterval: model.DefaultStatsPollInterval,
},
runner: runner,
@ -130,7 +129,7 @@ func TestBuildImagePreservesBuildLogOnFailure(t *testing.T) {
}
binDir := t.TempDir()
for _, name := range []string{"sudo", "ip", "curl", "ssh", "jq", "sha256sum", "e2fsck", "resize2fs", "mapdns"} {
for _, name := range []string{"sudo", "ip", "curl", "ssh", "jq", "sha256sum", "e2fsck", "resize2fs"} {
writeFakeExecutable(t, filepath.Join(binDir, name))
}
bashPath, err := exec.LookPath("bash")
@ -169,7 +168,6 @@ func TestBuildImagePreservesBuildLogOnFailure(t *testing.T) {
config: model.DaemonConfig{
RuntimeDir: t.TempDir(),
CustomizeScript: script,
MapDNSBin: "mapdns",
DefaultImageName: "default",
},
store: store,

View file

@ -2,8 +2,6 @@ package daemon
import (
"context"
"os"
"path/filepath"
"strings"
"banger/internal/model"
@ -19,7 +17,6 @@ func (d *Daemon) validateStartPrereqs(ctx context.Context, vm model.VMRecord, im
checks.RequireCommand(command, toolHint(command))
}
checks.RequireExecutable(d.config.FirecrackerBin, "firecracker binary", hint)
checks.RequireExecutable(d.config.MapDNSBin, "mapdns binary", `install mapdns or set "mapdns_bin" / BANGER_MAPDNS_BIN`)
checks.RequireFile(image.RootfsPath, "rootfs image", "select a valid image or rebuild the runtime bundle")
checks.RequireFile(image.KernelPath, "kernel image", `set "default_kernel" or refresh the runtime bundle`)
if strings.TrimSpace(image.InitrdPath) != "" {
@ -33,14 +30,6 @@ func (d *Daemon) validateStartPrereqs(ctx context.Context, vm model.VMRecord, im
if vm.Spec.NATEnabled {
d.addNATPrereqs(ctx, checks)
}
if dataFile := strings.TrimSpace(d.config.MapDNSDataFile); dataFile != "" {
parent := filepath.Dir(dataFile)
if parent != "." && parent != "" {
if _, err := os.Stat(parent); err != nil && !os.IsNotExist(err) {
checks.Addf("mapdns data directory %s is not accessible (%v)", parent, err)
}
}
}
return checks.Err("vm start preflight failed")
}
@ -52,7 +41,6 @@ func (d *Daemon) validateImageBuildPrereqs(ctx context.Context, baseRootfs, kern
checks.RequireCommand(command, toolHint(command))
}
checks.RequireExecutable(d.config.CustomizeScript, "customize.sh helper", hint)
checks.RequireExecutable(d.config.MapDNSBin, "mapdns binary", `install mapdns or set "mapdns_bin" / BANGER_MAPDNS_BIN`)
checks.RequireFile(baseRootfs, "base rootfs image", `pass --base-rootfs or set "default_base_rootfs"`)
checks.RequireFile(kernelPath, "kernel image", `pass --kernel or set "default_kernel"`)
if strings.TrimSpace(initrdPath) != "" {
@ -109,8 +97,6 @@ func toolHint(command string) string {
return "install jq"
case "sha256sum":
return "install coreutils"
case "mapdns":
return `install mapdns or set "mapdns_bin" / BANGER_MAPDNS_BIN`
case "ssh":
return "install openssh-client"
case "bash":

View file

@ -15,6 +15,7 @@ import (
"banger/internal/model"
"banger/internal/paths"
"banger/internal/system"
"banger/internal/vmdns"
)
func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm model.VMRecord, err error) {
@ -100,7 +101,7 @@ func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm mo
Runtime: model.VMRuntime{
State: model.VMStateCreated,
GuestIP: guestIP,
DNSName: name + ".vm",
DNSName: vmdns.RecordName(name),
VMDir: vmDir,
SystemOverlay: filepath.Join(vmDir, "system.cow"),
WorkDiskPath: filepath.Join(vmDir, "root.ext4"),
@ -859,30 +860,44 @@ func clearRuntimeHandles(vm *model.VMRecord) {
}
func (d *Daemon) setDNS(ctx context.Context, vmName, guestIP string) error {
if dataFile := strings.TrimSpace(d.config.MapDNSDataFile); dataFile != "" {
if err := os.MkdirAll(filepath.Dir(dataFile), 0o755); err != nil {
return err
}
if d.vmDNS == nil {
return nil
}
_, err := d.runner.Run(ctx, d.mapdnsBinary(), d.mapdnsArgs("set", vmName+".vm", guestIP)...)
if err == nil && d.logger != nil {
d.logger.Debug("dns record set", "dns_name", vmName+".vm", "guest_ip", guestIP)
}
return err
return d.vmDNS.Set(vmdns.RecordName(vmName), guestIP)
}
func (d *Daemon) removeDNS(ctx context.Context, dnsName string) error {
if dnsName == "" {
return nil
}
_, err := d.runner.Run(ctx, d.mapdnsBinary(), d.mapdnsArgs("rm", dnsName)...)
if err != nil && strings.Contains(err.Error(), "not found") {
if d.vmDNS == nil {
return nil
}
if err == nil && d.logger != nil {
d.logger.Debug("dns record removed", "dns_name", dnsName)
return d.vmDNS.Remove(dnsName)
}
func (d *Daemon) rebuildDNS(ctx context.Context) error {
if d.vmDNS == nil {
return nil
}
return err
vms, err := d.store.ListVMs(ctx)
if err != nil {
return err
}
records := make(map[string]string)
for _, vm := range vms {
if vm.State != model.VMStateRunning {
continue
}
if !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
continue
}
if strings.TrimSpace(vm.Runtime.GuestIP) == "" {
continue
}
records[vmdns.RecordName(vm.Name)] = vm.Runtime.GuestIP
}
return d.vmDNS.Replace(records)
}
func (d *Daemon) killVMProcess(ctx context.Context, pid int) error {
@ -927,19 +942,3 @@ func validateOptionalPositiveSetting(label string, value *int) error {
}
return nil
}
func (d *Daemon) mapdnsBinary() string {
if value := strings.TrimSpace(d.config.MapDNSBin); value != "" {
return value
}
return "mapdns"
}
func (d *Daemon) mapdnsArgs(subcommand string, args ...string) []string {
out := []string{subcommand}
if value := strings.TrimSpace(d.config.MapDNSDataFile); value != "" {
out = append(out, "--data-file", value)
}
out = append(out, args...)
return out
}

View file

@ -15,6 +15,7 @@ import (
"banger/internal/model"
"banger/internal/paths"
"banger/internal/store"
"banger/internal/vmdns"
)
func TestFindVMPrefixResolution(t *testing.T) {
@ -143,6 +144,65 @@ func TestReconcileStopsStaleRunningVMAndClearsRuntimeHandles(t *testing.T) {
}
}
func TestRebuildDNSIncludesOnlyLiveRunningVMs(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
liveSock := filepath.Join(t.TempDir(), "live.sock")
liveCmd := startFakeFirecrackerProcess(t, liveSock)
t.Cleanup(func() {
_ = liveCmd.Process.Kill()
_ = liveCmd.Wait()
})
live := testVM("live", "image-live", "172.16.0.21")
live.State = model.VMStateRunning
live.Runtime.State = model.VMStateRunning
live.Runtime.PID = liveCmd.Process.Pid
live.Runtime.APISockPath = liveSock
stale := testVM("stale", "image-stale", "172.16.0.22")
stale.State = model.VMStateRunning
stale.Runtime.State = model.VMStateRunning
stale.Runtime.PID = 999999
stale.Runtime.APISockPath = filepath.Join(t.TempDir(), "stale.sock")
stopped := testVM("stopped", "image-stopped", "172.16.0.23")
for _, vm := range []model.VMRecord{live, stale, stopped} {
if err := db.UpsertVM(ctx, vm); err != nil {
t.Fatalf("UpsertVM(%s): %v", vm.Name, err)
}
}
server, err := vmdns.New("127.0.0.1:0", nil)
if err != nil {
t.Fatalf("vmdns.New: %v", err)
}
t.Cleanup(func() {
if err := server.Close(); err != nil {
t.Fatalf("server.Close: %v", err)
}
})
d := &Daemon{store: db, vmDNS: server}
if err := d.rebuildDNS(ctx); err != nil {
t.Fatalf("rebuildDNS: %v", err)
}
if _, ok := server.Lookup("live.vm"); !ok {
t.Fatal("live.vm missing after rebuildDNS")
}
if _, ok := server.Lookup("stale.vm"); ok {
t.Fatal("stale.vm should not be published")
}
if _, ok := server.Lookup("stopped.vm"); ok {
t.Fatal("stopped.vm should not be published")
}
}
func TestSetVMRejectsStoppedOnlyChangesForRunningVM(t *testing.T) {
t.Parallel()
@ -316,7 +376,7 @@ func TestValidateStartPrereqsReportsNATUplinkFailure(t *testing.T) {
for _, name := range []string{
"sudo", "ip", "dmsetup", "losetup", "blockdev", "truncate", "pgrep", "ps",
"chown", "chmod", "kill", "e2cp", "e2rm", "debugfs", "mkfs.ext4", "mount",
"umount", "cp", "iptables", "sysctl", "mapdns",
"umount", "cp", "iptables", "sysctl",
} {
writeFakeExecutable(t, filepath.Join(binDir, name))
}
@ -339,7 +399,6 @@ func TestValidateStartPrereqsReportsNATUplinkFailure(t *testing.T) {
runner: runner,
config: model.DaemonConfig{
FirecrackerBin: firecrackerBin,
MapDNSBin: "mapdns",
},
}
vm := testVM("nat", "image-nat", "172.16.0.12")

View file

@ -38,8 +38,6 @@ type DaemonConfig struct {
RuntimeDir string
LogLevel string
FirecrackerBin string
MapDNSBin string
MapDNSDataFile string
SSHKeyPath string
NamegenPath string
CustomizeScript string

View file

@ -22,7 +22,6 @@ func TestBootstrapExtractsBundleAndValidatesChecksum(t *testing.T) {
"runtime/namegen": "namegen",
"runtime/customize.sh": "#!/bin/bash\n",
"runtime/packages.sh": "#!/bin/bash\n",
"runtime/dns.sh": "#!/bin/bash\n",
"runtime/packages.apt": "vim\n",
"runtime/rootfs-docker.ext4": "rootfs",
"runtime/wtf/root/boot/vmlinux-6.8.0-94-generic": "kernel",

257
internal/vmdns/server.go Normal file
View file

@ -0,0 +1,257 @@
package vmdns
import (
"errors"
"fmt"
"log/slog"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/miekg/dns"
)
const (
DefaultListenAddr = "127.0.0.1:42069"
recordTTLSeconds = 5
vmZoneSuffix = ".vm."
)
type Server struct {
logger *slog.Logger
mu sync.RWMutex
records map[string]netip.Addr
addr string
server *dns.Server
conn net.PacketConn
done chan error
}
func New(addr string, logger *slog.Logger) (*Server, error) {
packetConn, err := net.ListenPacket("udp", addr)
if err != nil {
return nil, err
}
s := &Server{
logger: logger,
records: make(map[string]netip.Addr),
addr: packetConn.LocalAddr().String(),
conn: packetConn,
done: make(chan error, 1),
}
s.server = &dns.Server{
PacketConn: packetConn,
Handler: dns.HandlerFunc(s.handleDNS),
}
go func() {
s.done <- s.server.ActivateAndServe()
close(s.done)
}()
return s, nil
}
func (s *Server) Addr() string {
if s == nil {
return ""
}
return s.addr
}
func (s *Server) Close() error {
if s == nil || s.server == nil {
return nil
}
connErr := error(nil)
if s.conn != nil {
connErr = s.conn.Close()
s.conn = nil
}
shutdownErr := s.server.Shutdown()
if isIgnorableCloseErr(shutdownErr) {
shutdownErr = nil
}
var serveErr error
select {
case serveErr = <-s.done:
case <-time.After(2 * time.Second):
serveErr = errors.New("timed out waiting for vm dns server shutdown")
}
if isClosedServeErr(serveErr) {
serveErr = nil
}
s.server = nil
s.done = nil
return errors.Join(connErr, shutdownErr, serveErr)
}
func (s *Server) Set(name, guestIP string) error {
if s == nil {
return nil
}
addr, err := netip.ParseAddr(strings.TrimSpace(guestIP))
if err != nil {
return fmt.Errorf("parse guest IP %q: %w", guestIP, err)
}
if !addr.Is4() {
return fmt.Errorf("guest IP must be IPv4: %q", guestIP)
}
fqdn, err := normalizeVMName(name)
if err != nil {
return err
}
s.mu.Lock()
s.records[fqdn] = addr
s.mu.Unlock()
if s.logger != nil {
s.logger.Debug("vm dns record set", "dns_name", displayName(fqdn), "guest_ip", addr.String())
}
return nil
}
func (s *Server) Remove(name string) error {
if s == nil {
return nil
}
fqdn, err := normalizeVMName(name)
if err != nil {
return nil
}
s.mu.Lock()
delete(s.records, fqdn)
s.mu.Unlock()
if s.logger != nil {
s.logger.Debug("vm dns record removed", "dns_name", displayName(fqdn))
}
return nil
}
func (s *Server) Replace(records map[string]string) error {
if s == nil {
return nil
}
next := make(map[string]netip.Addr, len(records))
for name, guestIP := range records {
fqdn, err := normalizeVMName(name)
if err != nil {
return err
}
addr, err := netip.ParseAddr(strings.TrimSpace(guestIP))
if err != nil {
return fmt.Errorf("parse guest IP for %s: %w", name, err)
}
if !addr.Is4() {
return fmt.Errorf("guest IP for %s must be IPv4: %q", name, guestIP)
}
next[fqdn] = addr
}
s.mu.Lock()
s.records = next
s.mu.Unlock()
return nil
}
func (s *Server) Lookup(name string) (netip.Addr, bool) {
if s == nil {
return netip.Addr{}, false
}
fqdn, err := normalizeVMName(name)
if err != nil {
return netip.Addr{}, false
}
s.mu.RLock()
defer s.mu.RUnlock()
addr, ok := s.records[fqdn]
return addr, ok
}
func RecordName(vmName string) string {
name := strings.TrimSpace(strings.ToLower(vmName))
name = strings.TrimSuffix(name, ".")
if strings.HasSuffix(name, ".vm") {
return name
}
return name + ".vm"
}
func normalizeVMName(name string) (string, error) {
name = strings.TrimSpace(name)
if name == "" {
return "", errors.New("dns name is required")
}
fqdn := strings.ToLower(dns.Fqdn(name))
if !strings.HasSuffix(fqdn, vmZoneSuffix) {
return "", fmt.Errorf("dns name must end with .vm: %q", name)
}
return fqdn, nil
}
func displayName(fqdn string) string {
return strings.TrimSuffix(fqdn, ".")
}
func isVMQueryName(name string) bool {
return strings.HasSuffix(strings.ToLower(dns.Fqdn(name)), vmZoneSuffix)
}
func (s *Server) handleDNS(w dns.ResponseWriter, req *dns.Msg) {
resp := new(dns.Msg)
resp.SetReply(req)
resp.Authoritative = true
if len(req.Question) == 0 {
resp.Rcode = dns.RcodeFormatError
_ = w.WriteMsg(resp)
return
}
question := req.Question[0]
if !isVMQueryName(question.Name) {
resp.Rcode = dns.RcodeRefused
_ = w.WriteMsg(resp)
return
}
addr, ok := s.Lookup(question.Name)
if !ok {
resp.Rcode = dns.RcodeNameError
_ = w.WriteMsg(resp)
return
}
if question.Qtype == dns.TypeA {
resp.Answer = []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: strings.ToLower(dns.Fqdn(question.Name)),
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: recordTTLSeconds,
},
A: net.IP(addr.AsSlice()),
},
}
}
_ = w.WriteMsg(resp)
}
func isClosedServeErr(err error) bool {
if err == nil {
return true
}
return errors.Is(err, net.ErrClosed) || strings.Contains(strings.ToLower(err.Error()), "closed")
}
func isIgnorableCloseErr(err error) bool {
if err == nil {
return true
}
return strings.Contains(strings.ToLower(err.Error()), "server not started")
}

View file

@ -0,0 +1,126 @@
package vmdns
import (
"net"
"testing"
"github.com/miekg/dns"
)
func TestRecordName(t *testing.T) {
if got := RecordName("DevBox"); got != "devbox.vm" {
t.Fatalf("RecordName = %q, want devbox.vm", got)
}
if got := RecordName("already.vm"); got != "already.vm" {
t.Fatalf("RecordName = %q, want already.vm", got)
}
}
func TestServerAnswersVMQueries(t *testing.T) {
server := startTestServer(t)
if err := server.Set("devbox.vm", "172.16.0.8"); err != nil {
t.Fatalf("Set: %v", err)
}
t.Run("A record", func(t *testing.T) {
resp := exchangeQuery(t, server.Addr(), "devbox.vm.", dns.TypeA)
if resp.Rcode != dns.RcodeSuccess {
t.Fatalf("rcode = %d, want success", resp.Rcode)
}
if len(resp.Answer) != 1 {
t.Fatalf("answer count = %d, want 1", len(resp.Answer))
}
a, ok := resp.Answer[0].(*dns.A)
if !ok {
t.Fatalf("answer type = %T, want *dns.A", resp.Answer[0])
}
if got := a.A.String(); got != "172.16.0.8" {
t.Fatalf("A = %q, want 172.16.0.8", got)
}
})
t.Run("known AAAA returns NODATA", func(t *testing.T) {
resp := exchangeQuery(t, server.Addr(), "devbox.vm.", dns.TypeAAAA)
if resp.Rcode != dns.RcodeSuccess {
t.Fatalf("rcode = %d, want success", resp.Rcode)
}
if len(resp.Answer) != 0 {
t.Fatalf("answer count = %d, want 0", len(resp.Answer))
}
})
t.Run("unknown name returns NXDOMAIN", func(t *testing.T) {
resp := exchangeQuery(t, server.Addr(), "missing.vm.", dns.TypeA)
if resp.Rcode != dns.RcodeNameError {
t.Fatalf("rcode = %d, want NXDOMAIN", resp.Rcode)
}
})
t.Run("outside zone returns REFUSED", func(t *testing.T) {
resp := exchangeQuery(t, server.Addr(), "example.com.", dns.TypeA)
if resp.Rcode != dns.RcodeRefused {
t.Fatalf("rcode = %d, want REFUSED", resp.Rcode)
}
})
}
func TestServerReplaceSwapsRecordSet(t *testing.T) {
server := startTestServer(t)
if err := server.Replace(map[string]string{
"alpha.vm": "172.16.0.2",
"beta.vm": "172.16.0.3",
}); err != nil {
t.Fatalf("Replace: %v", err)
}
if _, ok := server.Lookup("alpha.vm"); !ok {
t.Fatal("alpha.vm missing after replace")
}
if err := server.Replace(map[string]string{"beta.vm": "172.16.0.4"}); err != nil {
t.Fatalf("Replace second set: %v", err)
}
if _, ok := server.Lookup("alpha.vm"); ok {
t.Fatal("alpha.vm should have been removed by replace")
}
addr, ok := server.Lookup("beta.vm")
if !ok || addr.String() != "172.16.0.4" {
t.Fatalf("beta.vm = %v, %v, want 172.16.0.4", addr, ok)
}
}
func TestServerFailsWhenAddressAlreadyInUse(t *testing.T) {
packetConn, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ListenPacket: %v", err)
}
defer packetConn.Close()
if _, err := New(packetConn.LocalAddr().String(), nil); err == nil {
t.Fatal("New() succeeded on occupied UDP address, want failure")
}
}
func startTestServer(t *testing.T) *Server {
t.Helper()
server, err := New("127.0.0.1:0", nil)
if err != nil {
t.Fatalf("New: %v", err)
}
t.Cleanup(func() {
if err := server.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
})
return server
}
func exchangeQuery(t *testing.T, addr, name string, qtype uint16) *dns.Msg {
t.Helper()
client := &dns.Client{Net: "udp"}
req := new(dns.Msg)
req.SetQuestion(name, qtype)
resp, _, err := client.Exchange(req, addr)
if err != nil {
t.Fatalf("Exchange(%s, %d): %v", name, qtype, err)
}
return resp
}