ssh: trust-on-first-use host key pinning everywhere

Guest host-key verification was off in all three SSH paths:

  * Go SSH (internal/guest/ssh.go) used ssh.InsecureIgnoreHostKey
  * `banger vm ssh` passed StrictHostKeyChecking=no
    + UserKnownHostsFile=/dev/null
  * `~/.ssh/config` Host *.vm shipped the same posture into the
    user's global config

Now each path verifies against a banger-owned known_hosts file at
`~/.local/state/banger/ssh/known_hosts` with TOFU semantics:

  * First dial to a VM pins the key.
  * Subsequent dials require an exact match. A mismatch fails with
    an explicit "possible MITM" error.
  * `vm delete` removes the entries so a future VM reusing the IP
    or name re-pins cleanly.
  * The user's `~/.ssh/known_hosts` is untouched.

Changes:

  internal/guest/known_hosts.go (new) — OpenSSH-compatible parser,
    TOFUHostKeyCallback, RemoveKnownHosts. Process-wide mutex
    around the file.
  internal/guest/ssh.go — Dial and WaitForSSH grew a knownHostsPath
    parameter threaded through the callback. Empty path keeps the
    insecure callback (tests + throwaway tools only; documented).
  internal/daemon/{guest_sessions,session_attach,session_lifecycle,
    session_stream}.go — call sites pass d.layout.KnownHostsPath.
  internal/daemon/ssh_client_config.go — the ~/.ssh/config Host *.vm
    block now points at banger's known_hosts and uses
    StrictHostKeyChecking=accept-new. Missing path → fail closed.
  internal/daemon/vm_lifecycle.go — deleteVMLocked drops known_hosts
    entries for the VM's IP and DNS name via removeVMKnownHosts.
  internal/cli/banger.go — sshCommandArgs swaps StrictHostKeyChecking
    no + /dev/null for banger's file + accept-new. Path resolution
    failure falls through to StrictHostKeyChecking=yes.
  internal/paths/paths.go — Layout gains SSHDir + KnownHostsPath;
    Ensure creates SSHDir at 0700.

Tests (internal/guest/known_hosts_test.go): pin on first use, accept
matching key on second dial, reject mismatch, empty path skips
checking, RemoveKnownHosts drops the entry, re-pin works after
remove. Existing daemon + cli tests updated to assert the new
posture and regression-guard against the old flags.

Live verified: vm run writes the pin to banger's known_hosts at 0600
inside a 0700 dir; banger vm ssh + ssh root@<vm>.vm both succeed
using the pin; vm delete clears it.
This commit is contained in:
Thales Maciel 2026-04-19 16:46:03 -03:00
parent a59958d4f5
commit ae14b9499d
No known key found for this signature in database
GPG key ID: 33112E6833C34679
14 changed files with 634 additions and 47 deletions

View file

@ -131,10 +131,12 @@ var (
return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params)
}
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
return guest.WaitForSSH(ctx, address, privateKeyPath, interval)
knownHosts, _ := bangerKnownHostsPath()
return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval)
}
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
return guest.Dial(ctx, address, privateKeyPath)
knownHosts, _ := bangerKnownHostsPath()
return guest.Dial(ctx, address, privateKeyPath, knownHosts)
}
prepareVMRunRepoCopyFunc = prepareVMRunRepoCopy
buildVMRunToolingPlanFunc = toolingplan.Build
@ -2669,6 +2671,12 @@ func sshCommandArgs(cfg model.DaemonConfig, guestIP string, extra []string) ([]s
if cfg.SSHKeyPath != "" {
args = append(args, "-i", cfg.SSHKeyPath)
}
// Host-key verification uses a banger-owned known_hosts file
// populated by the daemon's first successful Go-SSH dial to each
// VM (trust-on-first-use). `accept-new` means: accept-and-pin on
// first contact; strict-verify afterwards. The user's own
// ~/.ssh/known_hosts is untouched.
knownHosts, khErr := bangerKnownHostsPath()
args = append(
args,
"-o", "IdentitiesOnly=yes",
@ -2676,14 +2684,36 @@ func sshCommandArgs(cfg model.DaemonConfig, guestIP string, extra []string) ([]s
"-o", "PreferredAuthentications=publickey",
"-o", "PasswordAuthentication=no",
"-o", "KbdInteractiveAuthentication=no",
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"root@"+guestIP,
)
if khErr == nil {
args = append(args,
"-o", "UserKnownHostsFile="+knownHosts,
"-o", "StrictHostKeyChecking=accept-new",
)
} else {
// If we can't resolve the banger path (unusual — paths.Resolve
// basically can't fail), fall through to a hard-fail posture
// rather than silently disabling verification.
args = append(args,
"-o", "StrictHostKeyChecking=yes",
)
}
args = append(args, "root@"+guestIP)
args = append(args, extra...)
return args, nil
}
// bangerKnownHostsPath returns the KnownHostsPath field of the layout
// produced by paths.Resolve. The daemon writes pins into this file and
// the CLI hands the same path to ssh, so both sides have to derive it
// the same way. A Resolve failure is passed through with an empty
// path.
func bangerKnownHostsPath() (string, error) {
	resolved, resolveErr := paths.Resolve()
	if resolveErr == nil {
		return resolved.KnownHostsPath, nil
	}
	return "", resolveErr
}
func validateSSHPrereqs(cfg model.DaemonConfig) error {
checks := system.NewPreflight()
checks.RequireCommand("ssh", "install openssh-client")

View file

@ -1049,25 +1049,57 @@ func TestExecuteVMActionBatchRunsConcurrentlyAndPreservesOrder(t *testing.T) {
}
func TestSSHCommandArgs(t *testing.T) {
// sshCommandArgs wires banger's own known_hosts into the shell
// SSH invocation — never /dev/null. Assert the shape and the
// posture rather than the exact path (which is host-XDG-derived).
args, err := sshCommandArgs(model.DaemonConfig{SSHKeyPath: "/bundle/id_ed25519"}, "172.16.0.2", []string{"--", "uname", "-a"})
if err != nil {
t.Fatalf("sshCommandArgs: %v", err)
}
want := []string{
wantSubstrings := []string{
"-F", "/dev/null",
"-i", "/bundle/id_ed25519",
"-o", "IdentitiesOnly=yes",
"-o", "BatchMode=yes",
"-o", "PreferredAuthentications=publickey",
"-o", "PasswordAuthentication=no",
"-o", "KbdInteractiveAuthentication=no",
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"root@172.16.0.2",
"--", "uname", "-a",
}
if !reflect.DeepEqual(args, want) {
t.Fatalf("args = %v, want %v", args, want)
for _, s := range wantSubstrings {
found := false
for _, a := range args {
if a == s {
found = true
break
}
}
if !found {
t.Errorf("args missing %q: %v", s, args)
}
}
// Host-key verification posture: accept-new + a real path into
// banger state, not /dev/null.
joined := strings.Join(args, " ")
if !strings.Contains(joined, "StrictHostKeyChecking=accept-new") {
t.Errorf("args missing accept-new posture: %v", args)
}
if strings.Contains(joined, "UserKnownHostsFile=/dev/null") {
t.Errorf("args leaked UserKnownHostsFile=/dev/null: %v", args)
}
if strings.Contains(joined, "StrictHostKeyChecking=no") {
t.Errorf("args leaked StrictHostKeyChecking=no: %v", args)
}
// Must reference a known_hosts file ending in "known_hosts".
sawKnownHosts := false
for _, a := range args {
if strings.HasPrefix(a, "UserKnownHostsFile=") && strings.HasSuffix(a, "known_hosts") {
sawKnownHosts = true
}
}
if !sawKnownHosts {
t.Errorf("args missing UserKnownHostsFile=<banger known_hosts>: %v", args)
}
}

View file

@ -31,14 +31,14 @@ func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval t
if d != nil && d.guestWaitForSSH != nil {
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
}
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval)
}
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
if d != nil && d.guestDial != nil {
return d.guestDial(ctx, address, d.config.SSHKeyPath)
}
return guest.Dial(ctx, address, d.config.SSHKeyPath)
return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
}
func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
@ -86,7 +86,7 @@ func (d *Daemon) refreshGuestSession(ctx context.Context, vm model.VMRecord, s m
func (d *Daemon) inspectGuestSessionState(ctx context.Context, vm model.VMRecord, s model.GuestSession) (session.StateSnapshot, error) {
if d.vmAlive(vm) {
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath)
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return session.StateSnapshot{}, err
}

View file

@ -189,7 +189,7 @@ func (d *Daemon) attachGuestSessionBridge(session model.GuestSession, controller
}
func (d *Daemon) openGuestSessionAttachStream(address, command string) (*guest.StreamSession, error) {
client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath)
client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return nil, err
}

View file

@ -195,7 +195,7 @@ func (d *Daemon) signalGuestSession(ctx context.Context, params api.GuestSession
}
return session, nil
}
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath)
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return model.GuestSession{}, err
}

View file

