banger/internal/cli/cli_test.go
Thales Maciel 108f7a0600
ssh-config: make the ssh <name>.vm shortcut opt-in
Before this change, every daemon.Open() wrote a Host *.vm stanza into
~/.ssh/config in a marker-fenced block. That's a real footgun for users
who manage their SSH config declaratively (chezmoi, dotfiles, NixOS):
banger was mutating host state outside its own directory on every
daemon start — easy to miss and hard to audit.

New contract: the daemon only ever writes its own ssh_config file at
~/.config/banger/ssh_config. ~/.ssh/config is untouched unless the user
opts in. `banger vm ssh <name>` still works out of the box — the
shortcut only matters for running plain `ssh sandbox.vm` from an arbitrary terminal.

The opt-in surface is `banger ssh-config`:

  banger ssh-config              # prints path + include-line +
                                 # install/uninstall hints
  banger ssh-config --install    # adds `Include <bangerConfig>` to
                                 # ~/.ssh/config inside a marker-fenced
                                 # block; idempotent; migrates any
                                 # legacy inline Host *.vm block from
                                 # pre-opt-in builds
  banger ssh-config --uninstall  # removes the new Include block AND
                                 # any legacy inline block

Doctor gains a gentle warn-level note when banger's ssh_config exists
but the user hasn't wired it in — not a fail, since the shortcut is
a convenience and `banger vm ssh` covers the essential case.

Tests cover: daemon writes banger file and does NOT touch ~/.ssh/config,
Install adds the block, Install is idempotent, Install migrates the
legacy inline block cleanly (removing it, preserving unrelated
entries, adding the new Include block), Uninstall removes both marker
variants, Uninstall is a no-op when ~/.ssh/config is absent, and
UserSSHIncludeInstalled detects both marker shapes.

README reframes the feature as optional convenience.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-20 13:57:26 -03:00

2097 lines
65 KiB
Go

