Configure direct SSH access for .vm hosts

Make daemon startup sync a managed `Host *.vm` block into `~/.ssh/config` so plain `ssh root@<vm>.vm` uses banger's managed key and the same publickey-only options as `banger vm ssh`.

Write the block directly instead of relying on a separate include file so it still applies when a user's SSH config ends inside another `Host` stanza, and remove the legacy managed include path. Add daemon tests that cover fresh config creation and managed-block replacement while preserving user entries.

Validate with `go test ./...`, `make build`, `ssh -G alp.vm`, and `ssh alp.vm true`.
This commit is contained in:
Thales Maciel 2026-03-22 16:48:42 -03:00
parent b7f6d1fe1b
commit ea2db1e868
No known key found for this signature in database
GPG key ID: 33112E6833C34679
3 changed files with 227 additions and 0 deletions

View file

@ -84,6 +84,7 @@ func Open(ctx context.Context) (d *Daemon, err error) {
closing: make(chan struct{}),
pid: os.Getpid(),
}
d.ensureVMSSHClientConfig()
d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel)
if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil {
d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error())

View file

@ -0,0 +1,131 @@
package daemon
import (
"fmt"
"os"
"path/filepath"
"strings"
"banger/internal/paths"
)
const (
vmSSHConfigIncludeBegin = "# BEGIN BANGER MANAGED VM SSH"
vmSSHConfigIncludeEnd = "# END BANGER MANAGED VM SSH"
)
func (d *Daemon) ensureVMSSHClientConfig() {
if err := syncVMSSHClientConfig(d.layout, d.config.SSHKeyPath); err != nil && d.logger != nil {
d.logger.Warn("vm ssh client config sync failed", "error", err.Error())
}
}
func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error {
keyPath = strings.TrimSpace(keyPath)
if keyPath == "" {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
sshDir := filepath.Join(home, ".ssh")
if err := os.MkdirAll(sshDir, 0o700); err != nil {
return err
}
userConfigPath := filepath.Join(sshDir, "config")
userConfig, err := readTextFileIfExists(userConfigPath)
if err != nil {
return err
}
updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath))
if err != nil {
return err
}
if err := writeTextFileIfChanged(userConfigPath, updated, 0o644); err != nil {
return err
}
legacyManagedPath := filepath.Join(layout.ConfigDir, "ssh", "ssh_config")
if err := os.Remove(legacyManagedPath); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
func renderManagedVMSSHBlock(keyPath string) string {
keyPath = strings.TrimSpace(keyPath)
return strings.Join([]string{
vmSSHConfigIncludeBegin,
"# Generated by banger for direct SSH access to VM DNS names.",
"Host *.vm",
" User root",
" IdentityFile " + keyPath,
" IdentitiesOnly yes",
" BatchMode yes",
" PreferredAuthentications publickey",
" PasswordAuthentication no",
" KbdInteractiveAuthentication no",
" StrictHostKeyChecking no",
" UserKnownHostsFile /dev/null",
" LogLevel ERROR",
vmSSHConfigIncludeEnd,
"",
}, "\n")
}
func upsertManagedBlock(existing, beginMarker, endMarker, block string) (string, error) {
existing = normalizeConfigText(existing)
block = normalizeConfigText(block)
start := strings.Index(existing, beginMarker)
if start >= 0 {
end := strings.Index(existing[start:], endMarker)
if end < 0 {
return "", fmt.Errorf("managed block %q is missing end marker %q", beginMarker, endMarker)
}
end += start + len(endMarker)
for end < len(existing) && existing[end] == '\n' {
end++
}
existing = strings.TrimRight(existing[:start]+existing[end:], "\n")
}
if strings.TrimSpace(existing) == "" {
return block, nil
}
return strings.TrimRight(existing, "\n") + "\n\n" + block, nil
}
func normalizeConfigText(text string) string {
text = strings.ReplaceAll(text, "\r\n", "\n")
text = strings.TrimRight(text, "\n")
if text == "" {
return ""
}
return text + "\n"
}
func readTextFileIfExists(path string) (string, error) {
data, err := os.ReadFile(path)
if err == nil {
return string(data), nil
}
if os.IsNotExist(err) {
return "", nil
}
return "", err
}
func writeTextFileIfChanged(path, content string, mode os.FileMode) error {
content = normalizeConfigText(content)
existing, err := readTextFileIfExists(path)
if err != nil {
return err
}
if existing == content {
return nil
}
return os.WriteFile(path, []byte(content), mode)
}

View file

@ -0,0 +1,95 @@
package daemon
import (
"os"
"path/filepath"
"strings"
"testing"
"banger/internal/paths"
)
func TestSyncVMSSHClientConfigCreatesManagedBlock(t *testing.T) {
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
layout := paths.Layout{
ConfigDir: filepath.Join(homeDir, ".config", "banger"),
}
keyPath := filepath.Join(layout.ConfigDir, "ssh", "id_ed25519")
if err := syncVMSSHClientConfig(layout, keyPath); err != nil {
t.Fatalf("syncVMSSHClientConfig: %v", err)
}
userConfigPath := filepath.Join(homeDir, ".ssh", "config")
userConfig, err := os.ReadFile(userConfigPath)
if err != nil {
t.Fatalf("ReadFile(user config): %v", err)
}
userContent := string(userConfig)
if !strings.Contains(userContent, vmSSHConfigIncludeBegin) {
t.Fatalf("user config = %q, want begin marker", userContent)
}
for _, want := range []string{
"Host *.vm",
"User root",
"IdentityFile " + keyPath,
"IdentitiesOnly yes",
"BatchMode yes",
"PasswordAuthentication no",
"UserKnownHostsFile /dev/null",
} {
if !strings.Contains(userContent, want) {
t.Fatalf("user config = %q, want %q", userContent, want)
}
}
}
func TestSyncVMSSHClientConfigReplacesManagedIncludeBlock(t *testing.T) {
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
layout := paths.Layout{
ConfigDir: filepath.Join(homeDir, ".config", "banger"),
}
keyPath := filepath.Join(layout.ConfigDir, "ssh", "id_ed25519")
sshDir := filepath.Join(homeDir, ".ssh")
if err := os.MkdirAll(sshDir, 0o700); err != nil {
t.Fatalf("MkdirAll(.ssh): %v", err)
}
initial := strings.Join([]string{
"ServerAliveInterval 120",
"",
vmSSHConfigIncludeBegin,
"Include /tmp/old-banger-config",
vmSSHConfigIncludeEnd,
"",
"Host other",
" HostName 192.0.2.5",
"",
}, "\n")
if err := os.WriteFile(filepath.Join(sshDir, "config"), []byte(initial), 0o644); err != nil {
t.Fatalf("WriteFile(user config): %v", err)
}
if err := syncVMSSHClientConfig(layout, keyPath); err != nil {
t.Fatalf("syncVMSSHClientConfig: %v", err)
}
userConfig, err := os.ReadFile(filepath.Join(sshDir, "config"))
if err != nil {
t.Fatalf("ReadFile(user config): %v", err)
}
userContent := string(userConfig)
if strings.Count(userContent, vmSSHConfigIncludeBegin) != 1 {
t.Fatalf("user config = %q, want one managed block", userContent)
}
if !strings.Contains(userContent, "ServerAliveInterval 120") || !strings.Contains(userContent, "Host other") {
t.Fatalf("user config = %q, want existing entries preserved", userContent)
}
if !strings.Contains(userContent, "Host *.vm") || !strings.Contains(userContent, "IdentityFile "+keyPath) {
t.Fatalf("user config = %q, want refreshed managed vm block", userContent)
}
}