ssh: trust-on-first-use host key pinning everywhere

Guest host-key verification was off in all three SSH paths:

  * Go SSH (internal/guest/ssh.go) used ssh.InsecureIgnoreHostKey
  * `banger vm ssh` passed StrictHostKeyChecking=no
    + UserKnownHostsFile=/dev/null
  * `~/.ssh/config` Host *.vm shipped the same posture into the
    user's global config

Now each path verifies against a banger-owned known_hosts file at
`~/.local/state/banger/ssh/known_hosts` with TOFU semantics:

  * First dial to a VM pins the key.
  * Subsequent dials require an exact match. A mismatch fails with
    an explicit "possible MITM" error.
  * `vm delete` removes the entries so a future VM reusing the IP
    or name re-pins cleanly.
  * The user's `~/.ssh/known_hosts` is untouched.

Changes:

  internal/guest/known_hosts.go (new) — OpenSSH-compatible parser,
    TOFUHostKeyCallback, RemoveKnownHosts. Process-wide mutex
    around the file.
  internal/guest/ssh.go — Dial and WaitForSSH grew a knownHostsPath
    parameter threaded through the callback. Empty path keeps the
    insecure callback (tests + throwaway tools only; documented).
  internal/daemon/{guest_sessions,session_attach,session_lifecycle,
    session_stream}.go — call sites pass d.layout.KnownHostsPath.
  internal/daemon/ssh_client_config.go — the ~/.ssh/config Host *.vm
    block now points at banger's known_hosts and uses
    StrictHostKeyChecking=accept-new. Missing path → fail closed.
  internal/daemon/vm_lifecycle.go — deleteVMLocked drops known_hosts
    entries for the VM's IP and DNS name via removeVMKnownHosts.
  internal/cli/banger.go — sshCommandArgs swaps StrictHostKeyChecking
    no + /dev/null for banger's file + accept-new. Path resolution
    failure falls through to StrictHostKeyChecking=yes.
  internal/paths/paths.go — Layout gains SSHDir + KnownHostsPath;
    Ensure creates SSHDir at 0700.

Tests (internal/guest/known_hosts_test.go): pin on first use, accept
matching key on second dial, reject mismatch, empty path skips
checking, RemoveKnownHosts drops the entry, re-pin works after
remove. Existing daemon + cli tests updated to assert the new
posture and regression-guard against the old flags.

Live verified: vm run writes the pin to banger's known_hosts at 0600
inside a 0700 dir; banger vm ssh + ssh root@<vm>.vm both succeed
using the pin; vm delete clears it.
This commit is contained in:
Thales Maciel 2026-04-19 16:46:03 -03:00
parent a59958d4f5
commit ae14b9499d
No known key found for this signature in database
GPG key ID: 33112E6833C34679
14 changed files with 634 additions and 47 deletions

View file