package cli
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"reflect"
"strings"
"testing"
"time"
"banger/internal/api"
"banger/internal/buildinfo"
"banger/internal/daemon/workspace"
"banger/internal/model"
"banger/internal/system"
"banger/internal/toolingplan"
"github.com/spf13/cobra"
)
func TestNewBangerCommandHasExpectedSubcommands(t *testing.T) {
	// The root command registers exactly these top-level subcommands, in
	// this order.
	expected := []string{"daemon", "doctor", "image", "internal", "kernel", "ps", "ssh-config", "version", "vm"}
	children := NewBangerCommand().Commands()
	got := make([]string, 0, len(children))
	for _, child := range children {
		got = append(got, child.Name())
	}
	if !reflect.DeepEqual(got, expected) {
		t.Fatalf("subcommands = %v, want %v", got, expected)
	}
}
func TestVersionCommandPrintsBuildInfo(t *testing.T) {
	// `banger version` prints the version, commit and build time reported
	// by buildinfo.Current().
	var out bytes.Buffer
	root := NewBangerCommand()
	root.SetArgs([]string{"version"})
	root.SetOut(&out)
	root.SetErr(&out)
	if err := root.Execute(); err != nil {
		t.Fatalf("Execute: %v", err)
	}
	bi := buildinfo.Current()
	expected := []string{
		"version: " + bi.Version,
		"commit: " + bi.Commit,
		"built_at: " + bi.BuiltAt,
	}
	printed := out.String()
	for _, line := range expected {
		if strings.Contains(printed, line) {
			continue
		}
		t.Fatalf("output = %q, want %q", printed, line)
	}
}
func TestImageCommandIncludesPull(t *testing.T) {
cmd := NewBangerCommand()
var image *cobra.Command
for _, sub := range cmd.Commands() {
if sub.Name() == "image" {
image = sub
break
}
}
if image == nil {
t.Fatalf("image command missing from root")
}
hasPull := false
for _, sub := range image.Commands() {
if sub.Name() == "pull" {
hasPull = true
if flag := sub.Flags().Lookup("kernel-ref"); flag == nil {
t.Errorf("image pull missing --kernel-ref flag")
}
if flag := sub.Flags().Lookup("size"); flag == nil {
t.Errorf("image pull missing --size flag")
}
}
}
if !hasPull {
t.Fatalf("image pull subcommand missing")
}
}
func TestKernelCommandExposesSubcommands(t *testing.T) {
cmd := NewBangerCommand()
var kernel *cobra.Command
for _, sub := range cmd.Commands() {
if sub.Name() == "kernel" {
kernel = sub
break
}
}
if kernel == nil {
t.Fatalf("kernel command missing from root")
}
names := []string{}
for _, sub := range kernel.Commands() {
names = append(names, sub.Name())
}
want := []string{"import", "list", "pull", "rm", "show"}
if !reflect.DeepEqual(names, want) {
t.Fatalf("kernel subcommands = %v, want %v", names, want)
}
}
func TestLegacyRemovedCommandIsRejected(t *testing.T) {
	// The retired `tui` command is no longer registered, so it is rejected
	// as an unknown command.
	const wantFragment = "unknown command \"tui\""
	root := NewBangerCommand()
	root.SetArgs([]string{"tui"})
	if err := root.Execute(); err == nil || !strings.Contains(err.Error(), wantFragment) {
		t.Fatalf("Execute() error = %v, want unknown legacy command", err)
	}
}
func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) {
	// A report containing a FAIL check makes `banger doctor` return an error,
	// while every check (passing or failing) is still printed.
	deps := defaultDeps()
	deps.doctor = func(context.Context) (system.Report, error) {
		checks := []system.CheckResult{
			{Name: "runtime bundle", Status: system.CheckStatusPass, Details: []string{"runtime dir /tmp/runtime"}},
			{Name: "feature nat", Status: system.CheckStatusFail, Details: []string{"missing iptables"}},
		}
		return system.Report{Checks: checks}, nil
	}
	var out bytes.Buffer
	root := deps.newRootCommand()
	root.SetArgs([]string{"doctor"})
	root.SetOut(&out)
	root.SetErr(&out)
	err := root.Execute()
	if err == nil || !strings.Contains(err.Error(), "doctor found failing checks") {
		t.Fatalf("Execute() error = %v, want doctor failure", err)
	}
	printed := out.String()
	for _, expect := range []struct{ line, desc string }{
		{"PASS\truntime bundle", "runtime bundle pass"},
		{"FAIL\tfeature nat", "feature nat fail"},
	} {
		if !strings.Contains(printed, expect.line) {
			t.Fatalf("output = %q, want %s", printed, expect.desc)
		}
	}
}
func TestDoctorCommandReturnsUnderlyingError(t *testing.T) {
	// An error from building the doctor report is surfaced by Execute.
	loadErr := errors.New("load failed")
	deps := defaultDeps()
	deps.doctor = func(context.Context) (system.Report, error) {
		return system.Report{}, loadErr
	}
	root := deps.newRootCommand()
	root.SetArgs([]string{"doctor"})
	if err := root.Execute(); err == nil || !strings.Contains(err.Error(), "load failed") {
		t.Fatalf("Execute() error = %v, want load failed", err)
	}
}
// TestInternalNATFlagsExist checks that `banger internal nat up` is
// registered and exposes the --guest-ip and --tap flags.
//
// Below the root, cobra's Command.Find does not report an unknown name as an
// error; it returns the deepest command it managed to match. Each lookup
// therefore also asserts the resolved command's name. That way a missing
// level fails with a precise message rather than a misleading "missing flag".
func TestInternalNATFlagsExist(t *testing.T) {
	root := NewBangerCommand()
	internal, _, err := root.Find([]string{"internal"})
	if err != nil {
		t.Fatalf("find internal: %v", err)
	}
	if internal.Name() != "internal" {
		t.Fatalf("find internal: resolved to %q", internal.Name())
	}
	nat, _, err := internal.Find([]string{"nat"})
	if err != nil {
		t.Fatalf("find nat: %v", err)
	}
	if nat.Name() != "nat" {
		t.Fatalf("find nat: resolved to %q, want internal nat subcommand", nat.Name())
	}
	up, _, err := nat.Find([]string{"up"})
	if err != nil {
		t.Fatalf("find nat up: %v", err)
	}
	if up.Name() != "up" {
		t.Fatalf("find nat up: resolved to %q, want nat up subcommand", up.Name())
	}
	for _, flagName := range []string{"guest-ip", "tap"} {
		if up.Flags().Lookup(flagName) == nil {
			t.Fatalf("missing flag %q", flagName)
		}
	}
}
// TestPSAndVMListAliasesAndFlagsExist checks that `banger ps` and
// `banger vm list` (including its `ls` and `ps` aliases) are registered, and
// that both expose --all, --latest and --quiet.
//
// Below the root, cobra's Command.Find returns the parent command with a nil
// error when a name does not match. A bare err check therefore cannot prove
// that an alias exists. Each alias lookup below must resolve to the very same
// `list` command.
func TestPSAndVMListAliasesAndFlagsExist(t *testing.T) {
	root := NewBangerCommand()
	ps, _, err := root.Find([]string{"ps"})
	if err != nil {
		t.Fatalf("find ps: %v", err)
	}
	if ps.Name() != "ps" {
		t.Fatalf("find ps: resolved to %q", ps.Name())
	}
	for _, flagName := range []string{"all", "latest", "quiet"} {
		if ps.Flags().Lookup(flagName) == nil {
			t.Fatalf("missing ps flag %q", flagName)
		}
	}
	vm, _, err := root.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	if vm.Name() != "vm" {
		t.Fatalf("find vm: resolved to %q", vm.Name())
	}
	list, _, err := vm.Find([]string{"list"})
	if err != nil {
		t.Fatalf("find list: %v", err)
	}
	if list.Name() != "list" {
		t.Fatalf("find list: resolved to %q, want vm list subcommand", list.Name())
	}
	for _, alias := range []string{"ls", "ps"} {
		resolved, _, err := vm.Find([]string{alias})
		if err != nil {
			t.Fatalf("find %s alias: %v", alias, err)
		}
		if resolved != list {
			t.Fatalf("find %s alias: resolved to %q, want vm list", alias, resolved.Name())
		}
	}
	for _, flagName := range []string{"all", "latest", "quiet"} {
		if list.Flags().Lookup(flagName) == nil {
			t.Fatalf("missing vm list flag %q", flagName)
		}
	}
}
func TestPSCommandRejectsArgs(t *testing.T) {
	// `banger ps` takes no positional arguments; passing one yields its
	// usage error.
	root := NewBangerCommand()
	root.SetArgs([]string{"ps", "extra"})
	err := root.Execute()
	rejected := err != nil && strings.Contains(err.Error(), "usage: banger ps")
	if !rejected {
		t.Fatalf("Execute() error = %v, want ps usage error", err)
	}
}
// TestVMCreateFlagsExist checks that `banger vm create` is registered and
// exposes its name, image, sizing, nat and no-start flags.
//
// Below the root, cobra's Command.Find returns the parent command with a nil
// error when a name does not match. The resolved command's name is therefore
// asserted explicitly; otherwise a missing subcommand would be reported as a
// misleading "missing flag".
func TestVMCreateFlagsExist(t *testing.T) {
	root := NewBangerCommand()
	vm, _, err := root.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	if vm.Name() != "vm" {
		t.Fatalf("find vm: resolved to %q", vm.Name())
	}
	create, _, err := vm.Find([]string{"create"})
	if err != nil {
		t.Fatalf("find create: %v", err)
	}
	if create.Name() != "create" {
		t.Fatalf("find create: resolved to %q, want vm create subcommand", create.Name())
	}
	for _, flagName := range []string{"name", "image", "vcpu", "memory", "system-overlay-size", "disk-size", "nat", "no-start"} {
		if create.Flags().Lookup(flagName) == nil {
			t.Fatalf("missing flag %q", flagName)
		}
	}
}
// TestVMRunFlagsExist checks that `banger vm run` is registered, exposes the
// create-style flags plus --branch/--from, and does NOT expose --no-start.
//
// Below the root, cobra's Command.Find returns the parent command with a nil
// error when a name does not match. The resolved command's name is therefore
// asserted explicitly. Otherwise a missing `run` would leave the flag checks
// running against `vm` itself.
func TestVMRunFlagsExist(t *testing.T) {
	root := NewBangerCommand()
	vm, _, err := root.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	if vm.Name() != "vm" {
		t.Fatalf("find vm: resolved to %q", vm.Name())
	}
	run, _, err := vm.Find([]string{"run"})
	if err != nil {
		t.Fatalf("find run: %v", err)
	}
	if run.Name() != "run" {
		t.Fatalf("find run: resolved to %q, want vm run subcommand", run.Name())
	}
	for _, flagName := range []string{"name", "image", "vcpu", "memory", "system-overlay-size", "disk-size", "nat", "branch", "from"} {
		if run.Flags().Lookup(flagName) == nil {
			t.Fatalf("missing flag %q", flagName)
		}
	}
	if run.Flags().Lookup("no-start") != nil {
		t.Fatal("vm run should not expose --no-start")
	}
}
// TestVMCreateFlagsShowResolvedDefaults checks that `vm create` flags carry
// concrete defaults (pflag DefValue): positive integers for --vcpu and
// --memory, and formatted sizes such as "8G" for the disk-size flags.
//
// The defaults are resolved when the command is built, from config plus host
// heuristics. The exact numbers therefore depend on the host running the
// tests, so only their shape is asserted here.
func TestVMCreateFlagsShowResolvedDefaults(t *testing.T) {
	root := NewBangerCommand()
	vm, _, err := root.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	if vm.Name() != "vm" {
		t.Fatalf("find vm: resolved to %q", vm.Name())
	}
	// cobra's Find does not error on an unknown name below the root, so
	// check the resolved command explicitly.
	create, _, err := vm.Find([]string{"create"})
	if err != nil {
		t.Fatalf("find create: %v", err)
	}
	if create.Name() != "create" {
		t.Fatalf("find create: resolved to %q, want vm create subcommand", create.Name())
	}
	for _, flagName := range []string{"vcpu", "memory"} {
		flag := create.Flags().Lookup(flagName)
		if flag == nil {
			t.Fatalf("flag %q missing", flagName)
		}
		// Parse the default instead of only comparing it against "" and
		// "0", so negative or non-numeric defaults are rejected as well.
		var value int
		if _, err := fmt.Sscan(flag.DefValue, &value); err != nil || value <= 0 {
			t.Errorf("flag %q default = %q, want a positive integer", flagName, flag.DefValue)
		}
	}
	for _, flagName := range []string{"system-overlay-size", "disk-size"} {
		flag := create.Flags().Lookup(flagName)
		if flag == nil {
			t.Fatalf("flag %q missing", flagName)
		}
		if !strings.ContainsAny(flag.DefValue, "GMK") {
			t.Errorf("flag %q default = %q, want a formatted size like '8G'", flagName, flag.DefValue)
		}
	}
}
func TestVMRunRejectsFromWithoutBranch(t *testing.T) {
	// `vm run --from` is only accepted together with --branch.
	const want = "--from requires --branch"
	root := NewBangerCommand()
	root.SetArgs([]string{"vm", "run", "--from", "HEAD"})
	err := root.Execute()
	if err != nil && strings.Contains(err.Error(), want) {
		return
	}
	t.Fatalf("Execute() error = %v, want --from requires --branch", err)
}
// TestImageRegisterFlagsExist checks that `banger image register` is
// registered and exposes every artifact-path flag it needs.
//
// Below the root, cobra's Command.Find returns the parent command with a nil
// error when a name does not match, so the resolved command's name is
// asserted explicitly.
func TestImageRegisterFlagsExist(t *testing.T) {
	root := NewBangerCommand()
	image, _, err := root.Find([]string{"image"})
	if err != nil {
		t.Fatalf("find image: %v", err)
	}
	if image.Name() != "image" {
		t.Fatalf("find image: resolved to %q", image.Name())
	}
	register, _, err := image.Find([]string{"register"})
	if err != nil {
		t.Fatalf("find register: %v", err)
	}
	if register.Name() != "register" {
		t.Fatalf("find register: resolved to %q, want image register subcommand", register.Name())
	}
	for _, flagName := range []string{"name", "rootfs", "work-seed", "kernel", "initrd", "modules", "docker"} {
		if register.Flags().Lookup(flagName) == nil {
			t.Fatalf("missing flag %q", flagName)
		}
	}
}
// TestImagePromoteCommandExists checks that `banger image promote` is
// registered.
//
// Below the root, cobra's Command.Find returns the parent command (`image`)
// with a nil error when `promote` is missing. An err check alone would
// therefore always pass. The resolved command's name is what proves the
// subcommand exists.
func TestImagePromoteCommandExists(t *testing.T) {
	root := NewBangerCommand()
	image, _, err := root.Find([]string{"image"})
	if err != nil {
		t.Fatalf("find image: %v", err)
	}
	if image.Name() != "image" {
		t.Fatalf("find image: resolved to %q", image.Name())
	}
	promote, _, err := image.Find([]string{"promote"})
	if err != nil {
		t.Fatalf("find promote: %v", err)
	}
	if promote.Name() != "promote" {
		t.Fatalf("find promote: resolved to %q, image promote subcommand missing", promote.Name())
	}
}
// TestVMKillFlagsExist checks that `banger vm kill` is registered and
// exposes --signal.
//
// Below the root, cobra's Command.Find returns the parent command with a nil
// error when a name does not match, so the resolved command's name is
// asserted explicitly.
func TestVMKillFlagsExist(t *testing.T) {
	root := NewBangerCommand()
	vm, _, err := root.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	if vm.Name() != "vm" {
		t.Fatalf("find vm: resolved to %q", vm.Name())
	}
	kill, _, err := vm.Find([]string{"kill"})
	if err != nil {
		t.Fatalf("find kill: %v", err)
	}
	if kill.Name() != "kill" {
		t.Fatalf("find kill: resolved to %q, want vm kill subcommand", kill.Name())
	}
	if kill.Flags().Lookup("signal") == nil {
		t.Fatal("missing signal flag")
	}
}
// TestVMPortsCommandExists checks that `banger vm ports` is registered.
//
// Below the root, cobra's Command.Find returns the parent command (`vm`) with
// a nil error when `ports` is missing. An err check alone would therefore
// always pass. The resolved command's name is what proves the subcommand
// exists.
func TestVMPortsCommandExists(t *testing.T) {
	root := NewBangerCommand()
	vm, _, err := root.Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	if vm.Name() != "vm" {
		t.Fatalf("find vm: resolved to %q", vm.Name())
	}
	ports, _, err := vm.Find([]string{"ports"})
	if err != nil {
		t.Fatalf("find ports: %v", err)
	}
	if ports.Name() != "ports" {
		t.Fatalf("find ports: resolved to %q, vm ports subcommand missing", ports.Name())
	}
}
func TestVMPortsCommandRejectsMultipleRefs(t *testing.T) {
	// `vm ports` takes exactly one VM reference; two refs yield its usage
	// error.
	const usage = "usage: banger vm ports <id-or-name>"
	root := NewBangerCommand()
	root.SetArgs([]string{"vm", "ports", "alpha", "beta"})
	if err := root.Execute(); err == nil || !strings.Contains(err.Error(), usage) {
		t.Fatalf("Execute() error = %v, want single-vm usage error", err)
	}
}
func TestVMSetParamsFromFlags(t *testing.T) {
	// Explicit CPU, memory, disk size and an enabled NAT flag are all carried
	// into the set request for the named VM.
	got, err := vmSetParamsFromFlags("devbox", 4, 2048, "16G", true, false)
	if err != nil {
		t.Fatalf("vmSetParamsFromFlags: %v", err)
	}
	switch {
	case got.IDOrName != "devbox" || got.VCPUCount == nil || *got.VCPUCount != 4:
		t.Fatalf("unexpected params: %+v", got)
	case got.MemoryMiB == nil || *got.MemoryMiB != 2048:
		t.Fatalf("unexpected memory: %+v", got)
	case got.WorkDiskSize != "16G":
		t.Fatalf("unexpected disk size: %+v", got)
	case got.NATEnabled == nil || !*got.NATEnabled:
		t.Fatalf("unexpected nat value: %+v", got)
	}
}
func TestVMCreateParamsFromFlagsAlwaysPopulatesResolvedValues(t *testing.T) {
	// The CLI is the single source of truth for effective defaults. Every
	// resolved value is sent to the daemon explicitly, whether or not the
	// user changed the flag, so the spec shown to the user matches the VM
	// that gets created.
	vm, _, err := NewBangerCommand().Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	create, _, err := vm.Find([]string{"create"})
	if err != nil {
		t.Fatalf("find create: %v", err)
	}
	const (
		wantVCPU    = 3
		wantMemory  = 4096
		wantOverlay = "10G"
		wantDisk    = "20G"
	)
	got, err := vmCreateParamsFromFlags(create, "devbox", "default", wantVCPU, wantMemory, wantOverlay, wantDisk, false, false)
	if err != nil {
		t.Fatalf("vmCreateParamsFromFlags: %v", err)
	}
	if got.VCPUCount == nil || *got.VCPUCount != wantVCPU {
		t.Errorf("VCPUCount = %v, want 3", got.VCPUCount)
	}
	if got.MemoryMiB == nil || *got.MemoryMiB != wantMemory {
		t.Errorf("MemoryMiB = %v, want 4096", got.MemoryMiB)
	}
	if got.SystemOverlaySize != wantOverlay {
		t.Errorf("SystemOverlaySize = %q, want 10G", got.SystemOverlaySize)
	}
	if got.WorkDiskSize != wantDisk {
		t.Errorf("WorkDiskSize = %q, want 20G", got.WorkDiskSize)
	}
}
func TestVMCreateParamsFromFlagsRejectsNonPositive(t *testing.T) {
	// A zero vcpu count or zero memory is rejected.
	vm, _, err := NewBangerCommand().Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	create, _, err := vm.Find([]string{"create"})
	if err != nil {
		t.Fatalf("find create: %v", err)
	}
	_, vcpuErr := vmCreateParamsFromFlags(create, "x", "", 0, 1024, "8G", "8G", false, false)
	if vcpuErr == nil {
		t.Error("expected error for vcpu=0")
	}
	_, memErr := vmCreateParamsFromFlags(create, "x", "", 2, 0, "8G", "8G", false, false)
	if memErr == nil {
		t.Error("expected error for memory=0")
	}
}
func TestVMCreateParamsFromFlagsIncludesChangedDiskFlags(t *testing.T) {
	// Disk sizes set on the flags are carried into the create request.
	vm, _, err := NewBangerCommand().Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	create, _, err := vm.Find([]string{"create"})
	if err != nil {
		t.Fatalf("find create: %v", err)
	}
	for _, setting := range []struct{ flag, value string }{
		{"system-overlay-size", "16G"},
		{"disk-size", "32G"},
	} {
		if err := create.Flags().Set(setting.flag, setting.value); err != nil {
			t.Fatalf("set %s flag: %v", setting.flag, err)
		}
	}
	got, err := vmCreateParamsFromFlags(create, "devbox", "default", model.DefaultVCPUCount, model.DefaultMemoryMiB, "16G", "32G", false, false)
	if err != nil {
		t.Fatalf("vmCreateParamsFromFlags: %v", err)
	}
	if !(got.SystemOverlaySize == "16G" && got.WorkDiskSize == "32G") {
		t.Fatalf("expected changed disk flags to be included: %+v", got)
	}
}
func TestVMCreateParamsFromFlagsRejectsNonPositiveCPUAndMemory(t *testing.T) {
	// Explicitly set non-positive --vcpu / --memory values are rejected with
	// a message naming the offending flag.
	vm, _, err := NewBangerCommand().Find([]string{"vm"})
	if err != nil {
		t.Fatalf("find vm: %v", err)
	}
	create, _, err := vm.Find([]string{"create"})
	if err != nil {
		t.Fatalf("find create: %v", err)
	}
	flags := create.Flags()

	if err := flags.Set("vcpu", "0"); err != nil {
		t.Fatalf("set vcpu flag: %v", err)
	}
	_, vcpuErr := vmCreateParamsFromFlags(create, "devbox", "default", 0, 0, "", "", false, false)
	if vcpuErr == nil || !strings.Contains(vcpuErr.Error(), "vcpu must be a positive integer") {
		t.Fatalf("vmCreateParamsFromFlags(vcpu=0) error = %v", vcpuErr)
	}

	if err := flags.Set("memory", "-1"); err != nil {
		t.Fatalf("set memory flag: %v", err)
	}
	_, memErr := vmCreateParamsFromFlags(create, "devbox", "default", 1, -1, "", "", false, false)
	if memErr == nil || !strings.Contains(memErr.Error(), "memory must be a positive integer") {
		t.Fatalf("vmCreateParamsFromFlags(memory=-1) error = %v", memErr)
	}
}
func TestRunVMCreatePollsUntilDone(t *testing.T) {
d := defaultDeps()
vm := model.VMRecord{
ID: "vm-id",
Name: "devbox",
Spec: model.VMSpec{WorkDiskSizeBytes: model.DefaultWorkDiskSize},
Runtime: model.VMRuntime{
State: model.VMStateRunning,
GuestIP: "172.16.0.2",
DNSName: "devbox.vm",
},
}
d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{
Operation: api.VMCreateOperation{
ID: "op-1",
Stage: "prepare_work_disk",
Detail: "cloning work seed",
},
}, nil
}
statusCalls := 0
d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
statusCalls++
if statusCalls == 1 {
return api.VMCreateStatusResult{
Operation: api.VMCreateOperation{
ID: "op-1",
Stage: "wait_vsock_agent",
Detail: "waiting for guest vsock agent",
},
}, nil
}
return api.VMCreateStatusResult{
Operation: api.VMCreateOperation{
ID: "op-1",
Stage: "ready",
Detail: "vm is ready",
Done: true,
Success: true,
VM: &vm,
},
}, nil
}
d.vmCreateCancel = func(context.Context, string, string) error {
t.Fatal("cancel should not be called")
return nil
}
got, err := d.runVMCreate(context.Background(), "/tmp/bangerd.sock", &bytes.Buffer{}, api.VMCreateParams{Name: "devbox"})
if err != nil {
t.Fatalf("d.runVMCreate: %v", err)
}
if got.Name != vm.Name || got.Runtime.GuestIP != vm.Runtime.GuestIP {
t.Fatalf("vm = %+v, want %+v", got, vm)
}
if statusCalls != 2 {
t.Fatalf("statusCalls = %d, want 2", statusCalls)
}
}
func TestVMCreateProgressRendererSuppressesDuplicateLines(t *testing.T) {
	// Rendering the same stage/detail twice in a row prints it only once.
	var stderr bytes.Buffer
	r := &vmCreateProgressRenderer{out: &stderr, enabled: true}
	for _, op := range []api.VMCreateOperation{
		{Stage: "prepare_work_disk", Detail: "cloning work seed"},
		{Stage: "prepare_work_disk", Detail: "cloning work seed"},
		{Stage: "wait_vsock_agent", Detail: "waiting for guest vsock agent"},
	} {
		r.render(op)
	}
	got := strings.Split(strings.TrimSpace(stderr.String()), "\n")
	if len(got) != 2 {
		t.Fatalf("rendered lines = %q, want 2 lines", stderr.String())
	}
	want := [...]string{
		"[vm create] preparing work disk: cloning work seed",
		"[vm create] waiting for vsock agent: waiting for guest vsock agent",
	}
	ordinal := [...]string{"first", "second"}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("%s line = %q", ordinal[i], got[i])
		}
	}
}
func TestVMRunProgressRendererSuppressesDuplicateLines(t *testing.T) {
	// Consecutive identical progress messages are printed only once.
	var stderr bytes.Buffer
	r := newVMRunProgressRenderer(&stderr)
	for _, msg := range []string{
		"waiting for guest ssh",
		"waiting for guest ssh",
		"overlaying host working tree",
	} {
		r.render(msg)
	}
	got := strings.Split(strings.TrimSpace(stderr.String()), "\n")
	if len(got) != 2 {
		t.Fatalf("rendered lines = %q, want 2 lines", stderr.String())
	}
	want := [...]string{"[vm run] waiting for guest ssh", "[vm run] overlaying host working tree"}
	ordinal := [...]string{"first", "second"}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("%s line = %q", ordinal[i], got[i])
		}
	}
}
func TestWithHeartbeatNoOpForNonTTY(t *testing.T) {
	// With a non-terminal writer (a bytes.Buffer), withHeartbeat just runs
	// the wrapped function and writes nothing.
	var stderr bytes.Buffer
	ran := false
	work := func() error {
		ran = true
		return nil
	}
	if err := withHeartbeat(&stderr, "image pull", work); err != nil {
		t.Fatalf("withHeartbeat: %v", err)
	}
	if !ran {
		t.Fatal("fn should have been called")
	}
	if stderr.Len() > 0 {
		t.Fatalf("stderr = %q, want empty for non-TTY", stderr.String())
	}
}
func TestWithHeartbeatPropagatesError(t *testing.T) {
	// The wrapped function's error is returned in a form errors.Is can match.
	boom := errors.New("boom")
	var stderr bytes.Buffer
	got := withHeartbeat(&stderr, "image pull", func() error {
		return boom
	})
	if !errors.Is(got, boom) {
		t.Fatalf("withHeartbeat error = %v, want %v", got, boom)
	}
}
func TestVMSetParamsFromFlagsConflict(t *testing.T) {
	// Passing both NAT booleans as true is rejected as a conflict.
	_, err := vmSetParamsFromFlags("devbox", -1, -1, "", true, true)
	if err == nil {
		t.Fatal("expected nat conflict error")
	}
}
func TestVMSetParamsFromFlagsRejectsNonPositiveCPUAndMemory(t *testing.T) {
	// A zero vcpu count or zero memory is rejected with a message naming the
	// offending field.
	_, vcpuErr := vmSetParamsFromFlags("devbox", 0, -1, "", false, false)
	if vcpuErr == nil || !strings.Contains(vcpuErr.Error(), "vcpu must be a positive integer") {
		t.Fatalf("vmSetParamsFromFlags(vcpu=0) error = %v", vcpuErr)
	}
	_, memErr := vmSetParamsFromFlags("devbox", -1, 0, "", false, false)
	if memErr == nil || !strings.Contains(memErr.Error(), "memory must be a positive integer") {
		t.Fatalf("vmSetParamsFromFlags(memory=0) error = %v", memErr)
	}
}
func TestAbsolutizeImageRegisterPaths(t *testing.T) {
	// Relative artifact paths are rewritten to absolute ones. The test
	// switches into a temp dir for its duration and restores the original
	// working directory on cleanup.
	workDir := t.TempDir()
	runtimePath := func(name string) string {
		return filepath.Join(".", "runtime", name)
	}
	params := api.ImageRegisterParams{
		RootfsPath:   runtimePath("rootfs-void.ext4"),
		WorkSeedPath: runtimePath("rootfs-void.work-seed.ext4"),
		KernelPath:   runtimePath("vmlinux"),
		InitrdPath:   runtimePath("initrd.img"),
		ModulesDir:   runtimePath("modules"),
	}
	origDir, err := os.Getwd()
	if err != nil {
		t.Fatalf("Getwd: %v", err)
	}
	if err := os.Chdir(workDir); err != nil {
		t.Fatalf("Chdir(%s): %v", workDir, err)
	}
	t.Cleanup(func() {
		// Best effort: there is nothing useful to report after the test ends.
		_ = os.Chdir(origDir)
	})
	if err := absolutizeImageRegisterPaths(&params); err != nil {
		t.Fatalf("absolutizeImageRegisterPaths: %v", err)
	}
	paths := []string{params.RootfsPath, params.WorkSeedPath, params.KernelPath, params.InitrdPath, params.ModulesDir}
	for _, p := range paths {
		if filepath.IsAbs(p) {
			continue
		}
		t.Fatalf("path %q is not absolute", p)
	}
}
// TestPrintImageListTableShowsRootfsSizes checks that the image table has a
// ROOTFS SIZE column with human-readable sizes instead of the rootfs path,
// and falls back to "-" when an image's rootfs file does not exist.
func TestPrintImageListTableShowsRootfsSizes(t *testing.T) {
	rootfs := filepath.Join(t.TempDir(), "rootfs.ext4")
	if err := os.WriteFile(rootfs, nil, 0o644); err != nil {
		t.Fatalf("WriteFile(%s): %v", rootfs, err)
	}
	if err := os.Truncate(rootfs, 8*1024); err != nil {
		t.Fatalf("Truncate(%s): %v", rootfs, err)
	}
	var out bytes.Buffer
	err := printImageListTable(&out, []model.Image{
		{
			ID:         "0123456789abcdef",
			Name:       "alpine",
			Managed:    true,
			RootfsPath: rootfs,
			CreatedAt:  time.Now().Add(-1 * time.Hour),
		},
		{
			ID:         "fedcba9876543210",
			Name:       "missing",
			Managed:    false,
			RootfsPath: filepath.Join(t.TempDir(), "missing.ext4"),
			CreatedAt:  time.Now().Add(-2 * time.Hour),
		},
	})
	if err != nil {
		t.Fatalf("printImageListTable() error = %v", err)
	}
	output := out.String()
	if !strings.Contains(output, "ROOTFS SIZE") {
		t.Fatalf("output = %q, want rootfs size header", output)
	}
	if strings.Contains(output, rootfs) {
		t.Fatalf("output = %q, should not include rootfs path", output)
	}
	// Assert against each image's own row: searching the whole output for
	// "8K" or "-" could be satisfied by an unrelated row or column.
	rowFor := func(name string) string {
		for _, line := range strings.Split(output, "\n") {
			if strings.Contains(line, name) {
				return line
			}
		}
		return ""
	}
	if row := rowFor("alpine"); !strings.Contains(row, "8K") {
		t.Fatalf("output = %q, want alpine row with 8K size", output)
	}
	if row := rowFor("missing"); !strings.Contains(row, "-") {
		t.Fatalf("output = %q, want fallback size for missing image", output)
	}
}
func TestSelectVMListVMsDefaultsToRunning(t *testing.T) {
	// With neither "all" nor "latest" requested, only running VMs are kept,
	// in their original order.
	base := time.Now()
	vms := []model.VMRecord{
		{ID: "running-1", State: model.VMStateRunning, CreatedAt: base.Add(-3 * time.Hour)},
		{ID: "stopped-1", State: model.VMStateStopped, CreatedAt: base.Add(-2 * time.Hour)},
		{ID: "running-2", State: model.VMStateRunning, CreatedAt: base.Add(-1 * time.Hour)},
	}
	selected := selectVMListVMs(vms, false, false)
	ok := len(selected) == 2 && selected[0].ID == "running-1" && selected[1].ID == "running-2"
	if !ok {
		t.Fatalf("selectVMListVMs() = %#v, want only running VMs in original order", selected)
	}
}
func TestSelectVMListVMsLatestUsesFilteredSet(t *testing.T) {
	// "latest" picks the newest VM from the already-filtered set: the newest
	// running VM by default, the newest VM of any state with "all".
	base := time.Now()
	vms := []model.VMRecord{
		{ID: "running-old", State: model.VMStateRunning, CreatedAt: base.Add(-3 * time.Hour)},
		{ID: "stopped-new", State: model.VMStateStopped, CreatedAt: base.Add(-30 * time.Minute)},
		{ID: "running-new", State: model.VMStateRunning, CreatedAt: base.Add(-1 * time.Hour)},
	}
	onlyID := func(got []model.VMRecord, id string) bool {
		return len(got) == 1 && got[0].ID == id
	}
	if got := selectVMListVMs(vms, false, true); !onlyID(got, "running-new") {
		t.Fatalf("selectVMListVMs(default latest) = %#v, want latest running VM", got)
	}
	if got := selectVMListVMs(vms, true, true); !onlyID(got, "stopped-new") {
		t.Fatalf("selectVMListVMs(all latest) = %#v, want latest VM across all states", got)
	}
}
func TestPrintVMIDListShowsFullIDs(t *testing.T) {
var out bytes.Buffer
err := printVMIDList(&out, []model.VMRecord{{ID: "0123456789abcdef0123456789abcdef"}, {ID: "fedcba9876543210fedcba9876543210"}})
if err != nil {
t.Fatalf("printVMIDList() error = %v", err)
}
lines := strings.Split(strings.TrimSpace(out.String()), "\n")
want := []string{"0123456789abcdef0123456789abcdef", "fedcba9876543210fedcba9876543210"}
if !reflect.DeepEqual(lines, want) {
t.Fatalf("lines = %v, want %v", lines, want)
}
}
func TestPrintVMListTableShowsImageNames(t *testing.T) {
var out bytes.Buffer
err := printVMListTable(&out, []model.VMRecord{
{
ID: "0123456789abcdef",
Name: "alp-fast",
ImageID: "image-alpine-123456",
State: model.VMStateRunning,
CreatedAt: time.Now().Add(-1 * time.Hour),
Spec: model.VMSpec{
VCPUCount: 2,
MemoryMiB: model.DefaultMemoryMiB,
WorkDiskSizeBytes: model.DefaultWorkDiskSize,
},
Runtime: model.VMRuntime{GuestIP: "172.16.0.4"},
},
{
ID: "fedcba9876543210",
Name: "mystery",
ImageID: "abcdef1234567890",
State: model.VMStateStopped,
CreatedAt: time.Now().Add(-2 * time.Hour),
Spec: model.VMSpec{
VCPUCount: 1,
MemoryMiB: 512,
WorkDiskSizeBytes: 4 * 1024 * 1024 * 1024,
},
},
}, map[string]string{
"image-alpine-123456": "alpine",
})
if err != nil {
t.Fatalf("printVMListTable() error = %v", err)
}
output := out.String()
if !strings.Contains(output, "IMAGE") || !strings.Contains(output, "MEM") {
t.Fatalf("output = %q, want vm list headers", output)
}
if !strings.Contains(output, "alp-fast") || !strings.Contains(output, "alpine") {
t.Fatalf("output = %q, want resolved image name", output)
}
if strings.Contains(output, "image-alpine-123456") {
t.Fatalf("output = %q, should not include full image id when name is known", output)
}
if !strings.Contains(output, shortID("abcdef1234567890")) {
t.Fatalf("output = %q, want short image id fallback", output)
}
if !strings.Contains(output, fmt.Sprintf("%d MiB", model.DefaultMemoryMiB)) {
t.Fatalf("output = %q, want updated default memory display", output)
}
}
// TestPrintVMPortsTableSortsAndRendersURLEndpoints checks the ports table:
// a PROTO/ENDPOINT header without VM/WEB columns, rows emitted in sorted
// order, and endpoints rendered verbatim.
//
// The input deliberately lists the udp port before the https one. The
// expected https-then-udp output can therefore only come from the printer's
// own sorting; with pre-sorted input the sort would go untested.
func TestPrintVMPortsTableSortsAndRendersURLEndpoints(t *testing.T) {
	result := api.VMPortsResult{
		Name: "alpha",
		Ports: []api.VMPort{
			{
				Proto:    "udp",
				Port:     53,
				Endpoint: "alpha.vm:53",
				Process:  "dnsd",
				Command:  "dnsd --foreground",
			},
			{
				Proto:    "https",
				Port:     443,
				Endpoint: "https://alpha.vm:443/",
				Process:  "caddy",
				Command:  "caddy run",
			},
		},
	}
	var out bytes.Buffer
	if err := printVMPortsTable(&out, result); err != nil {
		t.Fatalf("printVMPortsTable: %v", err)
	}
	lines := strings.Split(strings.TrimSpace(out.String()), "\n")
	if len(lines) != 3 {
		t.Fatalf("lines = %q, want header + 2 rows", lines)
	}
	if !strings.Contains(lines[0], "PROTO") || !strings.Contains(lines[0], "ENDPOINT") || strings.Contains(lines[0], "VM") || strings.Contains(lines[0], "WEB") {
		t.Fatalf("header = %q, want PROTO/ENDPOINT without VM/WEB", lines[0])
	}
	if !strings.Contains(lines[1], "https") || !strings.Contains(lines[1], "https://alpha.vm:443/") {
		t.Fatalf("first row = %q, want https endpoint row", lines[1])
	}
	if !strings.Contains(lines[2], "udp") || !strings.Contains(lines[2], "alpha.vm:53") {
		t.Fatalf("second row = %q, want udp endpoint row", lines[2])
	}
}
func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) {
	// After ssh exits cleanly and the VM reports healthy, a "still running"
	// reminder is written to stderr.
	deps := defaultDeps()
	deps.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
		return nil
	}
	deps.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
		return api.VMHealthResult{Name: "devbox", Healthy: true}, nil
	}
	var stderr bytes.Buffer
	err := deps.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false)
	if err != nil {
		t.Fatalf("d.runSSHSession: %v", err)
	}
	if got := stderr.String(); !strings.Contains(got, "devbox is still running") {
		t.Fatalf("stderr = %q, want reminder", got)
	}
}
// TestRunSSHSessionPreservesSSHExitStatusOnHealthWarning checks the case
// where ssh exits non-zero and the follow-up health check fails.
// runSSHSession prints a warning about the failed check, but still returns
// ssh's own exit error with its original exit code.
func TestRunSSHSessionPreservesSSHExitStatusOnHealthWarning(t *testing.T) {
	d := defaultDeps()
	d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
		return exitErrorWithCode(t, 1)
	}
	d.vmHealth = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
		return api.VMHealthResult{}, errors.New("dial failed")
	}
	var stderr bytes.Buffer
	err := d.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false)
	var exitErr *exec.ExitError
	// Check the exact code, not just the error type: the point of this test
	// is that ssh's status survives the health-check warning unchanged.
	if !errors.As(err, &exitErr) || exitErr.ExitCode() != 1 {
		t.Fatalf("d.runSSHSession error = %v, want exit 1", err)
	}
	if !strings.Contains(stderr.String(), "failed to check whether devbox is still running") {
		t.Fatalf("stderr = %q, want warning", stderr.String())
	}
}
func TestRunSSHSessionSkipsReminderOnSSHAuthFailure(t *testing.T) {
	// When ssh exits with status 255, the exit error is returned as-is, the
	// VM health check is not run, and no "still running" reminder is printed.
	deps := defaultDeps()
	deps.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
		return exitErrorWithCode(t, 255)
	}
	healthChecked := false
	deps.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
		healthChecked = true
		return api.VMHealthResult{Name: "devbox", Healthy: true}, nil
	}
	var stderr bytes.Buffer
	err := deps.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false)
	var exitErr *exec.ExitError
	if ok := errors.As(err, &exitErr); !ok || exitErr.ExitCode() != 255 {
		t.Fatalf("d.runSSHSession error = %v, want exit 255", err)
	}
	if healthChecked {
		t.Fatal("vm health should not run after ssh auth failure")
	}
	if got := stderr.String(); strings.Contains(got, "still running") {
		t.Fatalf("stderr = %q, should not contain reminder", got)
	}
}
func TestResolveVMTargetsDeduplicatesAndReportsErrors(t *testing.T) {
	// A ref that resolves to an already-selected VM is dropped (the first ref
	// wins). Ambiguous and unknown refs are reported as errors. Both results
	// keep the input order.
	vms := []model.VMRecord{
		testCLIResolvedVM("alpha-id", "alpha"),
		testCLIResolvedVM("alpine-id", "alpine"),
		testCLIResolvedVM("bravo-id", "bravo"),
	}
	refs := []string{"alpha", "alpha-id", "al", "missing", "br"}
	targets, errs := resolveVMTargets(vms, refs)

	if got := len(targets); got != 2 {
		t.Fatalf("len(targets) = %d, want 2", got)
	}
	if first := targets[0]; first.VM.ID != "alpha-id" || first.Ref != "alpha" {
		t.Fatalf("targets[0] = %+v, want alpha target", first)
	}
	if second := targets[1]; second.VM.ID != "bravo-id" || second.Ref != "br" {
		t.Fatalf("targets[1] = %+v, want bravo target", second)
	}

	if got := len(errs); got != 2 {
		t.Fatalf("len(errs) = %d, want 2", got)
	}
	if e := errs[0]; e.Ref != "al" || !strings.Contains(e.Err.Error(), "multiple VMs match") {
		t.Fatalf("errs[0] = %+v, want ambiguous prefix", e)
	}
	if e := errs[1]; e.Ref != "missing" || !strings.Contains(e.Err.Error(), `vm "missing" not found`) {
		t.Fatalf("errs[1] = %+v, want missing vm", e)
	}
}
func TestResolveVMRefPrefersExactMatchBeforePrefix(t *testing.T) {
vms := []model.VMRecord{
testCLIResolvedVM("1111111111111111111111111111111111111111111111111111111111111111", "alpha"),
testCLIResolvedVM("alpha222222222222222222222222222222222222222222222222222222222222", "bravo"),
}
vm, err := resolveVMRef(vms, "alpha")
if err != nil {
t.Fatalf("resolveVMRef(alpha): %v", err)
}
if vm.Name != "alpha" {
t.Fatalf("resolveVMRef(alpha) = %+v, want exact-name vm", vm)
}
}
// TestExecuteVMActionBatchRunsConcurrentlyAndPreservesOrder checks that
// executeVMActionBatch starts every action before any of them is allowed to
// finish, and that results come back in target order. Each action blocks
// until release is closed, so a sequential implementation never starts the
// second action and times out.
func TestExecuteVMActionBatchRunsConcurrentlyAndPreservesOrder(t *testing.T) {
	// The timeout only bounds how long a failure takes to be reported.
	// A generous value avoids flakes on loaded CI without weakening the check.
	const timeout = 2 * time.Second
	targets := []resolvedVMTarget{
		{Ref: "alpha", VM: testCLIResolvedVM("alpha-id", "alpha")},
		{Ref: "bravo", VM: testCLIResolvedVM("bravo-id", "bravo")},
	}
	started := make(chan string, len(targets))
	release := make(chan struct{})
	released := false
	releaseActions := func() {
		if !released {
			released = true
			close(release)
		}
	}
	// t.Fatal runs deferred calls, so blocked actions are always unblocked
	// and the goroutine running executeVMActionBatch does not leak.
	defer releaseActions()
	done := make(chan []vmBatchActionResult, 1)
	go func() {
		done <- executeVMActionBatch(context.Background(), targets, func(ctx context.Context, id string) (model.VMRecord, error) {
			started <- id
			<-release
			return model.VMRecord{ID: id, Name: id}, nil
		})
	}()
	for range targets {
		select {
		case <-started:
		case <-time.After(timeout):
			t.Fatal("batch actions did not overlap")
		}
	}
	releaseActions()
	var results []vmBatchActionResult
	select {
	case results = <-done:
	case <-time.After(timeout):
		t.Fatal("executeVMActionBatch did not finish")
	}
	if len(results) != len(targets) {
		t.Fatalf("len(results) = %d, want %d", len(results), len(targets))
	}
	for index, result := range results {
		if result.Target.Ref != targets[index].Ref {
			t.Fatalf("results[%d].Target.Ref = %q, want %q", index, result.Target.Ref, targets[index].Ref)
		}
		if result.VM.ID != targets[index].VM.ID {
			t.Fatalf("results[%d].VM.ID = %q, want %q", index, result.VM.ID, targets[index].VM.ID)
		}
	}
}
// TestSSHCommandArgs checks that sshCommandArgs builds an isolated ssh
// invocation. It ignores user ssh config, uses only the configured key,
// and disables password/keyboard-interactive auth. For host keys it uses
// banger's own known_hosts with accept-new, never /dev/null. The test
// asserts shape and posture rather than the exact known_hosts path, which
// is derived from the host's XDG dirs.
func TestSSHCommandArgs(t *testing.T) {
	args, err := sshCommandArgs(model.DaemonConfig{SSHKeyPath: "/bundle/id_ed25519"}, "172.16.0.2", []string{"--", "uname", "-a"})
	if err != nil {
		t.Fatalf("sshCommandArgs: %v", err)
	}
	// hasPair reports whether flag is immediately followed by value. ssh
	// reads a separate-token option's value from the next argument. Also,
	// "-o" occurs several times, so checking tokens in isolation would
	// accept a value that is not actually attached to its flag.
	hasPair := func(flag, value string) bool {
		for i := 0; i+1 < len(args); i++ {
			if args[i] == flag && args[i+1] == value {
				return true
			}
		}
		return false
	}
	for _, pair := range [][2]string{
		{"-F", "/dev/null"},
		{"-i", "/bundle/id_ed25519"},
		{"-o", "IdentitiesOnly=yes"},
		{"-o", "PasswordAuthentication=no"},
		{"-o", "KbdInteractiveAuthentication=no"},
	} {
		if !hasPair(pair[0], pair[1]) {
			t.Errorf("args missing %q %q pair: %v", pair[0], pair[1], args)
		}
	}
	hasArg := func(want string) bool {
		for _, a := range args {
			if a == want {
				return true
			}
		}
		return false
	}
	for _, s := range []string{"root@172.16.0.2", "--", "uname", "-a"} {
		if !hasArg(s) {
			t.Errorf("args missing %q: %v", s, args)
		}
	}
	// Host-key verification posture: accept-new plus a real path into
	// banger state, not /dev/null.
	joined := strings.Join(args, " ")
	if !strings.Contains(joined, "StrictHostKeyChecking=accept-new") {
		t.Errorf("args missing accept-new posture: %v", args)
	}
	if strings.Contains(joined, "UserKnownHostsFile=/dev/null") {
		t.Errorf("args leaked UserKnownHostsFile=/dev/null: %v", args)
	}
	if strings.Contains(joined, "StrictHostKeyChecking=no") {
		t.Errorf("args leaked StrictHostKeyChecking=no: %v", args)
	}
	// Must reference a known_hosts file ending in "known_hosts".
	sawKnownHosts := false
	for _, a := range args {
		if strings.HasPrefix(a, "UserKnownHostsFile=") && strings.HasSuffix(a, "known_hosts") {
			sawKnownHosts = true
			break
		}
	}
	if !sawKnownHosts {
		t.Errorf("args missing UserKnownHostsFile=<banger known_hosts>: %v", args)
	}
}
func TestValidateSSHPrereqs(t *testing.T) {
dir := t.TempDir()
keyPath := filepath.Join(dir, "id_ed25519")
if err := os.WriteFile(keyPath, []byte("key"), 0o600); err != nil {
t.Fatalf("write key: %v", err)
}
if err := validateSSHPrereqs(model.DaemonConfig{SSHKeyPath: keyPath}); err != nil {
t.Fatalf("validateSSHPrereqs: %v", err)
}
}
// exitErrorWithCode returns a genuine *exec.ExitError whose ExitCode() is
// code. Tests use it to make a stubbed command runner surface a specific
// exit status.
//
// It runs a plain POSIX `sh -c` rather than `bash -l`. A login shell
// sources the user's profile files, which is slower, requires bash, and
// lets unrelated profile output or failures leak into the test.
//
// code should be in 1..255: `exit 0` succeeds, so no *exec.ExitError is
// produced and the helper fails the test.
func exitErrorWithCode(t *testing.T, code int) *exec.ExitError {
	t.Helper()
	err := exec.Command("sh", "-c", fmt.Sprintf("exit %d", code)).Run()
	var exitErr *exec.ExitError
	if !errors.As(err, &exitErr) {
		t.Fatalf("exitErrorWithCode(%d) error = %v, want exit error", code, err)
	}
	return exitErr
}
func TestValidateSSHPrereqsFailsForMissingKey(t *testing.T) {
err := validateSSHPrereqs(model.DaemonConfig{SSHKeyPath: "/does/not/exist"})
if err == nil || !strings.Contains(err.Error(), "ssh private key") {
t.Fatalf("validateSSHPrereqs() error = %v, want missing key", err)
}
}
// TestVMRunPreflightRejectsSubmodules guards the CLI-side preflight policy.
// A repo containing submodules must be rejected before any VM is created,
// so the user gets a fast error instead of an orphaned VM. Full git
// inspection moved to internal/daemon/workspace and is tested there. This
// test only fakes the few git calls the preflight makes.
func TestVMRunPreflightRejectsSubmodules(t *testing.T) {
	d := defaultDeps()
	repoRoot := t.TempDir()
	savedOutputFunc := workspace.HostCommandOutputFunc
	t.Cleanup(func() { workspace.HostCommandOutputFunc = savedOutputFunc })

	// Canned git replies keyed by exact argv. The index entry with mode
	// 160000 in `ls-files --stage` is what represents a submodule.
	canned := []struct {
		args   []string
		output string
	}{
		{[]string{"-C", repoRoot, "rev-parse", "--show-toplevel"}, repoRoot + "\n"},
		{[]string{"-C", repoRoot, "rev-parse", "--is-bare-repository"}, "false\n"},
		{[]string{"-C", repoRoot, "ls-files", "--stage", "-z"}, "160000 deadbeef 0\tvendor/submodule\x00"},
	}
	workspace.HostCommandOutputFunc = func(_ context.Context, name string, args ...string) ([]byte, error) {
		t.Helper()
		if name != "git" {
			t.Fatalf("command = %q, want git", name)
		}
		for _, reply := range canned {
			if reflect.DeepEqual(args, reply.args) {
				return []byte(reply.output), nil
			}
		}
		t.Fatalf("unexpected git args: %v", args)
		return nil, nil
	}

	if _, err := d.vmRunPreflightRepo(context.Background(), repoRoot); err == nil || !strings.Contains(err.Error(), "submodules") {
		t.Fatalf("d.vmRunPreflightRepo() error = %v, want submodule rejection", err)
	}
}
// TestRunVMRunWorkspacePreparesAndAttaches drives runVMRun with a repo
// through the happy path. Create completes immediately, and the workspace
// is prepared for the named VM and source path. The tooling harness is
// uploaded once over a guest client, which is then closed. The session
// attaches interactively: ssh args end with the host and have no trailing
// command. No "VM ready." block is printed to stdout.
func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) {
	d := defaultDeps()
	repoRoot := t.TempDir()
	created := model.VMRecord{
		ID:   "vm-id",
		Name: "devbox",
		Runtime: model.VMRuntime{
			State:   model.VMStateRunning,
			GuestIP: "172.16.0.2",
			DNSName: "devbox.vm",
		},
	}
	d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
		op := api.VMCreateOperation{ID: "op-1", Stage: "ready", Detail: "vm is ready", Done: true, Success: true, VM: &created}
		return api.VMCreateBeginResult{Operation: op}, nil
	}
	d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
		t.Fatal("d.vmCreateStatus should not be called")
		return api.VMCreateStatusResult{}, nil
	}
	d.vmCreateCancel = func(context.Context, string, string) error {
		t.Fatal("d.vmCreateCancel should not be called")
		return nil
	}
	guest := &testVMRunGuestClient{}
	d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
	d.guestDial = func(context.Context, string, string) (vmRunGuestClient, error) { return guest, nil }
	var prepared api.VMWorkspacePrepareParams
	d.vmWorkspacePrepare = func(_ context.Context, _ string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
		prepared = params
		ws := model.WorkspacePrepareResult{VMID: created.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}
		return api.VMWorkspacePrepareResult{Workspace: ws}, nil
	}
	d.buildVMRunToolingPlan = func(context.Context, string) toolingplan.Plan {
		goStep := toolingplan.InstallStep{Tool: "go", Version: "1.25.0", Source: "go.mod"}
		return toolingplan.Plan{RepoManagedTools: []string{"go"}, Steps: []toolingplan.InstallStep{goStep}}
	}
	var attachArgs []string
	d.sshExec = func(_ context.Context, _ io.Reader, _, _ io.Writer, args []string) error {
		attachArgs = args
		return nil
	}
	d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
		return api.VMHealthResult{Name: "devbox", Healthy: false}, nil
	}

	repo := vmRunRepo{sourcePath: repoRoot}
	cfg := model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}
	var stdout, stderr bytes.Buffer
	runErr := d.runVMRun(
		context.Background(),
		"/tmp/bangerd.sock",
		cfg,
		strings.NewReader(""),
		&stdout, &stderr,
		api.VMCreateParams{Name: "devbox"},
		&repo,
		nil,
		false,
	)
	if runErr != nil {
		t.Fatalf("d.runVMRun: %v", runErr)
	}

	switch {
	case prepared.IDOrName != "devbox" || prepared.SourcePath != repoRoot:
		t.Fatalf("workspaceParams = %+v", prepared)
	case len(guest.uploads) != 1:
		t.Fatalf("uploads = %d, want tooling harness upload", len(guest.uploads))
	case !guest.closed:
		t.Fatal("guest client should be closed after tooling bootstrap")
	case len(attachArgs) == 0 || attachArgs[len(attachArgs)-1] != "root@172.16.0.2":
		t.Fatalf("sshArgsSeen = %v, want interactive ssh to 172.16.0.2 (no trailing command)", attachArgs)
	case strings.Contains(stdout.String(), "VM ready."):
		t.Fatalf("stdout = %q, want no next-steps block", stdout.String())
	}
}
// TestVMRunPrintsPostCreateProgress checks the "[vm run]" progress lines
// that runVMRun writes to stderr after the VM exists: ssh wait, workspace
// prep, tooling bootstrap start, the guest tooling log path, and attach.
// It also checks that no next-steps progress line is emitted.
func TestVMRunPrintsPostCreateProgress(t *testing.T) {
	d := defaultDeps()
	created := model.VMRecord{
		ID:   "vm-id",
		Name: "devbox",
		Runtime: model.VMRuntime{
			State:   model.VMStateRunning,
			GuestIP: "172.16.0.2",
		},
	}
	d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
		op := api.VMCreateOperation{ID: "op-1", Stage: "ready", Detail: "vm is ready", Done: true, Success: true, VM: &created}
		return api.VMCreateBeginResult{Operation: op}, nil
	}
	d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
		t.Fatal("d.vmCreateStatus should not be called")
		return api.VMCreateStatusResult{}, nil
	}
	d.vmCreateCancel = func(context.Context, string, string) error {
		t.Fatal("d.vmCreateCancel should not be called")
		return nil
	}
	d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
	d.guestDial = func(context.Context, string, string) (vmRunGuestClient, error) {
		return &testVMRunGuestClient{}, nil
	}
	d.vmWorkspacePrepare = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
		ws := model.WorkspacePrepareResult{VMID: created.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}
		return api.VMWorkspacePrepareResult{Workspace: ws}, nil
	}
	d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { return nil }
	d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
		return api.VMHealthResult{Name: "devbox", Healthy: false}, nil
	}

	repo := vmRunRepo{sourcePath: t.TempDir()}
	cfg := model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}
	var stdout, stderr bytes.Buffer
	runErr := d.runVMRun(
		context.Background(),
		"/tmp/bangerd.sock",
		cfg,
		strings.NewReader(""),
		&stdout, &stderr,
		api.VMCreateParams{Name: "devbox"},
		&repo,
		nil,
		false,
	)
	if runErr != nil {
		t.Fatalf("d.runVMRun: %v", runErr)
	}

	progress := stderr.String()
	wantLines := []string{
		"[vm run] waiting for guest ssh",
		"[vm run] preparing guest workspace",
		"[vm run] starting guest tooling bootstrap",
		"[vm run] guest tooling log: /root/.cache/banger/vm-run-tooling-repo.log",
		"[vm run] attaching to guest",
	}
	for _, line := range wantLines {
		if !strings.Contains(progress, line) {
			t.Fatalf("stderr = %q, want %q", progress, line)
		}
	}
	if strings.Contains(progress, "[vm run] printing next steps") {
		t.Fatalf("stderr = %q, should not print next-steps progress", progress)
	}
}
// TestRunVMRunWarnsWhenToolingHarnessStartFails checks that a failure to
// launch the guest tooling bootstrap is reported as a stderr warning rather
// than aborting. runVMRun still succeeds, and the interactive attach still
// happens exactly once.
func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) {
	d := defaultDeps()
	created := model.VMRecord{
		ID:   "vm-id",
		Name: "devbox",
		Runtime: model.VMRuntime{
			State:   model.VMStateRunning,
			GuestIP: "172.16.0.2",
		},
	}
	d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
		op := api.VMCreateOperation{ID: "op-1", Stage: "ready", Detail: "vm is ready", Done: true, Success: true, VM: &created}
		return api.VMCreateBeginResult{Operation: op}, nil
	}
	d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
		t.Fatal("d.vmCreateStatus should not be called")
		return api.VMCreateStatusResult{}, nil
	}
	d.vmCreateCancel = func(context.Context, string, string) error {
		t.Fatal("d.vmCreateCancel should not be called")
		return nil
	}
	d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
	guest := &testVMRunGuestClient{launchErr: errors.New("launch failed")}
	d.guestDial = func(context.Context, string, string) (vmRunGuestClient, error) { return guest, nil }
	d.vmWorkspacePrepare = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
		ws := model.WorkspacePrepareResult{VMID: created.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}
		return api.VMWorkspacePrepareResult{Workspace: ws}, nil
	}
	attaches := 0
	d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
		attaches++
		return nil
	}
	d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
		return api.VMHealthResult{Healthy: false}, nil
	}

	repo := vmRunRepo{sourcePath: t.TempDir()}
	cfg := model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}
	var stdout, stderr bytes.Buffer
	runErr := d.runVMRun(
		context.Background(),
		"/tmp/bangerd.sock",
		cfg,
		strings.NewReader(""),
		&stdout, &stderr,
		api.VMCreateParams{Name: "devbox"},
		&repo,
		nil,
		false,
	)
	if runErr != nil {
		t.Fatalf("d.runVMRun: %v", runErr)
	}

	const wantWarning = "[vm run] warning: guest tooling bootstrap start failed: launch guest tooling bootstrap"
	if got := stderr.String(); !strings.Contains(got, wantWarning) {
		t.Fatalf("stderr = %q, want tooling bootstrap warning", got)
	}
	if attaches != 1 {
		t.Fatalf("sshExec calls = %d, want 1 (interactive attach still runs)", attaches)
	}
}
// TestRunVMRunBareModeSkipsWorkspaceAndTooling checks that runVMRun with no
// repo (nil) never dials the guest or prepares a workspace. It goes
// straight to a single ssh attach and reports the attach on stderr.
func TestRunVMRunBareModeSkipsWorkspaceAndTooling(t *testing.T) {
	d := defaultDeps()
	bare := model.VMRecord{
		ID:      "vm-id",
		Name:    "bare",
		Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
	}
	d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
		op := api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &bare}
		return api.VMCreateBeginResult{Operation: op}, nil
	}
	d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
	d.guestDial = func(context.Context, string, string) (vmRunGuestClient, error) {
		t.Fatal("d.guestDial should not be called in bare mode")
		return nil, nil
	}
	d.vmWorkspacePrepare = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
		t.Fatal("d.vmWorkspacePrepare should not be called in bare mode")
		return api.VMWorkspacePrepareResult{}, nil
	}
	attaches := 0
	d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
		attaches++
		return nil
	}
	d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
		return api.VMHealthResult{Healthy: false}, nil
	}

	cfg := model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}
	var stdout, stderr bytes.Buffer
	runErr := d.runVMRun(
		context.Background(),
		"/tmp/bangerd.sock",
		cfg,
		strings.NewReader(""),
		&stdout, &stderr,
		api.VMCreateParams{Name: "bare"},
		nil,
		nil,
		false,
	)
	switch {
	case runErr != nil:
		t.Fatalf("d.runVMRun: %v", runErr)
	case attaches != 1:
		t.Fatalf("sshExec calls = %d, want 1", attaches)
	case !strings.Contains(stderr.String(), "[vm run] attaching to guest"):
		t.Fatalf("stderr = %q, want attach progress", stderr.String())
	}
}
// TestRunVMRunRMDeletesAfterSessionExits checks that with --rm the VM is
// deleted by name once the ssh session returns. It also checks that the
// "is still running" reminder is suppressed, since it would be misleading
// for a VM that is about to be deleted.
func TestRunVMRunRMDeletesAfterSessionExits(t *testing.T) {
	d := defaultDeps()
	ephemeral := model.VMRecord{
		ID:      "vm-id",
		Name:    "tmpbox",
		Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
	}
	d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
		op := api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &ephemeral}
		return api.VMCreateBeginResult{Operation: op}, nil
	}
	d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
	d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { return nil }
	d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
		return api.VMHealthResult{Healthy: false}, nil
	}
	var deleted string
	d.vmDelete = func(_ context.Context, _, idOrName string) error {
		deleted = idOrName
		return nil
	}

	const removeAfterSession = true // --rm
	cfg := model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}
	var stdout, stderr bytes.Buffer
	runErr := d.runVMRun(
		context.Background(),
		"/tmp/bangerd.sock",
		cfg,
		strings.NewReader(""),
		&stdout, &stderr,
		api.VMCreateParams{Name: "tmpbox"},
		nil,
		nil,
		removeAfterSession,
	)
	switch {
	case runErr != nil:
		t.Fatalf("d.runVMRun: %v", runErr)
	case deleted != "tmpbox":
		t.Fatalf("deletedRef = %q, want tmpbox", deleted)
	case strings.Contains(stderr.String(), "is still running"):
		t.Fatalf("stderr = %q, should not print still-running reminder under --rm", stderr.String())
	}
}
// TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout checks that --rm does not
// delete the VM when waiting for guest ssh times out. runVMRun returns an
// error and the VM is kept for debugging.
func TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout(t *testing.T) {
	d := defaultDeps()
	// vmRunSSHTimeout is package-level state. Shrink it so the blocking
	// fake wait below gives up quickly, and restore it for other tests.
	origTimeout := vmRunSSHTimeout
	vmRunSSHTimeout = 50 * time.Millisecond
	t.Cleanup(func() {
		vmRunSSHTimeout = origTimeout
	})
	vm := model.VMRecord{
		ID: "vm-id", Name: "slowvm",
		Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
	}
	d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
		return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
	}
	// Simulate a guest whose sshd never comes up: block until the wait
	// context ends and report why.
	d.guestWaitForSSH = func(ctx context.Context, _, _ string, _ time.Duration) error {
		<-ctx.Done()
		return ctx.Err()
	}
	// No session may start after a failed wait. Stub sshExec so a
	// regression fails here instead of running the default d.sshExec
	// (presumably the real ssh binary) against the fake guest IP.
	d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
		t.Fatal("d.sshExec should not be called after an ssh-wait timeout")
		return nil
	}
	deleteCalled := false
	d.vmDelete = func(context.Context, string, string) error {
		deleteCalled = true
		return nil
	}
	var stdout, stderr bytes.Buffer
	err := d.runVMRun(
		context.Background(),
		"/tmp/bangerd.sock",
		model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
		strings.NewReader(""),
		&stdout, &stderr,
		api.VMCreateParams{Name: "slowvm"},
		nil,
		nil,
		true, // --rm
	)
	if err == nil {
		t.Fatal("want timeout error")
	}
	if deleteCalled {
		t.Fatal("VM should NOT be deleted on ssh-wait timeout even with --rm (keep for debugging)")
	}
}
// TestRunVMRunSSHTimeoutReturnsActionableError checks that when guest ssh
// never comes up, runVMRun's error names the VM, says it did not come up,
// and suggests the `banger vm logs` and `banger vm delete` commands.
func TestRunVMRunSSHTimeoutReturnsActionableError(t *testing.T) {
	d := defaultDeps()
	// vmRunSSHTimeout is package-level state. Shrink it so the blocking
	// fake wait below gives up quickly, and restore it for other tests.
	origTimeout := vmRunSSHTimeout
	vmRunSSHTimeout = 50 * time.Millisecond
	t.Cleanup(func() {
		vmRunSSHTimeout = origTimeout
	})
	vm := model.VMRecord{
		ID: "vm-id", Name: "slowvm",
		Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
	}
	d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
		return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
	}
	// Simulate the guest never bringing sshd up. The wait-for-ssh child
	// context hits its deadline, returning DeadlineExceeded.
	d.guestWaitForSSH = func(ctx context.Context, _, _ string, _ time.Duration) error {
		<-ctx.Done()
		return ctx.Err()
	}
	// No session may start after a failed wait. Stub sshExec so a
	// regression fails here instead of running the default d.sshExec
	// (presumably the real ssh binary) against the fake guest IP.
	d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
		t.Fatal("d.sshExec should not be called after an ssh-wait timeout")
		return nil
	}
	var stdout, stderr bytes.Buffer
	err := d.runVMRun(
		context.Background(),
		"/tmp/bangerd.sock",
		model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
		strings.NewReader(""),
		&stdout, &stderr,
		api.VMCreateParams{Name: "slowvm"},
		nil,
		nil,
		false,
	)
	if err == nil {
		t.Fatal("want timeout error")
	}
	msg := err.Error()
	for _, want := range []string{
		"slowvm",
		"did not come up",
		"banger vm logs slowvm",
		"banger vm delete slowvm",
	} {
		if !strings.Contains(msg, want) {
			t.Fatalf("err = %q, want contains %q", msg, want)
		}
	}
}
// TestRunVMRunCommandModePropagatesExitCode checks command mode. When a
// command is given and there is no repo, runVMRun skips workspace prep and
// passes the command as the trailing ssh argument. It reports command
// progress on stderr and turns the remote exit status into an
// ExitCodeError with the same code.
func TestRunVMRunCommandModePropagatesExitCode(t *testing.T) {
	d := defaultDeps()
	box := model.VMRecord{
		ID:      "vm-id",
		Name:    "cmdbox",
		Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
	}
	d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
		op := api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &box}
		return api.VMCreateBeginResult{Operation: op}, nil
	}
	d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
	d.vmWorkspacePrepare = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
		t.Fatal("workspace prepare should not run without spec")
		return api.VMWorkspacePrepareResult{}, nil
	}
	var remoteArgs []string
	d.sshExec = func(_ context.Context, _ io.Reader, _, _ io.Writer, args []string) error {
		remoteArgs = args
		return exitErrorWithCode(t, 7)
	}

	cfg := model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}
	var stdout, stderr bytes.Buffer
	runErr := d.runVMRun(
		context.Background(),
		"/tmp/bangerd.sock",
		cfg,
		strings.NewReader(""),
		&stdout, &stderr,
		api.VMCreateParams{Name: "cmdbox"},
		nil,
		[]string{"false"},
		false,
	)
	var codeErr ExitCodeError
	switch {
	case !errors.As(runErr, &codeErr) || codeErr.Code != 7:
		t.Fatalf("d.runVMRun error = %v, want ExitCodeError{7}", runErr)
	case len(remoteArgs) == 0 || remoteArgs[len(remoteArgs)-1] != "false":
		t.Fatalf("sshArgsSeen = %v, want trailing command 'false'", remoteArgs)
	case !strings.Contains(stderr.String(), "[vm run] running command in guest"):
		t.Fatalf("stderr = %q, want command progress", stderr.String())
	}
}
// TestVMRunCommandRejectsBranchWithoutPath checks that `banger vm run
// --branch <name>` without a repo path fails with a "--branch requires a
// path" error.
func TestVMRunCommandRejectsBranchWithoutPath(t *testing.T) {
	root := NewBangerCommand()
	root.SetArgs([]string{"vm", "run", "--branch", "feat"})
	root.SetOut(io.Discard)
	root.SetErr(io.Discard)
	err := root.Execute()
	if err != nil && strings.Contains(err.Error(), "--branch requires a path") {
		return
	}
	t.Fatalf("Execute() error = %v, want --branch requires a path", err)
}
// TestSplitVMRunArgsPartitionsOnDash checks that splitVMRunArgs separates
// positional path args from the command that follows "--". Arguments go
// through a real cobra parse so that ArgsLenAtDash is populated.
func TestSplitVMRunArgsPartitionsOnDash(t *testing.T) {
	// sameStrings compares length and elements only, so a nil result
	// matches an empty expectation.
	sameStrings := func(got, want []string) bool {
		if len(got) != len(want) {
			return false
		}
		for i, s := range got {
			if s != want[i] {
				return false
			}
		}
		return true
	}
	for _, tc := range []struct {
		name     string
		argv     []string
		wantPath []string
		wantCmd  []string
	}{
		{"empty", []string{}, []string{}, nil},
		{"path only", []string{"./repo"}, []string{"./repo"}, nil},
		{"cmd only", []string{"--", "make", "test"}, []string{}, []string{"make", "test"}},
		{"path and cmd", []string{"./repo", "--", "ls"}, []string{"./repo"}, []string{"ls"}},
	} {
		t.Run(tc.name, func(t *testing.T) {
			var gotPath, gotCmd []string
			root := &cobra.Command{Use: "root"}
			root.AddCommand(&cobra.Command{
				Use:  "run",
				Args: cobra.ArbitraryArgs,
				RunE: func(cmd *cobra.Command, args []string) error {
					gotPath, gotCmd = splitVMRunArgs(cmd, args)
					return nil
				},
			})
			root.SetArgs(append([]string{"run"}, tc.argv...))
			root.SetOut(&bytes.Buffer{})
			root.SetErr(&bytes.Buffer{})
			if err := root.Execute(); err != nil {
				t.Fatalf("execute: %v", err)
			}
			if !sameStrings(gotPath, tc.wantPath) {
				t.Fatalf("path = %v, want %v", gotPath, tc.wantPath)
			}
			if !sameStrings(gotCmd, tc.wantCmd) {
				t.Fatalf("cmd = %v, want %v", gotCmd, tc.wantCmd)
			}
		})
	}
}
func TestVMRunToolingHarnessScriptUsesMiseOnly(t *testing.T) {
script := vmRunToolingHarnessScript(toolingplan.Plan{
RepoManagedTools: []string{"node"},
Steps: []toolingplan.InstallStep{{Tool: "go", Version: "1.25.0", Source: "go.mod"}},
Skips: []toolingplan.SkipNote{{Target: "python", Reason: "no .python-version"}},
})
for _, want := range []string{
`repo-managed mise tools: node`,
`run_best_effort "$MISE_BIN" install`,
`run_bounded_best_effort "$INSTALL_TIMEOUT_SECS" "$MISE_BIN" use -g --pin 'go@1.25.0'`,
`deterministic skip: python (no .python-version)`,
`run_best_effort "$MISE_BIN" reshim`,
} {
if !strings.Contains(script, want) {
t.Fatalf("script = %q, want %q", script, want)
}
}
for _, unwanted := range []string{`opencode run`, `PROMPT_FILE=`, `--format json`, `mimo-v2-pro-free`} {
if strings.Contains(script, unwanted) {
t.Fatalf("script = %q, want no %q", script, unwanted)
}
}
}
// TestVMRunGuestDirIsFixed pins the guest checkout directory to /root/repo.
// The shallow repo copy and checkout-script logic that used to live in the
// CLI is now in internal/daemon/workspace and is covered by that package's
// tests, so it is not duplicated here.
func TestVMRunGuestDirIsFixed(t *testing.T) {
	const wantDir = "/root/repo"
	got := vmRunGuestDir()
	if got == wantDir {
		return
	}
	t.Fatalf("vmRunGuestDir() = %q, want /root/repo", got)
}
// TestNewBangerdCommandRejectsArgs checks that the bangerd root command
// rejects positional arguments.
func TestNewBangerdCommandRejectsArgs(t *testing.T) {
	cmd := NewBangerdCommand()
	cmd.SetArgs([]string{"extra"})
	// Discard cobra's error/usage output so the expected rejection does not
	// clutter the test log, matching the other rejection tests here.
	cmd.SetOut(io.Discard)
	cmd.SetErr(io.Discard)
	if err := cmd.Execute(); err == nil {
		t.Fatal("expected extra args to be rejected")
	}
}
// TestDaemonOutdated checks daemonOutdated against two daemon executables.
// A daemon whose executable is a hard link to the current bangerd binary
// counts as current. A daemon whose executable is a separate file with
// different content is outdated.
func TestDaemonOutdated(t *testing.T) {
	d := defaultDeps()
	dir := t.TempDir()
	binPath := func(name string) string { return filepath.Join(dir, name) }
	current, same, stale := binPath("bangerd-current"), binPath("bangerd-same"), binPath("bangerd-stale")

	if err := os.WriteFile(current, []byte("current"), 0o755); err != nil {
		t.Fatalf("write current: %v", err)
	}
	if err := os.Link(current, same); err != nil {
		t.Fatalf("hard link: %v", err)
	}
	if err := os.WriteFile(stale, []byte("stale"), 0o755); err != nil {
		t.Fatalf("write stale: %v", err)
	}

	d.bangerdPath = func() (string, error) { return current, nil }
	// pid 1 runs the hard-linked binary; any other pid runs the stale one.
	d.daemonExePath = func(pid int) string {
		if pid != 1 {
			return stale
		}
		return same
	}

	if d.daemonOutdated(1) {
		t.Fatal("expected matching daemon executable to be current")
	}
	if !d.daemonOutdated(2) {
		t.Fatal("expected replaced daemon executable to be outdated")
	}
}
// TestDaemonStatusIncludesLogPathWhenStopped checks `banger daemon status`
// with XDG dirs pointed at fresh temp dirs. The output must report
// "stopped", the daemon log path under the state dir, and the DNS listener
// address.
func TestDaemonStatusIncludesLogPathWhenStopped(t *testing.T) {
	t.Setenv("XDG_CONFIG_HOME", filepath.Join(t.TempDir(), "config"))
	stateHome := filepath.Join(t.TempDir(), "state")
	t.Setenv("XDG_STATE_HOME", stateHome)
	t.Setenv("XDG_RUNTIME_DIR", filepath.Join(t.TempDir(), "runtime"))

	var out bytes.Buffer
	root := NewBangerCommand()
	root.SetOut(&out)
	root.SetErr(&out)
	root.SetArgs([]string{"daemon", "status"})
	if err := root.Execute(); err != nil {
		t.Fatalf("Execute: %v", err)
	}

	switch output := out.String(); {
	case !strings.Contains(output, "stopped\n"):
		t.Fatalf("output = %q, want stopped status", output)
	case !strings.Contains(output, "log: "+filepath.Join(stateHome, "banger", "bangerd.log")):
		t.Fatalf("output = %q, want daemon log path", output)
	case !strings.Contains(output, "dns: 127.0.0.1:42069"):
		t.Fatalf("output = %q, want dns listener", output)
	}
}
// TestDaemonStatusIncludesDaemonBuildInfoWhenRunning checks `banger daemon
// status` when the daemon ping succeeds. The output must report "running"
// with the pid, version, commit, and build time from the ping, plus the
// daemon log path.
func TestDaemonStatusIncludesDaemonBuildInfoWhenRunning(t *testing.T) {
	d := defaultDeps()
	t.Setenv("XDG_CONFIG_HOME", filepath.Join(t.TempDir(), "config"))
	stateHome := filepath.Join(t.TempDir(), "state")
	t.Setenv("XDG_STATE_HOME", stateHome)
	t.Setenv("XDG_RUNTIME_DIR", filepath.Join(t.TempDir(), "runtime"))

	ping := api.PingResult{
		Status:  "ok",
		PID:     42,
		Version: "v1.2.3",
		Commit:  "abc123",
		BuiltAt: "2026-03-22T12:00:00Z",
	}
	d.daemonPing = func(context.Context, string) (api.PingResult, error) { return ping, nil }

	var out bytes.Buffer
	root := d.newRootCommand()
	root.SetOut(&out)
	root.SetErr(&out)
	root.SetArgs([]string{"daemon", "status"})
	if err := root.Execute(); err != nil {
		t.Fatalf("Execute: %v", err)
	}

	output := out.String()
	expected := []string{
		"running\n",
		"pid: 42",
		"version: v1.2.3",
		"commit: abc123",
		"built_at: 2026-03-22T12:00:00Z",
		"log: " + filepath.Join(stateHome, "banger", "bangerd.log"),
	}
	for _, line := range expected {
		if !strings.Contains(output, line) {
			t.Fatalf("output = %q, want %q", output, line)
		}
	}
}
// TestBuildDaemonCommandIsDetachedFromCallerContext checks that the command
// used to spawn bangerd runs the given binary and has no Cancel hook, so
// the daemon process is not tied to the CLI request that started it.
func TestBuildDaemonCommandIsDetachedFromCallerContext(t *testing.T) {
	const daemonBin = "/tmp/bangerd"
	spawn := buildDaemonCommand(daemonBin)
	switch {
	case spawn.Path != daemonBin:
		t.Fatalf("command path = %q", spawn.Path)
	case spawn.Cancel != nil:
		t.Fatal("daemon process should not be tied to a CLI request context")
	}
}
// testCLIResolvedVM builds a minimal VMRecord with only ID and Name set,
// for resolver and batch tests.
func testCLIResolvedVM(id, name string) model.VMRecord {
	var record model.VMRecord
	record.ID, record.Name = id, name
	return record
}
// testRunGit runs `git -C dir <args...>` with commit signing disabled and
// returns its combined stdout and stderr. It fails the test if git exits
// with an error.
func testRunGit(t *testing.T, dir string, args ...string) string {
	t.Helper()
	argv := []string{"-c", "commit.gpgsign=false", "-C", dir}
	argv = append(argv, args...)
	combined, err := exec.Command("git", argv...).CombinedOutput()
	if err != nil {
		t.Fatalf("git %v: %v\n%s", args, err, combined)
	}
	return string(combined)
}
// testVMRunUpload records one UploadFile call made on the fake guest client.
type testVMRunUpload struct {
	path string
	mode os.FileMode
	data []byte
}

