From ea2db1e868fd7980d1a180a868518fd7549ee061 Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Sun, 22 Mar 2026 16:48:42 -0300 Subject: [PATCH] Configure direct SSH access for .vm hosts Make daemon startup sync a managed `Host *.vm` block into `~/.ssh/config` so plain `ssh root@.vm` uses banger's managed key and the same publickey-only options as `banger vm ssh`. Write the block directly instead of relying on a separate include file so it still applies when a user's SSH config ends inside another `Host` stanza, and remove the legacy managed include path. Add daemon tests that cover fresh config creation and managed-block replacement while preserving user entries. Validate with `go test ./...`, `make build`, `ssh -G alp.vm`, and `ssh alp.vm true`. --- internal/daemon/daemon.go | 1 + internal/daemon/ssh_client_config.go | 131 ++++++++++++++++++++++ internal/daemon/ssh_client_config_test.go | 95 ++++++++++++++++ 3 files changed, 227 insertions(+) create mode 100644 internal/daemon/ssh_client_config.go create mode 100644 internal/daemon/ssh_client_config_test.go diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 31f40ef..de1176e 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -84,6 +84,7 @@ func Open(ctx context.Context) (d *Daemon, err error) { closing: make(chan struct{}), pid: os.Getpid(), } + d.ensureVMSSHClientConfig() d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel) if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil { d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error()) diff --git a/internal/daemon/ssh_client_config.go b/internal/daemon/ssh_client_config.go new file mode 100644 index 0000000..299b38e --- /dev/null +++ b/internal/daemon/ssh_client_config.go @@ -0,0 +1,131 @@ +package daemon + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "banger/internal/paths" +) + +const ( + vmSSHConfigIncludeBegin = "# BEGIN BANGER MANAGED VM SSH" + vmSSHConfigIncludeEnd = "# END BANGER MANAGED VM SSH" +) + +func (d *Daemon) ensureVMSSHClientConfig() { + if err := syncVMSSHClientConfig(d.layout, d.config.SSHKeyPath); err != nil && d.logger != nil { + d.logger.Warn("vm ssh client config sync failed", "error", err.Error()) + } +} + +func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error { + keyPath = strings.TrimSpace(keyPath) + if keyPath == "" { + return nil + } + + home, err := os.UserHomeDir() + if err != nil { + return err + } + sshDir := filepath.Join(home, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + return err + } + userConfigPath := filepath.Join(sshDir, "config") + userConfig, err := readTextFileIfExists(userConfigPath) + if err != nil { + return err + } + updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath)) + if err != nil { + return err + } + if err := writeTextFileIfChanged(userConfigPath, updated, 0o644); err != nil { + return err + } + + legacyManagedPath := filepath.Join(layout.ConfigDir, "ssh", "ssh_config") + if err := os.Remove(legacyManagedPath); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} + +func renderManagedVMSSHBlock(keyPath string) string { + keyPath = strings.TrimSpace(keyPath) + return strings.Join([]string{ + vmSSHConfigIncludeBegin, + "# Generated by banger for direct SSH access to VM DNS names.", + "Host *.vm", + " User root", + " IdentityFile " + keyPath, + " IdentitiesOnly yes", + " BatchMode yes", + " PreferredAuthentications publickey", + " PasswordAuthentication no", + " KbdInteractiveAuthentication no", + " StrictHostKeyChecking no", + " UserKnownHostsFile /dev/null", + " LogLevel ERROR", + vmSSHConfigIncludeEnd, + "", + }, "\n") +} + +func upsertManagedBlock(existing, beginMarker, endMarker, block string) (string, error) { + existing = normalizeConfigText(existing) + block = normalizeConfigText(block) + + start := strings.Index(existing, beginMarker) + if start >= 0 { + end := strings.Index(existing[start:], endMarker) + if end < 0 { + return "", fmt.Errorf("managed block %q is missing end marker %q", beginMarker, endMarker) + } + end += start + len(endMarker) + for end < len(existing) && existing[end] == '\n' { + end++ + } + existing = strings.TrimRight(existing[:start]+existing[end:], "\n") + } + + if strings.TrimSpace(existing) == "" { + return block, nil + } + return strings.TrimRight(existing, "\n") + "\n\n" + block, nil +} + +func normalizeConfigText(text string) string { + text = strings.ReplaceAll(text, "\r\n", "\n") + text = strings.TrimRight(text, "\n") + if text == "" { + return "" + } + return text + "\n" +} + +func readTextFileIfExists(path string) (string, error) { + data, err := os.ReadFile(path) + if err == nil { + return string(data), nil + } + if os.IsNotExist(err) { + return "", nil + } + return "", err +} + +func writeTextFileIfChanged(path, content string, mode os.FileMode) error { + content = normalizeConfigText(content) + existing, err := readTextFileIfExists(path) + if err != nil { + return err + } + if existing == content { + return nil + } + return os.WriteFile(path, []byte(content), mode) +} diff --git a/internal/daemon/ssh_client_config_test.go b/internal/daemon/ssh_client_config_test.go new file mode 100644 index 0000000..80d8e95 --- /dev/null +++ b/internal/daemon/ssh_client_config_test.go @@ -0,0 +1,95 @@ +package daemon + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "banger/internal/paths" +) + +func TestSyncVMSSHClientConfigCreatesManagedBlock(t *testing.T) { + homeDir := t.TempDir() + t.Setenv("HOME", homeDir) + + layout := paths.Layout{ + ConfigDir: filepath.Join(homeDir, ".config", "banger"), + } + keyPath := filepath.Join(layout.ConfigDir, "ssh", "id_ed25519") + + if err := syncVMSSHClientConfig(layout, keyPath); err != nil { + t.Fatalf("syncVMSSHClientConfig: %v", err) + } + + userConfigPath := filepath.Join(homeDir, ".ssh", "config") + userConfig, err := os.ReadFile(userConfigPath) + if err != nil { + t.Fatalf("ReadFile(user config): %v", err) + } + userContent := string(userConfig) + if !strings.Contains(userContent, vmSSHConfigIncludeBegin) { + t.Fatalf("user config = %q, want begin marker", userContent) + } + for _, want := range []string{ + "Host *.vm", + "User root", + "IdentityFile " + keyPath, + "IdentitiesOnly yes", + "BatchMode yes", + "PasswordAuthentication no", + "UserKnownHostsFile /dev/null", + } { + if !strings.Contains(userContent, want) { + t.Fatalf("user config = %q, want %q", userContent, want) + } + } +} + +func TestSyncVMSSHClientConfigReplacesManagedIncludeBlock(t *testing.T) { + homeDir := t.TempDir() + t.Setenv("HOME", homeDir) + + layout := paths.Layout{ + ConfigDir: filepath.Join(homeDir, ".config", "banger"), + } + keyPath := filepath.Join(layout.ConfigDir, "ssh", "id_ed25519") + + sshDir := filepath.Join(homeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatalf("MkdirAll(.ssh): %v", err) + } + initial := strings.Join([]string{ + "ServerAliveInterval 120", + "", + vmSSHConfigIncludeBegin, + "Include /tmp/old-banger-config", + vmSSHConfigIncludeEnd, + "", + "Host other", + " HostName 192.0.2.5", + "", + }, "\n") + if err := os.WriteFile(filepath.Join(sshDir, "config"), []byte(initial), 0o644); err != nil { + t.Fatalf("WriteFile(user config): %v", err) + } + + if err := syncVMSSHClientConfig(layout, keyPath); err != nil { + t.Fatalf("syncVMSSHClientConfig: %v", err) + } + + userConfig, err := os.ReadFile(filepath.Join(sshDir, "config")) + if err != nil { + t.Fatalf("ReadFile(user config): %v", err) + } + userContent := string(userConfig) + if strings.Count(userContent, vmSSHConfigIncludeBegin) != 1 { + t.Fatalf("user config = %q, want one managed block", userContent) + } + if !strings.Contains(userContent, "ServerAliveInterval 120") || !strings.Contains(userContent, "Host other") { + t.Fatalf("user config = %q, want existing entries preserved", userContent) + } + if !strings.Contains(userContent, "Host *.vm") || !strings.Contains(userContent, "IdentityFile "+keyPath) { + t.Fatalf("user config = %q, want refreshed managed vm block", userContent) + } +}