Configure direct SSH access for .vm hosts
Make daemon startup sync a managed `Host *.vm` block into `~/.ssh/config` so plain `ssh root@<vm>.vm` uses banger's managed key and the same publickey-only options as `banger vm ssh`. Write the block directly instead of relying on a separate include file so it still applies when a user's SSH config ends inside another `Host` stanza, and remove the legacy managed include path. Add daemon tests that cover fresh config creation and managed-block replacement while preserving user entries. Validate with `go test ./...`, `make build`, `ssh -G alp.vm`, and `ssh alp.vm true`.
This commit is contained in:
parent
b7f6d1fe1b
commit
ea2db1e868
3 changed files with 227 additions and 0 deletions
|
|
@ -84,6 +84,7 @@ func Open(ctx context.Context) (d *Daemon, err error) {
|
||||||
closing: make(chan struct{}),
|
closing: make(chan struct{}),
|
||||||
pid: os.Getpid(),
|
pid: os.Getpid(),
|
||||||
}
|
}
|
||||||
|
d.ensureVMSSHClientConfig()
|
||||||
d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel)
|
d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel)
|
||||||
if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil {
|
if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil {
|
||||||
d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error())
|
d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error())
|
||||||
|
|
|
||||||
131
internal/daemon/ssh_client_config.go
Normal file
131
internal/daemon/ssh_client_config.go
Normal file
|
|
@ -0,0 +1,131 @@
|
||||||
|
package daemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"banger/internal/paths"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
vmSSHConfigIncludeBegin = "# BEGIN BANGER MANAGED VM SSH"
|
||||||
|
vmSSHConfigIncludeEnd = "# END BANGER MANAGED VM SSH"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (d *Daemon) ensureVMSSHClientConfig() {
|
||||||
|
if err := syncVMSSHClientConfig(d.layout, d.config.SSHKeyPath); err != nil && d.logger != nil {
|
||||||
|
d.logger.Warn("vm ssh client config sync failed", "error", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error {
|
||||||
|
keyPath = strings.TrimSpace(keyPath)
|
||||||
|
if keyPath == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sshDir := filepath.Join(home, ".ssh")
|
||||||
|
if err := os.MkdirAll(sshDir, 0o700); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
userConfigPath := filepath.Join(sshDir, "config")
|
||||||
|
userConfig, err := readTextFileIfExists(userConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := writeTextFileIfChanged(userConfigPath, updated, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
legacyManagedPath := filepath.Join(layout.ConfigDir, "ssh", "ssh_config")
|
||||||
|
if err := os.Remove(legacyManagedPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderManagedVMSSHBlock(keyPath string) string {
|
||||||
|
keyPath = strings.TrimSpace(keyPath)
|
||||||
|
return strings.Join([]string{
|
||||||
|
vmSSHConfigIncludeBegin,
|
||||||
|
"# Generated by banger for direct SSH access to VM DNS names.",
|
||||||
|
"Host *.vm",
|
||||||
|
" User root",
|
||||||
|
" IdentityFile " + keyPath,
|
||||||
|
" IdentitiesOnly yes",
|
||||||
|
" BatchMode yes",
|
||||||
|
" PreferredAuthentications publickey",
|
||||||
|
" PasswordAuthentication no",
|
||||||
|
" KbdInteractiveAuthentication no",
|
||||||
|
" StrictHostKeyChecking no",
|
||||||
|
" UserKnownHostsFile /dev/null",
|
||||||
|
" LogLevel ERROR",
|
||||||
|
vmSSHConfigIncludeEnd,
|
||||||
|
"",
|
||||||
|
}, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func upsertManagedBlock(existing, beginMarker, endMarker, block string) (string, error) {
|
||||||
|
existing = normalizeConfigText(existing)
|
||||||
|
block = normalizeConfigText(block)
|
||||||
|
|
||||||
|
start := strings.Index(existing, beginMarker)
|
||||||
|
if start >= 0 {
|
||||||
|
end := strings.Index(existing[start:], endMarker)
|
||||||
|
if end < 0 {
|
||||||
|
return "", fmt.Errorf("managed block %q is missing end marker %q", beginMarker, endMarker)
|
||||||
|
}
|
||||||
|
end += start + len(endMarker)
|
||||||
|
for end < len(existing) && existing[end] == '\n' {
|
||||||
|
end++
|
||||||
|
}
|
||||||
|
existing = strings.TrimRight(existing[:start]+existing[end:], "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(existing) == "" {
|
||||||
|
return block, nil
|
||||||
|
}
|
||||||
|
return strings.TrimRight(existing, "\n") + "\n\n" + block, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeConfigText(text string) string {
|
||||||
|
text = strings.ReplaceAll(text, "\r\n", "\n")
|
||||||
|
text = strings.TrimRight(text, "\n")
|
||||||
|
if text == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return text + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
func readTextFileIfExists(path string) (string, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err == nil {
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeTextFileIfChanged(path, content string, mode os.FileMode) error {
|
||||||
|
content = normalizeConfigText(content)
|
||||||
|
existing, err := readTextFileIfExists(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if existing == content {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, []byte(content), mode)
|
||||||
|
}
|
||||||
95
internal/daemon/ssh_client_config_test.go
Normal file
95
internal/daemon/ssh_client_config_test.go
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
package daemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"banger/internal/paths"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSyncVMSSHClientConfigCreatesManagedBlock(t *testing.T) {
|
||||||
|
homeDir := t.TempDir()
|
||||||
|
t.Setenv("HOME", homeDir)
|
||||||
|
|
||||||
|
layout := paths.Layout{
|
||||||
|
ConfigDir: filepath.Join(homeDir, ".config", "banger"),
|
||||||
|
}
|
||||||
|
keyPath := filepath.Join(layout.ConfigDir, "ssh", "id_ed25519")
|
||||||
|
|
||||||
|
if err := syncVMSSHClientConfig(layout, keyPath); err != nil {
|
||||||
|
t.Fatalf("syncVMSSHClientConfig: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userConfigPath := filepath.Join(homeDir, ".ssh", "config")
|
||||||
|
userConfig, err := os.ReadFile(userConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFile(user config): %v", err)
|
||||||
|
}
|
||||||
|
userContent := string(userConfig)
|
||||||
|
if !strings.Contains(userContent, vmSSHConfigIncludeBegin) {
|
||||||
|
t.Fatalf("user config = %q, want begin marker", userContent)
|
||||||
|
}
|
||||||
|
for _, want := range []string{
|
||||||
|
"Host *.vm",
|
||||||
|
"User root",
|
||||||
|
"IdentityFile " + keyPath,
|
||||||
|
"IdentitiesOnly yes",
|
||||||
|
"BatchMode yes",
|
||||||
|
"PasswordAuthentication no",
|
||||||
|
"UserKnownHostsFile /dev/null",
|
||||||
|
} {
|
||||||
|
if !strings.Contains(userContent, want) {
|
||||||
|
t.Fatalf("user config = %q, want %q", userContent, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncVMSSHClientConfigReplacesManagedIncludeBlock(t *testing.T) {
|
||||||
|
homeDir := t.TempDir()
|
||||||
|
t.Setenv("HOME", homeDir)
|
||||||
|
|
||||||
|
layout := paths.Layout{
|
||||||
|
ConfigDir: filepath.Join(homeDir, ".config", "banger"),
|
||||||
|
}
|
||||||
|
keyPath := filepath.Join(layout.ConfigDir, "ssh", "id_ed25519")
|
||||||
|
|
||||||
|
sshDir := filepath.Join(homeDir, ".ssh")
|
||||||
|
if err := os.MkdirAll(sshDir, 0o700); err != nil {
|
||||||
|
t.Fatalf("MkdirAll(.ssh): %v", err)
|
||||||
|
}
|
||||||
|
initial := strings.Join([]string{
|
||||||
|
"ServerAliveInterval 120",
|
||||||
|
"",
|
||||||
|
vmSSHConfigIncludeBegin,
|
||||||
|
"Include /tmp/old-banger-config",
|
||||||
|
vmSSHConfigIncludeEnd,
|
||||||
|
"",
|
||||||
|
"Host other",
|
||||||
|
" HostName 192.0.2.5",
|
||||||
|
"",
|
||||||
|
}, "\n")
|
||||||
|
if err := os.WriteFile(filepath.Join(sshDir, "config"), []byte(initial), 0o644); err != nil {
|
||||||
|
t.Fatalf("WriteFile(user config): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := syncVMSSHClientConfig(layout, keyPath); err != nil {
|
||||||
|
t.Fatalf("syncVMSSHClientConfig: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userConfig, err := os.ReadFile(filepath.Join(sshDir, "config"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFile(user config): %v", err)
|
||||||
|
}
|
||||||
|
userContent := string(userConfig)
|
||||||
|
if strings.Count(userContent, vmSSHConfigIncludeBegin) != 1 {
|
||||||
|
t.Fatalf("user config = %q, want one managed block", userContent)
|
||||||
|
}
|
||||||
|
if !strings.Contains(userContent, "ServerAliveInterval 120") || !strings.Contains(userContent, "Host other") {
|
||||||
|
t.Fatalf("user config = %q, want existing entries preserved", userContent)
|
||||||
|
}
|
||||||
|
if !strings.Contains(userContent, "Host *.vm") || !strings.Contains(userContent, "IdentityFile "+keyPath) {
|
||||||
|
t.Fatalf("user config = %q, want refreshed managed vm block", userContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue