package guest

import (
	"bufio"
	"encoding/base64"
	"errors"
	"fmt"
	"net"
	"os"
	"strings"
	"sync"

	"golang.org/x/crypto/ssh"
)

// TOFUHostKeyCallback returns a HostKeyCallback that implements
// trust-on-first-use against a banger-owned known_hosts file.
//
// Semantics:
//   - If the file has any entry for the host (the lookup key produced
//     by hostLookupKey, i.e. with the port stripped) → the presented key
//     must equal one of the keys pinned for that host. Anything else,
//     including a key of a type that was never pinned, returns an error
//     (MITM protection).
//   - If no entry exists for the host → append one and accept.
//
// The file format is compatible with OpenSSH so shell SSH clients can
// use the same path via `UserKnownHostsFile`.
//
// Access to the file is serialised inside this package by the
// process-wide knownHostsMu, so concurrent dials to different VMs
// don't interleave reads and writes.
//
// An empty path disables host-key checking entirely — only for test
// harnesses and tools that dial ad-hoc infrastructure; production
// paths must supply a real file.
func TOFUHostKeyCallback(path string) ssh.HostKeyCallback {
	if strings.TrimSpace(path) == "" {
		return ssh.InsecureIgnoreHostKey()
	}
	return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
		host := hostLookupKey(hostname, remote)

		knownHostsMu.Lock()
		defer knownHostsMu.Unlock()

		entries, err := loadKnownHosts(path)
		if err != nil {
			return fmt.Errorf("read known_hosts: %w", err)
		}

		// Scan every entry that names this host, regardless of key
		// type. Looking up by (host, key type) only would let a peer
		// that presents a not-yet-pinned key type fall through to the
		// first-use branch and silently bypass the existing pin.
		pinned := false
		for _, e := range entries {
			for _, h := range e.hosts {
				if h != host {
					continue
				}
				if keysEqual(e.key, key) {
					return nil
				}
				pinned = true
				break
			}
		}
		if pinned {
			return fmt.Errorf("banger: host key for %s does not match pinned entry — "+
				"possible MITM. If the VM was legitimately rebuilt, remove the old "+
				"entry from %s and retry.", host, path)
		}

		// First contact with this host: pin the key and accept.
		if err := appendKnownHost(path, host, key); err != nil {
			return fmt.Errorf("pin host key for %s: %w", host, err)
		}
		return nil
	}
}

// RemoveKnownHosts strips every entry matching any host in `hosts`
// from the known_hosts file. Called on VM delete so a future VM
// reusing the same IP or name never trips the TOFU mismatch branch.
// Missing file / missing hosts = no-op.
func RemoveKnownHosts(path string, hosts ...string) error { if strings.TrimSpace(path) == "" || len(hosts) == 0 { return nil } knownHostsMu.Lock() defer knownHostsMu.Unlock() entries, err := loadKnownHosts(path) if err != nil { return err } drop := make(map[string]struct{}, len(hosts)) for _, h := range hosts { h = strings.TrimSpace(h) if h == "" { continue } drop[h] = struct{}{} } if len(drop) == 0 { return nil } filtered := entries.filter(func(e knownHostEntry) bool { for _, h := range e.hosts { if _, skip := drop[h]; skip { return false } } return true }) return filtered.write(path) } var knownHostsMu sync.Mutex // knownHostEntry is one line in known_hosts: a set of host patterns // (comma-separated in the file), a key type, and a key blob. type knownHostEntry struct { hosts []string keyType string key ssh.PublicKey raw string } type knownHostList []knownHostEntry func (l knownHostList) match(host, keyType string) (knownHostEntry, bool) { for _, e := range l { if e.keyType != keyType { continue } for _, h := range e.hosts { if h == host { return e, true } } } return knownHostEntry{}, false } func (l knownHostList) filter(keep func(knownHostEntry) bool) knownHostList { out := make(knownHostList, 0, len(l)) for _, e := range l { if keep(e) { out = append(out, e) } } return out } func (l knownHostList) write(path string) error { if len(l) == 0 { // If everything got filtered, truncate the file rather than // removing it — callers may want the file to keep existing // (with 0600 perms) for later appends. 
return os.WriteFile(path, nil, 0o600) } var buf strings.Builder for _, e := range l { buf.WriteString(e.raw) if !strings.HasSuffix(e.raw, "\n") { buf.WriteByte('\n') } } return os.WriteFile(path, []byte(buf.String()), 0o600) } func loadKnownHosts(path string) (knownHostList, error) { f, err := os.Open(path) if err != nil { if os.IsNotExist(err) { return nil, nil } return nil, err } defer f.Close() var out knownHostList scanner := bufio.NewScanner(f) scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) for scanner.Scan() { line := scanner.Text() trimmed := strings.TrimSpace(line) if trimmed == "" || strings.HasPrefix(trimmed, "#") { continue } fields := strings.Fields(trimmed) if len(fields) < 3 { continue } keyBytes, err := base64.StdEncoding.DecodeString(fields[2]) if err != nil { continue } key, err := ssh.ParsePublicKey(keyBytes) if err != nil { continue } out = append(out, knownHostEntry{ hosts: strings.Split(fields[0], ","), keyType: fields[1], key: key, raw: line, }) } if err := scanner.Err(); err != nil { return nil, err } return out, nil } func appendKnownHost(path, host string, key ssh.PublicKey) error { line := fmt.Sprintf("%s %s %s\n", host, key.Type(), base64.StdEncoding.EncodeToString(key.Marshal()), ) f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600) if err != nil { return err } defer f.Close() _, err = f.WriteString(line) return err } // hostLookupKey returns the canonical key under which we store host // entries. For a TCP dial the SSH library hands us hostname of the // form "host:port"; we normalise to "host" so pinning by IP also // works for a hostname-based lookup that resolves to the same IP. // // If hostname contains a port, strip it. If it's empty, fall back to // the remote address. 
func hostLookupKey(hostname string, remote net.Addr) string { if h, _, err := net.SplitHostPort(hostname); err == nil { hostname = h } if strings.TrimSpace(hostname) != "" { return hostname } if remote != nil { if h, _, err := net.SplitHostPort(remote.String()); err == nil { return h } return remote.String() } return "" } func keysEqual(a, b ssh.PublicKey) bool { if a == nil || b == nil { return a == nil && b == nil } ba := a.Marshal() bb := b.Marshal() if len(ba) != len(bb) { return false } for i := range ba { if ba[i] != bb[i] { return false } } return true } // errHostKeyMismatch sentinel is currently unused but reserved for // callers that want to distinguish MITM from other failures. var errHostKeyMismatch = errors.New("host key mismatch") var _ = errHostKeyMismatch