package daemon

import (
	"os"
	"path/filepath"
	"strings"
	"testing"

	"banger/internal/paths"
)

// TestSyncVMSSHClientConfigWritesBangerFileOnly pins the opt-in contract:
// syncVMSSHClientConfig writes banger's own ssh_config file and leaves
// ~/.ssh/config alone.
func TestSyncVMSSHClientConfigWritesBangerFileOnly(t *testing.T) {
	home := t.TempDir()
	t.Setenv("HOME", home)

	configDir := filepath.Join(home, ".config", "banger")
	knownHosts := filepath.Join(home, ".local", "state", "banger", "ssh", "known_hosts")
	identity := filepath.Join(configDir, "ssh", "id_ed25519")
	layout := paths.Layout{
		ConfigDir:      configDir,
		KnownHostsPath: knownHosts,
	}

	if err := syncVMSSHClientConfig(layout, identity); err != nil {
		t.Fatalf("syncVMSSHClientConfig: %v", err)
	}

	// The banger-owned ssh_config must carry the `Host *.vm` stanza and
	// point at the key and known_hosts paths supplied above.
	raw, err := os.ReadFile(BangerSSHConfigPath(layout))
	if err != nil {
		t.Fatalf("ReadFile(banger ssh_config): %v", err)
	}
	rendered := string(raw)

	required := []string{
		"Host *.vm",
		"IdentityFile " + identity,
		"UserKnownHostsFile " + knownHosts,
		"StrictHostKeyChecking accept-new",
	}
	for _, want := range required {
		if strings.Contains(rendered, want) {
			continue
		}
		t.Fatalf("banger ssh_config missing %q:\n%s", want, rendered)
	}

	// The user's own ~/.ssh/config must not exist after the sync.
	userConfig := filepath.Join(home, ".ssh", "config")
	if _, statErr := os.Stat(userConfig); !os.IsNotExist(statErr) {
		t.Fatalf("~/.ssh/config should be untouched; stat err = %v", statErr)
	}

	// Regression guard: the legacy posture (strict checking off plus a
	// /dev/null known_hosts file) must stay out of the banger file.
	forbidden := []string{
		"StrictHostKeyChecking no",
		"UserKnownHostsFile /dev/null",
	}
	for _, must := range forbidden {
		if !strings.Contains(rendered, must) {
			continue
		}
		t.Fatalf("banger ssh_config leaked legacy posture %q:\n%s", must, rendered)
	}
}

// TestInstallUserSSHIncludeAddsIncludeBlock checks that, starting from a
// home directory without ~/.ssh/config, InstallUserSSHInclude produces one
// containing the begin marker and an Include line for banger's ssh_config.
func TestInstallUserSSHIncludeAddsIncludeBlock(t *testing.T) {
	home := t.TempDir()
	t.Setenv("HOME", home)

	layout := paths.Layout{ConfigDir: filepath.Join(home, ".config", "banger")}
	if err := os.MkdirAll(layout.ConfigDir, 0o755); err != nil {
		t.Fatalf("MkdirAll: %v", err)
	}

	// Seed a stand-in banger ssh_config so there is a file to include.
	stub := []byte("Host *.vm\n")
	if err := os.WriteFile(BangerSSHConfigPath(layout), stub, 0o644); err != nil {
		t.Fatalf("WriteFile(banger ssh_config): %v", err)
	}

	if err := InstallUserSSHInclude(layout); err != nil {
		t.Fatalf("InstallUserSSHInclude: %v", err)
	}

	raw, err := os.ReadFile(filepath.Join(home, ".ssh", "config"))
	if err != nil {
		t.Fatalf("ReadFile(~/.ssh/config): %v", err)
	}
	userConfig := string(raw)

	want := "Include " + BangerSSHConfigPath(layout)
	if !strings.Contains(userConfig, want) {
		t.Fatalf("user config missing %q:\n%s", want, userConfig)
	}
	if !strings.Contains(userConfig, bangerSSHIncludeBegin) {
		t.Fatalf("user config missing begin marker:\n%s", userConfig)
	}
}

