package daemon

import (
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"strings"

	"banger/internal/guest"
	"banger/internal/model"
	"banger/internal/paths"
)

// removeVMKnownHosts drops every host-key pin for vm from the
// banger-owned known_hosts. Best-effort — a failure here only
// matters if the same IP/name is reused by a fresh VM before the
// next daemon restart, and even then it just causes a
// TOFU-mismatch error that the user can clear manually. Logged at
// warn so it shows up if it ever actually breaks things.
//
// Entries are removed for the VM's guest IP and DNS name, whichever
// are non-blank. A blank knownHostsPath, or a VM with neither value,
// is a no-op. A nil logger suppresses the warning.
func removeVMKnownHosts(knownHostsPath string, vm model.VMRecord, logger *slog.Logger) {
	if strings.TrimSpace(knownHostsPath) == "" {
		return
	}
	var hosts []string
	if ip := strings.TrimSpace(vm.Runtime.GuestIP); ip != "" {
		hosts = append(hosts, ip)
	}
	if dns := strings.TrimSpace(vm.Runtime.DNSName); dns != "" {
		hosts = append(hosts, dns)
	}
	if len(hosts) == 0 {
		return
	}
	if err := guest.RemoveKnownHosts(knownHostsPath, hosts...); err != nil && logger != nil {
		logger.Warn("remove known_hosts entries", "vm_id", vm.ID, "error", err.Error())
	}
}

// Markers delimiting the banger-managed region inside the user's
// ~/.ssh/config. Everything between them (inclusive) is owned by
// banger and rewritten on every sync.
const (
	vmSSHConfigIncludeBegin = "# BEGIN BANGER MANAGED VM SSH"
	vmSSHConfigIncludeEnd   = "# END BANGER MANAGED VM SSH"
)

// ensureVMSSHClientConfig syncs the managed `Host *.vm` stanza into
// the user's ~/.ssh/config using the daemon's layout and configured
// SSH key. Failures are logged at warn (when a logger is set) and are
// otherwise ignored.
func (d *Daemon) ensureVMSSHClientConfig() {
	if err := syncVMSSHClientConfig(d.layout, d.config.SSHKeyPath); err != nil && d.logger != nil {
		d.logger.Warn("vm ssh client config sync failed", "error", err.Error())
	}
}

// syncVMSSHClientConfig inserts or refreshes banger's managed block in
// ~/.ssh/config. It creates ~/.ssh (mode 0700) if needed and only
// writes the config when its content actually changes. It also
// deletes the legacy <ConfigDir>/ssh/ssh_config file if one exists.
//
// A blank keyPath means there is nothing to configure: the function
// returns nil without touching any file.
func syncVMSSHClientConfig(layout paths.Layout, keyPath string) error {
	keyPath = strings.TrimSpace(keyPath)
	if keyPath == "" {
		return nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}
	sshDir := filepath.Join(home, ".ssh")
	if err := os.MkdirAll(sshDir, 0o700); err != nil {
		return err
	}
	userConfigPath := filepath.Join(sshDir, "config")
	userConfig, err := readTextFileIfExists(userConfigPath)
	if err != nil {
		return err
	}
	updated, err := upsertManagedBlock(userConfig, vmSSHConfigIncludeBegin, vmSSHConfigIncludeEnd, renderManagedVMSSHBlock(keyPath, layout.KnownHostsPath))
	if err != nil {
		return err
	}
	if err := writeTextFileIfChanged(userConfigPath, updated, 0o644); err != nil {
		return err
	}
	legacyManagedPath := filepath.Join(layout.ConfigDir, "ssh", "ssh_config")
	if err := os.Remove(legacyManagedPath); err != nil && !os.IsNotExist(err) {
		return err
	}
	return nil
}

// renderManagedVMSSHBlock produces the `Host *.vm` stanza banger
// writes into the user's ~/.ssh/config. Host-key verification uses
// the banger-owned known_hosts file at knownHostsPath — NOT the
// user's ~/.ssh/known_hosts, and NOT /dev/null. `accept-new` means
// first contact pins the key; any later mismatch fails the connect.
//
// Paths are quoted when they contain characters ssh_config would
// otherwise split on (see quoteSSHConfigArg). A blank keyPath omits
// the IdentityFile directive instead of emitting one with no argument,
// which ssh would reject as a configuration error.
func renderManagedVMSSHBlock(keyPath, knownHostsPath string) string {
	keyPath = strings.TrimSpace(keyPath)
	knownHostsPath = strings.TrimSpace(knownHostsPath)
	lines := []string{
		vmSSHConfigIncludeBegin,
		"# Generated by banger for direct SSH access to VM DNS names.",
		"# Host keys are pinned on first use into a banger-owned",
		"# known_hosts file (not ~/.ssh/known_hosts).",
		"Host *.vm",
		" User root",
	}
	if keyPath != "" {
		lines = append(lines, " IdentityFile "+quoteSSHConfigArg(keyPath))
	}
	lines = append(lines,
		" IdentitiesOnly yes",
		" BatchMode yes",
		" PreferredAuthentications publickey",
		" PasswordAuthentication no",
		" KbdInteractiveAuthentication no",
	)
	if knownHostsPath != "" {
		lines = append(lines,
			" UserKnownHostsFile "+quoteSSHConfigArg(knownHostsPath),
			" StrictHostKeyChecking accept-new",
		)
	} else {
		// Missing known_hosts path is a configuration anomaly — fail
		// closed rather than silently disable verification.
		lines = append(lines, " StrictHostKeyChecking yes")
	}
	lines = append(lines,
		" LogLevel ERROR",
		vmSSHConfigIncludeEnd,
		"",
	)
	return strings.Join(lines, "\n")
}

