package cli

import (
	"io"
	"os"
	"path/filepath"
	"reflect"
	"strings"
	"testing"

	"banger/internal/api"
	"banger/internal/model"
)

// TestNewBangerCommandHasExpectedSubcommands pins the top-level subcommand
// set of the banger CLI. want is listed alphabetically; this presumably
// relies on Commands() returning subcommands sorted by name (cobra's
// default) — revisit if command sorting is ever disabled.
func TestNewBangerCommandHasExpectedSubcommands(t *testing.T) {
	cmd := NewBangerCommand()
	subs := cmd.Commands()
	names := make([]string, 0, len(subs))
	for _, sub := range subs {
		names = append(names, sub.Name())
	}
	want := []string{"daemon", "image", "tui", "vm"}
	if !reflect.DeepEqual(names, want) {
		t.Fatalf("subcommands = %v, want %v", names, want)
	}
}

// TestVMCreateFlagsExist checks that `banger vm create` exposes every flag
// the create workflow reads.
func TestVMCreateFlagsExist(t *testing.T) {
	create, _, err := NewBangerCommand().Find([]string{"vm", "create"})
	if err != nil {
		t.Fatalf("find vm create: %v", err)
	}
	// Find may resolve to the closest ancestor instead of returning an error
	// when a subcommand is missing, so confirm we landed on the right one.
	if create.Name() != "create" {
		t.Fatalf("resolved command = %q, want %q", create.Name(), "create")
	}
	for _, flagName := range []string{"name", "image", "vcpu", "memory", "system-overlay-size", "disk-size", "nat", "no-start"} {
		if create.Flags().Lookup(flagName) == nil {
			// Errorf (not Fatalf) so every missing flag is reported in one run.
			t.Errorf("missing flag %q", flagName)
		}
	}
}

// TestVMKillFlagsExist checks that `banger vm kill` exposes a --signal flag.
func TestVMKillFlagsExist(t *testing.T) {
	kill, _, err := NewBangerCommand().Find([]string{"vm", "kill"})
	if err != nil {
		t.Fatalf("find vm kill: %v", err)
	}
	// See TestVMCreateFlagsExist: guard against Find returning an ancestor.
	if kill.Name() != "kill" {
		t.Fatalf("resolved command = %q, want %q", kill.Name(), "kill")
	}
	if kill.Flags().Lookup("signal") == nil {
		t.Fatal("missing signal flag")
	}
}

// TestVMSetParamsFromFlags checks that explicitly provided flag values are
// carried into the request params. Judging from the assertions, the
// positional arguments are: name, vCPU count, memory (MiB), work-disk size,
// and two NAT toggles (enable, disable) — confirm against vmSetParamsFromFlags.
func TestVMSetParamsFromFlags(t *testing.T) {
	params, err := vmSetParamsFromFlags("devbox", 4, 2048, "16G", true, false)
	if err != nil {
		t.Fatalf("vmSetParamsFromFlags: %v", err)
	}
	if params.IDOrName != "devbox" || params.VCPUCount == nil || *params.VCPUCount != 4 {
		t.Fatalf("unexpected params: %+v", params)
	}
	if params.MemoryMiB == nil || *params.MemoryMiB != 2048 {
		t.Fatalf("unexpected memory: %+v", params)
	}
	if params.WorkDiskSize != "16G" {
		t.Fatalf("unexpected disk size: %+v", params)
	}
	if params.NATEnabled == nil || !*params.NATEnabled {
		t.Fatalf("unexpected nat value: %+v", params)
	}
}

// TestVMSetParamsFromFlagsConflict checks that setting both NAT toggles at
// once is rejected. The -1 numeric arguments presumably mean "not set".
func TestVMSetParamsFromFlagsConflict(t *testing.T) {
	if _, err := vmSetParamsFromFlags("devbox", -1, -1, "", true, true); err == nil {
		t.Fatal("expected nat conflict error")
	}
}

// TestSSHCommandArgs pins the exact ssh argument vector built for a guest:
// identity file, host-key checking disabled, root@<ip>, then the
// caller-supplied trailing arguments passed through unchanged.
func TestSSHCommandArgs(t *testing.T) {
	args, err := sshCommandArgs(model.DaemonConfig{SSHKeyPath: "/bundle/id_ed25519"}, "172.16.0.2", []string{"--", "uname", "-a"})
	if err != nil {
		t.Fatalf("sshCommandArgs: %v", err)
	}
	want := []string{
		"-i", "/bundle/id_ed25519",
		"-o", "StrictHostKeyChecking=no",
		"-o", "UserKnownHostsFile=/dev/null",
		"root@172.16.0.2",
		"--", "uname", "-a",
	}
	if !reflect.DeepEqual(args, want) {
		t.Fatalf("args = %v, want %v", args, want)
	}
}

