From ae14b9499d4f63dc9e22eb750ccffe7fefa67891 Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Sun, 19 Apr 2026 16:46:03 -0300 Subject: [PATCH] ssh: trust-on-first-use host key pinning everywhere MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Guest host-key verification was off in all three SSH paths: * Go SSH (internal/guest/ssh.go) used ssh.InsecureIgnoreHostKey * `banger vm ssh` passed StrictHostKeyChecking=no + UserKnownHostsFile=/dev/null * `~/.ssh/config` Host *.vm shipped the same posture into the user's global config Now each path verifies against a banger-owned known_hosts file at `~/.local/state/banger/ssh/known_hosts` with TOFU semantics: * First dial to a VM pins the key. * Subsequent dials require an exact match. A mismatch fails with an explicit "possible MITM" error. * `vm delete` removes the entries so a future VM reusing the IP or name re-pins cleanly. * The user's `~/.ssh/known_hosts` is untouched. Changes: internal/guest/known_hosts.go (new) — OpenSSH-compatible parser, TOFUHostKeyCallback, RemoveKnownHosts. Process-wide mutex around the file. internal/guest/ssh.go — Dial and WaitForSSH grew a knownHostsPath parameter threaded through the callback. Empty path keeps the insecure callback (tests + throwaway tools only; documented). internal/daemon/{guest_sessions,session_attach,session_lifecycle, session_stream}.go — call sites pass d.layout.KnownHostsPath. internal/daemon/ssh_client_config.go — the ~/.ssh/config Host *.vm block now points at banger's known_hosts and uses StrictHostKeyChecking=accept-new. Missing path → fail closed. internal/daemon/vm_lifecycle.go — deleteVMLocked drops known_hosts entries for the VM's IP and DNS name via removeVMKnownHosts. internal/cli/banger.go — sshCommandArgs swaps StrictHostKeyChecking no + /dev/null for banger's file + accept-new. Path resolution failure falls through to StrictHostKeyChecking=yes. 
internal/paths/paths.go — Layout gains SSHDir + KnownHostsPath; Ensure creates SSHDir at 0700. Tests (internal/guest/known_hosts_test.go): pin on first use, accept matching key on second dial, reject mismatch, empty path skips checking, RemoveKnownHosts drops the entry, re-pin works after remove. Existing daemon + cli tests updated to assert the new posture and regression-guard against the old flags. Live verified: vm run writes the pin to banger's known_hosts at 0600 inside a 0700 dir; banger vm ssh + ssh root@<name>.vm both succeed using the pin; vm delete clears it. --- internal/cli/banger.go | 40 +++- internal/cli/cli_test.go | 46 +++- internal/daemon/guest_sessions.go | 6 +- internal/daemon/session_attach.go | 2 +- internal/daemon/session_lifecycle.go | 2 +- internal/daemon/session_stream.go | 2 +- internal/daemon/ssh_client_config.go | 59 ++++- internal/daemon/ssh_client_config_test.go | 17 +- internal/daemon/vm_lifecycle.go | 4 + internal/guest/known_hosts.go | 256 ++++++++++++++++++++++ internal/guest/known_hosts_test.go | 185 ++++++++++++++++ internal/guest/ssh.go | 15 +- internal/guest/ssh_more_test.go | 4 +- internal/paths/paths.go | 43 ++-- 14 files changed, 634 insertions(+), 47 deletions(-) create mode 100644 internal/guest/known_hosts.go create mode 100644 internal/guest/known_hosts_test.go diff --git a/internal/cli/banger.go b/internal/cli/banger.go index 1119d14..8cbeab1 100644 --- a/internal/cli/banger.go +++ b/internal/cli/banger.go @@ -131,10 +131,12 @@ var ( return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params) } guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { - return guest.WaitForSSH(ctx, address, privateKeyPath, interval) + knownHosts, _ := bangerKnownHostsPath() + return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval) } guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { - 
return guest.Dial(ctx, address, privateKeyPath) + knownHosts, _ := bangerKnownHostsPath() + return guest.Dial(ctx, address, privateKeyPath, knownHosts) } prepareVMRunRepoCopyFunc = prepareVMRunRepoCopy buildVMRunToolingPlanFunc = toolingplan.Build @@ -2669,6 +2671,12 @@ func sshCommandArgs(cfg model.DaemonConfig, guestIP string, extra []string) ([]s if cfg.SSHKeyPath != "" { args = append(args, "-i", cfg.SSHKeyPath) } + // Host-key verification uses a banger-owned known_hosts file + // populated by the daemon's first successful Go-SSH dial to each + // VM (trust-on-first-use). `accept-new` means: accept-and-pin on + // first contact; strict-verify afterwards. The user's own + // ~/.ssh/known_hosts is untouched. + knownHosts, khErr := bangerKnownHostsPath() args = append( args, "-o", "IdentitiesOnly=yes", @@ -2676,14 +2684,36 @@ func sshCommandArgs(cfg model.DaemonConfig, guestIP string, extra []string) ([]s "-o", "PreferredAuthentications=publickey", "-o", "PasswordAuthentication=no", "-o", "KbdInteractiveAuthentication=no", - "-o", "StrictHostKeyChecking=no", - "-o", "UserKnownHostsFile=/dev/null", - "root@"+guestIP, ) + if khErr == nil { + args = append(args, + "-o", "UserKnownHostsFile="+knownHosts, + "-o", "StrictHostKeyChecking=accept-new", + ) + } else { + // If we can't resolve the banger path (unusual — paths.Resolve + // basically can't fail), fall through to a hard-fail posture + // rather than silently disabling verification. + args = append(args, + "-o", "StrictHostKeyChecking=yes", + ) + } + args = append(args, "root@"+guestIP) args = append(args, extra...) return args, nil } +// bangerKnownHostsPath resolves the TOFU file the daemon writes into +// and the CLI reads back. Both sides must agree on the path or the +// pin doesn't round-trip. 
+func bangerKnownHostsPath() (string, error) { + layout, err := paths.Resolve() + if err != nil { + return "", err + } + return layout.KnownHostsPath, nil +} + func validateSSHPrereqs(cfg model.DaemonConfig) error { checks := system.NewPreflight() checks.RequireCommand("ssh", "install openssh-client") diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 0c74a6e..3aeaf55 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -1049,25 +1049,57 @@ func TestExecuteVMActionBatchRunsConcurrentlyAndPreservesOrder(t *testing.T) { } func TestSSHCommandArgs(t *testing.T) { + // sshCommandArgs wires banger's own known_hosts into the shell + // SSH invocation — never /dev/null. Assert the shape and the + // posture rather than the exact path (which is host-XDG-derived). args, err := sshCommandArgs(model.DaemonConfig{SSHKeyPath: "/bundle/id_ed25519"}, "172.16.0.2", []string{"--", "uname", "-a"}) if err != nil { t.Fatalf("sshCommandArgs: %v", err) } - want := []string{ + + wantSubstrings := []string{ "-F", "/dev/null", "-i", "/bundle/id_ed25519", "-o", "IdentitiesOnly=yes", - "-o", "BatchMode=yes", - "-o", "PreferredAuthentications=publickey", "-o", "PasswordAuthentication=no", "-o", "KbdInteractiveAuthentication=no", - "-o", "StrictHostKeyChecking=no", - "-o", "UserKnownHostsFile=/dev/null", "root@172.16.0.2", "--", "uname", "-a", } - if !reflect.DeepEqual(args, want) { - t.Fatalf("args = %v, want %v", args, want) + for _, s := range wantSubstrings { + found := false + for _, a := range args { + if a == s { + found = true + break + } + } + if !found { + t.Errorf("args missing %q: %v", s, args) + } + } + + // Host-key verification posture: accept-new + a real path into + // banger state, not /dev/null. 
+ joined := strings.Join(args, " ") + if !strings.Contains(joined, "StrictHostKeyChecking=accept-new") { + t.Errorf("args missing accept-new posture: %v", args) + } + if strings.Contains(joined, "UserKnownHostsFile=/dev/null") { + t.Errorf("args leaked UserKnownHostsFile=/dev/null: %v", args) + } + if strings.Contains(joined, "StrictHostKeyChecking=no") { + t.Errorf("args leaked StrictHostKeyChecking=no: %v", args) + } + // Must reference a known_hosts file ending in "known_hosts". + sawKnownHosts := false + for _, a := range args { + if strings.HasPrefix(a, "UserKnownHostsFile=") && strings.HasSuffix(a, "known_hosts") { + sawKnownHosts = true + } + } + if !sawKnownHosts { + t.Errorf("args missing UserKnownHostsFile=: %v", args) } } diff --git a/internal/daemon/guest_sessions.go b/internal/daemon/guest_sessions.go index 6a4cddb..bc59742 100644 --- a/internal/daemon/guest_sessions.go +++ b/internal/daemon/guest_sessions.go @@ -31,14 +31,14 @@ func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval t if d != nil && d.guestWaitForSSH != nil { return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval) } - return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, interval) + return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval) } func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) { if d != nil && d.guestDial != nil { return d.guestDial(ctx, address, d.config.SSHKeyPath) } - return guest.Dial(ctx, address, d.config.SSHKeyPath) + return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath) } func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) { @@ -86,7 +86,7 @@ func (d *Daemon) refreshGuestSession(ctx context.Context, vm model.VMRecord, s m func (d *Daemon) inspectGuestSessionState(ctx context.Context, vm model.VMRecord, s model.GuestSession) 
(session.StateSnapshot, error) { if d.vmAlive(vm) { - client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath) + client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath) if err != nil { return session.StateSnapshot{}, err } diff --git a/internal/daemon/session_attach.go b/internal/daemon/session_attach.go index 9fef26b..6c83da4 100644 --- a/internal/daemon/session_attach.go +++ b/internal/daemon/session_attach.go @@ -189,7 +189,7 @@ func (d *Daemon) attachGuestSessionBridge(session model.GuestSession, controller } func (d *Daemon) openGuestSessionAttachStream(address, command string) (*guest.StreamSession, error) { - client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath) + client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath, d.layout.KnownHostsPath) if err != nil { return nil, err } diff --git a/internal/daemon/session_lifecycle.go b/internal/daemon/session_lifecycle.go index 18e4b02..beeaa07 100644 --- a/internal/daemon/session_lifecycle.go +++ b/internal/daemon/session_lifecycle.go @@ -195,7 +195,7 @@ func (d *Daemon) signalGuestSession(ctx context.Context, params api.GuestSession } return session, nil } - client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath) + client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath) if err != nil { return model.GuestSession{}, err } diff --git a/internal/daemon/session_stream.go b/internal/daemon/session_stream.go index fea9c54..46970d0 100644 --- a/internal/daemon/session_stream.go +++ b/internal/daemon/session_stream.go @@ -90,7 +90,7 @@ func (d *Daemon) SendToGuestSession(ctx context.Context, params api.GuestSession func (d *Daemon) readGuestSessionLog(ctx context.Context, vm model.VMRecord, session model.GuestSession, stream string, tailLines int) (string, error) { if 
d.vmAlive(vm) { - client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath) + client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath) if err != nil { return "", err } diff --git a/internal/daemon/ssh_client_config.go b/internal/daemon/ssh_client_config.go index 299b38e..1fffd7d 100644 --- a/internal/daemon/ssh_client_config.go +++ b/internal/daemon/ssh_client_config.go @@ -2,13 +2,41 @@ package daemon import ( "fmt" + "log/slog" "os" "path/filepath" "strings" + "banger/internal/guest" + "banger/internal/model" "banger/internal/paths" ) +// removeVMKnownHosts drops every host-key pin for vm from the +// banger-owned known_hosts. Best-effort — a failure here only +// matters if the same IP/name is reused by a fresh VM before the +// next daemon restart, and even then it just causes a +// TOFU-mismatch error that the user can clear manually. Logged at +// warn so it shows up if it ever actually breaks things. 
+func removeVMKnownHosts(knownHostsPath string, vm model.VMRecord, logger *slog.Logger) { + if strings.TrimSpace(knownHostsPath) == "" { + return + } + var hosts []string + if ip := strings.TrimSpace(vm.Runtime.GuestIP); ip != "" { + hosts = append(hosts, ip) + } + if dns := strings.TrimSpace(vm.Runtime.DNSName); dns != "" { + hosts = append(hosts, dns) + } + if len(hosts) == 0 { + return + } + if err := guest.RemoveKnownHosts(knownHostsPath, hosts...); err != nil && logger != nil { + logger.Warn("remove known_hosts entries", "vm_id", vm.ID, "error", err.Error()) + } +} + const ( vmSSHConfigIncludeBegin = "# BEGIN BANGER MANAGED VM SSH" vmSSHConfigIncludeEnd = "# END BANGER MANAGED VM SSH" @@ -39,7 +67,7 @@ func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error { if err != nil { return err } - updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath)) + updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath, layout.KnownHostsPath)) if err != nil { return err } @@ -54,11 +82,19 @@ func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error { return nil } -func renderManagedVMSSHBlock(keyPath string) string { +// renderManagedVMSSHBlock produces the `Host *.vm` stanza banger +// writes into the user's ~/.ssh/config. Host-key verification uses +// the banger-owned known_hosts file at knownHostsPath — NOT the +// user's ~/.ssh/known_hosts, and NOT /dev/null. `accept-new` means +// first contact pins the key; any later mismatch fails the connect. 
+func renderManagedVMSSHBlock(keyPath, knownHostsPath string) string { keyPath = strings.TrimSpace(keyPath) - return strings.Join([]string{ + knownHostsPath = strings.TrimSpace(knownHostsPath) + lines := []string{ vmSSHConfigIncludeBegin, "# Generated by banger for direct SSH access to VM DNS names.", + "# Host keys are pinned on first use into a banger-owned", + "# known_hosts file (not ~/.ssh/known_hosts).", "Host *.vm", " User root", " IdentityFile " + keyPath, @@ -67,12 +103,23 @@ func renderManagedVMSSHBlock(keyPath string) string { " PreferredAuthentications publickey", " PasswordAuthentication no", " KbdInteractiveAuthentication no", - " StrictHostKeyChecking no", - " UserKnownHostsFile /dev/null", + } + if knownHostsPath != "" { + lines = append(lines, + " UserKnownHostsFile "+knownHostsPath, + " StrictHostKeyChecking accept-new", + ) + } else { + // Missing known_hosts path is a configuration anomaly — fail + // closed rather than silently disable verification. + lines = append(lines, " StrictHostKeyChecking yes") + } + lines = append(lines, " LogLevel ERROR", vmSSHConfigIncludeEnd, "", - }, "\n") + ) + return strings.Join(lines, "\n") } func upsertManagedBlock(existing, beginMarker, endMarker, block string) (string, error) { diff --git a/internal/daemon/ssh_client_config_test.go b/internal/daemon/ssh_client_config_test.go index 80d8e95..6838eb2 100644 --- a/internal/daemon/ssh_client_config_test.go +++ b/internal/daemon/ssh_client_config_test.go @@ -13,8 +13,10 @@ func TestSyncVMSSHClientConfigCreatesManagedBlock(t *testing.T) { homeDir := t.TempDir() t.Setenv("HOME", homeDir) + knownHostsPath := filepath.Join(homeDir, ".local", "state", "banger", "ssh", "known_hosts") layout := paths.Layout{ - ConfigDir: filepath.Join(homeDir, ".config", "banger"), + ConfigDir: filepath.Join(homeDir, ".config", "banger"), + KnownHostsPath: knownHostsPath, } keyPath := filepath.Join(layout.ConfigDir, "ssh", "id_ed25519") @@ -38,12 +40,23 @@ func 
TestSyncVMSSHClientConfigCreatesManagedBlock(t *testing.T) { "IdentitiesOnly yes", "BatchMode yes", "PasswordAuthentication no", - "UserKnownHostsFile /dev/null", + "UserKnownHostsFile " + knownHostsPath, + "StrictHostKeyChecking accept-new", } { if !strings.Contains(userContent, want) { t.Fatalf("user config = %q, want %q", userContent, want) } } + // Regression: the legacy posture (StrictHostKeyChecking no + + // UserKnownHostsFile /dev/null) must never reappear. + for _, must := range []string{ + "StrictHostKeyChecking no", + "UserKnownHostsFile /dev/null", + } { + if strings.Contains(userContent, must) { + t.Fatalf("user config leaked legacy posture %q:\n%s", must, userContent) + } + } } func TestSyncVMSSHClientConfigReplacesManagedIncludeBlock(t *testing.T) { diff --git a/internal/daemon/vm_lifecycle.go b/internal/daemon/vm_lifecycle.go index ed1750e..2bb8eb7 100644 --- a/internal/daemon/vm_lifecycle.go +++ b/internal/daemon/vm_lifecycle.go @@ -411,5 +411,9 @@ func (d *Daemon) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm return model.VMRecord{}, err } } + // Drop any host-key pins. A future VM reusing this IP or name + // would otherwise trip the TOFU mismatch branch in + // TOFUHostKeyCallback and fail to connect. + removeVMKnownHosts(d.layout.KnownHostsPath, vm, d.logger) return vm, nil } diff --git a/internal/guest/known_hosts.go b/internal/guest/known_hosts.go new file mode 100644 index 0000000..2dd3f90 --- /dev/null +++ b/internal/guest/known_hosts.go @@ -0,0 +1,256 @@ +package guest + +import ( + "bufio" + "encoding/base64" + "errors" + "fmt" + "net" + "os" + "strings" + "sync" + + "golang.org/x/crypto/ssh" +) + +// TOFUHostKeyCallback returns a HostKeyCallback that implements +// trust-on-first-use against a banger-owned known_hosts file. +// +// Semantics: +// - If the file has an entry for the host (port stripped) and key +// type → require an exact key match; a mismatch returns an error (MITM protection). +// - If no entry exists for that host and key type → append one and accept. 
+// +// The file format is compatible with OpenSSH so shell SSH clients can +// use the same path via `UserKnownHostsFile`. +// +// The callback holds a process-wide mutex around the file so concurrent +// dials to different VMs don't interleave writes. +// +// An empty path disables host-key checking entirely — only for test +// harnesses and tools that dial ad-hoc infrastructure; production +// paths must supply a real file. +func TOFUHostKeyCallback(path string) ssh.HostKeyCallback { + if strings.TrimSpace(path) == "" { + return ssh.InsecureIgnoreHostKey() + } + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + host := hostLookupKey(hostname, remote) + knownHostsMu.Lock() + defer knownHostsMu.Unlock() + + entries, err := loadKnownHosts(path) + if err != nil { + return fmt.Errorf("read known_hosts: %w", err) + } + stored, matched := entries.match(host, key.Type()) + if matched { + if keysEqual(stored.key, key) { + return nil + } + return fmt.Errorf("banger: host key for %s does not match pinned entry — "+ + "possible MITM. If the VM was legitimately rebuilt, remove the old "+ + "entry from %s and retry.", host, path) + } + if err := appendKnownHost(path, host, key); err != nil { + return fmt.Errorf("pin host key for %s: %w", host, err) + } + return nil + } +} + +// RemoveKnownHosts strips every entry matching any host in `hosts` +// from the known_hosts file. Called on VM delete so a future VM +// reusing the same IP or name never trips the TOFU mismatch branch. +// Missing file / missing hosts = no-op. 
+func RemoveKnownHosts(path string, hosts ...string) error { + if strings.TrimSpace(path) == "" || len(hosts) == 0 { + return nil + } + knownHostsMu.Lock() + defer knownHostsMu.Unlock() + + entries, err := loadKnownHosts(path) + if err != nil { + return err + } + drop := make(map[string]struct{}, len(hosts)) + for _, h := range hosts { + h = strings.TrimSpace(h) + if h == "" { + continue + } + drop[h] = struct{}{} + } + if len(drop) == 0 { + return nil + } + filtered := entries.filter(func(e knownHostEntry) bool { + for _, h := range e.hosts { + if _, skip := drop[h]; skip { + return false + } + } + return true + }) + return filtered.write(path) +} + +var knownHostsMu sync.Mutex + +// knownHostEntry is one line in known_hosts: a set of host patterns +// (comma-separated in the file), a key type, and a key blob. +type knownHostEntry struct { + hosts []string + keyType string + key ssh.PublicKey + raw string +} + +type knownHostList []knownHostEntry + +func (l knownHostList) match(host, keyType string) (knownHostEntry, bool) { + for _, e := range l { + if e.keyType != keyType { + continue + } + for _, h := range e.hosts { + if h == host { + return e, true + } + } + } + return knownHostEntry{}, false +} + +func (l knownHostList) filter(keep func(knownHostEntry) bool) knownHostList { + out := make(knownHostList, 0, len(l)) + for _, e := range l { + if keep(e) { + out = append(out, e) + } + } + return out +} + +func (l knownHostList) write(path string) error { + if len(l) == 0 { + // If everything got filtered, truncate the file rather than + // removing it — callers may want the file to keep existing + // (with 0600 perms) for later appends. 
+ return os.WriteFile(path, nil, 0o600) + } + var buf strings.Builder + for _, e := range l { + buf.WriteString(e.raw) + if !strings.HasSuffix(e.raw, "\n") { + buf.WriteByte('\n') + } + } + return os.WriteFile(path, []byte(buf.String()), 0o600) +} + +func loadKnownHosts(path string) (knownHostList, error) { + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + defer f.Close() + + var out knownHostList + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for scanner.Scan() { + line := scanner.Text() + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + fields := strings.Fields(trimmed) + if len(fields) < 3 { + continue + } + keyBytes, err := base64.StdEncoding.DecodeString(fields[2]) + if err != nil { + continue + } + key, err := ssh.ParsePublicKey(keyBytes) + if err != nil { + continue + } + out = append(out, knownHostEntry{ + hosts: strings.Split(fields[0], ","), + keyType: fields[1], + key: key, + raw: line, + }) + } + if err := scanner.Err(); err != nil { + return nil, err + } + return out, nil +} + +func appendKnownHost(path, host string, key ssh.PublicKey) error { + line := fmt.Sprintf("%s %s %s\n", + host, + key.Type(), + base64.StdEncoding.EncodeToString(key.Marshal()), + ) + f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + return err + } + defer f.Close() + _, err = f.WriteString(line) + return err +} + +// hostLookupKey returns the canonical key under which we store host +// entries. For a TCP dial the SSH library hands us hostname of the +// form "host:port"; we normalise to "host" so pinning by IP also +// works for a hostname-based lookup that resolves to the same IP. +// +// If hostname contains a port, strip it. If it's empty, fall back to +// the remote address. 
+func hostLookupKey(hostname string, remote net.Addr) string { + if h, _, err := net.SplitHostPort(hostname); err == nil { + hostname = h + } + if strings.TrimSpace(hostname) != "" { + return hostname + } + if remote != nil { + if h, _, err := net.SplitHostPort(remote.String()); err == nil { + return h + } + return remote.String() + } + return "" +} + +func keysEqual(a, b ssh.PublicKey) bool { + if a == nil || b == nil { + return a == nil && b == nil + } + ba := a.Marshal() + bb := b.Marshal() + if len(ba) != len(bb) { + return false + } + for i := range ba { + if ba[i] != bb[i] { + return false + } + } + return true +} + +// errHostKeyMismatch sentinel is currently unused but reserved for +// callers that want to distinguish MITM from other failures. +var errHostKeyMismatch = errors.New("host key mismatch") + +var _ = errHostKeyMismatch diff --git a/internal/guest/known_hosts_test.go b/internal/guest/known_hosts_test.go new file mode 100644 index 0000000..8c9e3b2 --- /dev/null +++ b/internal/guest/known_hosts_test.go @@ -0,0 +1,185 @@ +package guest + +import ( + "crypto/ed25519" + "crypto/rand" + "net" + "os" + "path/filepath" + "strings" + "testing" + + "golang.org/x/crypto/ssh" +) + +// makeTestHostKey generates a fresh ed25519 key and returns the +// ssh.PublicKey the server would present during a handshake. 
+func makeTestHostKey(t *testing.T) ssh.PublicKey { + t.Helper() + pub, _, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + sshPub, err := ssh.NewPublicKey(pub) + if err != nil { + t.Fatalf("NewPublicKey: %v", err) + } + return sshPub +} + +func TestTOFUHostKeyCallbackPinsOnFirstUse(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "known_hosts") + cb := TOFUHostKeyCallback(path) + + key := makeTestHostKey(t) + addr := &net.TCPAddr{IP: net.ParseIP("172.16.0.5"), Port: 22} + + if err := cb("172.16.0.5:22", addr, key); err != nil { + t.Fatalf("first-use callback: %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + content := string(data) + if !strings.Contains(content, "172.16.0.5") { + t.Errorf("known_hosts missing host:\n%s", content) + } + if !strings.Contains(content, key.Type()) { + t.Errorf("known_hosts missing key type:\n%s", content) + } +} + +func TestTOFUHostKeyCallbackAcceptsMatch(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "known_hosts") + cb := TOFUHostKeyCallback(path) + key := makeTestHostKey(t) + addr := &net.TCPAddr{IP: net.ParseIP("172.16.0.6"), Port: 22} + + if err := cb("172.16.0.6:22", addr, key); err != nil { + t.Fatalf("first-use: %v", err) + } + // Same key, second dial: must succeed. 
+ if err := cb("172.16.0.6:22", addr, key); err != nil { + t.Fatalf("second dial with matching key: %v", err) + } +} + +func TestTOFUHostKeyCallbackRejectsMismatch(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "known_hosts") + cb := TOFUHostKeyCallback(path) + addr := &net.TCPAddr{IP: net.ParseIP("172.16.0.7"), Port: 22} + + original := makeTestHostKey(t) + if err := cb("172.16.0.7:22", addr, original); err != nil { + t.Fatalf("pin original: %v", err) + } + + impostor := makeTestHostKey(t) + err := cb("172.16.0.7:22", addr, impostor) + if err == nil { + t.Fatal("expected mismatch error, got nil") + } + if !strings.Contains(err.Error(), "does not match") { + t.Errorf("error = %v, want message about mismatch", err) + } +} + +func TestTOFUEmptyPathDisablesVerification(t *testing.T) { + t.Parallel() + // Empty path returns an Insecure callback — useful for tests / + // throwaway tools. Document behaviour so the fallback doesn't + // silently regress to "always verify but without a file". 
+ cb := TOFUHostKeyCallback("") + addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22} + if err := cb("127.0.0.1:22", addr, makeTestHostKey(t)); err != nil { + t.Fatalf("empty-path callback should accept: %v", err) + } +} + +func TestRemoveKnownHostsDropsEntry(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "known_hosts") + cb := TOFUHostKeyCallback(path) + keep := makeTestHostKey(t) + drop := makeTestHostKey(t) + + if err := cb("172.16.0.10:22", &net.TCPAddr{IP: net.ParseIP("172.16.0.10"), Port: 22}, keep); err != nil { + t.Fatalf("pin keep: %v", err) + } + if err := cb("172.16.0.11:22", &net.TCPAddr{IP: net.ParseIP("172.16.0.11"), Port: 22}, drop); err != nil { + t.Fatalf("pin drop: %v", err) + } + + if err := RemoveKnownHosts(path, "172.16.0.11"); err != nil { + t.Fatalf("RemoveKnownHosts: %v", err) + } + + data, _ := os.ReadFile(path) + content := string(data) + if !strings.Contains(content, "172.16.0.10") { + t.Errorf("kept entry missing:\n%s", content) + } + if strings.Contains(content, "172.16.0.11") { + t.Errorf("dropped entry still present:\n%s", content) + } +} + +func TestRemoveKnownHostsMissingFileIsNoOp(t *testing.T) { + t.Parallel() + missing := filepath.Join(t.TempDir(), "absent") + if err := RemoveKnownHosts(missing, "any"); err != nil { + t.Fatalf("RemoveKnownHosts on missing: %v", err) + } +} + +func TestRemoveKnownHostsEmptyPathIsNoOp(t *testing.T) { + t.Parallel() + if err := RemoveKnownHosts("", "any"); err != nil { + t.Fatalf("RemoveKnownHosts(empty): %v", err) + } +} + +// TestTOFURewritesAllowsReuseAfterRemove: after a VM is deleted and +// its pin is cleared, a future VM reusing the same IP (with a fresh +// host key) should re-pin cleanly, not fail the mismatch branch. 
+func TestTOFURewritesAllowsReuseAfterRemove(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "known_hosts") + cb := TOFUHostKeyCallback(path) + addr := &net.TCPAddr{IP: net.ParseIP("172.16.0.15"), Port: 22} + + original := makeTestHostKey(t) + if err := cb("172.16.0.15:22", addr, original); err != nil { + t.Fatalf("pin original: %v", err) + } + + // VM deleted → pin removed. + if err := RemoveKnownHosts(path, "172.16.0.15"); err != nil { + t.Fatalf("RemoveKnownHosts: %v", err) + } + + // New VM, same IP, new host key. Must re-pin without error. + replacement := makeTestHostKey(t) + if err := cb("172.16.0.15:22", addr, replacement); err != nil { + t.Fatalf("re-pin after remove: %v", err) + } +} + +func TestHostLookupKeyStripsPort(t *testing.T) { + t.Parallel() + if got := hostLookupKey("10.0.0.1:22", nil); got != "10.0.0.1" { + t.Errorf("got %q, want 10.0.0.1", got) + } + if got := hostLookupKey("host.vm", nil); got != "host.vm" { + t.Errorf("got %q, want host.vm", got) + } + addr := &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 22} + if got := hostLookupKey("", addr); got != "1.2.3.4" { + t.Errorf("fallback: got %q, want 1.2.3.4", got) + } +} diff --git a/internal/guest/ssh.go b/internal/guest/ssh.go index 6723710..bbf2e4b 100644 --- a/internal/guest/ssh.go +++ b/internal/guest/ssh.go @@ -35,12 +35,15 @@ type StreamSession struct { closeOnce sync.Once } -func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { +// WaitForSSH polls Dial until it succeeds or ctx cancels. The +// knownHostsPath argument is the banger-owned TOFU file; empty +// disables host-key verification (tests only). 
+func WaitForSSH(ctx context.Context, address, privateKeyPath, knownHostsPath string, interval time.Duration) error { if interval <= 0 { interval = time.Second } for { - client, err := Dial(ctx, address, privateKeyPath) + client, err := Dial(ctx, address, privateKeyPath, knownHostsPath) if err == nil { _ = client.Close() return nil @@ -53,7 +56,11 @@ func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval ti } } -func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error) { +// Dial opens an SSH client to address, authenticating with the key +// at privateKeyPath and verifying the remote host key against the +// TOFU known_hosts file at knownHostsPath. An empty knownHostsPath +// disables verification (tests / one-shot tools only). +func Dial(ctx context.Context, address, privateKeyPath, knownHostsPath string) (*Client, error) { signer, err := privateKeySigner(privateKeyPath) if err != nil { return nil, err @@ -61,7 +68,7 @@ func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error) config := &ssh.ClientConfig{ User: "root", Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: TOFUHostKeyCallback(knownHostsPath), Timeout: 10 * time.Second, } dialer := &net.Dialer{Timeout: 10 * time.Second} diff --git a/internal/guest/ssh_more_test.go b/internal/guest/ssh_more_test.go index 271605e..4be594e 100644 --- a/internal/guest/ssh_more_test.go +++ b/internal/guest/ssh_more_test.go @@ -271,7 +271,7 @@ func TestWaitForSSHContextCancel(t *testing.T) { defer cancel() start := time.Now() - err := WaitForSSH(ctx, freeAddr(t), keyPath, 10*time.Millisecond) + err := WaitForSSH(ctx, freeAddr(t), keyPath, "", 10*time.Millisecond) if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("err = %v, want context.DeadlineExceeded", err) } @@ -286,7 +286,7 @@ func TestDialReturnsErrorForBadKey(t *testing.T) { if err := os.WriteFile(keyPath, []byte("nope"), 0o600); err 
!= nil { t.Fatalf("WriteFile: %v", err) } - _, err := Dial(context.Background(), freeAddr(t), keyPath) + _, err := Dial(context.Background(), freeAddr(t), keyPath, "") if err == nil { t.Fatal("expected error for bad key") } diff --git a/internal/paths/paths.go b/internal/paths/paths.go index ce9ef96..518ea63 100644 --- a/internal/paths/paths.go +++ b/internal/paths/paths.go @@ -9,21 +9,23 @@ import ( ) type Layout struct { - ConfigHome string - StateHome string - CacheHome string - RuntimeHome string - ConfigDir string - StateDir string - CacheDir string - RuntimeDir string - SocketPath string - DBPath string - DaemonLog string - VMsDir string - ImagesDir string - KernelsDir string - OCICacheDir string + ConfigHome string + StateHome string + CacheHome string + RuntimeHome string + ConfigDir string + StateDir string + CacheDir string + RuntimeDir string + SocketPath string + DBPath string + DaemonLog string + VMsDir string + ImagesDir string + KernelsDir string + OCICacheDir string + SSHDir string + KnownHostsPath string } func Resolve() (Layout, error) { @@ -56,6 +58,8 @@ func Resolve() (Layout, error) { layout.ImagesDir = filepath.Join(layout.StateDir, "images") layout.KernelsDir = filepath.Join(layout.StateDir, "kernels") layout.OCICacheDir = filepath.Join(layout.CacheDir, "oci") + layout.SSHDir = filepath.Join(layout.StateDir, "ssh") + layout.KnownHostsPath = filepath.Join(layout.SSHDir, "known_hosts") return layout, nil } @@ -65,6 +69,15 @@ func Ensure(layout Layout) error { return err } } + // SSH material (private key, known_hosts) — 0700 like ~/.ssh so + // strict SSH clients don't complain and no other host user can + // read it. Empty SSHDir means the caller built a Layout by hand + // (tests) and doesn't need the subdir; skip silently. + if strings.TrimSpace(layout.SSHDir) != "" { + if err := os.MkdirAll(layout.SSHDir, 0o700); err != nil { + return err + } + } return nil }