// testVMRunGuestClient is a recording fake handed out as a vmRunGuestClient.
// It stores the arguments of each call for later assertions and returns the
// canned errors configured on it.
type testVMRunGuestClient struct {
	closed bool

	// UploadFile: every call is appended to uploads, and the upload*
	// fields mirror the most recent call.
	uploads    []testVMRunUpload
	uploadPath string
	uploadMode os.FileMode
	uploadData []byte

	// Canned errors returned by UploadFile and RunScript.
	uploadErr   error
	checkoutErr error
	launchErr   error

	// RunScript: script holds the first script run, launchScript the most
	// recent one.
	script         string
	launchScript   string
	runScriptCalls int

	// Arguments of the latest StreamTar / StreamTarEntries calls.
	tarSourceDir    string
	tarCommand      string
	streamSourceDir string
	streamEntries   []string
	streamCommand   string
}

// Close marks the client as closed.
func (c *testVMRunGuestClient) Close() error {
	c.closed = true
	return nil
}

// UploadFile records a private copy of data, so later mutation by the
// caller cannot change what tests observe, and returns uploadErr.
func (c *testVMRunGuestClient) UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error {
	snapshot := append([]byte(nil), data...)
	c.uploads = append(c.uploads, testVMRunUpload{path: remotePath, mode: mode, data: snapshot})
	c.uploadPath, c.uploadMode, c.uploadData = remotePath, mode, snapshot
	return c.uploadErr
}