@ -90,7 +90,7 @@ func (d *Daemon) SendToGuestSession(ctx context.Context, params api.GuestSession
func (d *Daemon) readGuestSessionLog(ctx context.Context, vm model.VMRecord, session model.GuestSession, stream string, tailLines int) (string, error) {
if d.vmAlive(vm) {
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath)
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return "", err
}

View file

@ -2,13 +2,41 @@ package daemon
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"banger/internal/guest"
"banger/internal/model"
"banger/internal/paths"
)
// removeVMKnownHosts asks guest.RemoveKnownHosts to drop the pins
// recorded for vm's guest IP and DNS name from the known_hosts file at
// knownHostsPath.
//
// It is best-effort: nothing is returned to the caller. A blank path,
// or a VM with neither an IP nor a DNS name, is a silent no-op. A
// removal failure is logged at warn level (when logger is non-nil) and
// otherwise ignored — the worst outcome is a stale pin that makes a
// later VM reusing the address fail host-key verification until the
// entry is cleared by hand.
func removeVMKnownHosts(knownHostsPath string, vm model.VMRecord, logger *slog.Logger) {
	if strings.TrimSpace(knownHostsPath) == "" {
		return
	}

	targets := make([]string, 0, 2)
	for _, candidate := range []string{vm.Runtime.GuestIP, vm.Runtime.DNSName} {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			targets = append(targets, trimmed)
		}
	}
	if len(targets) == 0 {
		return
	}

	err := guest.RemoveKnownHosts(knownHostsPath, targets...)
	if err == nil || logger == nil {
		return
	}
	logger.Warn("remove known_hosts entries", "vm_id", vm.ID, "error", err.Error())
}
const (
vmSSHConfigIncludeBegin = "# BEGIN BANGER MANAGED VM SSH"
vmSSHConfigIncludeEnd = "# END BANGER MANAGED VM SSH"
@ -39,7 +67,7 @@ func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error {
if err != nil {
return err
}
updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath))
updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath, layout.KnownHostsPath))
if err != nil {
return err
}
@ -54,11 +82,19 @@ func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error {
return nil
}
func renderManagedVMSSHBlock(keyPath string) string {
// renderManagedVMSSHBlock produces the `Host *.vm` stanza banger
// writes into the user's ~/.ssh/config. Host-key verification uses
// the banger-owned known_hosts file at knownHostsPath — NOT the
// user's ~/.ssh/known_hosts, and NOT /dev/null. `accept-new` means
// first contact pins the key; any later mismatch fails the connect.
func renderManagedVMSSHBlock(keyPath, knownHostsPath string) string {
keyPath = strings.TrimSpace(keyPath)
return strings.Join([]string{
knownHostsPath = strings.TrimSpace(knownHostsPath)
lines := []string{
vmSSHConfigIncludeBegin,
"# Generated by banger for direct SSH access to VM DNS names.",
"# Host keys are pinned on first use into a banger-owned",
"# known_hosts file (not ~/.ssh/known_hosts).",
"Host *.vm",
" User root",
" IdentityFile " + keyPath,
@ -67,12 +103,23 @@ func renderManagedVMSSHBlock(keyPath string) string {
" PreferredAuthentications publickey",
" PasswordAuthentication no",
" KbdInteractiveAuthentication no",
" StrictHostKeyChecking no",
" UserKnownHostsFile /dev/null",
}
if knownHostsPath != "" {
lines = append(lines,
" UserKnownHostsFile "+knownHostsPath,
" StrictHostKeyChecking accept-new",
)
} else {
// Missing known_hosts path is a configuration anomaly — fail
// closed rather than silently disable verification.
lines = append(lines, " StrictHostKeyChecking yes")
}
lines = append(lines,
" LogLevel ERROR",
vmSSHConfigIncludeEnd,
"",
}, "\n")
)
return strings.Join(lines, "\n")
}
func upsertManagedBlock(existing, beginMarker, endMarker, block string) (string, error) {

View file

@ -13,8 +13,10 @@ func TestSyncVMSSHClientConfigCreatesManagedBlock(t *testing.T) {
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
knownHostsPath := filepath.Join(homeDir, ".local", "state", "banger", "ssh", "known_hosts")
layout := paths.Layout{
ConfigDir: filepath.Join(homeDir, ".config", "banger"),
ConfigDir: filepath.Join(homeDir, ".config", "banger"),
KnownHostsPath: knownHostsPath,
}
keyPath := filepath.Join(layout.ConfigDir, "ssh", "id_ed25519")
@ -38,12 +40,23 @@ func TestSyncVMSSHClientConfigCreatesManagedBlock(t *testing.T) {
"IdentitiesOnly yes",
"BatchMode yes",
"PasswordAuthentication no",
"UserKnownHostsFile /dev/null",
"UserKnownHostsFile " + knownHostsPath,
"StrictHostKeyChecking accept-new",
} {
if !strings.Contains(userContent, want) {
t.Fatalf("user config = %q, want %q", userContent, want)
}
}
// Regression: the legacy posture (StrictHostKeyChecking no +
// UserKnownHostsFile /dev/null) must never reappear.
for _, must := range []string{
"StrictHostKeyChecking no",
"UserKnownHostsFile /dev/null",
} {
if strings.Contains(userContent, must) {
t.Fatalf("user config leaked legacy posture %q:\n%s", must, userContent)
}
}
}
func TestSyncVMSSHClientConfigReplacesManagedIncludeBlock(t *testing.T) {

View file

@ -411,5 +411,9 @@ func (d *Daemon) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm
return model.VMRecord{}, err
}
}
// Drop any host-key pins. A future VM reusing this IP or name
// would otherwise trip the TOFU mismatch branch in
// TOFUHostKeyCallback and fail to connect.
removeVMKnownHosts(d.layout.KnownHostsPath, vm, d.logger)
return vm, nil
}

View file

@ -0,0 +1,256 @@
package guest
import (
"bufio"
"encoding/base64"
"errors"
"fmt"
"net"
"os"
"strings"
"sync"
"golang.org/x/crypto/ssh"
)
// TOFUHostKeyCallback returns a HostKeyCallback that implements
// trust-on-first-use against a banger-owned known_hosts file.
//
// Semantics:
//   - The lookup key is the dialled host with any port stripped (see
//     hostLookupKey); entries are NOT keyed by host:port.
//   - If the file has no entry naming that host → append one and
//     accept.
//   - If the file has at least one entry naming that host → the
//     presented key must be byte-identical to one of them. Anything
//     else is rejected with a "possible MITM" error.
//
// Pins are deliberately not scoped per key type. If they were, a peer
// presenting a key of a type that has not been pinned yet for an
// already-known host would look like "first use" and be silently
// accepted, which defeats the mismatch check.
//
// If no lookup key can be derived at all (empty hostname and no remote
// address) the callback fails closed instead of pinning an entry with
// an empty host field.
//
// The file format is compatible with OpenSSH so shell SSH clients can
// use the same path via `UserKnownHostsFile`.
//
// The package-level knownHostsMu is held across the whole
// read-check-append sequence, so concurrent dials inside this process
// don't interleave writes. It is an in-process lock only; it does not
// coordinate with other processes using the same file.
//
// An empty path disables host-key checking entirely — only for test
// harnesses and tools that dial ad-hoc infrastructure; production
// paths must supply a real file.
func TOFUHostKeyCallback(path string) ssh.HostKeyCallback {
	if strings.TrimSpace(path) == "" {
		return ssh.InsecureIgnoreHostKey()
	}
	return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
		host := hostLookupKey(hostname, remote)
		if host == "" {
			// Nothing to key the pin on; accepting would mean
			// accepting every key forever.
			return fmt.Errorf("banger: cannot determine host name for host key verification")
		}

		knownHostsMu.Lock()
		defer knownHostsMu.Unlock()

		entries, err := loadKnownHosts(path)
		if err != nil {
			return fmt.Errorf("read known_hosts: %w", err)
		}

		// Walk every entry naming this host, whatever its key type.
		pinned := false
		for _, e := range entries {
			namesHost := false
			for _, h := range e.hosts {
				if h == host {
					namesHost = true
					break
				}
			}
			if !namesHost {
				continue
			}
			pinned = true
			if keysEqual(e.key, key) {
				return nil
			}
		}
		if pinned {
			return fmt.Errorf("banger: host key for %s does not match pinned entry — "+
				"possible MITM. If the VM was legitimately rebuilt, remove the old "+
				"entry from %s and retry.", host, path)
		}

		// First contact: pin and accept.
		if err := appendKnownHost(path, host, key); err != nil {
			return fmt.Errorf("pin host key for %s: %w", host, err)
		}
		return nil
	}
}
// RemoveKnownHosts strips every entry naming any host in `hosts` from
// the known_hosts file at path. Called on VM delete so a future VM
// reusing the same IP or name never trips the TOFU mismatch branch.
//
// An entry is dropped when any one of its comma-separated host names
// equals (after trimming) one of the requested hosts.
//
// The call is a true no-op — the file is neither created nor
// rewritten — when path is blank, when no non-blank host is given,
// when the file does not exist, or when no entry matches. The file is
// rewritten only when at least one entry is actually removed.
//
// NOTE: a rewrite emits only the entries loadKnownHosts parsed, so
// comment lines and lines it could not parse do not survive a removal.
//
// Returns any error from reading or rewriting the file.
func RemoveKnownHosts(path string, hosts ...string) error {
	if strings.TrimSpace(path) == "" || len(hosts) == 0 {
		return nil
	}

	// Normalise the request before touching the file or the lock.
	drop := make(map[string]struct{}, len(hosts))
	for _, h := range hosts {
		h = strings.TrimSpace(h)
		if h == "" {
			continue
		}
		drop[h] = struct{}{}
	}
	if len(drop) == 0 {
		return nil
	}

	knownHostsMu.Lock()
	defer knownHostsMu.Unlock()

	entries, err := loadKnownHosts(path)
	if err != nil {
		return err
	}
	filtered := entries.filter(func(e knownHostEntry) bool {
		for _, h := range e.hosts {
			if _, skip := drop[h]; skip {
				return false
			}
		}
		return true
	})
	if len(filtered) == len(entries) {
		// Nothing matched (this also covers a missing file): leave
		// the filesystem exactly as it was.
		return nil
	}
	return filtered.write(path)
}
// knownHostsMu serialises known_hosts access made through this
// package: TOFUHostKeyCallback and RemoveKnownHosts both hold it for
// their full read-modify-write sequence. It is a single lock shared
// by every path, not one lock per file.
var knownHostsMu sync.Mutex
// knownHostEntry is one line in known_hosts: a set of host patterns
// (comma-separated in the file), a key type, and a key blob.
type knownHostEntry struct {
	// hosts is the first field of the line split on ",". Names are
	// compared by exact string equality; no pattern expansion is done.
	hosts []string
	// keyType is the second field of the line, stored verbatim.
	keyType string
	// key is the public key parsed from the base64 third field.
	key ssh.PublicKey
	// raw is the line as read from the file (without its newline),
	// re-emitted unchanged when the list is written back.
	raw string
}