// quoteSSHConfigArg renders value as a single ssh_config(5) argument.
//
// ssh_config splits arguments on whitespace, and older OpenSSH
// releases also split on '='. An unquoted path containing a space
// (e.g. under an "Application Support" directory) is therefore either
// truncated or, on newer OpenSSH, rejected as trailing garbage. A
// rejected line aborts every ssh invocation that reads the config.
// ssh_config(5) allows arguments to be wrapped in double quotes to
// represent spaces.
//
// Values with no such characters are returned unchanged, so
// previously generated files stay byte-identical.
//
// NOTE: a value that itself contains '"' cannot be represented
// portably across OpenSSH versions; it is passed through as-is.
func quoteSSHConfigArg(value string) string {
	if !strings.ContainsAny(value, " \t='#") {
		return value
	}
	return `"` + value + `"`
}

// upsertManagedBlock returns existing with every region delimited by
// beginMarker … endMarker removed, and block appended at the end.
// Any remaining user content is separated from block by exactly one
// blank line. Both existing and block are normalized first (CRLF →
// LF, single trailing newline). Newlines directly following a removed
// end marker are dropped together with the region.
//
// All managed regions are stripped, not just the first one. ssh uses
// the first value it obtains for most options, so a stale duplicate
// left above the freshly appended block would silently take
// precedence.
//
// It returns an error in two cases:
//   - A begin marker has no end marker after it. The caller then leaves
//     the file untouched rather than guessing where the block ends.
//   - Either marker is empty.
func upsertManagedBlock(existing, beginMarker, endMarker, block string) (string, error) {
	if beginMarker == "" || endMarker == "" {
		// An empty begin marker matches everywhere and would never let
		// the scan below advance.
		return "", fmt.Errorf("managed block markers must be non-empty")
	}
	existing = normalizeConfigText(existing)
	block = normalizeConfigText(block)

	// Walk the text with a cursor, copying only the content outside
	// managed regions. Each iteration consumes at least the begin
	// marker, so the loop always terminates.
	var kept strings.Builder
	rest := existing
	for {
		start := strings.Index(rest, beginMarker)
		if start < 0 {
			kept.WriteString(rest)
			break
		}
		bodyStart := start + len(beginMarker)
		rel := strings.Index(rest[bodyStart:], endMarker)
		if rel < 0 {
			return "", fmt.Errorf("managed block %q is missing end marker %q", beginMarker, endMarker)
		}
		end := bodyStart + rel + len(endMarker)
		for end < len(rest) && rest[end] == '\n' {
			end++
		}
		kept.WriteString(rest[:start])
		rest = rest[end:]
	}

	remaining := strings.TrimRight(kept.String(), "\n")
	if strings.TrimSpace(remaining) == "" {
		return block, nil
	}
	return remaining + "\n\n" + block, nil
}

// normalizeConfigText converts CRLF line endings to LF and ensures the
// text ends with exactly one newline. Text that is empty, or consists
// only of trailing newlines, becomes "".
func normalizeConfigText(text string) string {
	text = strings.ReplaceAll(text, "\r\n", "\n")
	text = strings.TrimRight(text, "\n")
	if text == "" {
		return ""
	}
	return text + "\n"
}

// readTextFileIfExists returns the contents of path, or "" with a nil
// error when the file does not exist. Any other read error is
// returned.
func readTextFileIfExists(path string) (string, error) {
	data, err := os.ReadFile(path)
	if err == nil {
		return string(data), nil
	}
	if os.IsNotExist(err) {
		return "", nil
	}
	return "", err
}

// writeTextFileIfChanged writes the normalized content to path, but
// skips the write if the file already holds exactly that text. mode is
// applied only when the file is created: os.WriteFile keeps the
// permissions of an existing file.
func writeTextFileIfChanged(path, content string, mode os.FileMode) error {
	content = normalizeConfigText(content)
	existing, err := readTextFileIfExists(path)
	if err != nil {
		return err
	}
	if existing == content {
		return nil
	}
	return os.WriteFile(path, []byte(content), mode)
}