ssh: trust-on-first-use host key pinning everywhere
Guest host-key verification was off in all three SSH paths:
* Go SSH (internal/guest/ssh.go) used ssh.InsecureIgnoreHostKey
* `banger vm ssh` passed StrictHostKeyChecking=no
+ UserKnownHostsFile=/dev/null
* `~/.ssh/config` Host *.vm shipped the same posture into the
user's global config
Now each path verifies against a banger-owned known_hosts file at
`~/.local/state/banger/ssh/known_hosts` with TOFU semantics:
* First dial to a VM pins the key.
* Subsequent dials require an exact match. A mismatch fails with
an explicit "possible MITM" error.
* `vm delete` removes the entries so a future VM reusing the IP
or name re-pins cleanly.
* The user's `~/.ssh/known_hosts` is untouched.
Changes:
internal/guest/known_hosts.go (new) — OpenSSH-compatible parser,
TOFUHostKeyCallback, RemoveKnownHosts. Process-wide mutex
around the file.
internal/guest/ssh.go — Dial and WaitForSSH grew a knownHostsPath
parameter threaded through the callback. Empty path keeps the
insecure callback (tests + throwaway tools only; documented).
internal/daemon/{guest_sessions,session_attach,session_lifecycle,
session_stream}.go — call sites pass d.layout.KnownHostsPath.
internal/daemon/ssh_client_config.go — the ~/.ssh/config Host *.vm
block now points at banger's known_hosts and uses
StrictHostKeyChecking=accept-new. Missing path → fail closed.
internal/daemon/vm_lifecycle.go — deleteVMLocked drops known_hosts
entries for the VM's IP and DNS name via removeVMKnownHosts.
internal/cli/banger.go — sshCommandArgs swaps StrictHostKeyChecking
no + /dev/null for banger's file + accept-new. Path resolution
failure falls through to StrictHostKeyChecking=yes.
internal/paths/paths.go — Layout gains SSHDir + KnownHostsPath;
Ensure creates SSHDir at 0700.
Tests (internal/guest/known_hosts_test.go): pin on first use, accept
matching key on second dial, reject mismatch, empty path skips
checking, RemoveKnownHosts drops the entry, re-pin works after
remove. Existing daemon + cli tests updated to assert the new
posture and regression-guard against the old flags.
Live verified: vm run writes the pin to banger's known_hosts at 0600
inside a 0700 dir; banger vm ssh + ssh root@<vm>.vm both succeed
using the pin; vm delete clears it.
This commit is contained in:
parent
a59958d4f5
commit
ae14b9499d
14 changed files with 634 additions and 47 deletions
|
|
@ -131,10 +131,12 @@ var (
|
|||
return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params)
|
||||
}
|
||||
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
|
||||
return guest.WaitForSSH(ctx, address, privateKeyPath, interval)
|
||||
knownHosts, _ := bangerKnownHostsPath()
|
||||
return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval)
|
||||
}
|
||||
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
|
||||
return guest.Dial(ctx, address, privateKeyPath)
|
||||
knownHosts, _ := bangerKnownHostsPath()
|
||||
return guest.Dial(ctx, address, privateKeyPath, knownHosts)
|
||||
}
|
||||
prepareVMRunRepoCopyFunc = prepareVMRunRepoCopy
|
||||
buildVMRunToolingPlanFunc = toolingplan.Build
|
||||
|
|
@ -2669,6 +2671,12 @@ func sshCommandArgs(cfg model.DaemonConfig, guestIP string, extra []string) ([]s
|
|||
if cfg.SSHKeyPath != "" {
|
||||
args = append(args, "-i", cfg.SSHKeyPath)
|
||||
}
|
||||
// Host-key verification uses a banger-owned known_hosts file
|
||||
// populated by the daemon's first successful Go-SSH dial to each
|
||||
// VM (trust-on-first-use). `accept-new` means: accept-and-pin on
|
||||
// first contact; strict-verify afterwards. The user's own
|
||||
// ~/.known_hosts is untouched.
|
||||
knownHosts, khErr := bangerKnownHostsPath()
|
||||
args = append(
|
||||
args,
|
||||
"-o", "IdentitiesOnly=yes",
|
||||
|
|
@ -2676,14 +2684,36 @@ func sshCommandArgs(cfg model.DaemonConfig, guestIP string, extra []string) ([]s
|
|||
"-o", "PreferredAuthentications=publickey",
|
||||
"-o", "PasswordAuthentication=no",
|
||||
"-o", "KbdInteractiveAuthentication=no",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"root@"+guestIP,
|
||||
)
|
||||
if khErr == nil {
|
||||
args = append(args,
|
||||
"-o", "UserKnownHostsFile="+knownHosts,
|
||||
"-o", "StrictHostKeyChecking=accept-new",
|
||||
)
|
||||
} else {
|
||||
// If we can't resolve the banger path (unusual — paths.Resolve
|
||||
// basically can't fail), fall through to a hard-fail posture
|
||||
// rather than silently disabling verification.
|
||||
args = append(args,
|
||||
"-o", "StrictHostKeyChecking=yes",
|
||||
)
|
||||
}
|
||||
args = append(args, "root@"+guestIP)
|
||||
args = append(args, extra...)
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// bangerKnownHostsPath resolves the TOFU file the daemon writes into
|
||||
// and the CLI reads back. Both sides must agree on the path or the
|
||||
// pin doesn't round-trip.
|
||||
func bangerKnownHostsPath() (string, error) {
|
||||
layout, err := paths.Resolve()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return layout.KnownHostsPath, nil
|
||||
}
|
||||
|
||||
func validateSSHPrereqs(cfg model.DaemonConfig) error {
|
||||
checks := system.NewPreflight()
|
||||
checks.RequireCommand("ssh", "install openssh-client")
|
||||
|
|
|
|||
|
|
@ -1049,25 +1049,57 @@ func TestExecuteVMActionBatchRunsConcurrentlyAndPreservesOrder(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSSHCommandArgs(t *testing.T) {
|
||||
// sshCommandArgs wires banger's own known_hosts into the shell
|
||||
// SSH invocation — never /dev/null. Assert the shape and the
|
||||
// posture rather than the exact path (which is host-XDG-derived).
|
||||
args, err := sshCommandArgs(model.DaemonConfig{SSHKeyPath: "/bundle/id_ed25519"}, "172.16.0.2", []string{"--", "uname", "-a"})
|
||||
if err != nil {
|
||||
t.Fatalf("sshCommandArgs: %v", err)
|
||||
}
|
||||
want := []string{
|
||||
|
||||
wantSubstrings := []string{
|
||||
"-F", "/dev/null",
|
||||
"-i", "/bundle/id_ed25519",
|
||||
"-o", "IdentitiesOnly=yes",
|
||||
"-o", "BatchMode=yes",
|
||||
"-o", "PreferredAuthentications=publickey",
|
||||
"-o", "PasswordAuthentication=no",
|
||||
"-o", "KbdInteractiveAuthentication=no",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"root@172.16.0.2",
|
||||
"--", "uname", "-a",
|
||||
}
|
||||
if !reflect.DeepEqual(args, want) {
|
||||
t.Fatalf("args = %v, want %v", args, want)
|
||||
for _, s := range wantSubstrings {
|
||||
found := false
|
||||
for _, a := range args {
|
||||
if a == s {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("args missing %q: %v", s, args)
|
||||
}
|
||||
}
|
||||
|
||||
// Host-key verification posture: accept-new + a real path into
|
||||
// banger state, not /dev/null.
|
||||
joined := strings.Join(args, " ")
|
||||
if !strings.Contains(joined, "StrictHostKeyChecking=accept-new") {
|
||||
t.Errorf("args missing accept-new posture: %v", args)
|
||||
}
|
||||
if strings.Contains(joined, "UserKnownHostsFile=/dev/null") {
|
||||
t.Errorf("args leaked UserKnownHostsFile=/dev/null: %v", args)
|
||||
}
|
||||
if strings.Contains(joined, "StrictHostKeyChecking=no") {
|
||||
t.Errorf("args leaked StrictHostKeyChecking=no: %v", args)
|
||||
}
|
||||
// Must reference a known_hosts file ending in "known_hosts".
|
||||
sawKnownHosts := false
|
||||
for _, a := range args {
|
||||
if strings.HasPrefix(a, "UserKnownHostsFile=") && strings.HasSuffix(a, "known_hosts") {
|
||||
sawKnownHosts = true
|
||||
}
|
||||
}
|
||||
if !sawKnownHosts {
|
||||
t.Errorf("args missing UserKnownHostsFile=<banger known_hosts>: %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -31,14 +31,14 @@ func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval t
|
|||
if d != nil && d.guestWaitForSSH != nil {
|
||||
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
|
||||
}
|
||||
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
|
||||
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval)
|
||||
}
|
||||
|
||||
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
|
||||
if d != nil && d.guestDial != nil {
|
||||
return d.guestDial(ctx, address, d.config.SSHKeyPath)
|
||||
}
|
||||
return guest.Dial(ctx, address, d.config.SSHKeyPath)
|
||||
return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
}
|
||||
|
||||
func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
|
||||
|
|
@ -86,7 +86,7 @@ func (d *Daemon) refreshGuestSession(ctx context.Context, vm model.VMRecord, s m
|
|||
|
||||
func (d *Daemon) inspectGuestSessionState(ctx context.Context, vm model.VMRecord, s model.GuestSession) (session.StateSnapshot, error) {
|
||||
if d.vmAlive(vm) {
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath)
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return session.StateSnapshot{}, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -189,7 +189,7 @@ func (d *Daemon) attachGuestSessionBridge(session model.GuestSession, controller
|
|||
}
|
||||
|
||||
func (d *Daemon) openGuestSessionAttachStream(address, command string) (*guest.StreamSession, error) {
|
||||
client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath)
|
||||
client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -195,7 +195,7 @@ func (d *Daemon) signalGuestSession(ctx context.Context, params api.GuestSession
|
|||
}
|
||||
return session, nil
|
||||
}
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath)
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ func (d *Daemon) SendToGuestSession(ctx context.Context, params api.GuestSession
|
|||
|
||||
func (d *Daemon) readGuestSessionLog(ctx context.Context, vm model.VMRecord, session model.GuestSession, stream string, tailLines int) (string, error) {
|
||||
if d.vmAlive(vm) {
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath)
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,13 +2,41 @@ package daemon
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"banger/internal/guest"
|
||||
"banger/internal/model"
|
||||
"banger/internal/paths"
|
||||
)
|
||||
|
||||
// removeVMKnownHosts drops every host-key pin for vm from the
|
||||
// banger-owned known_hosts. Best-effort — a failure here only
|
||||
// matters if the same IP/name is reused by a fresh VM before the
|
||||
// next daemon restart, and even then it just causes a
|
||||
// TOFU-mismatch error that the user can clear manually. Logged at
|
||||
// warn so it shows up if it ever actually breaks things.
|
||||
func removeVMKnownHosts(knownHostsPath string, vm model.VMRecord, logger *slog.Logger) {
|
||||
if strings.TrimSpace(knownHostsPath) == "" {
|
||||
return
|
||||
}
|
||||
var hosts []string
|
||||
if ip := strings.TrimSpace(vm.Runtime.GuestIP); ip != "" {
|
||||
hosts = append(hosts, ip)
|
||||
}
|
||||
if dns := strings.TrimSpace(vm.Runtime.DNSName); dns != "" {
|
||||
hosts = append(hosts, dns)
|
||||
}
|
||||
if len(hosts) == 0 {
|
||||
return
|
||||
}
|
||||
if err := guest.RemoveKnownHosts(knownHostsPath, hosts...); err != nil && logger != nil {
|
||||
logger.Warn("remove known_hosts entries", "vm_id", vm.ID, "error", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
vmSSHConfigIncludeBegin = "# BEGIN BANGER MANAGED VM SSH"
|
||||
vmSSHConfigIncludeEnd = "# END BANGER MANAGED VM SSH"
|
||||
|
|
@ -39,7 +67,7 @@ func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath))
|
||||
updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath, layout.KnownHostsPath))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -54,11 +82,19 @@ func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func renderManagedVMSSHBlock(keyPath string) string {
|
||||
// renderManagedVMSSHBlock produces the `Host *.vm` stanza banger
|
||||
// writes into the user's ~/.ssh/config. Host-key verification uses
|
||||
// the banger-owned known_hosts file at knownHostsPath — NOT the
|
||||
// user's ~/.ssh/known_hosts, and NOT /dev/null. `accept-new` means
|
||||
// first contact pins the key; any later mismatch fails the connect.
|
||||
func renderManagedVMSSHBlock(keyPath, knownHostsPath string) string {
|
||||
keyPath = strings.TrimSpace(keyPath)
|
||||
return strings.Join([]string{
|
||||
knownHostsPath = strings.TrimSpace(knownHostsPath)
|
||||
lines := []string{
|
||||
vmSSHConfigIncludeBegin,
|
||||
"# Generated by banger for direct SSH access to VM DNS names.",
|
||||
"# Host keys are pinned on first use into a banger-owned",
|
||||
"# known_hosts file (not ~/.ssh/known_hosts).",
|
||||
"Host *.vm",
|
||||
" User root",
|
||||
" IdentityFile " + keyPath,
|
||||
|
|
@ -67,12 +103,23 @@ func renderManagedVMSSHBlock(keyPath string) string {
|
|||
" PreferredAuthentications publickey",
|
||||
" PasswordAuthentication no",
|
||||
" KbdInteractiveAuthentication no",
|
||||
" StrictHostKeyChecking no",
|
||||
" UserKnownHostsFile /dev/null",
|
||||
}
|
||||
if knownHostsPath != "" {
|
||||
lines = append(lines,
|
||||
" UserKnownHostsFile "+knownHostsPath,
|
||||
" StrictHostKeyChecking accept-new",
|
||||
)
|
||||
} else {
|
||||
// Missing known_hosts path is a configuration anomaly — fail
|
||||
// closed rather than silently disable verification.
|
||||
lines = append(lines, " StrictHostKeyChecking yes")
|
||||
}
|
||||
lines = append(lines,
|
||||
" LogLevel ERROR",
|
||||
vmSSHConfigIncludeEnd,
|
||||
"",
|
||||
}, "\n")
|
||||
)
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func upsertManagedBlock(existing, beginMarker, endMarker, block string) (string, error) {
|
||||
|
|
|
|||
|
|
@ -13,8 +13,10 @@ func TestSyncVMSSHClientConfigCreatesManagedBlock(t *testing.T) {
|
|||
homeDir := t.TempDir()
|
||||
t.Setenv("HOME", homeDir)
|
||||
|
||||
knownHostsPath := filepath.Join(homeDir, ".local", "state", "banger", "ssh", "known_hosts")
|
||||
layout := paths.Layout{
|
||||
ConfigDir: filepath.Join(homeDir, ".config", "banger"),
|
||||
ConfigDir: filepath.Join(homeDir, ".config", "banger"),
|
||||
KnownHostsPath: knownHostsPath,
|
||||
}
|
||||
keyPath := filepath.Join(layout.ConfigDir, "ssh", "id_ed25519")
|
||||
|
||||
|
|
@ -38,12 +40,23 @@ func TestSyncVMSSHClientConfigCreatesManagedBlock(t *testing.T) {
|
|||
"IdentitiesOnly yes",
|
||||
"BatchMode yes",
|
||||
"PasswordAuthentication no",
|
||||
"UserKnownHostsFile /dev/null",
|
||||
"UserKnownHostsFile " + knownHostsPath,
|
||||
"StrictHostKeyChecking accept-new",
|
||||
} {
|
||||
if !strings.Contains(userContent, want) {
|
||||
t.Fatalf("user config = %q, want %q", userContent, want)
|
||||
}
|
||||
}
|
||||
// Regression: the legacy posture (StrictHostKeyChecking no +
|
||||
// UserKnownHostsFile /dev/null) must never reappear.
|
||||
for _, must := range []string{
|
||||
"StrictHostKeyChecking no",
|
||||
"UserKnownHostsFile /dev/null",
|
||||
} {
|
||||
if strings.Contains(userContent, must) {
|
||||
t.Fatalf("user config leaked legacy posture %q:\n%s", must, userContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncVMSSHClientConfigReplacesManagedIncludeBlock(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -411,5 +411,9 @@ func (d *Daemon) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm
|
|||
return model.VMRecord{}, err
|
||||
}
|
||||
}
|
||||
// Drop any host-key pins. A future VM reusing this IP or name
|
||||
// would otherwise trip the TOFU mismatch branch in
|
||||
// TOFUHostKeyCallback and fail to connect.
|
||||
removeVMKnownHosts(d.layout.KnownHostsPath, vm, d.logger)
|
||||
return vm, nil
|
||||
}
|
||||
|
|
|
|||
256
internal/guest/known_hosts.go
Normal file
256
internal/guest/known_hosts.go
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
package guest
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// TOFUHostKeyCallback returns a HostKeyCallback that implements
|
||||
// trust-on-first-use against a banger-owned known_hosts file.
|
||||
//
|
||||
// Semantics:
|
||||
// - If the file has an entry for `host:port` → require an exact
|
||||
// key match; a mismatch returns an error (MITM protection).
|
||||
// - If no entry exists → append one and accept.
|
||||
//
|
||||
// The file format is compatible with OpenSSH so shell SSH clients can
|
||||
// use the same path via `UserKnownHostsFile`.
|
||||
//
|
||||
// Callers keep a process-wide mutex on the file so concurrent dials
|
||||
// to different VMs don't interleave writes.
|
||||
//
|
||||
// An empty path disables host-key checking entirely — only for test
|
||||
// harnesses and tools that dial ad-hoc infrastructure; production
|
||||
// paths must supply a real file.
|
||||
func TOFUHostKeyCallback(path string) ssh.HostKeyCallback {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return ssh.InsecureIgnoreHostKey()
|
||||
}
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
host := hostLookupKey(hostname, remote)
|
||||
knownHostsMu.Lock()
|
||||
defer knownHostsMu.Unlock()
|
||||
|
||||
entries, err := loadKnownHosts(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read known_hosts: %w", err)
|
||||
}
|
||||
stored, matched := entries.match(host, key.Type())
|
||||
if matched {
|
||||
if keysEqual(stored.key, key) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("banger: host key for %s does not match pinned entry — "+
|
||||
"possible MITM. If the VM was legitimately rebuilt, remove the old "+
|
||||
"entry from %s and retry.", host, path)
|
||||
}
|
||||
if err := appendKnownHost(path, host, key); err != nil {
|
||||
return fmt.Errorf("pin host key for %s: %w", host, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveKnownHosts strips every entry matching any host in `hosts`
|
||||
// from the known_hosts file. Called on VM delete so a future VM
|
||||
// reusing the same IP or name never trips the TOFU mismatch branch.
|
||||
// Missing file / missing hosts = no-op.
|
||||
func RemoveKnownHosts(path string, hosts ...string) error {
|
||||
if strings.TrimSpace(path) == "" || len(hosts) == 0 {
|
||||
return nil
|
||||
}
|
||||
knownHostsMu.Lock()
|
||||
defer knownHostsMu.Unlock()
|
||||
|
||||
entries, err := loadKnownHosts(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
drop := make(map[string]struct{}, len(hosts))
|
||||
for _, h := range hosts {
|
||||
h = strings.TrimSpace(h)
|
||||
if h == "" {
|
||||
continue
|
||||
}
|
||||
drop[h] = struct{}{}
|
||||
}
|
||||
if len(drop) == 0 {
|
||||
return nil
|
||||
}
|
||||
filtered := entries.filter(func(e knownHostEntry) bool {
|
||||
for _, h := range e.hosts {
|
||||
if _, skip := drop[h]; skip {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return filtered.write(path)
|
||||
}
|
||||
|
||||
var knownHostsMu sync.Mutex
|
||||
|
||||
// knownHostEntry is one line in known_hosts: a set of host patterns
|
||||
// (comma-separated in the file), a key type, and a key blob.
|
||||
type knownHostEntry struct {
|
||||
hosts []string
|
||||
keyType string
|
||||
key ssh.PublicKey
|
||||
raw string
|
||||
}
|
||||
|
||||
type knownHostList []knownHostEntry
|
||||
|
||||
func (l knownHostList) match(host, keyType string) (knownHostEntry, bool) {
|
||||
for _, e := range l {
|
||||
if e.keyType != keyType {
|
||||
continue
|
||||
}
|
||||
for _, h := range e.hosts {
|
||||
if h == host {
|
||||
return e, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return knownHostEntry{}, false
|
||||
}
|
||||
|
||||
func (l knownHostList) filter(keep func(knownHostEntry) bool) knownHostList {
|
||||
out := make(knownHostList, 0, len(l))
|
||||
for _, e := range l {
|
||||
if keep(e) {
|
||||
out = append(out, e)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (l knownHostList) write(path string) error {
|
||||
if len(l) == 0 {
|
||||
// If everything got filtered, truncate the file rather than
|
||||
// removing it — callers may want the file to keep existing
|
||||
// (with 0600 perms) for later appends.
|
||||
return os.WriteFile(path, nil, 0o600)
|
||||
}
|
||||
var buf strings.Builder
|
||||
for _, e := range l {
|
||||
buf.WriteString(e.raw)
|
||||
if !strings.HasSuffix(e.raw, "\n") {
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
return os.WriteFile(path, []byte(buf.String()), 0o600)
|
||||
}
|
||||
|
||||
func loadKnownHosts(path string) (knownHostList, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var out knownHostList
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(trimmed)
|
||||
if len(fields) < 3 {
|
||||
continue
|
||||
}
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(fields[2])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
key, err := ssh.ParsePublicKey(keyBytes)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, knownHostEntry{
|
||||
hosts: strings.Split(fields[0], ","),
|
||||
keyType: fields[1],
|
||||
key: key,
|
||||
raw: line,
|
||||
})
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func appendKnownHost(path, host string, key ssh.PublicKey) error {
|
||||
line := fmt.Sprintf("%s %s %s\n",
|
||||
host,
|
||||
key.Type(),
|
||||
base64.StdEncoding.EncodeToString(key.Marshal()),
|
||||
)
|
||||
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
_, err = f.WriteString(line)
|
||||
return err
|
||||
}
|
||||
|
||||
// hostLookupKey returns the canonical key under which we store host
|
||||
// entries. For a TCP dial the SSH library hands us hostname of the
|
||||
// form "host:port"; we normalise to "host" so pinning by IP also
|
||||
// works for a hostname-based lookup that resolves to the same IP.
|
||||
//
|
||||
// If hostname contains a port, strip it. If it's empty, fall back to
|
||||
// the remote address.
|
||||
func hostLookupKey(hostname string, remote net.Addr) string {
|
||||
if h, _, err := net.SplitHostPort(hostname); err == nil {
|
||||
hostname = h
|
||||
}
|
||||
if strings.TrimSpace(hostname) != "" {
|
||||
return hostname
|
||||
}
|
||||
if remote != nil {
|
||||
if h, _, err := net.SplitHostPort(remote.String()); err == nil {
|
||||
return h
|
||||
}
|
||||
return remote.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func keysEqual(a, b ssh.PublicKey) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == nil && b == nil
|
||||
}
|
||||
ba := a.Marshal()
|
||||
bb := b.Marshal()
|
||||
if len(ba) != len(bb) {
|
||||
return false
|
||||
}
|
||||
for i := range ba {
|
||||
if ba[i] != bb[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// errHostKeyMismatch sentinel is currently unused but reserved for
|
||||
// callers that want to distinguish MITM from other failures.
|
||||
var errHostKeyMismatch = errors.New("host key mismatch")
|
||||
|
||||
var _ = errHostKeyMismatch
|
||||
185
internal/guest/known_hosts_test.go
Normal file
185
internal/guest/known_hosts_test.go
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
package guest
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// makeTestHostKey generates a fresh ed25519 key and returns the
|
||||
// ssh.PublicKey the server would present during a handshake.
|
||||
func makeTestHostKey(t *testing.T) ssh.PublicKey {
|
||||
t.Helper()
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
sshPub, err := ssh.NewPublicKey(pub)
|
||||
if err != nil {
|
||||
t.Fatalf("NewPublicKey: %v", err)
|
||||
}
|
||||
return sshPub
|
||||
}
|
||||
|
||||
func TestTOFUHostKeyCallbackPinsOnFirstUse(t *testing.T) {
|
||||
t.Parallel()
|
||||
path := filepath.Join(t.TempDir(), "known_hosts")
|
||||
cb := TOFUHostKeyCallback(path)
|
||||
|
||||
key := makeTestHostKey(t)
|
||||
addr := &net.TCPAddr{IP: net.ParseIP("172.16.0.5"), Port: 22}
|
||||
|
||||
if err := cb("172.16.0.5:22", addr, key); err != nil {
|
||||
t.Fatalf("first-use callback: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile: %v", err)
|
||||
}
|
||||
content := string(data)
|
||||
if !strings.Contains(content, "172.16.0.5") {
|
||||
t.Errorf("known_hosts missing host:\n%s", content)
|
||||
}
|
||||
if !strings.Contains(content, key.Type()) {
|
||||
t.Errorf("known_hosts missing key type:\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOFUHostKeyCallbackAcceptsMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
path := filepath.Join(t.TempDir(), "known_hosts")
|
||||
cb := TOFUHostKeyCallback(path)
|
||||
key := makeTestHostKey(t)
|
||||
addr := &net.TCPAddr{IP: net.ParseIP("172.16.0.6"), Port: 22}
|
||||
|
||||
if err := cb("172.16.0.6:22", addr, key); err != nil {
|
||||
t.Fatalf("first-use: %v", err)
|
||||
}
|
||||
// Same key, second dial: must succeed.
|
||||
if err := cb("172.16.0.6:22", addr, key); err != nil {
|
||||
t.Fatalf("second dial with matching key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOFUHostKeyCallbackRejectsMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
path := filepath.Join(t.TempDir(), "known_hosts")
|
||||
cb := TOFUHostKeyCallback(path)
|
||||
addr := &net.TCPAddr{IP: net.ParseIP("172.16.0.7"), Port: 22}
|
||||
|
||||
original := makeTestHostKey(t)
|
||||
if err := cb("172.16.0.7:22", addr, original); err != nil {
|
||||
t.Fatalf("pin original: %v", err)
|
||||
}
|
||||
|
||||
impostor := makeTestHostKey(t)
|
||||
err := cb("172.16.0.7:22", addr, impostor)
|
||||
if err == nil {
|
||||
t.Fatal("expected mismatch error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "does not match") {
|
||||
t.Errorf("error = %v, want message about mismatch", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOFUEmptyPathDisablesVerification(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Empty path returns an Insecure callback — useful for tests /
|
||||
// throwaway tools. Document behaviour so the fallback doesn't
|
||||
// silently regress to "always verify but without a file".
|
||||
cb := TOFUHostKeyCallback("")
|
||||
addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}
|
||||
if err := cb("127.0.0.1:22", addr, makeTestHostKey(t)); err != nil {
|
||||
t.Fatalf("empty-path callback should accept: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveKnownHostsDropsEntry(t *testing.T) {
|
||||
t.Parallel()
|
||||
path := filepath.Join(t.TempDir(), "known_hosts")
|
||||
cb := TOFUHostKeyCallback(path)
|
||||
keep := makeTestHostKey(t)
|
||||
drop := makeTestHostKey(t)
|
||||
|
||||
if err := cb("172.16.0.10:22", &net.TCPAddr{IP: net.ParseIP("172.16.0.10"), Port: 22}, keep); err != nil {
|
||||
t.Fatalf("pin keep: %v", err)
|
||||
}
|
||||
if err := cb("172.16.0.11:22", &net.TCPAddr{IP: net.ParseIP("172.16.0.11"), Port: 22}, drop); err != nil {
|
||||
t.Fatalf("pin drop: %v", err)
|
||||
}
|
||||
|
||||
if err := RemoveKnownHosts(path, "172.16.0.11"); err != nil {
|
||||
t.Fatalf("RemoveKnownHosts: %v", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(path)
|
||||
content := string(data)
|
||||
if !strings.Contains(content, "172.16.0.10") {
|
||||
t.Errorf("kept entry missing:\n%s", content)
|
||||
}
|
||||
if strings.Contains(content, "172.16.0.11") {
|
||||
t.Errorf("dropped entry still present:\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveKnownHostsMissingFileIsNoOp(t *testing.T) {
|
||||
t.Parallel()
|
||||
missing := filepath.Join(t.TempDir(), "absent")
|
||||
if err := RemoveKnownHosts(missing, "any"); err != nil {
|
||||
t.Fatalf("RemoveKnownHosts on missing: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveKnownHostsEmptyPathIsNoOp(t *testing.T) {
|
||||
t.Parallel()
|
||||
if err := RemoveKnownHosts("", "any"); err != nil {
|
||||
t.Fatalf("RemoveKnownHosts(empty): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTOFURewritesAllowsReuseAfterRemove: after a VM is deleted and
|
||||
// its pin is cleared, a future VM reusing the same IP (with a fresh
|
||||
// host key) should re-pin cleanly, not fail the mismatch branch.
|
||||
func TestTOFURewritesAllowsReuseAfterRemove(t *testing.T) {
|
||||
t.Parallel()
|
||||
path := filepath.Join(t.TempDir(), "known_hosts")
|
||||
cb := TOFUHostKeyCallback(path)
|
||||
addr := &net.TCPAddr{IP: net.ParseIP("172.16.0.15"), Port: 22}
|
||||
|
||||
original := makeTestHostKey(t)
|
||||
if err := cb("172.16.0.15:22", addr, original); err != nil {
|
||||
t.Fatalf("pin original: %v", err)
|
||||
}
|
||||
|
||||
// VM deleted → pin removed.
|
||||
if err := RemoveKnownHosts(path, "172.16.0.15"); err != nil {
|
||||
t.Fatalf("RemoveKnownHosts: %v", err)
|
||||
}
|
||||
|
||||
// New VM, same IP, new host key. Must re-pin without error.
|
||||
replacement := makeTestHostKey(t)
|
||||
if err := cb("172.16.0.15:22", addr, replacement); err != nil {
|
||||
t.Fatalf("re-pin after remove: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostLookupKeyStripsPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := hostLookupKey("10.0.0.1:22", nil); got != "10.0.0.1" {
|
||||
t.Errorf("got %q, want 10.0.0.1", got)
|
||||
}
|
||||
if got := hostLookupKey("host.vm", nil); got != "host.vm" {
|
||||
t.Errorf("got %q, want host.vm", got)
|
||||
}
|
||||
addr := &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 22}
|
||||
if got := hostLookupKey("", addr); got != "1.2.3.4" {
|
||||
t.Errorf("fallback: got %q, want 1.2.3.4", got)
|
||||
}
|
||||
}
|
||||
|
|
@ -35,12 +35,15 @@ type StreamSession struct {
|
|||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
|
||||
// WaitForSSH polls Dial until it succeeds or ctx cancels. The
|
||||
// knownHostsPath argument is the banger-owned TOFU file; empty
|
||||
// disables host-key verification (tests only).
|
||||
func WaitForSSH(ctx context.Context, address, privateKeyPath, knownHostsPath string, interval time.Duration) error {
|
||||
if interval <= 0 {
|
||||
interval = time.Second
|
||||
}
|
||||
for {
|
||||
client, err := Dial(ctx, address, privateKeyPath)
|
||||
client, err := Dial(ctx, address, privateKeyPath, knownHostsPath)
|
||||
if err == nil {
|
||||
_ = client.Close()
|
||||
return nil
|
||||
|
|
@ -53,7 +56,11 @@ func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval ti
|
|||
}
|
||||
}
|
||||
|
||||
func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error) {
|
||||
// Dial opens an SSH client to address, authenticating with the key
|
||||
// at privateKeyPath and verifying the remote host key against the
|
||||
// TOFU known_hosts file at knownHostsPath. An empty knownHostsPath
|
||||
// disables verification (tests / one-shot tools only).
|
||||
func Dial(ctx context.Context, address, privateKeyPath, knownHostsPath string) (*Client, error) {
|
||||
signer, err := privateKeySigner(privateKeyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -61,7 +68,7 @@ func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error)
|
|||
config := &ssh.ClientConfig{
|
||||
User: "root",
|
||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
HostKeyCallback: TOFUHostKeyCallback(knownHostsPath),
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
|
|
|
|||
|
|
@ -271,7 +271,7 @@ func TestWaitForSSHContextCancel(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
err := WaitForSSH(ctx, freeAddr(t), keyPath, 10*time.Millisecond)
|
||||
err := WaitForSSH(ctx, freeAddr(t), keyPath, "", 10*time.Millisecond)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("err = %v, want context.DeadlineExceeded", err)
|
||||
}
|
||||
|
|
@ -286,7 +286,7 @@ func TestDialReturnsErrorForBadKey(t *testing.T) {
|
|||
if err := os.WriteFile(keyPath, []byte("nope"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
_, err := Dial(context.Background(), freeAddr(t), keyPath)
|
||||
_, err := Dial(context.Background(), freeAddr(t), keyPath, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for bad key")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,21 +9,23 @@ import (
|
|||
)
|
||||
|
||||
type Layout struct {
|
||||
ConfigHome string
|
||||
StateHome string
|
||||
CacheHome string
|
||||
RuntimeHome string
|
||||
ConfigDir string
|
||||
StateDir string
|
||||
CacheDir string
|
||||
RuntimeDir string
|
||||
SocketPath string
|
||||
DBPath string
|
||||
DaemonLog string
|
||||
VMsDir string
|
||||
ImagesDir string
|
||||
KernelsDir string
|
||||
OCICacheDir string
|
||||
ConfigHome string
|
||||
StateHome string
|
||||
CacheHome string
|
||||
RuntimeHome string
|
||||
ConfigDir string
|
||||
StateDir string
|
||||
CacheDir string
|
||||
RuntimeDir string
|
||||
SocketPath string
|
||||
DBPath string
|
||||
DaemonLog string
|
||||
VMsDir string
|
||||
ImagesDir string
|
||||
KernelsDir string
|
||||
OCICacheDir string
|
||||
SSHDir string
|
||||
KnownHostsPath string
|
||||
}
|
||||
|
||||
func Resolve() (Layout, error) {
|
||||
|
|
@ -56,6 +58,8 @@ func Resolve() (Layout, error) {
|
|||
layout.ImagesDir = filepath.Join(layout.StateDir, "images")
|
||||
layout.KernelsDir = filepath.Join(layout.StateDir, "kernels")
|
||||
layout.OCICacheDir = filepath.Join(layout.CacheDir, "oci")
|
||||
layout.SSHDir = filepath.Join(layout.StateDir, "ssh")
|
||||
layout.KnownHostsPath = filepath.Join(layout.SSHDir, "known_hosts")
|
||||
return layout, nil
|
||||
}
|
||||
|
||||
|
|
@ -65,6 +69,15 @@ func Ensure(layout Layout) error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
// SSH material (private key, known_hosts) — 0700 like ~/.ssh so
|
||||
// strict SSH clients don't complain and no other host user can
|
||||
// read it. Empty SSHDir means the caller built a Layout by hand
|
||||
// (tests) and doesn't need the subdir; skip silently.
|
||||
if strings.TrimSpace(layout.SSHDir) != "" {
|
||||
if err := os.MkdirAll(layout.SSHDir, 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue