Guest host-key verification was off in all three SSH paths:
* Go SSH (internal/guest/ssh.go) used ssh.InsecureIgnoreHostKey
* `banger vm ssh` passed StrictHostKeyChecking=no
+ UserKnownHostsFile=/dev/null
* `~/.ssh/config` Host *.vm shipped the same posture into the
user's global config
Now each path verifies against a banger-owned known_hosts file at
`~/.local/state/banger/ssh/known_hosts` with TOFU semantics:
* First dial to a VM pins the key.
* Subsequent dials require an exact match. A mismatch fails with
an explicit "possible MITM" error.
* `vm delete` removes the entries so a future VM reusing the IP
or name re-pins cleanly.
* The user's `~/.ssh/known_hosts` is untouched.
Changes:
internal/guest/known_hosts.go (new) — OpenSSH-compatible parser,
TOFUHostKeyCallback, RemoveKnownHosts. Process-wide mutex
around the file.
internal/guest/ssh.go — Dial and WaitForSSH grew a knownHostsPath
parameter threaded through the callback. Empty path keeps the
insecure callback (tests + throwaway tools only; documented).
internal/daemon/{guest_sessions,session_attach,session_lifecycle,
session_stream}.go — call sites pass d.layout.KnownHostsPath.
internal/daemon/ssh_client_config.go — the ~/.ssh/config Host *.vm
block now points at banger's known_hosts and uses
StrictHostKeyChecking=accept-new. Missing path → fail closed.
internal/daemon/vm_lifecycle.go — deleteVMLocked drops known_hosts
entries for the VM's IP and DNS name via removeVMKnownHosts.
internal/cli/banger.go — sshCommandArgs swaps StrictHostKeyChecking=no
+ UserKnownHostsFile=/dev/null for banger's known_hosts file +
accept-new. If the path cannot be resolved, it falls back to
StrictHostKeyChecking=yes.
internal/paths/paths.go — Layout gains SSHDir + KnownHostsPath;
Ensure creates SSHDir at 0700.
Tests (internal/guest/known_hosts_test.go): pin on first use, accept
matching key on second dial, reject mismatch, empty path skips
checking, RemoveKnownHosts drops the entry, re-pin works after
remove. Existing daemon + cli tests updated to assert the new
posture and regression-guard against the old flags.
Live verified: vm run writes the pin to banger's known_hosts at 0600
inside a 0700 dir; banger vm ssh + ssh root@<vm>.vm both succeed
using the pin; vm delete clears it.
293 lines
7.4 KiB
Go
293 lines
7.4 KiB
Go
package guest
|
|
|
|
import (
|
|
"archive/tar"
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// writeTestKey generates a throwaway RSA private key, PEM-encodes it as a
// PKCS#1 "RSA PRIVATE KEY" block, and writes it with mode 0600 to a file
// named "id_rsa" inside a fresh t.TempDir(). It returns the path to that
// file. Any failure aborts the calling test via t.Fatalf.
func writeTestKey(t *testing.T) string {
	t.Helper()

	// 1024 bits is only acceptable because the key is a test fixture that
	// lives in a temp dir removed when the test ends.
	privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		t.Fatalf("GenerateKey: %v", err)
	}

	privateKeyPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	})

	keyPath := filepath.Join(t.TempDir(), "id_rsa")
	if err := os.WriteFile(keyPath, privateKeyPEM, 0o600); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	return keyPath
}
|
|
|
|
func TestAuthorizedPublicKeyFingerprint(t *testing.T) {
|
|
t.Parallel()
|
|
keyPath := writeTestKey(t)
|
|
|
|
fp, err := AuthorizedPublicKeyFingerprint(keyPath)
|
|
if err != nil {
|
|
t.Fatalf("AuthorizedPublicKeyFingerprint: %v", err)
|
|
}
|
|
if !regexp.MustCompile(`^[0-9a-f]{64}$`).MatchString(fp) {
|
|
t.Fatalf("fingerprint = %q, want 64 hex chars", fp)
|
|
}
|
|
|
|
fp2, err := AuthorizedPublicKeyFingerprint(keyPath)
|
|
if err != nil {
|
|
t.Fatalf("AuthorizedPublicKeyFingerprint (second): %v", err)
|
|
}
|
|
if fp != fp2 {
|
|
t.Fatalf("fingerprint not deterministic: %q vs %q", fp, fp2)
|
|
}
|
|
}
|
|
|
|
func TestAuthorizedPublicKeyFingerprintMissingFile(t *testing.T) {
|
|
t.Parallel()
|
|
_, err := AuthorizedPublicKeyFingerprint(filepath.Join(t.TempDir(), "nope"))
|
|
if err == nil {
|
|
t.Fatal("expected error for missing key file")
|
|
}
|
|
}
|
|
|
|
func TestAuthorizedPublicKeyBadPEM(t *testing.T) {
|
|
t.Parallel()
|
|
keyPath := filepath.Join(t.TempDir(), "bad")
|
|
if err := os.WriteFile(keyPath, []byte("not a private key"), 0o600); err != nil {
|
|
t.Fatalf("WriteFile: %v", err)
|
|
}
|
|
if _, err := AuthorizedPublicKey(keyPath); err == nil {
|
|
t.Fatal("expected ParsePrivateKey error")
|
|
}
|
|
}
|
|
|
|
func TestShellQuote(t *testing.T) {
|
|
t.Parallel()
|
|
cases := []struct {
|
|
in, want string
|
|
}{
|
|
{"", "''"},
|
|
{"simple", "'simple'"},
|
|
{"with space", "'with space'"},
|
|
{"it's", `'it'"'"'s'`},
|
|
{"a'b'c", `'a'"'"'b'"'"'c'`},
|
|
{"/path/to/file", "'/path/to/file'"},
|
|
}
|
|
for _, tc := range cases {
|
|
got := shellQuote(tc.in)
|
|
if got != tc.want {
|
|
t.Errorf("shellQuote(%q) = %q, want %q", tc.in, got, tc.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWriteTarEntriesArchiveRejectsEscape(t *testing.T) {
|
|
t.Parallel()
|
|
dir := t.TempDir()
|
|
var buf bytes.Buffer
|
|
err := writeTarEntriesArchive(&buf, dir, []string{"../escape"})
|
|
if err == nil {
|
|
t.Fatal("expected error for escaping entry")
|
|
}
|
|
if !strings.Contains(err.Error(), "escapes source dir") {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWriteTarEntriesArchiveRejectsDot(t *testing.T) {
|
|
t.Parallel()
|
|
dir := t.TempDir()
|
|
var buf bytes.Buffer
|
|
for _, bad := range []string{".", ".."} {
|
|
if err := writeTarEntriesArchive(&buf, dir, []string{bad}); err == nil {
|
|
t.Errorf("expected error for entry %q", bad)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWriteTarEntriesArchiveDedupsAndSkipsBlank(t *testing.T) {
|
|
t.Parallel()
|
|
sourceDir := filepath.Join(t.TempDir(), "repo")
|
|
if err := os.MkdirAll(sourceDir, 0o755); err != nil {
|
|
t.Fatalf("MkdirAll: %v", err)
|
|
}
|
|
if err := os.WriteFile(filepath.Join(sourceDir, "a.txt"), []byte("A"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile: %v", err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := writeTarEntriesArchive(&buf, sourceDir, []string{"a.txt", "a.txt", "", " "}); err != nil {
|
|
t.Fatalf("writeTarEntriesArchive: %v", err)
|
|
}
|
|
|
|
tr := tar.NewReader(&buf)
|
|
var names []string
|
|
for {
|
|
h, err := tr.Next()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("tar.Next: %v", err)
|
|
}
|
|
names = append(names, h.Name)
|
|
}
|
|
if len(names) != 1 || names[0] != "repo/a.txt" {
|
|
t.Fatalf("names = %v, want [repo/a.txt]", names)
|
|
}
|
|
}
|
|
|
|
func TestWriteTarEntriesArchiveSymlink(t *testing.T) {
|
|
t.Parallel()
|
|
sourceDir := filepath.Join(t.TempDir(), "repo")
|
|
if err := os.MkdirAll(sourceDir, 0o755); err != nil {
|
|
t.Fatalf("MkdirAll: %v", err)
|
|
}
|
|
if err := os.WriteFile(filepath.Join(sourceDir, "target.txt"), []byte("T"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile: %v", err)
|
|
}
|
|
linkPath := filepath.Join(sourceDir, "link")
|
|
if err := os.Symlink("target.txt", linkPath); err != nil {
|
|
t.Skipf("symlink unsupported: %v", err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := writeTarEntriesArchive(&buf, sourceDir, []string{"link"}); err != nil {
|
|
t.Fatalf("writeTarEntriesArchive: %v", err)
|
|
}
|
|
|
|
tr := tar.NewReader(&buf)
|
|
h, err := tr.Next()
|
|
if err != nil {
|
|
t.Fatalf("tar.Next: %v", err)
|
|
}
|
|
if h.Typeflag != tar.TypeSymlink {
|
|
t.Fatalf("typeflag = %v, want TypeSymlink", h.Typeflag)
|
|
}
|
|
if h.Linkname != "target.txt" {
|
|
t.Fatalf("linkname = %q, want target.txt", h.Linkname)
|
|
}
|
|
}
|
|
|
|
func TestWriteTarEntriesArchiveMissingPath(t *testing.T) {
|
|
t.Parallel()
|
|
sourceDir := t.TempDir()
|
|
var buf bytes.Buffer
|
|
err := writeTarEntriesArchive(&buf, sourceDir, []string{"missing.txt"})
|
|
if err == nil {
|
|
t.Fatal("expected error for missing entry")
|
|
}
|
|
}
|
|
|
|
func TestStreamSessionNilSafe(t *testing.T) {
|
|
t.Parallel()
|
|
var s *StreamSession
|
|
if s.Stdin() != nil || s.Stdout() != nil || s.Stderr() != nil {
|
|
t.Fatal("nil StreamSession getters should return nil")
|
|
}
|
|
if err := s.Wait(); err != nil {
|
|
t.Fatalf("nil Wait error: %v", err)
|
|
}
|
|
if err := s.Close(); err != nil {
|
|
t.Fatalf("nil Close error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestClientNilClose(t *testing.T) {
|
|
t.Parallel()
|
|
var c *Client
|
|
if err := c.Close(); err != nil {
|
|
t.Fatalf("nil Close error: %v", err)
|
|
}
|
|
c2 := &Client{}
|
|
if err := c2.Close(); err != nil {
|
|
t.Fatalf("empty Close error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestClientRunScriptOutputNotConnected(t *testing.T) {
|
|
t.Parallel()
|
|
var c *Client
|
|
if _, err := c.RunScriptOutput(context.Background(), "true"); err == nil {
|
|
t.Fatal("expected not-connected error")
|
|
}
|
|
c2 := &Client{}
|
|
if _, err := c2.RunScriptOutput(context.Background(), "true"); err == nil {
|
|
t.Fatal("expected not-connected error")
|
|
}
|
|
}
|
|
|
|
func TestClientStartCommandNotConnected(t *testing.T) {
|
|
t.Parallel()
|
|
var c *Client
|
|
if _, err := c.StartCommand(context.Background(), "true"); err == nil {
|
|
t.Fatal("expected not-connected error")
|
|
}
|
|
}
|
|
|
|
func TestClientRunScriptNotConnected(t *testing.T) {
|
|
t.Parallel()
|
|
var c *Client
|
|
if err := c.RunScript(context.Background(), "true", io.Discard); err == nil {
|
|
t.Fatal("expected not-connected error")
|
|
}
|
|
}
|
|
|
|
// freeAddr grabs a loopback port by listening briefly, then closing. Next
// Dial to it deterministically fails with "connection refused" — no real
// server on the far end, no flakiness from random ports being taken.
//
// It returns the listener's "127.0.0.1:<port>" address. A failure to listen
// or to close the listener aborts the calling test via t.Fatalf.
func freeAddr(t *testing.T) string {
	t.Helper()

	// Port 0 asks the OS to pick an unused port.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("net.Listen: %v", err)
	}

	// Capture the address before closing; the listener is released so that
	// nothing is accepting on the returned address.
	addr := ln.Addr().String()
	if err := ln.Close(); err != nil {
		t.Fatalf("Close listener: %v", err)
	}
	return addr
}
|
|
|
|
func TestWaitForSSHContextCancel(t *testing.T) {
|
|
t.Parallel()
|
|
keyPath := writeTestKey(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
start := time.Now()
|
|
err := WaitForSSH(ctx, freeAddr(t), keyPath, "", 10*time.Millisecond)
|
|
if !errors.Is(err, context.DeadlineExceeded) {
|
|
t.Fatalf("err = %v, want context.DeadlineExceeded", err)
|
|
}
|
|
if elapsed := time.Since(start); elapsed > 2*time.Second {
|
|
t.Fatalf("took too long: %v", elapsed)
|
|
}
|
|
}
|
|
|
|
func TestDialReturnsErrorForBadKey(t *testing.T) {
|
|
t.Parallel()
|
|
keyPath := filepath.Join(t.TempDir(), "bogus")
|
|
if err := os.WriteFile(keyPath, []byte("nope"), 0o600); err != nil {
|
|
t.Fatalf("WriteFile: %v", err)
|
|
}
|
|
_, err := Dial(context.Background(), freeAddr(t), keyPath, "")
|
|
if err == nil {
|
|
t.Fatal("expected error for bad key")
|
|
}
|
|
}
|