package config

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"banger/internal/paths"
)

// TestLoadDefaultsResolveFirecrackerAndGenerateSSHKey checks what Load
// returns for a config directory without a config.toml: the firecracker
// binary is found on PATH, an SSH key pair is created at
// <ConfigDir>/ssh/id_ed25519(.pub), and the default image name and web listen
// address are filled in.
func TestLoadDefaultsResolveFirecrackerAndGenerateSSHKey(t *testing.T) {
	configDir := t.TempDir()
	binDir := t.TempDir()
	firecrackerPath := filepath.Join(binDir, "firecracker")
	if err := os.WriteFile(firecrackerPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
		t.Fatalf("write firecracker: %v", err)
	}
	// PATH contains only binDir, so the stub written above is the only
	// "firecracker" a PATH lookup can find.
	t.Setenv("PATH", binDir)

	cfg, err := Load(paths.Layout{ConfigDir: configDir})
	if err != nil {
		t.Fatalf("Load: %v", err)
	}
	if cfg.FirecrackerBin != firecrackerPath {
		t.Fatalf("FirecrackerBin = %q, want %q", cfg.FirecrackerBin, firecrackerPath)
	}
	wantKey := filepath.Join(configDir, "ssh", "id_ed25519")
	if cfg.SSHKeyPath != wantKey {
		t.Fatalf("SSHKeyPath = %q, want %q", cfg.SSHKeyPath, wantKey)
	}
	for _, path := range []string{wantKey, wantKey + ".pub"} {
		if _, err := os.Stat(path); err != nil {
			t.Fatalf("stat %s: %v", path, err)
		}
	}
	if cfg.DefaultImageName != "debian-bookworm" {
		t.Fatalf("DefaultImageName = %q, want debian-bookworm", cfg.DefaultImageName)
	}
	if cfg.WebListenAddr != "127.0.0.1:7777" {
		t.Fatalf("WebListenAddr = %q", cfg.WebListenAddr)
	}
}

// TestLoadAppliesConfigOverrides checks that every setting written to
// config.toml replaces the corresponding default, including an explicitly
// empty web_listen_addr.
func TestLoadAppliesConfigOverrides(t *testing.T) {
	// BANGER_LOG_LEVEL is an environment override for log_level (see
	// TestLoadAppliesLogLevelEnvOverride). Remove any ambient value so it
	// cannot mask the file setting asserted below. t.Setenv registers a
	// cleanup that restores the original value after the test.
	t.Setenv("BANGER_LOG_LEVEL", "")
	if err := os.Unsetenv("BANGER_LOG_LEVEL"); err != nil {
		t.Fatalf("unset BANGER_LOG_LEVEL: %v", err)
	}

	configDir := t.TempDir()
	// Keep the custom key path inside this test's temp dir rather than a
	// shared location such as /tmp, in case Load creates the key there.
	sshKeyPath := filepath.Join(configDir, "custom-key")
	data := []byte(fmt.Sprintf(`
log_level = "debug"
web_listen_addr = ""
firecracker_bin = "/opt/firecracker"
ssh_key_path = %q
default_image_name = "void"
auto_stop_stale_after = "1h"
stats_poll_interval = "15s"
metrics_poll_interval = "30s"
bridge_name = "br-test"
bridge_ip = "10.0.0.1"
cidr = "25"
tap_pool_size = 8
default_dns = "9.9.9.9"
`, sshKeyPath))
	if err := os.WriteFile(filepath.Join(configDir, "config.toml"), data, 0o644); err != nil {
		t.Fatalf("write config.toml: %v", err)
	}

	cfg, err := Load(paths.Layout{ConfigDir: configDir})
	if err != nil {
		t.Fatalf("Load: %v", err)
	}
	if cfg.LogLevel != "debug" {
		t.Fatalf("LogLevel = %q", cfg.LogLevel)
	}
	if cfg.WebListenAddr != "" {
		t.Fatalf("WebListenAddr = %q, want empty", cfg.WebListenAddr)
	}
	if cfg.FirecrackerBin != "/opt/firecracker" {
		t.Fatalf("FirecrackerBin = %q", cfg.FirecrackerBin)
	}
	if cfg.SSHKeyPath != sshKeyPath {
		t.Fatalf("SSHKeyPath = %q, want %q", cfg.SSHKeyPath, sshKeyPath)
	}
	if cfg.DefaultImageName != "void" {
		t.Fatalf("DefaultImageName = %q", cfg.DefaultImageName)
	}
	if cfg.AutoStopStaleAfter != time.Hour {
		t.Fatalf("AutoStopStaleAfter = %s", cfg.AutoStopStaleAfter)
	}
	if cfg.StatsPollInterval != 15*time.Second {
		t.Fatalf("StatsPollInterval = %s", cfg.StatsPollInterval)
	}
	if cfg.MetricsPollInterval != 30*time.Second {
		t.Fatalf("MetricsPollInterval = %s", cfg.MetricsPollInterval)
	}
	if cfg.BridgeName != "br-test" || cfg.BridgeIP != "10.0.0.1" || cfg.CIDR != "25" {
		t.Fatalf("bridge config = %+v", cfg)
	}
	if cfg.TapPoolSize != 8 {
		t.Fatalf("TapPoolSize = %d", cfg.TapPoolSize)
	}
	if cfg.DefaultDNS != "9.9.9.9" {
		t.Fatalf("DefaultDNS = %q", cfg.DefaultDNS)
	}
}

// TestLoadAppliesLogLevelEnvOverride checks that BANGER_LOG_LEVEL sets
// cfg.LogLevel when no config.toml is present.
func TestLoadAppliesLogLevelEnvOverride(t *testing.T) {
	t.Setenv("BANGER_LOG_LEVEL", "warn")

	cfg, err := Load(paths.Layout{ConfigDir: t.TempDir()})
	if err != nil {
		t.Fatalf("Load: %v", err)
	}
	if cfg.LogLevel != "warn" {
		t.Fatalf("LogLevel = %q, want warn", cfg.LogLevel)
	}
}

// TestLoadAcceptsFileSyncEntries checks that [[file_sync]] tables are decoded
// in order into cfg.FileSync, with an optional mode string kept as written.
func TestLoadAcceptsFileSyncEntries(t *testing.T) {
	configDir := t.TempDir()
	data := []byte(`
[[file_sync]]
host = "~/.aws"
guest = "~/.aws"

[[file_sync]]
host = "/etc/resolv.conf"
guest = "/root/.config/resolv.conf"
mode = "0644"
`)
	if err := os.WriteFile(filepath.Join(configDir, "config.toml"), data, 0o644); err != nil {
		t.Fatal(err)
	}

	cfg, err := Load(paths.Layout{ConfigDir: configDir})
	if err != nil {
		t.Fatalf("Load: %v", err)
	}
	if len(cfg.FileSync) != 2 {
		t.Fatalf("FileSync = %+v", cfg.FileSync)
	}
	if cfg.FileSync[0].Host != "~/.aws" || cfg.FileSync[0].Guest != "~/.aws" {
		t.Fatalf("entry[0] = %+v", cfg.FileSync[0])
	}
	if cfg.FileSync[1].Mode != "0644" {
		t.Fatalf("entry[1] mode = %q", cfg.FileSync[1].Mode)
	}
}

// TestLoadRejectsInvalidFileSyncEntries checks that Load fails for each
// malformed [[file_sync]] entry below, with an error message containing the
// expected substring.
func TestLoadRejectsInvalidFileSyncEntries(t *testing.T) {
	cases := []struct {
		name string
		toml string
		want string // substring expected in Load's error message
	}{
		{
			name: "empty host",
			toml: `[[file_sync]]` + "\n" + `host = ""` + "\n" + `guest = "~/foo"`,
			want: "host path is required",
		},
		{
			name: "empty guest",
			toml: `[[file_sync]]` + "\n" + `host = "~/foo"` + "\n" + `guest = ""`,
			want: "guest path is required",
		},
		{
			name: "relative host",
			toml: `[[file_sync]]` + "\n" + `host = "foo/bar"` + "\n" + `guest = "~/foo"`,
			want: "must be absolute",
		},
		{
			name: "guest outside /root",
			toml: `[[file_sync]]` + "\n" + `host = "~/x"` + "\n" + `guest = "/etc/resolv.conf"`,
			want: "must be under /root or ~/",
		},
		{
			name: "path traversal",
			toml: `[[file_sync]]` + "\n" + `host = "~/../secrets"` + "\n" + `guest = "~/secrets"`,
			want: "'..' segments",
		},
		{
			name: "tilde user",
			toml: `[[file_sync]]` + "\n" + `host = "~other/foo"` + "\n" + `guest = "~/foo"`,
			want: "only '~/' is expanded",
		},
		{
			name: "invalid mode",
			toml: `[[file_sync]]` + "\n" + `host = "~/x"` + "\n" + `guest = "~/x"` + "\n" + `mode = "rwx"`,
			want: "must be octal",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			configDir := t.TempDir()
			if err := os.WriteFile(filepath.Join(configDir, "config.toml"), []byte(tc.toml+"\n"), 0o644); err != nil {
				t.Fatal(err)
			}
			_, err := Load(paths.Layout{ConfigDir: configDir})
			if err == nil {
				t.Fatalf("Load: want error containing %q", tc.want)
			}
			if !strings.Contains(err.Error(), tc.want) {
				t.Fatalf("Load error = %v, want contains %q", err, tc.want)
			}
		})
	}
}