@ -131,10 +131,12 @@ var (
return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params)
}
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
return guest.WaitForSSH(ctx, address, privateKeyPath, interval)
knownHosts, _ := bangerKnownHostsPath()
return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval)
}
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
return guest.Dial(ctx, address, privateKeyPath)
knownHosts, _ := bangerKnownHostsPath()
return guest.Dial(ctx, address, privateKeyPath, knownHosts)
}
prepareVMRunRepoCopyFunc = prepareVMRunRepoCopy
buildVMRunToolingPlanFunc = toolingplan.Build
@ -2669,6 +2671,12 @@ func sshCommandArgs(cfg model.DaemonConfig, guestIP string, extra []string) ([]s
if cfg.SSHKeyPath != "" {
args = append(args, "-i", cfg.SSHKeyPath)
}
// Host-key verification uses a banger-owned known_hosts file
// populated by the daemon's first successful Go-SSH dial to each
// VM (trust-on-first-use). `accept-new` means: accept-and-pin on
// first contact; strict-verify afterwards. The user's own
// ~/.ssh/known_hosts is untouched.
knownHosts, khErr := bangerKnownHostsPath()
args = append(
args,
"-o", "IdentitiesOnly=yes",
@ -2676,14 +2684,36 @@ func sshCommandArgs(cfg model.DaemonConfig, guestIP string, extra []string) ([]s
"-o", "PreferredAuthentications=publickey",
"-o", "PasswordAuthentication=no",
"-o", "KbdInteractiveAuthentication=no",
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"root@"+guestIP,
)
if khErr == nil {
args = append(args,
"-o", "UserKnownHostsFile="+knownHosts,
"-o", "StrictHostKeyChecking=accept-new",
)
} else {
// If we can't resolve the banger path (unusual — paths.Resolve
// basically can't fail), fall through to a hard-fail posture
// rather than silently disabling verification.
args = append(args,
"-o", "StrictHostKeyChecking=yes",
)
}
args = append(args, "root@"+guestIP)
args = append(args, extra...)
return args, nil
}
// bangerKnownHostsPath returns the KnownHostsPath of the layout
// produced by paths.Resolve. The daemon pins host keys into this file
// and the CLI verifies against it, so both must derive the path the
// same way. A Resolve failure is returned with an empty path.
func bangerKnownHostsPath() (string, error) {
	resolved, resolveErr := paths.Resolve()
	if resolveErr != nil {
		return "", resolveErr
	}
	knownHosts := resolved.KnownHostsPath
	return knownHosts, nil
}
func validateSSHPrereqs(cfg model.DaemonConfig) error {
checks := system.NewPreflight()
checks.RequireCommand("ssh", "install openssh-client")

View file

@ -1049,25 +1049,57 @@ func TestExecuteVMActionBatchRunsConcurrentlyAndPreservesOrder(t *testing.T) {
}
func TestSSHCommandArgs(t *testing.T) {
// sshCommandArgs wires banger's own known_hosts into the shell
// SSH invocation — never /dev/null. Assert the shape and the
// posture rather than the exact path (which is host-XDG-derived).
args, err := sshCommandArgs(model.DaemonConfig{SSHKeyPath: "/bundle/id_ed25519"}, "172.16.0.2", []string{"--", "uname", "-a"})
if err != nil {
t.Fatalf("sshCommandArgs: %v", err)
}
want := []string{
wantSubstrings := []string{
"-F", "/dev/null",
"-i", "/bundle/id_ed25519",
"-o", "IdentitiesOnly=yes",
"-o", "BatchMode=yes",
"-o", "PreferredAuthentications=publickey",
"-o", "PasswordAuthentication=no",
"-o", "KbdInteractiveAuthentication=no",
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"root@172.16.0.2",
"--", "uname", "-a",
}
if !reflect.DeepEqual(args, want) {
t.Fatalf("args = %v, want %v", args, want)
for _, s := range wantSubstrings {
found := false
for _, a := range args {
if a == s {
found = true
break
}
}
if !found {
t.Errorf("args missing %q: %v", s, args)
}
}
// Host-key verification posture: accept-new + a real path into
// banger state, not /dev/null.
joined := strings.Join(args, " ")
if !strings.Contains(joined, "StrictHostKeyChecking=accept-new") {
t.Errorf("args missing accept-new posture: %v", args)
}
if strings.Contains(joined, "UserKnownHostsFile=/dev/null") {
t.Errorf("args leaked UserKnownHostsFile=/dev/null: %v", args)
}
if strings.Contains(joined, "StrictHostKeyChecking=no") {
t.Errorf("args leaked StrictHostKeyChecking=no: %v", args)
}
// Must reference a known_hosts file ending in "known_hosts".
sawKnownHosts := false
for _, a := range args {
if strings.HasPrefix(a, "UserKnownHostsFile=") && strings.HasSuffix(a, "known_hosts") {
sawKnownHosts = true
}
}
if !sawKnownHosts {
t.Errorf("args missing UserKnownHostsFile=<banger known_hosts>: %v", args)
}
}

View file

@ -31,14 +31,14 @@ func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval t
if d != nil && d.guestWaitForSSH != nil {
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
}
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval)
}
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
if d != nil && d.guestDial != nil {
return d.guestDial(ctx, address, d.config.SSHKeyPath)
}
return guest.Dial(ctx, address, d.config.SSHKeyPath)
return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
}
func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
@ -86,7 +86,7 @@ func (d *Daemon) refreshGuestSession(ctx context.Context, vm model.VMRecord, s m
func (d *Daemon) inspectGuestSessionState(ctx context.Context, vm model.VMRecord, s model.GuestSession) (session.StateSnapshot, error) {
if d.vmAlive(vm) {
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath)
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return session.StateSnapshot{}, err
}

View file

@ -189,7 +189,7 @@ func (d *Daemon) attachGuestSessionBridge(session model.GuestSession, controller
}
func (d *Daemon) openGuestSessionAttachStream(address, command string) (*guest.StreamSession, error) {
client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath)
client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return nil, err
}

View file

@ -195,7 +195,7 @@ func (d *Daemon) signalGuestSession(ctx context.Context, params api.GuestSession
}
return session, nil
}
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath)
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return model.GuestSession{}, err
}

View file

