Guest host-key verification was off in all three SSH paths:
* Go SSH (internal/guest/ssh.go) used ssh.InsecureIgnoreHostKey
* `banger vm ssh` passed StrictHostKeyChecking=no
+ UserKnownHostsFile=/dev/null
* `~/.ssh/config` Host *.vm shipped the same posture into the
user's global config
Now each path verifies against a banger-owned known_hosts file at
`~/.local/state/banger/ssh/known_hosts` with TOFU semantics:
* First dial to a VM pins the key.
* Subsequent dials require an exact match. A mismatch fails with
an explicit "possible MITM" error.
* `vm delete` removes the entries so a future VM reusing the IP
or name re-pins cleanly.
* The user's `~/.ssh/known_hosts` is untouched.
Changes:
internal/guest/known_hosts.go (new) — OpenSSH-compatible parser,
TOFUHostKeyCallback, RemoveKnownHosts. Process-wide mutex
around the file.
internal/guest/ssh.go — Dial and WaitForSSH grew a knownHostsPath
parameter that is passed to TOFUHostKeyCallback. An empty path keeps
the insecure callback (tests + throwaway tools only; documented).
internal/daemon/{guest_sessions,session_attach,session_lifecycle,
session_stream}.go — call sites pass d.layout.KnownHostsPath.
internal/daemon/ssh_client_config.go — the ~/.ssh/config Host *.vm
block now points at banger's known_hosts and uses
StrictHostKeyChecking=accept-new. Missing path → fail closed.
internal/daemon/vm_lifecycle.go — deleteVMLocked drops known_hosts
entries for the VM's IP and DNS name via removeVMKnownHosts.
internal/cli/banger.go — sshCommandArgs swaps StrictHostKeyChecking=no
+ UserKnownHostsFile=/dev/null for banger's known_hosts file +
StrictHostKeyChecking=accept-new. A path-resolution failure falls
back to StrictHostKeyChecking=yes.
internal/paths/paths.go — Layout gains SSHDir + KnownHostsPath;
Ensure creates SSHDir at 0700.
Tests (internal/guest/known_hosts_test.go): pin on first use, accept
matching key on second dial, reject mismatch, empty path skips
checking, RemoveKnownHosts drops the entry, re-pin works after
remove. Existing daemon + cli tests updated to assert the new
posture and regression-guard against the old flags.
Live verified: vm run writes the pin to banger's known_hosts at 0600
inside a 0700 dir; banger vm ssh + ssh root@<vm>.vm both succeed
using the pin; vm delete clears it.
436 lines
9.8 KiB
Go
436 lines
9.8 KiB
Go
package guest
|
|
|
|
import (
|
|
"archive/tar"
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type Client struct {
|
|
client *ssh.Client
|
|
}
|
|
|
|
type StreamSession struct {
|
|
client *Client
|
|
session *ssh.Session
|
|
stdin io.WriteCloser
|
|
stdout io.Reader
|
|
stderr io.Reader
|
|
waitCh chan error
|
|
closeOnce sync.Once
|
|
}
|
|
|
|
// WaitForSSH polls Dial until it succeeds or ctx cancels. The
|
|
// knownHostsPath argument is the banger-owned TOFU file; empty
|
|
// disables host-key verification (tests only).
|
|
func WaitForSSH(ctx context.Context, address, privateKeyPath, knownHostsPath string, interval time.Duration) error {
|
|
if interval <= 0 {
|
|
interval = time.Second
|
|
}
|
|
for {
|
|
client, err := Dial(ctx, address, privateKeyPath, knownHostsPath)
|
|
if err == nil {
|
|
_ = client.Close()
|
|
return nil
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(interval):
|
|
}
|
|
}
|
|
}
|
|
|
|
// Dial opens an SSH client to address, authenticating with the key
|
|
// at privateKeyPath and verifying the remote host key against the
|
|
// TOFU known_hosts file at knownHostsPath. An empty knownHostsPath
|
|
// disables verification (tests / one-shot tools only).
|
|
func Dial(ctx context.Context, address, privateKeyPath, knownHostsPath string) (*Client, error) {
|
|
signer, err := privateKeySigner(privateKeyPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
config := &ssh.ClientConfig{
|
|
User: "root",
|
|
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
|
HostKeyCallback: TOFUHostKeyCallback(knownHostsPath),
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
|
conn, err := dialer.DialContext(ctx, "tcp", address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sshConn, chans, reqs, err := ssh.NewClientConn(conn, address, config)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return nil, err
|
|
}
|
|
client := ssh.NewClient(sshConn, chans, reqs)
|
|
return &Client{client: client}, nil
|
|
}
|
|
|
|
func (c *Client) Close() error {
|
|
if c == nil || c.client == nil {
|
|
return nil
|
|
}
|
|
return c.client.Close()
|
|
}
|
|
|
|
func (c *Client) RunScript(ctx context.Context, script string, logWriter io.Writer) error {
|
|
return c.runSession(ctx, "bash -se", strings.NewReader(script), logWriter)
|
|
}
|
|
|
|
// RunScriptOutput runs script on the guest and returns its stdout.
|
|
// Stderr is discarded. Use for capturing structured output (patches, JSON,
|
|
// file content) where mixing stderr into stdout would corrupt the result.
|
|
func (c *Client) RunScriptOutput(ctx context.Context, script string) ([]byte, error) {
|
|
if c == nil || c.client == nil {
|
|
return nil, fmt.Errorf("ssh client is not connected")
|
|
}
|
|
session, err := c.client.NewSession()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer session.Close()
|
|
session.Stdin = strings.NewReader(script)
|
|
var stdout bytes.Buffer
|
|
session.Stdout = &stdout
|
|
// session.Stderr left nil: stderr is intentionally discarded.
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = c.client.Close()
|
|
case <-done:
|
|
}
|
|
}()
|
|
err = session.Run("bash -se")
|
|
done <- nil
|
|
return stdout.Bytes(), err
|
|
}
|
|
|
|
func (c *Client) UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error {
|
|
command := fmt.Sprintf("install -D -m %04o /dev/stdin %s", mode.Perm(), shellQuote(remotePath))
|
|
return c.runSession(ctx, command, bytes.NewReader(data), logWriter)
|
|
}
|
|
|
|
func (c *Client) StreamTar(ctx context.Context, sourceDir, remoteCommand string, logWriter io.Writer) error {
|
|
reader, writer := io.Pipe()
|
|
writeErr := make(chan error, 1)
|
|
go func() {
|
|
writeErr <- writeTarArchive(writer, sourceDir)
|
|
_ = writer.Close()
|
|
}()
|
|
|
|
runErr := c.runSession(ctx, remoteCommand, reader, logWriter)
|
|
tarErr := <-writeErr
|
|
return errors.Join(runErr, tarErr)
|
|
}
|
|
|
|
func (c *Client) StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error {
|
|
reader, writer := io.Pipe()
|
|
writeErr := make(chan error, 1)
|
|
go func() {
|
|
writeErr <- writeTarEntriesArchive(writer, sourceDir, entries)
|
|
_ = writer.Close()
|
|
}()
|
|
|
|
runErr := c.runSession(ctx, remoteCommand, reader, logWriter)
|
|
tarErr := <-writeErr
|
|
return errors.Join(runErr, tarErr)
|
|
}
|
|
|
|
func (c *Client) StartCommand(ctx context.Context, command string) (*StreamSession, error) {
|
|
if c == nil || c.client == nil {
|
|
return nil, fmt.Errorf("ssh client is not connected")
|
|
}
|
|
session, err := c.client.NewSession()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
stdin, err := session.StdinPipe()
|
|
if err != nil {
|
|
_ = session.Close()
|
|
return nil, err
|
|
}
|
|
stdout, err := session.StdoutPipe()
|
|
if err != nil {
|
|
_ = session.Close()
|
|
return nil, err
|
|
}
|
|
stderr, err := session.StderrPipe()
|
|
if err != nil {
|
|
_ = session.Close()
|
|
return nil, err
|
|
}
|
|
done := make(chan struct{})
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = session.Close()
|
|
_ = c.client.Close()
|
|
case <-done:
|
|
}
|
|
}()
|
|
if err := session.Start(command); err != nil {
|
|
close(done)
|
|
_ = session.Close()
|
|
return nil, err
|
|
}
|
|
stream := &StreamSession{
|
|
client: c,
|
|
session: session,
|
|
stdin: stdin,
|
|
stdout: stdout,
|
|
stderr: stderr,
|
|
waitCh: make(chan error, 1),
|
|
}
|
|
go func() {
|
|
err := session.Wait()
|
|
close(done)
|
|
stream.waitCh <- err
|
|
close(stream.waitCh)
|
|
}()
|
|
return stream, nil
|
|
}
|
|
|
|
func (s *StreamSession) Stdin() io.WriteCloser {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
return s.stdin
|
|
}
|
|
|
|
func (s *StreamSession) Stdout() io.Reader {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
return s.stdout
|
|
}
|
|
|
|
func (s *StreamSession) Stderr() io.Reader {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
return s.stderr
|
|
}
|
|
|
|
func (s *StreamSession) Wait() error {
|
|
if s == nil || s.waitCh == nil {
|
|
return nil
|
|
}
|
|
err, ok := <-s.waitCh
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *StreamSession) Close() error {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
var err error
|
|
s.closeOnce.Do(func() {
|
|
err = errors.Join(
|
|
func() error {
|
|
if s.session != nil {
|
|
return s.session.Close()
|
|
}
|
|
return nil
|
|
}(),
|
|
func() error {
|
|
if s.client != nil {
|
|
return s.client.Close()
|
|
}
|
|
return nil
|
|
}(),
|
|
)
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (c *Client) runSession(ctx context.Context, command string, stdin io.Reader, logWriter io.Writer) error {
|
|
if c == nil || c.client == nil {
|
|
return fmt.Errorf("ssh client is not connected")
|
|
}
|
|
session, err := c.client.NewSession()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer session.Close()
|
|
session.Stdin = stdin
|
|
if logWriter != nil {
|
|
session.Stdout = logWriter
|
|
session.Stderr = logWriter
|
|
}
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = c.client.Close()
|
|
case <-done:
|
|
}
|
|
}()
|
|
|
|
err = session.Run(command)
|
|
done <- nil
|
|
return err
|
|
}
|
|
|
|
func privateKeySigner(path string) (ssh.Signer, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ssh.ParsePrivateKey(data)
|
|
}
|
|
|
|
func AuthorizedPublicKey(path string) ([]byte, error) {
|
|
signer, err := privateKeySigner(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ssh.MarshalAuthorizedKey(signer.PublicKey()), nil
|
|
}
|
|
|
|
func AuthorizedPublicKeyFingerprint(path string) (string, error) {
|
|
key, err := AuthorizedPublicKey(path)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
sum := sha256.Sum256([]byte(strings.TrimSpace(string(key))))
|
|
return hex.EncodeToString(sum[:]), nil
|
|
}
|
|
|
|
// shellQuote wraps value in single quotes for use as one shell word.
// Each embedded single quote is replaced by the sequence '"'"', which
// closes the quoted string, emits a double-quoted ', and reopens it.
func shellQuote(value string) string {
	const escapedQuote = `'"'"'`
	pieces := strings.Split(value, "'")
	return "'" + strings.Join(pieces, escapedQuote) + "'"
}
|
|
|
|
// writeTarArchive writes the whole tree rooted at sourceDir to dst as
// a tar stream. Entry names are slash-separated and prefixed with the
// base name of sourceDir, so the archive unpacks into a single
// top-level directory. Symlinks are stored as links (not followed);
// only regular files have their content copied.
//
// The error from closing the tar writer is reported when nothing
// failed earlier: Close flushes padding and writes the end-of-archive
// trailer, so ignoring it could report a truncated archive as success.
func writeTarArchive(dst io.Writer, sourceDir string) (retErr error) {
	tw := tar.NewWriter(dst)
	defer func() {
		if closeErr := tw.Close(); retErr == nil {
			retErr = closeErr
		}
	}()

	sourceDir = filepath.Clean(sourceDir)
	rootName := filepath.Base(sourceDir)
	return filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		name := rootName
		if path != sourceDir {
			relPath, err := filepath.Rel(sourceDir, path)
			if err != nil {
				return err
			}
			// Tar names always use forward slashes, whatever the host OS.
			name = filepath.ToSlash(filepath.Join(rootName, relPath))
		}
		linkTarget := ""
		if info.Mode()&os.ModeSymlink != 0 {
			linkTarget, err = os.Readlink(path)
			if err != nil {
				return err
			}
		}
		header, err := tar.FileInfoHeader(info, linkTarget)
		if err != nil {
			return err
		}
		header.Name = name
		if err := tw.WriteHeader(header); err != nil {
			return err
		}
		// Directories, symlinks and other special files carry no body.
		if !info.Mode().IsRegular() {
			return nil
		}
		file, err := os.Open(path)
		if err != nil {
			return err
		}
		defer file.Close()
		_, err = io.Copy(tw, file)
		return err
	})
}
|
|
|
|
// writeTarEntriesArchive writes only the listed entries of sourceDir
// to dst as a tar stream. Each entry is a path relative to sourceDir;
// entries are trimmed, cleaned, de-duplicated and sorted before being
// written, and blank entries are skipped. Entry names in the archive
// are slash-separated and prefixed with the base name of sourceDir.
// Entries are not walked recursively: a directory entry contributes
// only its own header. Symlinks are stored as links (not followed).
//
// An entry that cleans to ".", "..", or a path starting with "../" is
// rejected with an error before anything is archived.
//
// The error from closing the tar writer is reported when nothing
// failed earlier: Close flushes padding and writes the end-of-archive
// trailer, so ignoring it could report a truncated archive as success.
func writeTarEntriesArchive(dst io.Writer, sourceDir string, entries []string) (retErr error) {
	tw := tar.NewWriter(dst)
	defer func() {
		if closeErr := tw.Close(); retErr == nil {
			retErr = closeErr
		}
	}()

	sourceDir = filepath.Clean(sourceDir)
	rootName := filepath.Base(sourceDir)

	uniqueEntries := make([]string, 0, len(entries))
	seen := make(map[string]struct{}, len(entries))
	for _, entry := range entries {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		entry = filepath.Clean(entry)
		if entry == "." || entry == ".." || strings.HasPrefix(entry, ".."+string(filepath.Separator)) {
			return fmt.Errorf("tar entry %q escapes source dir", entry)
		}
		if _, ok := seen[entry]; ok {
			continue
		}
		seen[entry] = struct{}{}
		uniqueEntries = append(uniqueEntries, entry)
	}
	// Sorted order keeps the archive layout deterministic.
	sort.Strings(uniqueEntries)

	for _, entry := range uniqueEntries {
		fullPath := filepath.Join(sourceDir, entry)
		// Lstat so that a symlink is archived as a link, not its target.
		info, err := os.Lstat(fullPath)
		if err != nil {
			return err
		}
		linkTarget := ""
		if info.Mode()&os.ModeSymlink != 0 {
			linkTarget, err = os.Readlink(fullPath)
			if err != nil {
				return err
			}
		}
		header, err := tar.FileInfoHeader(info, linkTarget)
		if err != nil {
			return err
		}
		header.Name = path.Join(rootName, filepath.ToSlash(entry))
		if err := tw.WriteHeader(header); err != nil {
			return err
		}
		// Directories, symlinks and other special files carry no body.
		if !info.Mode().IsRegular() {
			continue
		}
		file, err := os.Open(fullPath)
		if err != nil {
			return err
		}
		if _, err := io.Copy(tw, file); err != nil {
			_ = file.Close()
			return err
		}
		if err := file.Close(); err != nil {
			return err
		}
	}
	return nil
}
|