// TestInstallUserSSHIncludeIsIdempotent runs InstallUserSSHInclude three
// times in a row and expects exactly one begin marker in ~/.ssh/config.
func TestInstallUserSSHIncludeIsIdempotent(t *testing.T) {
	home := t.TempDir()
	t.Setenv("HOME", home)

	layout := paths.Layout{ConfigDir: filepath.Join(home, ".config", "banger")}
	if err := os.MkdirAll(layout.ConfigDir, 0o755); err != nil {
		t.Fatalf("MkdirAll: %v", err)
	}
	stub := []byte("Host *.vm\n")
	if err := os.WriteFile(BangerSSHConfigPath(layout), stub, 0o644); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}

	for attempt := 0; attempt < 3; attempt++ {
		if err := InstallUserSSHInclude(layout); err != nil {
			t.Fatalf("InstallUserSSHInclude (%d): %v", attempt, err)
		}
	}

	raw, err := os.ReadFile(filepath.Join(home, ".ssh", "config"))
	if err != nil {
		t.Fatalf("ReadFile: %v", err)
	}
	userConfig := string(raw)

	markers := strings.Count(userConfig, bangerSSHIncludeBegin)
	if markers != 1 {
		t.Fatalf("begin markers = %d, want 1:\n%s", markers, userConfig)
	}
}

// TestInstallUserSSHIncludeMigratesLegacyInlineBlock seeds ~/.ssh/config
// with a legacy marker-delimited inline block surrounded by unrelated
// settings, then checks that InstallUserSSHInclude drops the legacy begin
// marker, adds the new begin marker, and keeps the unrelated settings.
func TestInstallUserSSHIncludeMigratesLegacyInlineBlock(t *testing.T) {
	home := t.TempDir()
	t.Setenv("HOME", home)

	layout := paths.Layout{ConfigDir: filepath.Join(home, ".config", "banger")}
	if err := os.MkdirAll(layout.ConfigDir, 0o755); err != nil {
		t.Fatalf("MkdirAll: %v", err)
	}
	stub := []byte("Host *.vm\n")
	if err := os.WriteFile(BangerSSHConfigPath(layout), stub, 0o644); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}

	sshDir := filepath.Join(home, ".ssh")
	if err := os.MkdirAll(sshDir, 0o700); err != nil {
		t.Fatalf("MkdirAll(.ssh): %v", err)
	}
	userConfigPath := filepath.Join(sshDir, "config")

	legacyLines := []string{
		"ServerAliveInterval 120",
		"",
		vmSSHConfigIncludeBegin,
		"Host *.vm",
		" User root",
		" IdentityFile /some/old/key",
		vmSSHConfigIncludeEnd,
		"",
		"Host other",
		" HostName 192.0.2.5",
		"",
	}
	legacy := strings.Join(legacyLines, "\n")
	if err := os.WriteFile(userConfigPath, []byte(legacy), 0o600); err != nil {
		t.Fatalf("seed legacy config: %v", err)
	}

	if err := InstallUserSSHInclude(layout); err != nil {
		t.Fatalf("InstallUserSSHInclude: %v", err)
	}

	raw, err := os.ReadFile(userConfigPath)
	if err != nil {
		t.Fatalf("ReadFile: %v", err)
	}
	migrated := string(raw)

	// The legacy inline block's begin marker must no longer appear.
	if strings.Contains(migrated, vmSSHConfigIncludeBegin) {
		t.Fatalf("legacy inline block survived:\n%s", migrated)
	}
	// The new Include block's begin marker must be there instead.
	if !strings.Contains(migrated, bangerSSHIncludeBegin) {
		t.Fatalf("new include block missing:\n%s", migrated)
	}
	// Settings outside the managed block must survive the migration.
	preserved := []string{"ServerAliveInterval 120", "Host other"}
	for _, want := range preserved {
		if strings.Contains(migrated, want) {
			continue
		}
		t.Fatalf("user config lost unrelated entry %q:\n%s", want, migrated)
	}
}