@ -90,7 +90,7 @@ func (d *Daemon) SendToGuestSession(ctx context.Context, params api.GuestSession
func (d *Daemon) readGuestSessionLog(ctx context.Context, vm model.VMRecord, session model.GuestSession, stream string, tailLines int) (string, error) {
if d.vmAlive(vm) {
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath)
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return "", err
}

View file

@ -2,13 +2,41 @@ package daemon
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"banger/internal/guest"
"banger/internal/model"
"banger/internal/paths"
)
// removeVMKnownHosts deletes the host-key pins recorded for vm (its
// guest IP and its DNS name, whichever are non-blank) from the
// known_hosts file at knownHostsPath.
//
// The call is best-effort: a blank path or a VM with neither an IP
// nor a DNS name is a no-op, and a removal failure is only logged at
// warn level (and dropped entirely when logger is nil). A stale pin
// left behind would surface later as a TOFU mismatch if the same
// IP/name is reused by a VM with a different host key.
func removeVMKnownHosts(knownHostsPath string, vm model.VMRecord, logger *slog.Logger) {
	if strings.TrimSpace(knownHostsPath) == "" {
		return
	}
	// Collect the identities the VM may have been pinned under, IP first.
	targets := make([]string, 0, 2)
	for _, candidate := range []string{vm.Runtime.GuestIP, vm.Runtime.DNSName} {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			targets = append(targets, trimmed)
		}
	}
	if len(targets) == 0 {
		return
	}
	err := guest.RemoveKnownHosts(knownHostsPath, targets...)
	if err == nil || logger == nil {
		return
	}
	logger.Warn("remove known_hosts entries", "vm_id", vm.ID, "error", err.Error())
}
const (
vmSSHConfigIncludeBegin = "# BEGIN BANGER MANAGED VM SSH"
vmSSHConfigIncludeEnd = "# END BANGER MANAGED VM SSH"
@ -39,7 +67,7 @@ func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error {
if err != nil {
return err
}
updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath))
updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath, layout.KnownHostsPath))
if err != nil {
return err
}
@ -54,11 +82,19 @@ func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error {
return nil
}
func renderManagedVMSSHBlock(keyPath string) string {
// renderManagedVMSSHBlock produces the `Host *.vm` stanza banger
// writes into the user's ~/.ssh/config. Host-key verification uses
// the banger-owned known_hosts file at knownHostsPath — NOT the
// user's ~/.ssh/known_hosts, and NOT /dev/null. `accept-new` means
// first contact pins the key; any later mismatch fails the connect.
func renderManagedVMSSHBlock(keyPath, knownHostsPath string) string {
keyPath = strings.TrimSpace(keyPath)
return strings.Join([]string{
knownHostsPath = strings.TrimSpace(knownHostsPath)
lines := []string{
vmSSHConfigIncludeBegin,
"# Generated by banger for direct SSH access to VM DNS names.",
"# Host keys are pinned on first use into a banger-owned",
"# known_hosts file (not ~/.ssh/known_hosts).",
"Host *.vm",
" User root",
" IdentityFile " + keyPath,
@ -67,12 +103,23 @@ func renderManagedVMSSHBlock(keyPath string) string {
" PreferredAuthentications publickey",
" PasswordAuthentication no",
" KbdInteractiveAuthentication no",
" StrictHostKeyChecking no",
" UserKnownHostsFile /dev/null",
}
if knownHostsPath != "" {
lines = append(lines,
" UserKnownHostsFile "+knownHostsPath,
" StrictHostKeyChecking accept-new",
)
} else {
// Missing known_hosts path is a configuration anomaly — fail
// closed rather than silently disable verification.
lines = append(lines, " StrictHostKeyChecking yes")
}
lines = append(lines,
" LogLevel ERROR",
vmSSHConfigIncludeEnd,
"",
}, "\n")
)
return strings.Join(lines, "\n")
}
func upsertManagedBlock(existing, beginMarker, endMarker, block string) (string, error) {

View file

@ -13,8 +13,10 @@ func TestSyncVMSSHClientConfigCreatesManagedBlock(t *testing.T) {
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
knownHostsPath := filepath.Join(homeDir, ".local", "state", "banger", "ssh", "known_hosts")
layout := paths.Layout{
ConfigDir: filepath.Join(homeDir, ".config", "banger"),
KnownHostsPath: knownHostsPath,
}
keyPath := filepath.Join(layout.ConfigDir, "ssh", "id_ed25519")
@ -38,12 +40,23 @@ func TestSyncVMSSHClientConfigCreatesManagedBlock(t *testing.T) {
"IdentitiesOnly yes",
"BatchMode yes",
"PasswordAuthentication no",
"UserKnownHostsFile /dev/null",
"UserKnownHostsFile " + knownHostsPath,
"StrictHostKeyChecking accept-new",
} {
if !strings.Contains(userContent, want) {
t.Fatalf("user config = %q, want %q", userContent, want)
}
}
// Regression: the legacy posture (StrictHostKeyChecking no +
// UserKnownHostsFile /dev/null) must never reappear.
for _, must := range []string{
"StrictHostKeyChecking no",
"UserKnownHostsFile /dev/null",
} {
if strings.Contains(userContent, must) {
t.Fatalf("user config leaked legacy posture %q:\n%s", must, userContent)
}
}
}
func TestSyncVMSSHClientConfigReplacesManagedIncludeBlock(t *testing.T) {

View file

@ -411,5 +411,9 @@ func (d *Daemon) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm
return model.VMRecord{}, err
}
}
// Drop any host-key pins. A future VM reusing this IP or name
// would otherwise trip the TOFU mismatch branch in
// TOFUHostKeyCallback and fail to connect.
removeVMKnownHosts(d.layout.KnownHostsPath, vm, d.logger)
return vm, nil
}

View file

@ -0,0 +1,256 @@
package guest
import (
"bufio"
"encoding/base64"
"errors"
"fmt"
"net"
"os"
"strings"
"sync"
"golang.org/x/crypto/ssh"
)
// TOFUHostKeyCallback returns a HostKeyCallback that implements
// trust-on-first-use against a banger-owned known_hosts file.
//
// Semantics:
//   - Entries are looked up by the dialed host with any ":port"
//     suffix stripped (see hostLookupKey).
//   - If the file already holds one or more entries for that host,
//     the presented key must be byte-identical to one of the pinned
//     keys. Anything else is rejected as a possible MITM — including
//     a key of a type that was never pinned, so an attacker cannot
//     sidestep the pin by offering a different key algorithm.
//   - If no entry exists for the host → append one and accept.
//
// The file format is compatible with OpenSSH so shell SSH clients can
// use the same path via `UserKnownHostsFile`.
//
// The read-check-append sequence runs under the package-level
// knownHostsMu, so concurrent dials inside this process don't
// interleave writes. The mutex does not coordinate with other
// processes that touch the same file.
//
// An empty path disables host-key checking entirely — only for test
// harnesses and tools that dial ad-hoc infrastructure; production
// paths must supply a real file.
func TOFUHostKeyCallback(path string) ssh.HostKeyCallback {
	if strings.TrimSpace(path) == "" {
		return ssh.InsecureIgnoreHostKey()
	}
	return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
		host := hostLookupKey(hostname, remote)
		knownHostsMu.Lock()
		defer knownHostsMu.Unlock()
		entries, err := loadKnownHosts(path)
		if err != nil {
			return fmt.Errorf("read known_hosts: %w", err)
		}
		// Scan every entry for this host regardless of key type: the
		// host counts as "known" as soon as any pin exists, and only a
		// key equal to one of those pins is acceptable.
		known := false
		for _, e := range entries {
			for _, h := range e.hosts {
				if h != host {
					continue
				}
				known = true
				if keysEqual(e.key, key) {
					return nil
				}
				break
			}
		}
		if known {
			return fmt.Errorf("banger: host key for %s does not match pinned entry — "+
				"possible MITM. If the VM was legitimately rebuilt, remove the old "+
				"entry from %s and retry.", host, path)
		}
		// First contact: pin the presented key.
		if err := appendKnownHost(path, host, key); err != nil {
			return fmt.Errorf("pin host key for %s: %w", host, err)
		}
		return nil
	}
}
// RemoveKnownHosts strips every entry matching any host in `hosts`
// from the known_hosts file. Called on VM delete so a future VM
// reusing the same IP or name never trips the TOFU mismatch branch.
//
// An empty path, an empty/blank host list, a missing file, or a file
// with no matching entry is a true no-op: the file is neither created
// nor rewritten. An entry listing several comma-separated hosts is
// dropped as a whole if any one of them matches.
//
// When a rewrite does happen, only the lines loadKnownHosts could
// parse are written back; comments, blank lines and unparseable lines
// are not preserved.
func RemoveKnownHosts(path string, hosts ...string) error {
	if strings.TrimSpace(path) == "" || len(hosts) == 0 {
		return nil
	}
	drop := make(map[string]struct{}, len(hosts))
	for _, h := range hosts {
		if h = strings.TrimSpace(h); h != "" {
			drop[h] = struct{}{}
		}
	}
	if len(drop) == 0 {
		return nil
	}
	knownHostsMu.Lock()
	defer knownHostsMu.Unlock()
	entries, err := loadKnownHosts(path)
	if err != nil {
		return fmt.Errorf("read known_hosts: %w", err)
	}
	filtered := entries.filter(func(e knownHostEntry) bool {
		for _, h := range e.hosts {
			if _, skip := drop[h]; skip {
				return false
			}
		}
		return true
	})
	if len(filtered) == len(entries) {
		// Nothing matched (this includes the missing-file case, where
		// entries is empty). Leave the file exactly as it is instead of
		// creating or rewriting it.
		return nil
	}
	if err := filtered.write(path); err != nil {
		return fmt.Errorf("rewrite known_hosts: %w", err)
	}
	return nil
}
var knownHostsMu sync.Mutex
// knownHostEntry is one line in known_hosts: a set of host patterns
// (comma-separated in the file), a key type, and a key blob.
type knownHostEntry struct {
// hosts is the first field of the line split on ",".
hosts []string
// keyType is the second field of the line, stored verbatim.
keyType string
// key is the third field, base64-decoded and parsed as an SSH
// public key.
key ssh.PublicKey
// raw is the line exactly as read from the file; it is what gets
// written back when the list is rewritten.
raw string
}
// knownHostList is the parsed contents of a known_hosts file, in file
// order.
type knownHostList []knownHostEntry
// match returns the first entry whose key type equals keyType and
// whose host list contains host exactly. The boolean reports whether
// such an entry was found; on a miss the entry is the zero value.
func (l knownHostList) match(host, keyType string) (knownHostEntry, bool) {
	for i := range l {
		candidate := l[i]
		if candidate.keyType != keyType {
			continue
		}
		for _, name := range candidate.hosts {
			if name == host {
				return candidate, true
			}
		}
	}
	return knownHostEntry{}, false
}
// filter returns a new list holding, in order, the entries for which
// keep reports true. The receiver is left untouched and the result is
// never nil.
func (l knownHostList) filter(keep func(knownHostEntry) bool) knownHostList {
	kept := make(knownHostList, 0, len(l))
	for i := range l {
		if entry := l[i]; keep(entry) {
			kept = append(kept, entry)
		}
	}
	return kept
}
// write replaces the file at path with the raw lines of l, one entry
// per line, created with mode 0600. An empty list truncates the file
// rather than removing it, so it keeps existing for later appends.
func (l knownHostList) write(path string) error {
	var content []byte
	for _, entry := range l {
		content = append(content, entry.raw...)
		// Guarantee exactly one entry per line even if raw carries no
		// terminator.
		if !strings.HasSuffix(entry.raw, "\n") {
			content = append(content, '\n')
		}
	}
	return os.WriteFile(path, content, 0o600)
}
// loadKnownHosts reads the known_hosts file at path and returns its
// parseable entries in file order. A missing file yields (nil, nil).
// Blank lines, comment lines, and lines that do not parse as
// "<hosts> <type> <base64 key>" are skipped silently. Lines longer
// than 1 MiB make the scan fail with an error.
func loadKnownHosts(path string) (knownHostList, error) {
	f, err := os.Open(path)
	switch {
	case err == nil:
		// opened; fall through to the scan below
	case os.IsNotExist(err):
		return nil, nil
	default:
		return nil, err
	}
	defer f.Close()

	sc := bufio.NewScanner(f)
	sc.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	var list knownHostList
	for sc.Scan() {
		if entry, ok := parseKnownHostLine(sc.Text()); ok {
			list = append(list, entry)
		}
	}
	if scanErr := sc.Err(); scanErr != nil {
		return nil, scanErr
	}
	return list, nil
}

