banger/internal/guest/ssh_more_test.go
Thales Maciel ae14b9499d
ssh: trust-on-first-use host key pinning everywhere
Guest host-key verification was off in all three SSH paths:

  * Go SSH (internal/guest/ssh.go) used ssh.InsecureIgnoreHostKey
  * `banger vm ssh` passed StrictHostKeyChecking=no
    + UserKnownHostsFile=/dev/null
  * `~/.ssh/config` Host *.vm shipped the same posture into the
    user's global config

Now each path verifies against a banger-owned known_hosts file at
`~/.local/state/banger/ssh/known_hosts` with TOFU semantics:

  * First dial to a VM pins the key.
  * Subsequent dials require an exact match. A mismatch fails with
    an explicit "possible MITM" error.
  * `vm delete` removes the entries so a future VM reusing the IP
    or name re-pins cleanly.
  * The user's `~/.ssh/known_hosts` is untouched.

Changes:

  internal/guest/known_hosts.go (new) — OpenSSH-compatible parser,
    TOFUHostKeyCallback, RemoveKnownHosts. Process-wide mutex
    around the file.
  internal/guest/ssh.go — Dial and WaitForSSH grew a knownHostsPath
    parameter threaded through the callback. Empty path keeps the
    insecure callback (tests + throwaway tools only; documented).
  internal/daemon/{guest_sessions,session_attach,session_lifecycle,
    session_stream}.go — call sites pass d.layout.KnownHostsPath.
  internal/daemon/ssh_client_config.go — the ~/.ssh/config Host *.vm
    block now points at banger's known_hosts and uses
    StrictHostKeyChecking=accept-new. Missing path → fail closed.
  internal/daemon/vm_lifecycle.go — deleteVMLocked drops known_hosts
    entries for the VM's IP and DNS name via removeVMKnownHosts.
  internal/cli/banger.go — sshCommandArgs replaces
    StrictHostKeyChecking=no + UserKnownHostsFile=/dev/null with banger's
    known_hosts file + StrictHostKeyChecking=accept-new. If the path cannot
    be resolved, it falls back to StrictHostKeyChecking=yes.
  internal/paths/paths.go — Layout gains SSHDir + KnownHostsPath;
    Ensure creates SSHDir at 0700.

Tests (internal/guest/known_hosts_test.go): pin on first use, accept
matching key on second dial, reject mismatch, empty path skips
checking, RemoveKnownHosts drops the entry, re-pin works after
remove. Existing daemon + cli tests updated to assert the new
posture and regression-guard against the old flags.

Live verified: vm run writes the pin to banger's known_hosts at 0600
inside a 0700 dir; banger vm ssh + ssh root@<vm>.vm both succeed
using the pin; vm delete clears it.
2026-04-19 16:46:03 -03:00

293 lines
7.4 KiB
Go