// StreamTar records the source dir and remote command.
func (c *testVMRunGuestClient) StreamTar(ctx context.Context, sourceDir, remoteCommand string, logWriter io.Writer) error {
	c.tarSourceDir, c.tarCommand = sourceDir, remoteCommand
	return nil
}

// RunScript records script in launchScript on every call. The first call
// also stores it in script and returns checkoutErr when that is non-nil.
// Otherwise it returns launchErr.
func (c *testVMRunGuestClient) RunScript(ctx context.Context, script string, logWriter io.Writer) error {
	c.runScriptCalls++
	c.launchScript = script
	if c.runScriptCalls > 1 {
		return c.launchErr
	}
	c.script = script
	if c.checkoutErr != nil {
		return c.checkoutErr
	}
	return c.launchErr
}

// StreamTarEntries records its arguments, copying entries so the caller's
// slice can be reused.
func (c *testVMRunGuestClient) StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error {
	c.streamSourceDir, c.streamCommand = sourceDir, remoteCommand
	c.streamEntries = append([]string(nil), entries...)
	return nil
}
// stubEnsureDaemonForSend points the XDG config, state, and runtime dirs at
// fresh temp dirs. It also makes d.daemonPing report a healthy daemon whose
// PID is the test process, so `ensureDaemon` short-circuits instead of
// spawning bangerd.
func stubEnsureDaemonForSend(t *testing.T, d *deps) {
	t.Helper()
	isolated := [...]struct{ envVar, subdir string }{
		{"XDG_CONFIG_HOME", "config"},
		{"XDG_STATE_HOME", "state"},
		{"XDG_RUNTIME_DIR", "run"},
	}
	for _, dir := range isolated {
		t.Setenv(dir.envVar, filepath.Join(t.TempDir(), dir.subdir))
	}
	d.daemonPing = func(context.Context, string) (api.PingResult, error) {
		healthy := api.PingResult{Status: "ok", PID: os.Getpid()}
		return healthy, nil
	}
}
// TestVMWorkspaceExportCommandExists checks that `banger vm workspace
// export` is registered. For a non-root command, cobra's Find returns the
// deepest command it could match, without an error, when a trailing name
// is unknown. The test therefore checks which command was actually found
// and that no arguments were left unresolved, rather than relying on err
// alone.
func TestVMWorkspaceExportCommandExists(t *testing.T) {
	root := NewBangerCommand()
	found, rest, err := root.Find([]string{"vm", "workspace", "export"})
	if err != nil {
		t.Fatalf("find vm workspace export: %v", err)
	}
	if len(rest) != 0 {
		t.Fatalf("find vm workspace export: unresolved args %v (found %q)", rest, found.CommandPath())
	}
	// Walk back up to the root so every level is verified by name.
	level := found
	for _, want := range []string{"export", "workspace", "vm"} {
		if level == nil || level.Name() != want {
			t.Fatalf("found %q, want banger vm workspace export", found.CommandPath())
		}
		level = level.Parent()
	}
	if level != root {
		t.Fatalf("found %q, want vm directly under the root command", found.CommandPath())
	}
}
// TestVMWorkspaceExportRejectsMissingArg checks that `banger vm workspace
// export` without a VM argument fails with a usage error.
func TestVMWorkspaceExportRejectsMissingArg(t *testing.T) {
	cmd := NewBangerCommand()
	cmd.SetArgs([]string{"vm", "workspace", "export"})
	// Discard cobra's error/usage output so the expected failure does not
	// clutter the test log, as the other rejection tests in this file do.
	cmd.SetOut(io.Discard)
	cmd.SetErr(io.Discard)
	err := cmd.Execute()
	if err == nil || !strings.Contains(err.Error(), "usage: banger vm workspace export") {
		t.Fatalf("Execute() error = %v, want usage error", err)
	}
}
// TestVMWorkspaceExportWritesToStdout checks that without --output the
// exported patch bytes are written verbatim to stdout.
func TestVMWorkspaceExportWritesToStdout(t *testing.T) {
	d := defaultDeps()
	stubEnsureDaemonForSend(t, d)
	patch := []byte("diff --git a/main.go b/main.go\nindex 0000000..1111111 100644\n")
	d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
		result := api.WorkspaceExportResult{GuestPath: params.GuestPath, Patch: patch, HasChanges: true}
		result.ChangedFiles = []string{"main.go"}
		return result, nil
	}

	var stdout bytes.Buffer
	root := d.newRootCommand()
	root.SetOut(&stdout)
	root.SetErr(io.Discard)
	root.SetArgs([]string{"vm", "workspace", "export", "devbox"})
	if err := root.Execute(); err != nil {
		t.Fatalf("Execute: %v", err)
	}
	if got := stdout.Bytes(); !bytes.Equal(got, patch) {
		t.Fatalf("stdout = %q, want %q", got, patch)
	}
}
// TestVMWorkspaceExportWritesToFile checks that --output writes the patch
// to the given file and mentions that file on stderr.
func TestVMWorkspaceExportWritesToFile(t *testing.T) {
	d := defaultDeps()
	stubEnsureDaemonForSend(t, d)
	patch := []byte("diff --git a/main.go b/main.go\n")
	d.vmWorkspaceExport = func(context.Context, string, api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
		result := api.WorkspaceExportResult{GuestPath: "/root/repo", Patch: patch, HasChanges: true}
		result.ChangedFiles = []string{"main.go"}
		return result, nil
	}
	target := filepath.Join(t.TempDir(), "worker.diff")

	var stderr bytes.Buffer
	root := d.newRootCommand()
	root.SetOut(io.Discard)
	root.SetErr(&stderr)
	root.SetArgs([]string{"vm", "workspace", "export", "devbox", "--output", target})
	if err := root.Execute(); err != nil {
		t.Fatalf("Execute: %v", err)
	}

	written, readErr := os.ReadFile(target)
	switch {
	case readErr != nil:
		t.Fatalf("ReadFile: %v", readErr)
	case !bytes.Equal(written, patch):
		t.Fatalf("file content = %q, want %q", written, patch)
	case !strings.Contains(stderr.String(), "worker.diff"):
		t.Fatalf("stderr = %q, want output path mentioned", stderr.String())
	}
}
// TestVMWorkspaceExportNoChanges checks that when the export reports no
// changes, stdout stays empty and stderr says "no changes".
func TestVMWorkspaceExportNoChanges(t *testing.T) {
	d := defaultDeps()
	stubEnsureDaemonForSend(t, d)
	d.vmWorkspaceExport = func(context.Context, string, api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
		return api.WorkspaceExportResult{GuestPath: "/root/repo", HasChanges: false}, nil
	}

	var stdout, stderr bytes.Buffer
	root := d.newRootCommand()
	root.SetOut(&stdout)
	root.SetErr(&stderr)
	root.SetArgs([]string{"vm", "workspace", "export", "devbox"})
	if err := root.Execute(); err != nil {
		t.Fatalf("Execute: %v", err)
	}

	switch {
	case stdout.Len() != 0:
		t.Fatalf("stdout = %q, want empty when no changes", stdout.String())
	case !strings.Contains(stderr.String(), "no changes"):
		t.Fatalf("stderr = %q, want 'no changes'", stderr.String())
	}
}
// TestVMWorkspaceExportGuestPathFlag checks that the VM argument and the
// --guest-path flag are forwarded in the export request.
func TestVMWorkspaceExportGuestPathFlag(t *testing.T) {
	d := defaultDeps()
	stubEnsureDaemonForSend(t, d)
	var sent api.WorkspaceExportParams
	d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
		sent = params
		return api.WorkspaceExportResult{HasChanges: false}, nil
	}

	root := d.newRootCommand()
	root.SetOut(io.Discard)
	root.SetErr(io.Discard)
	root.SetArgs([]string{"vm", "workspace", "export", "devbox", "--guest-path", "/root/project"})
	if err := root.Execute(); err != nil {
		t.Fatalf("Execute: %v", err)
	}

	switch {
	case sent.GuestPath != "/root/project":
		t.Fatalf("GuestPath = %q, want /root/project", sent.GuestPath)
	case sent.IDOrName != "devbox":
		t.Fatalf("IDOrName = %q, want devbox", sent.IDOrName)
	}
}
// TestVMWorkspaceExportBaseCommitFlag checks that --base-commit is forwarded
// unchanged in the export request.
func TestVMWorkspaceExportBaseCommitFlag(t *testing.T) {
	d := defaultDeps()
	stubEnsureDaemonForSend(t, d)
	var sent api.WorkspaceExportParams
	d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
		sent = params
		result := api.WorkspaceExportResult{HasChanges: false}
		result.BaseCommit = params.BaseCommit
		return result, nil
	}

	root := d.newRootCommand()
	root.SetOut(io.Discard)
	root.SetErr(io.Discard)
	root.SetArgs([]string{"vm", "workspace", "export", "devbox", "--base-commit", "abc1234deadbeef"})
	if err := root.Execute(); err != nil {
		t.Fatalf("Execute: %v", err)
	}
	if got := sent.BaseCommit; got != "abc1234deadbeef" {
		t.Fatalf("BaseCommit = %q, want abc1234deadbeef", got)
	}
}