package cli

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"reflect"
	"strings"
	"testing"
	"time"

	"banger/internal/api"
	"banger/internal/model"
	"banger/internal/system"
)

// TestNewBangerCommandHasExpectedSubcommands verifies that the root command
// registers exactly the expected top-level subcommands, in this order.
func TestNewBangerCommandHasExpectedSubcommands(t *testing.T) {
	cmd := NewBangerCommand()
	subs := cmd.Commands()
	names := make([]string, 0, len(subs))
	for _, sub := range subs {
		names = append(names, sub.Name())
	}
	want := []string{"daemon", "doctor", "image", "internal", "vm"}
	if !reflect.DeepEqual(names, want) {
		t.Fatalf("subcommands = %v, want %v", names, want)
	}
}

// TestLegacyRemovedCommandIsRejected verifies that the removed "tui" command
// is reported as an unknown command.
func TestLegacyRemovedCommandIsRejected(t *testing.T) {
	cmd := NewBangerCommand()
	cmd.SetArgs([]string{"tui"})
	err := cmd.Execute()
	if err == nil || !strings.Contains(err.Error(), "unknown command \"tui\"") {
		t.Fatalf("Execute() error = %v, want unknown legacy command", err)
	}
}

// TestDoctorCommandPrintsReportAndFailsOnHardFailures verifies that
// `banger doctor` prints each check with its status and returns an error when
// any check has failed.
func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) {
	original := doctorFunc
	t.Cleanup(func() { doctorFunc = original })
	doctorFunc = func(context.Context) (system.Report, error) {
		return system.Report{
			Checks: []system.CheckResult{
				{Name: "runtime bundle", Status: system.CheckStatusPass, Details: []string{"runtime dir /tmp/runtime"}},
				{Name: "feature nat", Status: system.CheckStatusFail, Details: []string{"missing iptables"}},
			},
		}, nil
	}

	cmd := NewBangerCommand()
	var stdout bytes.Buffer
	cmd.SetOut(&stdout)
	cmd.SetErr(&stdout)
	cmd.SetArgs([]string{"doctor"})

	err := cmd.Execute()
	if err == nil || !strings.Contains(err.Error(), "doctor found failing checks") {
		t.Fatalf("Execute() error = %v, want doctor failure", err)
	}
	output := stdout.String()
	if !strings.Contains(output, "PASS\truntime bundle") {
		t.Fatalf("output = %q, want runtime bundle pass", output)
	}
	if !strings.Contains(output, "FAIL\tfeature nat") {
		t.Fatalf("output = %q, want feature nat fail", output)
	}
}

// TestDoctorCommandReturnsUnderlyingError verifies that an error returned by
// the doctor backend is surfaced from Execute.
func TestDoctorCommandReturnsUnderlyingError(t *testing.T) {
	original := doctorFunc
	t.Cleanup(func() { doctorFunc = original })
	doctorFunc = func(context.Context) (system.Report, error) {
		return system.Report{}, errors.New("load failed")
	}

	cmd := NewBangerCommand()
	cmd.SetArgs([]string{"doctor"})
	err := cmd.Execute()
	if err == nil || !strings.Contains(err.Error(), "load failed") {
		t.Fatalf("Execute() error = %v, want load failed", err)
	}
}

// TestInternalNATFlagsExist verifies that `internal nat up` defines the
// --guest-ip and --tap flags.
func TestInternalNATFlagsExist(t *testing.T) {
	root := NewBangerCommand()
	internal, _, err := root.Find([]string{"internal"})
	if err != nil {
		t.Fatalf("find internal: %v", err)
	}
	nat, _, err := internal.Find([]string{"nat"})
	if err != nil {
		t.Fatalf("find nat: %v", err)
	}
	up, _, err := nat.Find([]string{"up"})
	if err != nil {
		t.Fatalf("find nat up: %v", err)
	}
	for _, flagName := range []string{"guest-ip", "tap"} {
		if up.Flags().Lookup(flagName) == nil {
			t.Fatalf("missing flag %q", flagName)
		}
	}
}

// TestInternalPackagesCommandSupportsAlpine verifies that
// `internal packages alpine` prints the expected packages, one per line.
func TestInternalPackagesCommandSupportsAlpine(t *testing.T) {
	cmd := NewBangerCommand()
	var stdout bytes.Buffer
	cmd.SetOut(&stdout)
	cmd.SetArgs([]string{"internal", "packages", "alpine"})

	if err := cmd.Execute(); err != nil {
		t.Fatalf("Execute(): %v", err)
	}
	output := stdout.String()
	for _, want := range []string{"alpine-base", "docker", "libgcc", "libstdc++", "mkinitfs", "openssh"} {
		if !strings.Contains(output, want+"\n") {
			t.Fatalf("output = %q, want package %q", output, want)
		}
	}
}