// parseKnownHostLine turns one known_hosts line into an entry. It
// reports false for blank lines, comments, lines with fewer than
// three fields, and lines whose third field is not a base64-encoded
// SSH public key.
func parseKnownHostLine(line string) (knownHostEntry, bool) {
	body := strings.TrimSpace(line)
	if body == "" || strings.HasPrefix(body, "#") {
		return knownHostEntry{}, false
	}
	fields := strings.Fields(body)
	if len(fields) < 3 {
		return knownHostEntry{}, false
	}
	blob, err := base64.StdEncoding.DecodeString(fields[2])
	if err != nil {
		return knownHostEntry{}, false
	}
	parsed, err := ssh.ParsePublicKey(blob)
	if err != nil {
		return knownHostEntry{}, false
	}
	return knownHostEntry{
		hosts:   strings.Split(fields[0], ","),
		keyType: fields[1],
		key:     parsed,
		raw:     line, // keep the untrimmed original for rewrites
	}, true
}
// appendKnownHost appends a "<host> <key type> <base64 key>" line for
// key to the file at path, creating the file with mode 0600 if it
// does not exist yet.
func appendKnownHost(path, host string, key ssh.PublicKey) error {
	out, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600)
	if err != nil {
		return err
	}
	defer out.Close()
	encoded := base64.StdEncoding.EncodeToString(key.Marshal())
	_, err = out.WriteString(host + " " + key.Type() + " " + encoded + "\n")
	return err
}
// hostLookupKey derives the name under which a host's key is stored
// and looked up.
//
// The hostname handed to a host-key callback is normally of the form
// "host:port"; the port is dropped so entries are keyed by host
// alone. A hostname without a port is used unchanged. When the
// hostname is blank, the host part of the remote address is used
// instead (or the whole address string if it has no port). With
// neither available the result is "".
func hostLookupKey(hostname string, remote net.Addr) string {
	name := hostname
	if bare, _, err := net.SplitHostPort(hostname); err == nil {
		name = bare
	}
	switch {
	case strings.TrimSpace(name) != "":
		return name
	case remote == nil:
		return ""
	}
	addr := remote.String()
	if bare, _, err := net.SplitHostPort(addr); err == nil {
		return bare
	}
	return addr
}
// keysEqual reports whether a and b have byte-identical wire
// encodings (as produced by Marshal). Two nil keys are equal; a nil
// key never equals a non-nil one.
func keysEqual(a, b ssh.PublicKey) bool {
	switch {
	case a == nil && b == nil:
		return true
	case a == nil || b == nil:
		return false
	}
	return string(a.Marshal()) == string(b.Marshal())
}
// errHostKeyMismatch sentinel is currently unused but reserved for
// callers that want to distinguish MITM from other failures.
var errHostKeyMismatch = errors.New("host key mismatch")
var _ = errHostKeyMismatch