// TestValidateSSHPrereqs checks that an existing private key file passes
// validation.
func TestValidateSSHPrereqs(t *testing.T) {
	dir := t.TempDir()
	keyPath := filepath.Join(dir, "id_ed25519")
	if err := os.WriteFile(keyPath, []byte("key"), 0o600); err != nil {
		t.Fatalf("write key: %v", err)
	}
	if err := validateSSHPrereqs(model.DaemonConfig{SSHKeyPath: keyPath}); err != nil {
		t.Fatalf("validateSSHPrereqs: %v", err)
	}
}

// TestValidateSSHPrereqsFailsForMissingKey checks that a missing private key
// is reported with an error mentioning the ssh private key.
func TestValidateSSHPrereqsFailsForMissingKey(t *testing.T) {
	// A path inside a fresh temp dir is guaranteed not to exist, unlike a
	// hard-coded absolute path.
	missing := filepath.Join(t.TempDir(), "missing_id_ed25519")
	err := validateSSHPrereqs(model.DaemonConfig{SSHKeyPath: missing})
	if err == nil || !strings.Contains(err.Error(), "ssh private key") {
		t.Fatalf("validateSSHPrereqs() error = %v, want missing key", err)
	}
}

// TestNewBangerdCommandRejectsArgs checks that the daemon entry point refuses
// positional arguments.
func TestNewBangerdCommandRejectsArgs(t *testing.T) {
	cmd := NewBangerdCommand()
	// The expected failure prints usage/error text; keep test output clean.
	cmd.SetOut(io.Discard)
	cmd.SetErr(io.Discard)
	cmd.SetArgs([]string{"extra"})
	if err := cmd.Execute(); err == nil {
		t.Fatal("expected extra args to be rejected")
	}
}

// TestDaemonOutdated checks stale-daemon detection. "same" is a hard link to
// the current bangerd binary while "stale" is a different file, so the check
// presumably compares file identity rather than path — confirm in
// daemonOutdated. The package-level hooks are swapped for fakes and restored
// via t.Cleanup.
func TestDaemonOutdated(t *testing.T) {
	dir := t.TempDir()
	current := filepath.Join(dir, "bangerd-current")
	same := filepath.Join(dir, "bangerd-same")
	stale := filepath.Join(dir, "bangerd-stale")
	if err := os.WriteFile(current, []byte("current"), 0o755); err != nil {
		t.Fatalf("write current: %v", err)
	}
	if err := os.Link(current, same); err != nil {
		t.Fatalf("hard link: %v", err)
	}
	if err := os.WriteFile(stale, []byte("stale"), 0o755); err != nil {
		t.Fatalf("write stale: %v", err)
	}

	origBangerdPath := bangerdPathFunc
	origDaemonExePath := daemonExePath
	t.Cleanup(func() {
		bangerdPathFunc = origBangerdPath
		daemonExePath = origDaemonExePath
	})
	bangerdPathFunc = func() (string, error) {
		return current, nil
	}
	// pid 1 maps to the hard-linked (identical) binary; any other pid maps
	// to the replaced one.
	daemonExePath = func(pid int) string {
		if pid == 1 {
			return same
		}
		return stale
	}

	if daemonOutdated(1) {
		t.Fatal("expected matching daemon executable to be current")
	}
	if !daemonOutdated(2) {
		t.Fatal("expected replaced daemon executable to be outdated")
	}
}

// TestAbsolutizeImageBuildPaths checks that relative image build paths are
// resolved against the working directory while absolute ones are kept.
// It changes the process working directory, so it must not run in parallel.
func TestAbsolutizeImageBuildPaths(t *testing.T) {
	// Resolve symlinks in the temp dir (e.g. macOS /var -> /private/var):
	// after Chdir, os.Getwd may report the resolved path, and expectations
	// built from the unresolved one would then not match.
	dir, err := filepath.EvalSymlinks(t.TempDir())
	if err != nil {
		t.Fatalf("resolve temp dir: %v", err)
	}
	prev, err := os.Getwd()
	if err != nil {
		t.Fatalf("getwd: %v", err)
	}
	if err := os.Chdir(dir); err != nil {
		t.Fatalf("chdir: %v", err)
	}
	// Best-effort restore; there is nothing useful to do if it fails.
	t.Cleanup(func() { _ = os.Chdir(prev) })

	params := api.ImageBuildParams{
		BaseRootfs: "images/base.ext4",
		KernelPath: "/kernel",
		InitrdPath: "boot/initrd.img",
		ModulesDir: "modules",
	}
	if err := absolutizeImageBuildPaths(&params); err != nil {
		t.Fatalf("absolutizeImageBuildPaths: %v", err)
	}
	want := api.ImageBuildParams{
		BaseRootfs: filepath.Join(dir, "images/base.ext4"),
		KernelPath: "/kernel",
		InitrdPath: filepath.Join(dir, "boot/initrd.img"),
		ModulesDir: filepath.Join(dir, "modules"),
	}
	if !reflect.DeepEqual(params, want) {
		t.Fatalf("params = %+v, want %+v", params, want)
	}
}