package daemon import ( "fmt" "os" "path/filepath" "strings" "banger/internal/paths" ) const ( vmSSHConfigIncludeBegin = "# BEGIN BANGER MANAGED VM SSH" vmSSHConfigIncludeEnd = "# END BANGER MANAGED VM SSH" ) func (d *Daemon) ensureVMSSHClientConfig() { if err := syncVMSSHClientConfig(d.layout, d.config.SSHKeyPath); err != nil && d.logger != nil { d.logger.Warn("vm ssh client config sync failed", "error", err.Error()) } } func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error { keyPath = strings.TrimSpace(keyPath) if keyPath == "" { return nil } home, err := os.UserHomeDir() if err != nil { return err } sshDir := filepath.Join(home, ".ssh") if err := os.MkdirAll(sshDir, 0o700); err != nil { return err } userConfigPath := filepath.Join(sshDir, "config") userConfig, err := readTextFileIfExists(userConfigPath) if err != nil { return err } updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath)) if err != nil { return err } if err := writeTextFileIfChanged(userConfigPath, updated, 0o644); err != nil { return err } legacyManagedPath := filepath.Join(layout.ConfigDir, "ssh", "ssh_config") if err := os.Remove(legacyManagedPath); err != nil && !os.IsNotExist(err) { return err } return nil } func renderManagedVMSSHBlock(keyPath string) string { keyPath = strings.TrimSpace(keyPath) return strings.Join([]string{ vmSSHConfigIncludeBegin, "# Generated by banger for direct SSH access to VM DNS names.", "Host *.vm", " User root", " IdentityFile " + keyPath, " IdentitiesOnly yes", " BatchMode yes", " PreferredAuthentications publickey", " PasswordAuthentication no", " KbdInteractiveAuthentication no", " StrictHostKeyChecking no", " UserKnownHostsFile /dev/null", " LogLevel ERROR", vmSSHConfigIncludeEnd, "", }, "\n") } func upsertManagedBlock(existing, beginMarker, endMarker, block string) (string, error) { existing = normalizeConfigText(existing) block = normalizeConfigText(block) start := strings.Index(existing, beginMarker) if start >= 0 { end := strings.Index(existing[start:], endMarker) if end < 0 { return "", fmt.Errorf("managed block %q is missing end marker %q", beginMarker, endMarker) } end += start + len(endMarker) for end < len(existing) && existing[end] == '\n' { end++ } existing = strings.TrimRight(existing[:start]+existing[end:], "\n") } if strings.TrimSpace(existing) == "" { return block, nil } return strings.TrimRight(existing, "\n") + "\n\n" + block, nil } func normalizeConfigText(text string) string { text = strings.ReplaceAll(text, "\r\n", "\n") text = strings.TrimRight(text, "\n") if text == "" { return "" } return text + "\n" } func readTextFileIfExists(path string) (string, error) { data, err := os.ReadFile(path) if err == nil { return string(data), nil } if os.IsNotExist(err) { return "", nil } return "", err } func writeTextFileIfChanged(path, content string, mode os.FileMode) error { content = normalizeConfigText(content) existing, err := readTextFileIfExists(path) if err != nil { return err } if existing == content { return nil } return os.WriteFile(path, []byte(content), mode) }