View file

@ -0,0 +1,185 @@
package guest
import (
"crypto/ed25519"
"crypto/rand"
"net"
"os"
"path/filepath"
"strings"
"testing"
"golang.org/x/crypto/ssh"
)
// makeTestHostKey creates a brand-new ed25519 key pair and returns
// its public half wrapped as an ssh.PublicKey, i.e. what a server
// would present as its host key. Any failure aborts the test.
func makeTestHostKey(t *testing.T) ssh.PublicKey {
	t.Helper()
	public, _, genErr := ed25519.GenerateKey(rand.Reader)
	if genErr != nil {
		t.Fatalf("GenerateKey: %v", genErr)
	}
	hostKey, wrapErr := ssh.NewPublicKey(public)
	if wrapErr != nil {
		t.Fatalf("NewPublicKey: %v", wrapErr)
	}
	return hostKey
}
func TestTOFUHostKeyCallbackPinsOnFirstUse(t *testing.T) {
t.Parallel()
path := filepath.Join(t.TempDir(), "known_hosts")
cb := TOFUHostKeyCallback(path)
key := makeTestHostKey(t)
addr := &net.TCPAddr{IP: net.ParseIP("172.16.0.5"), Port: 22}
if err := cb("172.16.0.5:22", addr, key); err != nil {
t.Fatalf("first-use callback: %v", err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("ReadFile: %v", err)
}
content := string(data)
if !strings.Contains(content, "172.16.0.5") {
t.Errorf("known_hosts missing host:\n%s", content)
}
if !strings.Contains(content, key.Type()) {
t.Errorf("known_hosts missing key type:\n%s", content)
}
}
// TestTOFUHostKeyCallbackAcceptsMatch checks that a host presenting
// the same key it was pinned with is accepted on a later dial.
func TestTOFUHostKeyCallbackAcceptsMatch(t *testing.T) {
	t.Parallel()
	const hostPort = "172.16.0.6:22"
	verify := TOFUHostKeyCallback(filepath.Join(t.TempDir(), "known_hosts"))
	hostKey := makeTestHostKey(t)
	remote := &net.TCPAddr{IP: net.ParseIP("172.16.0.6"), Port: 22}
	// First dial pins the key.
	if err := verify(hostPort, remote, hostKey); err != nil {
		t.Fatalf("first-use: %v", err)
	}
	// Second dial presents the pinned key again and must pass.
	if err := verify(hostPort, remote, hostKey); err != nil {
		t.Fatalf("second dial with matching key: %v", err)
	}
}
// TestTOFUHostKeyCallbackRejectsMismatch checks that once a host is
// pinned, a different key for the same host is refused with an error
// that mentions the mismatch.
func TestTOFUHostKeyCallbackRejectsMismatch(t *testing.T) {
	t.Parallel()
	const hostPort = "172.16.0.7:22"
	verify := TOFUHostKeyCallback(filepath.Join(t.TempDir(), "known_hosts"))
	remote := &net.TCPAddr{IP: net.ParseIP("172.16.0.7"), Port: 22}
	if err := verify(hostPort, remote, makeTestHostKey(t)); err != nil {
		t.Fatalf("pin original: %v", err)
	}
	// A freshly generated key stands in for an impostor.
	mismatchErr := verify(hostPort, remote, makeTestHostKey(t))
	if mismatchErr == nil {
		t.Fatal("expected mismatch error, got nil")
	}
	if !strings.Contains(mismatchErr.Error(), "does not match") {
		t.Errorf("error = %v, want message about mismatch", mismatchErr)
	}
}
// TestTOFUEmptyPathDisablesVerification pins down the documented
// escape hatch: an empty known_hosts path yields a callback that
// accepts any key. Guarding it here keeps the fallback from silently
// turning into "always verify but without a file".
func TestTOFUEmptyPathDisablesVerification(t *testing.T) {
	t.Parallel()
	verify := TOFUHostKeyCallback("")
	remote := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}
	err := verify("127.0.0.1:22", remote, makeTestHostKey(t))
	if err != nil {
		t.Fatalf("empty-path callback should accept: %v", err)
	}
}
// TestRemoveKnownHostsDropsEntry pins two hosts, removes one, and
// checks that only the removed host's line disappears from the file.
func TestRemoveKnownHostsDropsEntry(t *testing.T) {
	t.Parallel()
	path := filepath.Join(t.TempDir(), "known_hosts")
	cb := TOFUHostKeyCallback(path)
	keep := makeTestHostKey(t)
	drop := makeTestHostKey(t)
	if err := cb("172.16.0.10:22", &net.TCPAddr{IP: net.ParseIP("172.16.0.10"), Port: 22}, keep); err != nil {
		t.Fatalf("pin keep: %v", err)
	}
	if err := cb("172.16.0.11:22", &net.TCPAddr{IP: net.ParseIP("172.16.0.11"), Port: 22}, drop); err != nil {
		t.Fatalf("pin drop: %v", err)
	}
	if err := RemoveKnownHosts(path, "172.16.0.11"); err != nil {
		t.Fatalf("RemoveKnownHosts: %v", err)
	}
	// Fail loudly on a read error instead of reporting a misleading
	// "kept entry missing" against empty content.
	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("ReadFile: %v", err)
	}
	content := string(data)
	if !strings.Contains(content, "172.16.0.10") {
		t.Errorf("kept entry missing:\n%s", content)
	}
	if strings.Contains(content, "172.16.0.11") {
		t.Errorf("dropped entry still present:\n%s", content)
	}
}
// TestRemoveKnownHostsMissingFileIsNoOp checks that removing a host
// from a known_hosts file that does not exist reports no error.
func TestRemoveKnownHostsMissingFileIsNoOp(t *testing.T) {
	t.Parallel()
	absentPath := filepath.Join(t.TempDir(), "absent")
	err := RemoveKnownHosts(absentPath, "any")
	if err != nil {
		t.Fatalf("RemoveKnownHosts on missing: %v", err)
	}
}
// TestRemoveKnownHostsEmptyPathIsNoOp checks that an empty
// known_hosts path is tolerated and reports no error.
func TestRemoveKnownHostsEmptyPathIsNoOp(t *testing.T) {
	t.Parallel()
	err := RemoveKnownHosts("", "any")
	if err != nil {
		t.Fatalf("RemoveKnownHosts(empty): %v", err)
	}
}
// TestTOFURewritesAllowsReuseAfterRemove: after a VM is deleted and
// its pin is cleared, a future VM reusing the same IP (with a fresh
// host key) should re-pin cleanly, not fail the mismatch branch. The
// fresh pin must then be enforced like any other: the new key keeps
// working and the old key is rejected.
func TestTOFURewritesAllowsReuseAfterRemove(t *testing.T) {
	t.Parallel()
	path := filepath.Join(t.TempDir(), "known_hosts")
	cb := TOFUHostKeyCallback(path)
	addr := &net.TCPAddr{IP: net.ParseIP("172.16.0.15"), Port: 22}
	original := makeTestHostKey(t)
	if err := cb("172.16.0.15:22", addr, original); err != nil {
		t.Fatalf("pin original: %v", err)
	}
	// VM deleted → pin removed.
	if err := RemoveKnownHosts(path, "172.16.0.15"); err != nil {
		t.Fatalf("RemoveKnownHosts: %v", err)
	}
	// New VM, same IP, new host key. Must re-pin without error.
	replacement := makeTestHostKey(t)
	if err := cb("172.16.0.15:22", addr, replacement); err != nil {
		t.Fatalf("re-pin after remove: %v", err)
	}
	// Guard against a regression where removal leaves the host
	// unverified: the replacement must now be the enforced pin.
	if err := cb("172.16.0.15:22", addr, replacement); err != nil {
		t.Fatalf("dial with re-pinned key: %v", err)
	}
	if err := cb("172.16.0.15:22", addr, original); err == nil {
		t.Fatal("original key accepted after re-pin, want mismatch error")
	}
}
// TestHostLookupKeyStripsPort covers the three shapes hostLookupKey
// handles: "host:port", a bare host, and an empty hostname that falls
// back to the remote address.
func TestHostLookupKeyStripsPort(t *testing.T) {
	t.Parallel()
	stripped := hostLookupKey("10.0.0.1:22", nil)
	if stripped != "10.0.0.1" {
		t.Errorf("got %q, want 10.0.0.1", stripped)
	}
	bare := hostLookupKey("host.vm", nil)
	if bare != "host.vm" {
		t.Errorf("got %q, want host.vm", bare)
	}
	remote := &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 22}
	fromRemote := hostLookupKey("", remote)
	if fromRemote != "1.2.3.4" {
		t.Errorf("fallback: got %q, want 1.2.3.4", fromRemote)
	}
}

View file

@ -35,12 +35,15 @@ type StreamSession struct {
closeOnce sync.Once
}
func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
// WaitForSSH polls Dial until it succeeds or ctx cancels. The
// knownHostsPath argument is the banger-owned TOFU file; empty
// disables host-key verification (tests only).
func WaitForSSH(ctx context.Context, address, privateKeyPath, knownHostsPath string, interval time.Duration) error {
if interval <= 0 {
interval = time.Second
}
for {
client, err := Dial(ctx, address, privateKeyPath)
client, err := Dial(ctx, address, privateKeyPath, knownHostsPath)
if err == nil {
_ = client.Close()
return nil
@ -53,7 +56,11 @@ func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval ti
}
}
func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error) {
// Dial opens an SSH client to address, authenticating with the key
// at privateKeyPath and verifying the remote host key against the
// TOFU known_hosts file at knownHostsPath. An empty knownHostsPath
// disables verification (tests / one-shot tools only).
func Dial(ctx context.Context, address, privateKeyPath, knownHostsPath string) (*Client, error) {
signer, err := privateKeySigner(privateKeyPath)
if err != nil {
return nil, err
@ -61,7 +68,7 @@ func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error)
config := &ssh.ClientConfig{
User: "root",
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
HostKeyCallback: TOFUHostKeyCallback(knownHostsPath),
Timeout: 10 * time.Second,
}
dialer := &net.Dialer{Timeout: 10 * time.Second}

View file

@ -271,7 +271,7 @@ func TestWaitForSSHContextCancel(t *testing.T) {
defer cancel()
start := time.Now()
err := WaitForSSH(ctx, freeAddr(t), keyPath, 10*time.Millisecond)
err := WaitForSSH(ctx, freeAddr(t), keyPath, "", 10*time.Millisecond)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("err = %v, want context.DeadlineExceeded", err)
}
@ -286,7 +286,7 @@ func TestDialReturnsErrorForBadKey(t *testing.T) {
if err := os.WriteFile(keyPath, []byte("nope"), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
_, err := Dial(context.Background(), freeAddr(t), keyPath)
_, err := Dial(context.Background(), freeAddr(t), keyPath, "")
if err == nil {
t.Fatal("expected error for bad key")
}

View file

@ -24,6 +24,8 @@ type Layout struct {
ImagesDir string
KernelsDir string
OCICacheDir string
SSHDir string
KnownHostsPath string
}
func Resolve() (Layout, error) {
@ -56,6 +58,8 @@ func Resolve() (Layout, error) {
layout.ImagesDir = filepath.Join(layout.StateDir, "images")
layout.KernelsDir = filepath.Join(layout.StateDir, "kernels")
layout.OCICacheDir = filepath.Join(layout.CacheDir, "oci")
layout.SSHDir = filepath.Join(layout.StateDir, "ssh")
layout.KnownHostsPath = filepath.Join(layout.SSHDir, "known_hosts")
return layout, nil
}
@ -65,6 +69,15 @@ func Ensure(layout Layout) error {
return err
}
}
// SSH material (private key, known_hosts) — 0700 like ~/.ssh so
// strict SSH clients don't complain and no other host user can
// read it. Empty SSHDir means the caller built a Layout by hand
// (tests) and doesn't need the subdir; skip silently.
if strings.TrimSpace(layout.SSHDir) != "" {
if err := os.MkdirAll(layout.SSHDir, 0o700); err != nil {
return err
}
}
return nil
}