package vmdns import ( "errors" "fmt" "log/slog" "net" "net/netip" "strings" "sync" "time" "github.com/miekg/dns" ) const ( DefaultListenAddr = "127.0.0.1:42069" recordTTLSeconds = 5 vmZoneSuffix = ".vm." ) type Server struct { logger *slog.Logger mu sync.RWMutex records map[string]netip.Addr addr string server *dns.Server conn net.PacketConn done chan error } func New(addr string, logger *slog.Logger) (*Server, error) { packetConn, err := net.ListenPacket("udp", addr) if err != nil { return nil, err } s := &Server{ logger: logger, records: make(map[string]netip.Addr), addr: packetConn.LocalAddr().String(), conn: packetConn, done: make(chan error, 1), } s.server = &dns.Server{ PacketConn: packetConn, Handler: dns.HandlerFunc(s.handleDNS), } go func() { s.done <- s.server.ActivateAndServe() close(s.done) }() return s, nil } func (s *Server) Addr() string { if s == nil { return "" } return s.addr } func (s *Server) Close() error { if s == nil || s.server == nil { return nil } connErr := error(nil) if s.conn != nil { connErr = s.conn.Close() s.conn = nil } shutdownErr := s.server.Shutdown() if isIgnorableCloseErr(shutdownErr) { shutdownErr = nil } var serveErr error select { case serveErr = <-s.done: case <-time.After(2 * time.Second): serveErr = errors.New("timed out waiting for vm dns server shutdown") } if isClosedServeErr(serveErr) { serveErr = nil } s.server = nil s.done = nil return errors.Join(connErr, shutdownErr, serveErr) } func (s *Server) Set(name, guestIP string) error { if s == nil { return nil } addr, err := netip.ParseAddr(strings.TrimSpace(guestIP)) if err != nil { return fmt.Errorf("parse guest IP %q: %w", guestIP, err) } if !addr.Is4() { return fmt.Errorf("guest IP must be IPv4: %q", guestIP) } fqdn, err := normalizeVMName(name) if err != nil { return err } s.mu.Lock() s.records[fqdn] = addr s.mu.Unlock() if s.logger != nil { s.logger.Debug("vm dns record set", "dns_name", displayName(fqdn), "guest_ip", addr.String()) } return nil } func (s *Server) Remove(name string) error { if s == nil { return nil } fqdn, err := normalizeVMName(name) if err != nil { return nil } s.mu.Lock() delete(s.records, fqdn) s.mu.Unlock() if s.logger != nil { s.logger.Debug("vm dns record removed", "dns_name", displayName(fqdn)) } return nil } func (s *Server) Replace(records map[string]string) error { if s == nil { return nil } next := make(map[string]netip.Addr, len(records)) for name, guestIP := range records { fqdn, err := normalizeVMName(name) if err != nil { return err } addr, err := netip.ParseAddr(strings.TrimSpace(guestIP)) if err != nil { return fmt.Errorf("parse guest IP for %s: %w", name, err) } if !addr.Is4() { return fmt.Errorf("guest IP for %s must be IPv4: %q", name, guestIP) } next[fqdn] = addr } s.mu.Lock() s.records = next s.mu.Unlock() return nil } func (s *Server) Lookup(name string) (netip.Addr, bool) { if s == nil { return netip.Addr{}, false } fqdn, err := normalizeVMName(name) if err != nil { return netip.Addr{}, false } s.mu.RLock() defer s.mu.RUnlock() addr, ok := s.records[fqdn] return addr, ok } func RecordName(vmName string) string { name := strings.TrimSpace(strings.ToLower(vmName)) name = strings.TrimSuffix(name, ".") if strings.HasSuffix(name, ".vm") { return name } return name + ".vm" } func normalizeVMName(name string) (string, error) { name = strings.TrimSpace(name) if name == "" { return "", errors.New("dns name is required") } fqdn := strings.ToLower(dns.Fqdn(name)) if !strings.HasSuffix(fqdn, vmZoneSuffix) { return "", fmt.Errorf("dns name must end with .vm: %q", name) } return fqdn, nil } func displayName(fqdn string) string { return strings.TrimSuffix(fqdn, ".") } func isVMQueryName(name string) bool { return strings.HasSuffix(strings.ToLower(dns.Fqdn(name)), vmZoneSuffix) } func (s *Server) handleDNS(w dns.ResponseWriter, req *dns.Msg) { resp := new(dns.Msg) resp.SetReply(req) resp.Authoritative = true if len(req.Question) == 0 { resp.Rcode = dns.RcodeFormatError _ = w.WriteMsg(resp) return } question := req.Question[0] if !isVMQueryName(question.Name) { resp.Rcode = dns.RcodeRefused _ = w.WriteMsg(resp) return } addr, ok := s.Lookup(question.Name) if !ok { resp.Rcode = dns.RcodeNameError _ = w.WriteMsg(resp) return } if question.Qtype == dns.TypeA { resp.Answer = []dns.RR{ &dns.A{ Hdr: dns.RR_Header{ Name: strings.ToLower(dns.Fqdn(question.Name)), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: recordTTLSeconds, }, A: net.IP(addr.AsSlice()), }, } } _ = w.WriteMsg(resp) } func isClosedServeErr(err error) bool { if err == nil { return true } return errors.Is(err, net.ErrClosed) || strings.Contains(strings.ToLower(err.Error()), "closed") } func isIgnorableCloseErr(err error) bool { if err == nil { return true } return strings.Contains(strings.ToLower(err.Error()), "server not started") }