// knownHostList is the parsed content of a known_hosts file, in file
// order.
type knownHostList []knownHostEntry
// match returns the first entry, in list order, whose key type equals
// keyType and whose host names include host exactly. The boolean is
// false (and the entry zero) when there is no such entry.
func (l knownHostList) match(host, keyType string) (knownHostEntry, bool) {
	for i := range l {
		if l[i].keyType != keyType {
			continue
		}
		for _, name := range l[i].hosts {
			if name == host {
				return l[i], true
			}
		}
	}
	return knownHostEntry{}, false
}
// filter returns a new list holding, in their original order, the
// entries for which keep reports true. The receiver is not modified.
func (l knownHostList) filter(keep func(knownHostEntry) bool) knownHostList {
	kept := make(knownHostList, 0, len(l))
	for i := range l {
		if !keep(l[i]) {
			continue
		}
		kept = append(kept, l[i])
	}
	return kept
}
// write replaces the content of path with the raw lines of l, one per
// line, each terminated by a single newline. The file is created with
// mode 0600 if it does not exist.
//
// An empty list truncates the file to zero length rather than removing
// it, so the file keeps existing for later appends.
func (l knownHostList) write(path string) error {
	var content []byte
	for _, entry := range l {
		content = append(content, entry.raw...)
		if !strings.HasSuffix(entry.raw, "\n") {
			content = append(content, '\n')
		}
	}
	return os.WriteFile(path, content, 0o600)
}
// loadKnownHosts parses the known_hosts file at path into a list of
// entries, in file order.
//
// A missing file is not an error: it yields a nil list. Lines that are
// blank, start with "#", have fewer than three whitespace-separated
// fields, or whose third field is not a base64-encoded public key that
// ssh.ParsePublicKey accepts are skipped silently. Lines up to 1 MiB
// long are supported. Any other open or read error is returned.
func loadKnownHosts(path string) (knownHostList, error) {
	file, err := os.Open(path)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, nil
		}
		return nil, err
	}
	defer file.Close()

	lines := bufio.NewScanner(file)
	lines.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	var entries knownHostList
	for lines.Scan() {
		raw := lines.Text()
		// Fields ignores leading/trailing whitespace, so fields[0]
		// starts with the line's first non-space character.
		fields := strings.Fields(raw)
		if len(fields) < 3 || strings.HasPrefix(fields[0], "#") {
			continue
		}
		blob, decodeErr := base64.StdEncoding.DecodeString(fields[2])
		if decodeErr != nil {
			continue
		}
		parsed, parseErr := ssh.ParsePublicKey(blob)
		if parseErr != nil {
			continue
		}
		entries = append(entries, knownHostEntry{
			hosts:   strings.Split(fields[0], ","),
			keyType: fields[1],
			key:     parsed,
			raw:     raw,
		})
	}
	if scanErr := lines.Err(); scanErr != nil {
		return nil, scanErr
	}
	return entries, nil
}
// appendKnownHost pins key for host by appending one line of the form
// "<host> <key-type> <base64-key>\n" to the file at path, creating the
// file with mode 0600 if it does not exist yet.
//
// Both the write error and the Close error are reported: this is a
// write path, and a failure surfaced only at Close would otherwise
// leave the caller believing the pin was stored.
func appendKnownHost(path, host string, key ssh.PublicKey) error {
	line := fmt.Sprintf("%s %s %s\n",
		host,
		key.Type(),
		base64.StdEncoding.EncodeToString(key.Marshal()),
	)
	f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600)
	if err != nil {
		return err
	}
	if _, err := f.WriteString(line); err != nil {
		// The write failure is the error worth reporting; the close
		// here only releases the descriptor.
		_ = f.Close()
		return err
	}
	return f.Close()
}
// hostLookupKey derives the name under which a host's pin is stored
// and looked up.
//
// When hostname parses as "host:port" the port is dropped and only the
// host part is used; otherwise hostname is used as given. If the
// result is blank, the remote address is used instead, again with its
// port dropped when it has one. With a blank hostname and a nil remote
// the result is the empty string.
func hostLookupKey(hostname string, remote net.Addr) string {
	name := hostname
	if host, _, err := net.SplitHostPort(name); err == nil {
		name = host
	}

	switch {
	case strings.TrimSpace(name) != "":
		return name
	case remote == nil:
		return ""
	}

	// No usable hostname: fall back to the peer address.
	if host, _, err := net.SplitHostPort(remote.String()); err == nil {
		return host
	}
	return remote.String()
}
// keysEqual reports whether a and b are the same public key, judged by
// comparing their marshalled wire encodings byte for byte. Two nil
// keys are equal; a nil key never equals a non-nil one.
func keysEqual(a, b ssh.PublicKey) bool {
	switch {
	case a == nil && b == nil:
		return true
	case a == nil || b == nil:
		return false
	}
	return string(a.Marshal()) == string(b.Marshal())
}
// errHostKeyMismatch is a sentinel reserved for callers that want to
// distinguish a host-key mismatch (possible MITM) from other dial
// failures. It is currently unused: nothing in this file returns or
// wraps it, so errors.Is against it will not match today.
var errHostKeyMismatch = errors.New("host key mismatch")