// TestVMCreateFlagsExist verifies that `vm create` defines all of its flags.
func TestVMCreateFlagsExist(t *testing.T) {
	root := NewBangerCommand()
	vm, _, err := root.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	create, _, err := vm.Find([]string{"create"})
	if err != nil {
		t.Fatalf("find create: %v", err)
	}
	for _, flagName := range []string{"name", "image", "vcpu", "memory", "system-overlay-size", "disk-size", "nat", "no-start"} {
		if create.Flags().Lookup(flagName) == nil {
			t.Fatalf("missing flag %q", flagName)
		}
	}
}

// TestVMCreateFlagsShowStaticDefaults verifies that the declared defaults of
// the `vm create` sizing flags match the model package defaults.
func TestVMCreateFlagsShowStaticDefaults(t *testing.T) {
	root := NewBangerCommand()
	vm, _, err := root.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	create, _, err := vm.Find([]string{"create"})
	if err != nil {
		t.Fatalf("find create: %v", err)
	}

	if got := create.Flags().Lookup("vcpu").DefValue; got != fmt.Sprintf("%d", model.DefaultVCPUCount) {
		t.Fatalf("vcpu default = %q, want %d", got, model.DefaultVCPUCount)
	}
	if got := create.Flags().Lookup("memory").DefValue; got != fmt.Sprintf("%d", model.DefaultMemoryMiB) {
		t.Fatalf("memory default = %q, want %d", got, model.DefaultMemoryMiB)
	}
	if got := create.Flags().Lookup("system-overlay-size").DefValue; got != model.FormatSizeBytes(model.DefaultSystemOverlaySize) {
		t.Fatalf("system-overlay-size default = %q, want %q", got, model.FormatSizeBytes(model.DefaultSystemOverlaySize))
	}
	if got := create.Flags().Lookup("disk-size").DefValue; got != model.FormatSizeBytes(model.DefaultWorkDiskSize) {
		t.Fatalf("disk-size default = %q, want %q", got, model.FormatSizeBytes(model.DefaultWorkDiskSize))
	}
}

// TestImageRegisterFlagsExist verifies that `image register` defines all of
// its flags.
func TestImageRegisterFlagsExist(t *testing.T) {
	root := NewBangerCommand()
	image, _, err := root.Find([]string{"image"})
	if err != nil {
		t.Fatalf("find image: %v", err)
	}
	register, _, err := image.Find([]string{"register"})
	if err != nil {
		t.Fatalf("find register: %v", err)
	}
	for _, flagName := range []string{"name", "rootfs", "work-seed", "kernel", "initrd", "modules", "docker"} {
		if register.Flags().Lookup(flagName) == nil {
			t.Fatalf("missing flag %q", flagName)
		}
	}
}

// TestImagePromoteCommandExists verifies that `image promote` is registered.
func TestImagePromoteCommandExists(t *testing.T) {
	root := NewBangerCommand()
	image, _, err := root.Find([]string{"image"})
	if err != nil {
		t.Fatalf("find image: %v", err)
	}
	if _, _, err := image.Find([]string{"promote"}); err != nil {
		t.Fatalf("find promote: %v", err)
	}
}

// TestVMKillFlagsExist verifies that `vm kill` defines a --signal flag.
func TestVMKillFlagsExist(t *testing.T) {
	root := NewBangerCommand()
	vm, _, err := root.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	kill, _, err := vm.Find([]string{"kill"})
	if err != nil {
		t.Fatalf("find kill: %v", err)
	}
	if kill.Flags().Lookup("signal") == nil {
		t.Fatal("missing signal flag")
	}
}

// TestVMPortsCommandExists verifies that `vm ports` is registered.
func TestVMPortsCommandExists(t *testing.T) {
	root := NewBangerCommand()
	vm, _, err := root.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	if _, _, err := vm.Find([]string{"ports"}); err != nil {
		t.Fatalf("find ports: %v", err)
	}
}

// TestVMPortsCommandRejectsMultipleRefs verifies that `vm ports` returns a
// usage error when given more than one VM reference.
func TestVMPortsCommandRejectsMultipleRefs(t *testing.T) {
	cmd := NewBangerCommand()
	cmd.SetArgs([]string{"vm", "ports", "alpha", "beta"})
	err := cmd.Execute()
	if err == nil || !strings.Contains(err.Error(), "usage: banger vm ports ") {
		t.Fatalf("Execute() error = %v, want single-vm usage error", err)
	}
}

// TestVMSetParamsFromFlags verifies that explicit flag values are copied into
// the resulting VM set params.
func TestVMSetParamsFromFlags(t *testing.T) {
	params, err := vmSetParamsFromFlags("devbox", 4, 2048, "16G", true, false)
	if err != nil {
		t.Fatalf("vmSetParamsFromFlags: %v", err)
	}
	if params.IDOrName != "devbox" || params.VCPUCount == nil || *params.VCPUCount != 4 {
		t.Fatalf("unexpected params: %+v", params)
	}
	if params.MemoryMiB == nil || *params.MemoryMiB != 2048 {
		t.Fatalf("unexpected memory: %+v", params)
	}
	if params.WorkDiskSize != "16G" {
		t.Fatalf("unexpected disk size: %+v", params)
	}
	if params.NATEnabled == nil || !*params.NATEnabled {
		t.Fatalf("unexpected nat value: %+v", params)
	}
}

// TestVMCreateParamsFromFlagsOmitsStaticDefaultsWhenFlagsAreUnchanged verifies
// that sizing fields stay unset when their flags were not changed, even though
// the passed values equal the static defaults.
func TestVMCreateParamsFromFlagsOmitsStaticDefaultsWhenFlagsAreUnchanged(t *testing.T) {
	cmd := NewBangerCommand()
	vm, _, err := cmd.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	create, _, err := vm.Find([]string{"create"})
	if err != nil {
		t.Fatalf("find create: %v", err)
	}

	params, err := vmCreateParamsFromFlags(
		create,
		"devbox",
		"default",
		model.DefaultVCPUCount,
		model.DefaultMemoryMiB,
		model.FormatSizeBytes(model.DefaultSystemOverlaySize),
		model.FormatSizeBytes(model.DefaultWorkDiskSize),
		false,
		false,
	)
	if err != nil {
		t.Fatalf("vmCreateParamsFromFlags: %v", err)
	}
	if params.VCPUCount != nil || params.MemoryMiB != nil || params.SystemOverlaySize != "" || params.WorkDiskSize != "" {
		t.Fatalf("expected unchanged defaults to stay omitted: %+v", params)
	}
}

// TestVMCreateParamsFromFlagsIncludesChangedDiskFlags verifies that disk size
// flags that were explicitly set are carried into the create params.
func TestVMCreateParamsFromFlagsIncludesChangedDiskFlags(t *testing.T) {
	cmd := NewBangerCommand()
	vm, _, err := cmd.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	create, _, err := vm.Find([]string{"create"})
	if err != nil {
		t.Fatalf("find create: %v", err)
	}
	if err := create.Flags().Set("system-overlay-size", "16G"); err != nil {
		t.Fatalf("set system-overlay-size flag: %v", err)
	}
	if err := create.Flags().Set("disk-size", "32G"); err != nil {
		t.Fatalf("set disk-size flag: %v", err)
	}

	params, err := vmCreateParamsFromFlags(create, "devbox", "default", model.DefaultVCPUCount, model.DefaultMemoryMiB, "16G", "32G", false, false)
	if err != nil {
		t.Fatalf("vmCreateParamsFromFlags: %v", err)
	}
	if params.SystemOverlaySize != "16G" || params.WorkDiskSize != "32G" {
		t.Fatalf("expected changed disk flags to be included: %+v", params)
	}
}

// TestVMCreateParamsFromFlagsRejectsNonPositiveCPUAndMemory verifies that
// explicitly set non-positive vcpu and memory values are rejected.
func TestVMCreateParamsFromFlagsRejectsNonPositiveCPUAndMemory(t *testing.T) {
	cmd := NewBangerCommand()
	vm, _, err := cmd.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	create, _, err := vm.Find([]string{"create"})
	if err != nil {
		t.Fatalf("find create: %v", err)
	}

	if err := create.Flags().Set("vcpu", "0"); err != nil {
		t.Fatalf("set vcpu flag: %v", err)
	}
	if _, err := vmCreateParamsFromFlags(create, "devbox", "default", 0, 0, "", "", false, false); err == nil || !strings.Contains(err.Error(), "vcpu must be a positive integer") {
		t.Fatalf("vmCreateParamsFromFlags(vcpu=0) error = %v", err)
	}

	if err := create.Flags().Set("memory", "-1"); err != nil {
		t.Fatalf("set memory flag: %v", err)
	}
	if _, err := vmCreateParamsFromFlags(create, "devbox", "default", 1, -1, "", "", false, false); err == nil || !strings.Contains(err.Error(), "memory must be a positive integer") {
		t.Fatalf("vmCreateParamsFromFlags(memory=-1) error = %v", err)
	}
}

// TestRunVMCreatePollsUntilDone verifies that runVMCreate keeps polling the
// create operation until it reports Done, then returns the resulting VM
// without ever cancelling the operation.
func TestRunVMCreatePollsUntilDone(t *testing.T) {
	origBegin := vmCreateBeginFunc
	origStatus := vmCreateStatusFunc
	origCancel := vmCreateCancelFunc
	t.Cleanup(func() {
		vmCreateBeginFunc = origBegin
		vmCreateStatusFunc = origStatus
		vmCreateCancelFunc = origCancel
	})

	vm := model.VMRecord{
		ID:   "vm-id",
		Name: "devbox",
		Spec: model.VMSpec{WorkDiskSizeBytes: model.DefaultWorkDiskSize},
		Runtime: model.VMRuntime{
			State:   model.VMStateRunning,
			GuestIP: "172.16.0.2",
			DNSName: "devbox.vm",
		},
	}

	vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
		return api.VMCreateBeginResult{
			Operation: api.VMCreateOperation{
				ID:     "op-1",
				Stage:  "prepare_work_disk",
				Detail: "cloning work seed",
			},
		}, nil
	}

	// The first status poll reports an in-progress stage; the second reports
	// successful completion.
	statusCalls := 0
	vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
		statusCalls++
		if statusCalls == 1 {
			return api.VMCreateStatusResult{
				Operation: api.VMCreateOperation{
					ID:     "op-1",
					Stage:  "wait_opencode",
					Detail: "waiting for opencode on guest port 4096",
				},
			}, nil
		}
		return api.VMCreateStatusResult{
			Operation: api.VMCreateOperation{
				ID:      "op-1",
				Stage:   "ready",
				Detail:  "vm is ready",
				Done:    true,
				Success: true,
				VM:      &vm,
			},
		}, nil
	}

	vmCreateCancelFunc = func(context.Context, string, string) error {
		t.Fatal("cancel should not be called")
		return nil
	}

	got, err := runVMCreate(context.Background(), "/tmp/bangerd.sock", &bytes.Buffer{}, api.VMCreateParams{Name: "devbox"})
	if err != nil {
		t.Fatalf("runVMCreate: %v", err)
	}
	if got.Name != vm.Name || got.Runtime.GuestIP != vm.Runtime.GuestIP {
		t.Fatalf("vm = %+v, want %+v", got, vm)
	}
	if statusCalls != 2 {
		t.Fatalf("statusCalls = %d, want 2", statusCalls)
	}
}