package guest
import (
"archive/tar"
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"io"
"net"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
"time"
)
// writeTestKey creates a throwaway 1024-bit RSA private key, stores it as a
// PKCS#1 PEM file ("RSA PRIVATE KEY") named id_rsa inside a per-test temp
// directory with mode 0600, and returns that file's path. The test is failed
// immediately if key generation or the write fails.
//
// NOTE(review): the small key size is presumably chosen for speed; the key is
// only used as test fixture material.
func writeTestKey(t *testing.T) string {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		t.Fatalf("GenerateKey: %v", err)
	}

	block := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	}
	path := filepath.Join(t.TempDir(), "id_rsa")
	if err := os.WriteFile(path, pem.EncodeToMemory(block), 0o600); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	return path
}
// TestAuthorizedPublicKeyFingerprint checks that the fingerprint of a freshly
// generated key is exactly 64 lowercase hex characters. It also checks that
// computing it twice for the same key file yields the same value.
func TestAuthorizedPublicKeyFingerprint(t *testing.T) {
	t.Parallel()
	path := writeTestKey(t)
	hex64 := regexp.MustCompile(`^[0-9a-f]{64}$`)

	first, err := AuthorizedPublicKeyFingerprint(path)
	if err != nil {
		t.Fatalf("AuthorizedPublicKeyFingerprint: %v", err)
	}
	if matched := hex64.MatchString(first); !matched {
		t.Fatalf("fingerprint = %q, want 64 hex chars", first)
	}

	second, err := AuthorizedPublicKeyFingerprint(path)
	if err != nil {
		t.Fatalf("AuthorizedPublicKeyFingerprint (second): %v", err)
	}
	if second != first {
		t.Fatalf("fingerprint not deterministic: %q vs %q", first, second)
	}
}
// TestAuthorizedPublicKeyFingerprintMissingFile asserts that fingerprinting a
// path that does not exist reports an error.
func TestAuthorizedPublicKeyFingerprintMissingFile(t *testing.T) {
	t.Parallel()
	missing := filepath.Join(t.TempDir(), "nope")
	if _, err := AuthorizedPublicKeyFingerprint(missing); err == nil {
		t.Fatal("expected error for missing key file")
	}
}
// TestAuthorizedPublicKeyBadPEM asserts that AuthorizedPublicKey rejects a
// file whose contents are not a parseable private key.
func TestAuthorizedPublicKeyBadPEM(t *testing.T) {
	t.Parallel()
	path := filepath.Join(t.TempDir(), "bad")
	garbage := []byte("not a private key")
	if err := os.WriteFile(path, garbage, 0o600); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}

	_, err := AuthorizedPublicKey(path)
	if err == nil {
		t.Fatal("expected ParsePrivateKey error")
	}
}
// TestShellQuote pins shellQuote's output. Each input is wrapped in single
// quotes, and every embedded single quote is written as '"'"'.
func TestShellQuote(t *testing.T) {
	t.Parallel()
	for _, c := range []struct{ in, want string }{
		{in: "", want: "''"},
		{in: "simple", want: "'simple'"},
		{in: "with space", want: "'with space'"},
		{in: "it's", want: `'it'"'"'s'`},
		{in: "a'b'c", want: `'a'"'"'b'"'"'c'`},
		{in: "/path/to/file", want: "'/path/to/file'"},
	} {
		if got := shellQuote(c.in); got != c.want {
			t.Errorf("shellQuote(%q) = %q, want %q", c.in, got, c.want)
		}
	}
}
// TestWriteTarEntriesArchiveRejectsEscape asserts that an entry climbing out
// of the source directory ("../escape") is refused. The error must mention
// "escapes source dir".
func TestWriteTarEntriesArchiveRejectsEscape(t *testing.T) {
	t.Parallel()
	root := t.TempDir()
	var out bytes.Buffer

	err := writeTarEntriesArchive(&out, root, []string{"../escape"})
	switch {
	case err == nil:
		t.Fatal("expected error for escaping entry")
	case !strings.Contains(err.Error(), "escapes source dir"):
		t.Fatalf("unexpected error: %v", err)
	}
}
// TestWriteTarEntriesArchiveRejectsDot asserts that "." and ".." are each
// rejected as archive entries. A single output buffer is shared across both
// attempts, as in the original test.
func TestWriteTarEntriesArchiveRejectsDot(t *testing.T) {
	t.Parallel()
	root := t.TempDir()
	var out bytes.Buffer
	rejected := []string{".", ".."}
	for i := range rejected {
		entry := rejected[i]
		err := writeTarEntriesArchive(&out, root, []string{entry})
		if err == nil {
			t.Errorf("expected error for entry %q", entry)
		}
	}
}
// TestWriteTarEntriesArchiveDedupsAndSkipsBlank feeds a duplicated entry plus
// empty and whitespace-only entries. It expects exactly one archive member,
// named with the source dir's base name as prefix ("repo/a.txt").
func TestWriteTarEntriesArchiveDedupsAndSkipsBlank(t *testing.T) {
	t.Parallel()
	repo := filepath.Join(t.TempDir(), "repo")
	if err := os.MkdirAll(repo, 0o755); err != nil {
		t.Fatalf("MkdirAll: %v", err)
	}
	if err := os.WriteFile(filepath.Join(repo, "a.txt"), []byte("A"), 0o644); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}

	var archive bytes.Buffer
	entries := []string{"a.txt", "a.txt", "", " "}
	if err := writeTarEntriesArchive(&archive, repo, entries); err != nil {
		t.Fatalf("writeTarEntriesArchive: %v", err)
	}

	// Collect every member name until the reader signals end of archive.
	reader := tar.NewReader(&archive)
	var got []string
	for hdr, err := reader.Next(); err != io.EOF; hdr, err = reader.Next() {
		if err != nil {
			t.Fatalf("tar.Next: %v", err)
		}
		got = append(got, hdr.Name)
	}

	if len(got) != 1 || got[0] != "repo/a.txt" {
		t.Fatalf("names = %v, want [repo/a.txt]", got)
	}
}
// TestWriteTarEntriesArchiveSymlink asserts that a symlink entry is archived
// as a symlink member that keeps its relative target. The test is skipped
// when the platform cannot create symlinks.
func TestWriteTarEntriesArchiveSymlink(t *testing.T) {
	t.Parallel()
	repo := filepath.Join(t.TempDir(), "repo")
	if err := os.MkdirAll(repo, 0o755); err != nil {
		t.Fatalf("MkdirAll: %v", err)
	}
	if err := os.WriteFile(filepath.Join(repo, "target.txt"), []byte("T"), 0o644); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	if err := os.Symlink("target.txt", filepath.Join(repo, "link")); err != nil {
		t.Skipf("symlink unsupported: %v", err)
	}

	var archive bytes.Buffer
	if err := writeTarEntriesArchive(&archive, repo, []string{"link"}); err != nil {
		t.Fatalf("writeTarEntriesArchive: %v", err)
	}

	hdr, err := tar.NewReader(&archive).Next()
	if err != nil {
		t.Fatalf("tar.Next: %v", err)
	}
	switch {
	case hdr.Typeflag != tar.TypeSymlink:
		t.Fatalf("typeflag = %v, want TypeSymlink", hdr.Typeflag)
	case hdr.Linkname != "target.txt":
		t.Fatalf("linkname = %q, want target.txt", hdr.Linkname)
	}
}
// TestWriteTarEntriesArchiveMissingPath asserts that naming a file absent from
// the source directory produces an error.
func TestWriteTarEntriesArchiveMissingPath(t *testing.T) {
	t.Parallel()
	var out bytes.Buffer
	if err := writeTarEntriesArchive(&out, t.TempDir(), []string{"missing.txt"}); err == nil {
		t.Fatal("expected error for missing entry")
	}
}
// TestStreamSessionNilSafe asserts that every method tested here can be called
// on a nil *StreamSession. The stream getters return nil; Wait and Close
// return no error.
func TestStreamSessionNilSafe(t *testing.T) {
	t.Parallel()
	var session *StreamSession

	// Same short-circuit order as checking Stdin, then Stdout, then Stderr.
	if !(session.Stdin() == nil && session.Stdout() == nil && session.Stderr() == nil) {
		t.Fatal("nil StreamSession getters should return nil")
	}
	if err := session.Wait(); err != nil {
		t.Fatalf("nil Wait error: %v", err)
	}
	if err := session.Close(); err != nil {
		t.Fatalf("nil Close error: %v", err)
	}
}
func TestClientNilClose(t *testing.T) {
t.Parallel()
var c *Client
if err := c.Close(); err != nil {
t.Fatalf("nil Close error: %v", err)
}
c2 := &Client{}
if err := c2.Close(); err != nil {
t.Fatalf("empty Close error: %v", err)
}
}
func TestClientRunScriptOutputNotConnected(t *testing.T) {
t.Parallel()
var c *Client
if _, err := c.RunScriptOutput(context.Background(), "true"); err == nil {
t.Fatal("expected not-connected error")
}
c2 := &Client{}
if _, err := c2.RunScriptOutput(context.Background(), "true"); err == nil {
t.Fatal("expected not-connected error")
}
}
// TestClientStartCommandNotConnected asserts that StartCommand on a nil
// *Client reports an error.
func TestClientStartCommandNotConnected(t *testing.T) {
	t.Parallel()
	var client *Client
	ctx := context.Background()
	_, err := client.StartCommand(ctx, "true")
	if err == nil {
		t.Fatal("expected not-connected error")
	}
}
// TestClientRunScriptNotConnected asserts that RunScript on a nil *Client
// reports an error. Any output is sent to io.Discard.
func TestClientRunScriptNotConnected(t *testing.T) {
	t.Parallel()
	var client *Client
	ctx := context.Background()
	err := client.RunScript(ctx, "true", io.Discard)
	if err == nil {
		t.Fatal("expected not-connected error")
	}
}
// freeAddr returns a "127.0.0.1:port" address that had no listener when the
// function returned. It binds an ephemeral loopback port, records the address,
// then closes the listener again. The test fails immediately if either the
// listen or the close fails.
//
// A later Dial to the address is expected to fail with "connection refused".
// That is NOT guaranteed: between Close and the caller's Dial, another process
// or a parallel test may bind the same port. Callers should therefore treat
// the address as very likely, not certainly, unused, and should not assert on
// the exact dial error.
func freeAddr(t *testing.T) string {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("net.Listen: %v", err)
	}
	addr := ln.Addr().String()
	if err := ln.Close(); err != nil {
		t.Fatalf("Close listener: %v", err)
	}
	return addr
}
func TestWaitForSSHContextCancel(t *testing.T) {
t.Parallel()
keyPath := writeTestKey(t)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
err := WaitForSSH(ctx, freeAddr(t), keyPath, "", 10*time.Millisecond)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("err = %v, want context.DeadlineExceeded", err)
}
if elapsed := time.Since(start); elapsed > 2*time.Second {
t.Fatalf("took too long: %v", elapsed)
}
}
// TestDialReturnsErrorForBadKey asserts that Dial fails when the private key
// file holds garbage instead of a key. An empty known_hosts path is passed.
func TestDialReturnsErrorForBadKey(t *testing.T) {
	t.Parallel()
	bogus := filepath.Join(t.TempDir(), "bogus")
	if err := os.WriteFile(bogus, []byte("nope"), 0o600); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	if _, err := Dial(context.Background(), freeAddr(t), bogus, ""); err == nil {
		t.Fatal("expected error for bad key")
	}
}