// Blank assignment references the sentinel — presumably to keep
// unused-variable linters quiet until it is wired in.
var _ = errHostKeyMismatch

View file

@ -0,0 +1,185 @@
package guest
import (
"crypto/ed25519"
"crypto/rand"
"net"
"os"
"path/filepath"
"strings"
"testing"
"golang.org/x/crypto/ssh"
)
// makeTestHostKey mints a throwaway ed25519 key pair and returns its
// public half wrapped as an ssh.PublicKey, i.e. what a server would
// present as its host key. Any failure aborts the calling test.
func makeTestHostKey(t *testing.T) ssh.PublicKey {
	t.Helper()

	rawPub, _, genErr := ed25519.GenerateKey(rand.Reader)
	if genErr != nil {
		t.Fatalf("GenerateKey: %v", genErr)
	}

	hostKey, wrapErr := ssh.NewPublicKey(rawPub)
	if wrapErr != nil {
		t.Fatalf("NewPublicKey: %v", wrapErr)
	}
	return hostKey
}
// TestTOFUHostKeyCallbackPinsOnFirstUse checks that the first callback
// invocation for an unknown host succeeds and leaves a line in the
// known_hosts file mentioning both the host and the key type.
func TestTOFUHostKeyCallbackPinsOnFirstUse(t *testing.T) {
	t.Parallel()

	khFile := filepath.Join(t.TempDir(), "known_hosts")
	verify := TOFUHostKeyCallback(khFile)
	hostKey := makeTestHostKey(t)
	peer := &net.TCPAddr{IP: net.ParseIP("172.16.0.5"), Port: 22}

	if err := verify("172.16.0.5:22", peer, hostKey); err != nil {
		t.Fatalf("first-use callback: %v", err)
	}

	raw, err := os.ReadFile(khFile)
	if err != nil {
		t.Fatalf("ReadFile: %v", err)
	}
	pinned := string(raw)
	if !strings.Contains(pinned, "172.16.0.5") {
		t.Errorf("known_hosts missing host:\n%s", pinned)
	}
	if !strings.Contains(pinned, hostKey.Type()) {
		t.Errorf("known_hosts missing key type:\n%s", pinned)
	}
}
// TestTOFUHostKeyCallbackAcceptsMatch checks that once a key has been
// pinned, a second dial presenting the very same key is accepted.
func TestTOFUHostKeyCallbackAcceptsMatch(t *testing.T) {
	t.Parallel()

	verify := TOFUHostKeyCallback(filepath.Join(t.TempDir(), "known_hosts"))
	hostKey := makeTestHostKey(t)
	peer := &net.TCPAddr{IP: net.ParseIP("172.16.0.6"), Port: 22}

	// First dial pins the key.
	firstErr := verify("172.16.0.6:22", peer, hostKey)
	if firstErr != nil {
		t.Fatalf("first-use: %v", firstErr)
	}

	// Second dial, identical key: has to pass verification.
	secondErr := verify("172.16.0.6:22", peer, hostKey)
	if secondErr != nil {
		t.Fatalf("second dial with matching key: %v", secondErr)
	}
}
// TestTOFUHostKeyCallbackRejectsMismatch checks that after one key has
// been pinned for a host, a different key presented for the same host
// is refused with an error mentioning the mismatch.
func TestTOFUHostKeyCallbackRejectsMismatch(t *testing.T) {
	t.Parallel()

	verify := TOFUHostKeyCallback(filepath.Join(t.TempDir(), "known_hosts"))
	peer := &net.TCPAddr{IP: net.ParseIP("172.16.0.7"), Port: 22}

	pinnedKey := makeTestHostKey(t)
	if err := verify("172.16.0.7:22", peer, pinnedKey); err != nil {
		t.Fatalf("pin original: %v", err)
	}

	otherKey := makeTestHostKey(t)
	mismatchErr := verify("172.16.0.7:22", peer, otherKey)
	switch {
	case mismatchErr == nil:
		t.Fatal("expected mismatch error, got nil")
	case !strings.Contains(mismatchErr.Error(), "does not match"):
		t.Errorf("error = %v, want message about mismatch", mismatchErr)
	}
}
// TestTOFUEmptyPathDisablesVerification pins down the documented
// escape hatch: with an empty known_hosts path the callback accepts
// any key. The test exists so that fallback cannot silently turn into
// "always verify, but with no file to verify against".
func TestTOFUEmptyPathDisablesVerification(t *testing.T) {
	t.Parallel()

	verify := TOFUHostKeyCallback("")
	peer := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}

	err := verify("127.0.0.1:22", peer, makeTestHostKey(t))
	if err != nil {
		t.Fatalf("empty-path callback should accept: %v", err)
	}
}
// TestRemoveKnownHostsDropsEntry pins two hosts, removes one of them,
// and checks that only the other host's line is left in the file.
func TestRemoveKnownHostsDropsEntry(t *testing.T) {
	t.Parallel()
	path := filepath.Join(t.TempDir(), "known_hosts")
	cb := TOFUHostKeyCallback(path)
	keep := makeTestHostKey(t)
	drop := makeTestHostKey(t)
	if err := cb("172.16.0.10:22", &net.TCPAddr{IP: net.ParseIP("172.16.0.10"), Port: 22}, keep); err != nil {
		t.Fatalf("pin keep: %v", err)
	}
	if err := cb("172.16.0.11:22", &net.TCPAddr{IP: net.ParseIP("172.16.0.11"), Port: 22}, drop); err != nil {
		t.Fatalf("pin drop: %v", err)
	}
	if err := RemoveKnownHosts(path, "172.16.0.11"); err != nil {
		t.Fatalf("RemoveKnownHosts: %v", err)
	}
	// Fail on the read itself: an ignored error here would otherwise
	// surface as a misleading "kept entry missing".
	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("ReadFile: %v", err)
	}
	content := string(data)
	if !strings.Contains(content, "172.16.0.10") {
		t.Errorf("kept entry missing:\n%s", content)
	}
	if strings.Contains(content, "172.16.0.11") {
		t.Errorf("dropped entry still present:\n%s", content)
	}
}
// TestRemoveKnownHostsMissingFileIsNoOp checks that removing a host
// from a known_hosts file that does not exist reports no error.
func TestRemoveKnownHostsMissingFileIsNoOp(t *testing.T) {
	t.Parallel()

	absent := filepath.Join(t.TempDir(), "absent")
	err := RemoveKnownHosts(absent, "any")
	if err != nil {
		t.Fatalf("RemoveKnownHosts on missing: %v", err)
	}
}
// TestRemoveKnownHostsEmptyPathIsNoOp checks that an empty known_hosts
// path makes RemoveKnownHosts return without error.
func TestRemoveKnownHostsEmptyPathIsNoOp(t *testing.T) {
	t.Parallel()

	err := RemoveKnownHosts("", "any")
	if err != nil {
		t.Fatalf("RemoveKnownHosts(empty): %v", err)
	}
}
// TestTOFURewritesAllowsReuseAfterRemove covers the VM-delete flow:
// pin a key for an IP, clear the pin, then present a brand-new key for
// the same IP. The new key must be pinned cleanly instead of hitting
// the mismatch branch.
func TestTOFURewritesAllowsReuseAfterRemove(t *testing.T) {
	t.Parallel()

	khFile := filepath.Join(t.TempDir(), "known_hosts")
	verify := TOFUHostKeyCallback(khFile)
	peer := &net.TCPAddr{IP: net.ParseIP("172.16.0.15"), Port: 22}

	firstKey := makeTestHostKey(t)
	if err := verify("172.16.0.15:22", peer, firstKey); err != nil {
		t.Fatalf("pin original: %v", err)
	}

	// The VM goes away and takes its pin with it.
	if err := RemoveKnownHosts(khFile, "172.16.0.15"); err != nil {
		t.Fatalf("RemoveKnownHosts: %v", err)
	}

	// A fresh VM lands on the same IP with a different host key.
	secondKey := makeTestHostKey(t)
	if err := verify("172.16.0.15:22", peer, secondKey); err != nil {
		t.Fatalf("re-pin after remove: %v", err)
	}
}
// TestHostLookupKeyStripsPort covers the three shapes hostLookupKey
// handles: "host:port" loses its port, a bare name passes through, and
// an empty hostname falls back to the remote address's host part.
func TestHostLookupKeyStripsPort(t *testing.T) {
	t.Parallel()

	withPort := hostLookupKey("10.0.0.1:22", nil)
	if withPort != "10.0.0.1" {
		t.Errorf("got %q, want 10.0.0.1", withPort)
	}

	bareName := hostLookupKey("host.vm", nil)
	if bareName != "host.vm" {
		t.Errorf("got %q, want host.vm", bareName)
	}

	peer := &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 22}
	fromRemote := hostLookupKey("", peer)
	if fromRemote != "1.2.3.4" {
		t.Errorf("fallback: got %q, want 1.2.3.4", fromRemote)
	}
}