// TestUninstallUserSSHIncludeRemovesBothMarkerBlocks seeds ~/.ssh/config
// with both the legacy and the new marker blocks plus an unrelated host,
// then checks that UninstallUserSSHInclude removes both begin markers and
// keeps the unrelated host.
func TestUninstallUserSSHIncludeRemovesBothMarkerBlocks(t *testing.T) {
	home := t.TempDir()
	t.Setenv("HOME", home)

	sshDir := filepath.Join(home, ".ssh")
	if err := os.MkdirAll(sshDir, 0o700); err != nil {
		t.Fatalf("MkdirAll: %v", err)
	}
	userConfigPath := filepath.Join(sshDir, "config")

	seedLines := []string{
		"Host keep",
		" HostName 198.51.100.1",
		"",
		vmSSHConfigIncludeBegin,
		"Host *.vm",
		vmSSHConfigIncludeEnd,
		"",
		bangerSSHIncludeBegin,
		"Include /tmp/banger-ssh-config",
		bangerSSHIncludeEnd,
		"",
	}
	seed := strings.Join(seedLines, "\n")
	if err := os.WriteFile(userConfigPath, []byte(seed), 0o600); err != nil {
		t.Fatalf("seed: %v", err)
	}

	if err := UninstallUserSSHInclude(); err != nil {
		t.Fatalf("UninstallUserSSHInclude: %v", err)
	}

	raw, err := os.ReadFile(userConfigPath)
	if err != nil {
		t.Fatalf("ReadFile: %v", err)
	}
	remaining := string(raw)

	for _, banned := range []string{vmSSHConfigIncludeBegin, bangerSSHIncludeBegin} {
		if !strings.Contains(remaining, banned) {
			continue
		}
		t.Fatalf("residue of %q:\n%s", banned, remaining)
	}
	if !strings.Contains(remaining, "Host keep") {
		t.Fatalf("lost unrelated entry:\n%s", remaining)
	}
}

// TestUninstallUserSSHIncludeIsNoOpWhenMissing checks that
// UninstallUserSSHInclude succeeds when ~/.ssh/config does not exist and
// does not create the file.
func TestUninstallUserSSHIncludeIsNoOpWhenMissing(t *testing.T) {
	home := t.TempDir()
	t.Setenv("HOME", home)

	if err := UninstallUserSSHInclude(); err != nil {
		t.Fatalf("UninstallUserSSHInclude on missing file: %v", err)
	}

	// The file must still be absent afterwards.
	userConfigPath := filepath.Join(home, ".ssh", "config")
	if _, statErr := os.Stat(userConfigPath); !os.IsNotExist(statErr) {
		t.Fatalf("~/.ssh/config unexpectedly created; stat err = %v", statErr)
	}
}

// TestUserSSHIncludeInstalledDetectsBothMarkers is a table-driven check
// that UserSSHIncludeInstalled reports true for either the legacy or the
// new marker block, and false for a missing file or unrelated content.
func TestUserSSHIncludeInstalledDetectsBothMarkers(t *testing.T) {
	scenarios := []struct {
		name      string
		seed      string // contents of ~/.ssh/config; empty means no file is written
		installed bool   // expected result of UserSSHIncludeInstalled
	}{
		{
			name:      "missing file",
			seed:      "",
			installed: false,
		},
		{
			name:      "unrelated only",
			seed:      "Host other\n HostName 1.2.3.4\n",
			installed: false,
		},
		{
			name:      "legacy marker",
			seed:      vmSSHConfigIncludeBegin + "\nHost *.vm\n" + vmSSHConfigIncludeEnd + "\n",
			installed: true,
		},
		{
			name:      "new marker",
			seed:      bangerSSHIncludeBegin + "\nInclude /tmp/banger\n" + bangerSSHIncludeEnd + "\n",
			installed: true,
		},
	}

	for _, scenario := range scenarios {
		t.Run(scenario.name, func(t *testing.T) {
			home := t.TempDir()
			t.Setenv("HOME", home)

			if scenario.seed != "" {
				sshDir := filepath.Join(home, ".ssh")
				if err := os.MkdirAll(sshDir, 0o700); err != nil {
					t.Fatalf("MkdirAll: %v", err)
				}
				userConfigPath := filepath.Join(sshDir, "config")
				if err := os.WriteFile(userConfigPath, []byte(scenario.seed), 0o600); err != nil {
					t.Fatalf("WriteFile: %v", err)
				}
			}

			got, err := UserSSHIncludeInstalled()
			if err != nil {
				t.Fatalf("UserSSHIncludeInstalled: %v", err)
			}
			if got != scenario.installed {
				t.Fatalf("got %v, want %v", got, scenario.installed)
			}
		})
	}
}