// TestVMCreateProgressRendererSuppressesDuplicateLines verifies that the
// progress renderer prints one line per distinct stage/detail update and
// skips consecutive repeats.
func TestVMCreateProgressRendererSuppressesDuplicateLines(t *testing.T) {
	var stderr bytes.Buffer
	renderer := &vmCreateProgressRenderer{out: &stderr, enabled: true}

	renderer.render(api.VMCreateOperation{Stage: "prepare_work_disk", Detail: "cloning work seed"})
	renderer.render(api.VMCreateOperation{Stage: "prepare_work_disk", Detail: "cloning work seed"})
	renderer.render(api.VMCreateOperation{Stage: "wait_opencode", Detail: "waiting for opencode on guest port 4096"})

	lines := strings.Split(strings.TrimSpace(stderr.String()), "\n")
	if len(lines) != 2 {
		t.Fatalf("rendered lines = %q, want 2 lines", stderr.String())
	}
	if lines[0] != "[vm create] preparing work disk: cloning work seed" {
		t.Fatalf("first line = %q", lines[0])
	}
	if lines[1] != "[vm create] waiting for opencode: waiting for opencode on guest port 4096" {
		t.Fatalf("second line = %q", lines[1])
	}
}

// TestVMSetParamsFromFlagsConflict verifies that passing both NAT booleans as
// true is rejected.
func TestVMSetParamsFromFlagsConflict(t *testing.T) {
	if _, err := vmSetParamsFromFlags("devbox", -1, -1, "", true, true); err == nil {
		t.Fatal("expected nat conflict error")
	}
}

// TestVMSetParamsFromFlagsRejectsNonPositiveCPUAndMemory verifies that a zero
// vcpu or memory value is rejected with a descriptive error.
func TestVMSetParamsFromFlagsRejectsNonPositiveCPUAndMemory(t *testing.T) {
	if _, err := vmSetParamsFromFlags("devbox", 0, -1, "", false, false); err == nil || !strings.Contains(err.Error(), "vcpu must be a positive integer") {
		t.Fatalf("vmSetParamsFromFlags(vcpu=0) error = %v", err)
	}
	if _, err := vmSetParamsFromFlags("devbox", -1, 0, "", false, false); err == nil || !strings.Contains(err.Error(), "memory must be a positive integer") {
		t.Fatalf("vmSetParamsFromFlags(memory=0) error = %v", err)
	}
}

// TestAbsolutizeImageRegisterPaths verifies that every relative path in the
// image register params is rewritten to an absolute path.
func TestAbsolutizeImageRegisterPaths(t *testing.T) {
	tmp := t.TempDir()
	params := api.ImageRegisterParams{
		RootfsPath:   filepath.Join(".", "runtime", "rootfs-void.ext4"),
		WorkSeedPath: filepath.Join(".", "runtime", "rootfs-void.work-seed.ext4"),
		KernelPath:   filepath.Join(".", "runtime", "vmlinux"),
		InitrdPath:   filepath.Join(".", "runtime", "initrd.img"),
		ModulesDir:   filepath.Join(".", "runtime", "modules"),
	}

	wd, err := os.Getwd()
	if err != nil {
		t.Fatalf("Getwd: %v", err)
	}
	if err := os.Chdir(tmp); err != nil {
		t.Fatalf("Chdir(%s): %v", tmp, err)
	}
	t.Cleanup(func() { _ = os.Chdir(wd) })

	if err := absolutizeImageRegisterPaths(&params); err != nil {
		t.Fatalf("absolutizeImageRegisterPaths: %v", err)
	}
	for _, value := range []string{
		params.RootfsPath,
		params.WorkSeedPath,
		params.KernelPath,
		params.InitrdPath,
		params.ModulesDir,
	} {
		if !filepath.IsAbs(value) {
			t.Fatalf("path %q is not absolute", value)
		}
	}
}