View file

@ -35,12 +35,15 @@ type StreamSession struct {
closeOnce sync.Once
}
func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
// WaitForSSH polls Dial until it succeeds or ctx cancels. The
// knownHostsPath argument is the banger-owned TOFU file; empty
// disables host-key verification (tests only).
func WaitForSSH(ctx context.Context, address, privateKeyPath, knownHostsPath string, interval time.Duration) error {
if interval <= 0 {
interval = time.Second
}
for {
client, err := Dial(ctx, address, privateKeyPath)
client, err := Dial(ctx, address, privateKeyPath, knownHostsPath)
if err == nil {
_ = client.Close()
return nil
@ -53,7 +56,11 @@ func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval ti
}
}
func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error) {
// Dial opens an SSH client to address, authenticating with the key
// at privateKeyPath and verifying the remote host key against the
// TOFU known_hosts file at knownHostsPath. An empty knownHostsPath
// disables verification (tests / one-shot tools only).
func Dial(ctx context.Context, address, privateKeyPath, knownHostsPath string) (*Client, error) {
signer, err := privateKeySigner(privateKeyPath)
if err != nil {
return nil, err
@ -61,7 +68,7 @@ func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error)
config := &ssh.ClientConfig{
User: "root",
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
HostKeyCallback: TOFUHostKeyCallback(knownHostsPath),
Timeout: 10 * time.Second,
}
dialer := &net.Dialer{Timeout: 10 * time.Second}

View file

@ -271,7 +271,7 @@ func TestWaitForSSHContextCancel(t *testing.T) {
defer cancel()
start := time.Now()
err := WaitForSSH(ctx, freeAddr(t), keyPath, 10*time.Millisecond)
err := WaitForSSH(ctx, freeAddr(t), keyPath, "", 10*time.Millisecond)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("err = %v, want context.DeadlineExceeded", err)
}
@ -286,7 +286,7 @@ func TestDialReturnsErrorForBadKey(t *testing.T) {
if err := os.WriteFile(keyPath, []byte("nope"), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
_, err := Dial(context.Background(), freeAddr(t), keyPath)
_, err := Dial(context.Background(), freeAddr(t), keyPath, "")
if err == nil {
t.Fatal("expected error for bad key")
}

View file

@ -9,21 +9,23 @@ import (
)
type Layout struct {
ConfigHome string
StateHome string
CacheHome string
RuntimeHome string
ConfigDir string
StateDir string
CacheDir string
RuntimeDir string
SocketPath string
DBPath string
DaemonLog string
VMsDir string
ImagesDir string
KernelsDir string
OCICacheDir string
ConfigHome string
StateHome string
CacheHome string
RuntimeHome string
ConfigDir string
StateDir string
CacheDir string
RuntimeDir string
SocketPath string
DBPath string
DaemonLog string
VMsDir string
ImagesDir string
KernelsDir string
OCICacheDir string
SSHDir string
KnownHostsPath string
}
func Resolve() (Layout, error) {
@ -56,6 +58,8 @@ func Resolve() (Layout, error) {
layout.ImagesDir = filepath.Join(layout.StateDir, "images")
layout.KernelsDir = filepath.Join(layout.StateDir, "kernels")
layout.OCICacheDir = filepath.Join(layout.CacheDir, "oci")
layout.SSHDir = filepath.Join(layout.StateDir, "ssh")
layout.KnownHostsPath = filepath.Join(layout.SSHDir, "known_hosts")
return layout, nil
}
@ -65,6 +69,15 @@ func Ensure(layout Layout) error {
return err
}
}
// SSH material (private key, known_hosts) — 0700 like ~/.ssh so
// strict SSH clients don't complain and no other host user can
// read it. Empty SSHDir means the caller built a Layout by hand
// (tests) and doesn't need the subdir; skip silently.
if strings.TrimSpace(layout.SSHDir) != "" {
if err := os.MkdirAll(layout.SSHDir, 0o700); err != nil {
return err
}
}
return nil
}