// TestPrintImageListTableShowsRootfsSizes verifies that the image table shows
// a human-readable rootfs size instead of the rootfs path, and falls back to
// "-" for an image whose rootfs file does not exist.
func TestPrintImageListTableShowsRootfsSizes(t *testing.T) {
	rootfs := filepath.Join(t.TempDir(), "rootfs.ext4")
	if err := os.WriteFile(rootfs, nil, 0o644); err != nil {
		t.Fatalf("WriteFile(%s): %v", rootfs, err)
	}
	if err := os.Truncate(rootfs, 8*1024); err != nil {
		t.Fatalf("Truncate(%s): %v", rootfs, err)
	}

	var out bytes.Buffer
	err := printImageListTable(&out, []model.Image{
		{
			ID:         "0123456789abcdef",
			Name:       "alpine",
			Managed:    true,
			RootfsPath: rootfs,
			CreatedAt:  time.Now().Add(-1 * time.Hour),
		},
		{
			ID:         "fedcba9876543210",
			Name:       "missing",
			Managed:    false,
			RootfsPath: filepath.Join(t.TempDir(), "missing.ext4"),
			CreatedAt:  time.Now().Add(-2 * time.Hour),
		},
	})
	if err != nil {
		t.Fatalf("printImageListTable() error = %v", err)
	}

	output := out.String()
	if !strings.Contains(output, "ROOTFS SIZE") {
		t.Fatalf("output = %q, want rootfs size header", output)
	}
	if !strings.Contains(output, "alpine") || !strings.Contains(output, "8K") {
		t.Fatalf("output = %q, want alpine row with 8K size", output)
	}
	if strings.Contains(output, rootfs) {
		t.Fatalf("output = %q, should not include rootfs path", output)
	}

	// Look for the fallback in the missing image's own row; a dash anywhere
	// else in the table must not satisfy this check.
	var missingRow string
	for _, line := range strings.Split(output, "\n") {
		if strings.Contains(line, "missing") {
			missingRow = line
			break
		}
	}
	if missingRow == "" || !strings.Contains(missingRow, "-") {
		t.Fatalf("output = %q, want fallback size for missing image", output)
	}
}

// TestPrintVMPortsTableSortsAndRendersURLEndpoints verifies the ports table
// layout: a PROTO/ENDPOINT header without VM or WEB columns, followed by one
// row per port in the asserted order.
func TestPrintVMPortsTableSortsAndRendersURLEndpoints(t *testing.T) {
	result := api.VMPortsResult{
		Name: "alpha",
		Ports: []api.VMPort{
			{
				Proto:    "https",
				Port:     443,
				Endpoint: "https://alpha.vm:443/",
				Process:  "caddy",
				Command:  "caddy run",
			},
			{
				Proto:    "udp",
				Port:     53,
				Endpoint: "alpha.vm:53",
				Process:  "dnsd",
				Command:  "dnsd --foreground",
			},
		},
	}

	var out bytes.Buffer
	if err := printVMPortsTable(&out, result); err != nil {
		t.Fatalf("printVMPortsTable: %v", err)
	}

	lines := strings.Split(strings.TrimSpace(out.String()), "\n")
	if len(lines) != 3 {
		t.Fatalf("lines = %q, want header + 2 rows", lines)
	}
	if !strings.Contains(lines[0], "PROTO") || !strings.Contains(lines[0], "ENDPOINT") || strings.Contains(lines[0], "VM") || strings.Contains(lines[0], "WEB") {
		t.Fatalf("header = %q, want PROTO/ENDPOINT without VM/WEB", lines[0])
	}
	if !strings.Contains(lines[1], "https") || !strings.Contains(lines[1], "https://alpha.vm:443/") {
		t.Fatalf("first row = %q, want https endpoint row", lines[1])
	}
	if !strings.Contains(lines[2], "udp") || !strings.Contains(lines[2], "alpha.vm:53") {
		t.Fatalf("second row = %q, want udp endpoint row", lines[2])
	}
}

// TestRunSSHSessionPrintsReminderWhenHealthCheckPasses verifies that after a
// successful ssh session with a healthy VM, a "still running" reminder is
// written to stderr.
func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) {
	origSSHExec := sshExecFunc
	origHealth := vmHealthFunc
	t.Cleanup(func() {
		sshExecFunc = origSSHExec
		vmHealthFunc = origHealth
	})

	sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
		return nil
	}
	vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
		return api.VMHealthResult{Name: "devbox", Healthy: true}, nil
	}

	var stderr bytes.Buffer
	if err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}); err != nil {
		t.Fatalf("runSSHSession: %v", err)
	}
	if !strings.Contains(stderr.String(), "devbox is still running") {
		t.Fatalf("stderr = %q, want reminder", stderr.String())
	}
}

// TestRunSSHSessionPreservesSSHExitStatusOnHealthWarning verifies that when
// the post-session health check fails, a warning is printed and ssh's own
// exit status (here 1) is still returned to the caller.
func TestRunSSHSessionPreservesSSHExitStatusOnHealthWarning(t *testing.T) {
	origSSHExec := sshExecFunc
	origHealth := vmHealthFunc
	t.Cleanup(func() {
		sshExecFunc = origSSHExec
		vmHealthFunc = origHealth
	})

	sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
		return exitErrorWithCode(t, 1)
	}
	vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
		return api.VMHealthResult{}, errors.New("dial failed")
	}

	var stderr bytes.Buffer
	err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"})
	var exitErr *exec.ExitError
	if !errors.As(err, &exitErr) || exitErr.ExitCode() != 1 {
		t.Fatalf("runSSHSession error = %v, want exit 1", err)
	}
	if !strings.Contains(stderr.String(), "failed to check whether devbox is still running") {
		t.Fatalf("stderr = %q, want warning", stderr.String())
	}
}

// TestRunSSHSessionSkipsReminderOnSSHAuthFailure verifies that when ssh exits
// with status 255, the health check is skipped, no reminder is printed, and
// the 255 exit status is returned.
func TestRunSSHSessionSkipsReminderOnSSHAuthFailure(t *testing.T) {
	origSSHExec := sshExecFunc
	origHealth := vmHealthFunc
	t.Cleanup(func() {
		sshExecFunc = origSSHExec
		vmHealthFunc = origHealth
	})

	healthCalled := false
	sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
		return exitErrorWithCode(t, 255)
	}
	vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
		healthCalled = true
		return api.VMHealthResult{Name: "devbox", Healthy: true}, nil
	}

	var stderr bytes.Buffer
	err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"})
	var exitErr *exec.ExitError
	if !errors.As(err, &exitErr) || exitErr.ExitCode() != 255 {
		t.Fatalf("runSSHSession error = %v, want exit 255", err)
	}
	if healthCalled {
		t.Fatal("vm health should not run after ssh auth failure")
	}
	if strings.Contains(stderr.String(), "still running") {
		t.Fatalf("stderr = %q, should not contain reminder", stderr.String())
	}
}

// TestResolveVMTargetsDeduplicatesAndReportsErrors verifies that refs
// resolving to the same VM are collapsed to the first ref, and that ambiguous
// and unknown refs are reported as per-ref errors in input order.
func TestResolveVMTargetsDeduplicatesAndReportsErrors(t *testing.T) {
	vms := []model.VMRecord{
		testCLIResolvedVM("alpha-id", "alpha"),
		testCLIResolvedVM("alpine-id", "alpine"),
		testCLIResolvedVM("bravo-id", "bravo"),
	}

	targets, errs := resolveVMTargets(vms, []string{"alpha", "alpha-id", "al", "missing", "br"})
	if len(targets) != 2 {
		t.Fatalf("len(targets) = %d, want 2", len(targets))
	}
	if targets[0].VM.ID != "alpha-id" || targets[0].Ref != "alpha" {
		t.Fatalf("targets[0] = %+v, want alpha target", targets[0])
	}
	if targets[1].VM.ID != "bravo-id" || targets[1].Ref != "br" {
		t.Fatalf("targets[1] = %+v, want bravo target", targets[1])
	}
	if len(errs) != 2 {
		t.Fatalf("len(errs) = %d, want 2", len(errs))
	}
	if errs[0].Ref != "al" || !strings.Contains(errs[0].Err.Error(), "multiple VMs match") {
		t.Fatalf("errs[0] = %+v, want ambiguous prefix", errs[0])
	}
	if errs[1].Ref != "missing" || !strings.Contains(errs[1].Err.Error(), `vm "missing" not found`) {
		t.Fatalf("errs[1] = %+v, want missing vm", errs[1])
	}
}

// TestResolveVMRefPrefersExactMatchBeforePrefix verifies that an exact name
// match wins over another VM whose ID merely starts with the ref.
func TestResolveVMRefPrefersExactMatchBeforePrefix(t *testing.T) {
	vms := []model.VMRecord{
		testCLIResolvedVM("1111111111111111111111111111111111111111111111111111111111111111", "alpha"),
		testCLIResolvedVM("alpha222222222222222222222222222222222222222222222222222222222222", "bravo"),
	}

	vm, err := resolveVMRef(vms, "alpha")
	if err != nil {
		t.Fatalf("resolveVMRef(alpha): %v", err)
	}
	if vm.Name != "alpha" {
		t.Fatalf("resolveVMRef(alpha) = %+v, want exact-name vm", vm)
	}
}

// TestExecuteVMActionBatchRunsConcurrentlyAndPreservesOrder verifies that all
// batch actions are in flight at the same time and that results come back in
// the same order as the targets.
func TestExecuteVMActionBatchRunsConcurrentlyAndPreservesOrder(t *testing.T) {
	targets := []resolvedVMTarget{
		{Ref: "alpha", VM: testCLIResolvedVM("alpha-id", "alpha")},
		{Ref: "bravo", VM: testCLIResolvedVM("bravo-id", "bravo")},
	}
	started := make(chan string, len(targets))
	release := make(chan struct{})

	// releaseWorkers closes release at most once. It is deferred so that an
	// early t.Fatal still unblocks the actions instead of leaking goroutines
	// parked on <-release. Only this test goroutine touches released.
	released := false
	releaseWorkers := func() {
		if !released {
			released = true
			close(release)
		}
	}
	defer releaseWorkers()

	done := make(chan []vmBatchActionResult, 1)
	go func() {
		done <- executeVMActionBatch(context.Background(), targets, func(ctx context.Context, id string) (model.VMRecord, error) {
			started <- id
			<-release
			return model.VMRecord{ID: id, Name: id}, nil
		})
	}()

	// Every action must report that it started before any is released, which
	// is only possible if the batch runs them concurrently.
	for range targets {
		select {
		case <-started:
		case <-time.After(500 * time.Millisecond):
			t.Fatal("batch actions did not overlap")
		}
	}
	releaseWorkers()

	var results []vmBatchActionResult
	select {
	case results = <-done:
	case <-time.After(500 * time.Millisecond):
		t.Fatal("executeVMActionBatch did not finish")
	}
	if len(results) != len(targets) {
		t.Fatalf("len(results) = %d, want %d", len(results), len(targets))
	}
	for index, result := range results {
		if result.Target.Ref != targets[index].Ref {
			t.Fatalf("results[%d].Target.Ref = %q, want %q", index, result.Target.Ref, targets[index].Ref)
		}
		if result.VM.ID != targets[index].VM.ID {
			t.Fatalf("results[%d].VM.ID = %q, want %q", index, result.VM.ID, targets[index].VM.ID)
		}
	}
}

// TestSSHCommandArgs verifies the exact ssh argument list: key-only,
// non-interactive auth with host key checking disabled, followed by the
// root@guest target and the passed-through command arguments.
func TestSSHCommandArgs(t *testing.T) {
	args, err := sshCommandArgs(model.DaemonConfig{SSHKeyPath: "/bundle/id_ed25519"}, "172.16.0.2", []string{"--", "uname", "-a"})
	if err != nil {
		t.Fatalf("sshCommandArgs: %v", err)
	}
	want := []string{
		"-F", "/dev/null",
		"-i", "/bundle/id_ed25519",
		"-o", "IdentitiesOnly=yes",
		"-o", "BatchMode=yes",
		"-o", "PreferredAuthentications=publickey",
		"-o", "PasswordAuthentication=no",
		"-o", "KbdInteractiveAuthentication=no",
		"-o", "StrictHostKeyChecking=no",
		"-o", "UserKnownHostsFile=/dev/null",
		"root@172.16.0.2",
		"--", "uname", "-a",
	}
	if !reflect.DeepEqual(args, want) {
		t.Fatalf("args = %v, want %v", args, want)
	}
}

// TestValidateSSHPrereqs verifies that an existing private key file passes
// validation.
func TestValidateSSHPrereqs(t *testing.T) {
	dir := t.TempDir()
	keyPath := filepath.Join(dir, "id_ed25519")
	if err := os.WriteFile(keyPath, []byte("key"), 0o600); err != nil {
		t.Fatalf("write key: %v", err)
	}
	if err := validateSSHPrereqs(model.DaemonConfig{SSHKeyPath: keyPath}); err != nil {
		t.Fatalf("validateSSHPrereqs: %v", err)
	}
}

// exitErrorWithCode returns a real *exec.ExitError carrying the given exit
// status, produced by running a shell that exits with that code. It fails the
// test if no ExitError results (for example when code is 0).
func exitErrorWithCode(t *testing.T, code int) *exec.ExitError {
	t.Helper()
	// Use a plain, non-login POSIX shell. A login shell (bash -l) sources the
	// user's profile files, which is slow and may print or fail, and bash is
	// not guaranteed to be installed.
	cmd := exec.Command("sh", "-c", fmt.Sprintf("exit %d", code))
	err := cmd.Run()
	var exitErr *exec.ExitError
	if !errors.As(err, &exitErr) {
		t.Fatalf("exitErrorWithCode(%d) error = %v, want exit error", code, err)
	}
	return exitErr
}

// TestValidateSSHPrereqsFailsForMissingKey verifies that a nonexistent private
// key path is reported as an ssh private key error.
func TestValidateSSHPrereqsFailsForMissingKey(t *testing.T) {
	err := validateSSHPrereqs(model.DaemonConfig{SSHKeyPath: "/does/not/exist"})
	if err == nil || !strings.Contains(err.Error(), "ssh private key") {
		t.Fatalf("validateSSHPrereqs() error = %v, want missing key", err)
	}
}

// TestNewBangerdCommandRejectsArgs verifies that the daemon command accepts
// no positional arguments.
func TestNewBangerdCommandRejectsArgs(t *testing.T) {
	cmd := NewBangerdCommand()
	cmd.SetArgs([]string{"extra"})
	if err := cmd.Execute(); err == nil {
		t.Fatal("expected extra args to be rejected")
	}
}

// TestDaemonOutdated verifies that a running daemon whose executable is the
// same file as the current bangerd (here a hard link) is treated as current,
// while one backed by a different file is treated as outdated.
func TestDaemonOutdated(t *testing.T) {
	dir := t.TempDir()
	current := filepath.Join(dir, "bangerd-current")
	same := filepath.Join(dir, "bangerd-same")
	stale := filepath.Join(dir, "bangerd-stale")
	if err := os.WriteFile(current, []byte("current"), 0o755); err != nil {
		t.Fatalf("write current: %v", err)
	}
	if err := os.Link(current, same); err != nil {
		t.Fatalf("hard link: %v", err)
	}
	if err := os.WriteFile(stale, []byte("stale"), 0o755); err != nil {
		t.Fatalf("write stale: %v", err)
	}

	origBangerdPath := bangerdPathFunc
	origDaemonExePath := daemonExePath
	t.Cleanup(func() {
		bangerdPathFunc = origBangerdPath
		daemonExePath = origDaemonExePath
	})

	bangerdPathFunc = func() (string, error) {
		return current, nil
	}
	// PID 1 maps to the hard link of the current binary; any other PID maps
	// to the stale binary.
	daemonExePath = func(pid int) string {
		if pid == 1 {
			return same
		}
		return stale
	}

	if daemonOutdated(1) {
		t.Fatal("expected matching daemon executable to be current")
	}
	if !daemonOutdated(2) {
		t.Fatal("expected replaced daemon executable to be outdated")
	}
}

// TestDaemonStatusIncludesLogPathWhenStopped verifies that `daemon status`
// with no running daemon reports "stopped" along with the log path derived
// from XDG_STATE_HOME and the DNS and web listener addresses.
func TestDaemonStatusIncludesLogPathWhenStopped(t *testing.T) {
	configHome := filepath.Join(t.TempDir(), "config")
	stateHome := filepath.Join(t.TempDir(), "state")
	runtimeHome := filepath.Join(t.TempDir(), "runtime")
	t.Setenv("XDG_CONFIG_HOME", configHome)
	t.Setenv("XDG_STATE_HOME", stateHome)
	t.Setenv("XDG_RUNTIME_DIR", runtimeHome)

	cmd := NewBangerCommand()
	var stdout bytes.Buffer
	cmd.SetOut(&stdout)
	cmd.SetErr(&stdout)
	cmd.SetArgs([]string{"daemon", "status"})

	if err := cmd.Execute(); err != nil {
		t.Fatalf("Execute: %v", err)
	}
	output := stdout.String()
	if !strings.Contains(output, "stopped\n") {
		t.Fatalf("output = %q, want stopped status", output)
	}
	if !strings.Contains(output, "log: "+filepath.Join(stateHome, "banger", "bangerd.log")) {
		t.Fatalf("output = %q, want daemon log path", output)
	}
	if !strings.Contains(output, "dns: 127.0.0.1:42069") {
		t.Fatalf("output = %q, want dns listener", output)
	}
	if !strings.Contains(output, "web: http://127.0.0.1:7777") {
		t.Fatalf("output = %q, want default web listener", output)
	}
}

// TestBuildDaemonCommandIsDetachedFromCallerContext verifies that the daemon
// command targets the given binary and has no Cancel hook, i.e. it was not
// built with a request context.
func TestBuildDaemonCommandIsDetachedFromCallerContext(t *testing.T) {
	cmd := buildDaemonCommand("/tmp/bangerd")
	if cmd.Path != "/tmp/bangerd" {
		t.Fatalf("command path = %q", cmd.Path)
	}
	if cmd.Cancel != nil {
		t.Fatal("daemon process should not be tied to a CLI request context")
	}
}

// TestAbsolutizeImageBuildPaths verifies that relative build paths are
// resolved against the working directory, while absolute paths and the
// FromImage reference are left unchanged.
func TestAbsolutizeImageBuildPaths(t *testing.T) {
	dir := t.TempDir()
	prev, err := os.Getwd()
	if err != nil {
		t.Fatalf("getwd: %v", err)
	}
	if err := os.Chdir(dir); err != nil {
		t.Fatalf("chdir: %v", err)
	}
	t.Cleanup(func() {
		_ = os.Chdir(prev)
	})

	params := api.ImageBuildParams{
		FromImage:  "base-image",
		KernelPath: "/kernel",
		InitrdPath: "boot/initrd.img",
		ModulesDir: "modules",
	}
	if err := absolutizeImageBuildPaths(&params); err != nil {
		t.Fatalf("absolutizeImageBuildPaths: %v", err)
	}

	want := api.ImageBuildParams{
		FromImage:  "base-image",
		KernelPath: "/kernel",
		InitrdPath: filepath.Join(dir, "boot/initrd.img"),
		ModulesDir: filepath.Join(dir, "modules"),
	}
	if !reflect.DeepEqual(params, want) {
		t.Fatalf("params = %+v, want %+v", params, want)
	}
}

// testCLIResolvedVM builds a minimal VM record with only ID and Name set, for
// exercising VM reference resolution.
func testCLIResolvedVM(id, name string) model.VMRecord {
	return model.VMRecord{ID: id, Name: name}
}