From c42fcbe01207e0f4f918dbceff25c24b96a8397f Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Sun, 19 Apr 2026 19:03:55 -0300 Subject: [PATCH] cli + daemon: move test seams off package globals onto injected structs CLI: introduce internal/cli.deps which owns every RPC/SSH/host-command seam the tree used to reach through mutable package vars. Command builders, orchestrators, and the completion helpers become methods on *deps. Tests construct their own deps per case, so fakes no longer leak across cases and tests are free to run in parallel. Daemon: move workspaceInspectRepoFunc + workspaceImportFunc onto the Daemon struct (workspaceInspectRepo / workspaceImport), mirroring the existing guestWaitForSSH / guestDial pattern. Workspace-prepare tests drop t.Parallel() guards now that they no longer mutate process-wide state. Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/cli/banger.go | 129 ++------ internal/cli/cli_test.go | 447 ++++++++++------------------ internal/cli/commands_daemon.go | 4 +- internal/cli/commands_image.go | 48 +-- internal/cli/commands_internal.go | 2 +- internal/cli/commands_kernel.go | 36 +-- internal/cli/commands_vm.go | 160 +++++----- internal/cli/commands_vm_session.go | 82 ++--- internal/cli/completion.go | 53 ++-- internal/cli/completion_test.go | 84 +++--- internal/cli/daemon_lifecycle.go | 24 +- internal/cli/deps.go | 165 ++++++++++ internal/cli/prune_test.go | 55 ++-- internal/cli/ssh.go | 6 +- internal/cli/vm_create.go | 10 +- internal/cli/vm_run.go | 26 +- internal/daemon/daemon.go | 3 + internal/daemon/workspace.go | 30 +- internal/daemon/workspace_test.go | 33 +- 19 files changed, 664 insertions(+), 733 deletions(-) create mode 100644 internal/cli/deps.go diff --git a/internal/cli/banger.go b/internal/cli/banger.go index f278b13..fd87775 100644 --- a/internal/cli/banger.go +++ b/internal/cli/banger.go @@ -1,125 +1,25 @@ package cli import ( - "context" "errors" "fmt" - "io" - "os" - "os/exec" "path/filepath" "strings" - "time" "banger/internal/api" "banger/internal/buildinfo" - "banger/internal/daemon" - "banger/internal/guest" - "banger/internal/paths" - "banger/internal/rpc" - "banger/internal/toolingplan" "github.com/spf13/cobra" ) -var ( - bangerdPathFunc = paths.BangerdPath - daemonExePath = func(pid int) string { - return filepath.Join("/proc", fmt.Sprintf("%d", pid), "exe") - } - doctorFunc = daemon.Doctor - sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { - sshCmd := exec.CommandContext(ctx, "ssh", args...) - sshCmd.Stdout = stdout - sshCmd.Stderr = stderr - sshCmd.Stdin = stdin - return sshCmd.Run() - } - hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) { - cmd := exec.CommandContext(ctx, name, args...) - output, err := cmd.CombinedOutput() - if err == nil { - return output, nil - } - command := strings.TrimSpace(strings.Join(append([]string{name}, args...), " ")) - detail := strings.TrimSpace(string(output)) - if detail == "" { - return output, fmt.Errorf("%s: %w", command, err) - } - return output, fmt.Errorf("%s: %w: %s", command, err, detail) - } - vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { - return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName}) - } - vmSSHFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error) { - return rpc.Call[api.VMSSHResult](ctx, socketPath, "vm.ssh", api.VMRefParams{IDOrName: idOrName}) - } - vmDeleteFunc = func(ctx context.Context, socketPath, idOrName string) error { - _, err := rpc.Call[api.VMShowResult](ctx, socketPath, "vm.delete", api.VMRefParams{IDOrName: idOrName}) - return err - } - vmListFunc = func(ctx context.Context, socketPath string) (api.VMListResult, error) { - return rpc.Call[api.VMListResult](ctx, socketPath, "vm.list", api.Empty{}) - } - daemonPingFunc = func(ctx context.Context, socketPath string) (api.PingResult, error) { - return rpc.Call[api.PingResult](ctx, socketPath, "ping", api.Empty{}) - } - vmCreateBeginFunc = func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error) { - return rpc.Call[api.VMCreateBeginResult](ctx, socketPath, "vm.create.begin", params) - } - vmCreateStatusFunc = func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error) { - return rpc.Call[api.VMCreateStatusResult](ctx, socketPath, "vm.create.status", api.VMCreateStatusParams{ID: operationID}) - } - vmCreateCancelFunc = func(ctx context.Context, socketPath, operationID string) error { - _, err := rpc.Call[api.Empty](ctx, socketPath, "vm.create.cancel", api.VMCreateStatusParams{ID: operationID}) - return err - } - vmPortsFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) { - return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName}) - } - vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { - return rpc.Call[api.VMWorkspacePrepareResult](ctx, socketPath, "vm.workspace.prepare", params) - } - vmWorkspaceExportFunc = func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { - return rpc.Call[api.WorkspaceExportResult](ctx, socketPath, "vm.workspace.export", params) - } - guestSessionStartFunc = func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) { - return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.start", params) - } - guestSessionGetFunc = func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) { - return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.get", params) - } - guestSessionListFunc = func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error) { - return rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: idOrName}) - } - guestSessionStopFunc = func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) { - return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.stop", params) - } - guestSessionKillFunc = func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) { - return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.kill", params) - } - guestSessionLogsFunc = func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) { - return rpc.Call[api.GuestSessionLogsResult](ctx, socketPath, "guest.session.logs", params) - } - guestSessionAttachBeginFunc = func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) { - return rpc.Call[api.GuestSessionAttachBeginResult](ctx, socketPath, "guest.session.attach.begin", params) - } - guestSessionSendFunc = func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { - return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params) - } - guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { - knownHosts, _ := bangerKnownHostsPath() - return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval) - } - guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { - knownHosts, _ := bangerKnownHostsPath() - return guest.Dial(ctx, address, privateKeyPath, knownHosts) - } - buildVMRunToolingPlanFunc = toolingplan.Build - cwdFunc = os.Getwd -) - +// NewBangerCommand builds the top-level cobra tree with production +// defaults wired into the dependency struct. Tests reach into the +// package directly — see newRootCommand + defaultDeps. func NewBangerCommand() *cobra.Command { + return defaultDeps().newRootCommand() +} + +func (d *deps) newRootCommand() *cobra.Command { root := &cobra.Command{ Use: "banger", Short: "Manage development VMs and images", @@ -127,17 +27,26 @@ func NewBangerCommand() *cobra.Command { SilenceErrors: true, RunE: helpNoArgs, } - root.AddCommand(newDaemonCommand(), newDoctorCommand(), newImageCommand(), newInternalCommand(), newKernelCommand(), newVersionCommand(), newPSCommand(), newVMCommand()) + root.AddCommand( + d.newDaemonCommand(), + d.newDoctorCommand(), + d.newImageCommand(), + d.newInternalCommand(), + d.newKernelCommand(), + newVersionCommand(), + d.newPSCommand(), + d.newVMCommand(), + ) return root } -func newDoctorCommand() *cobra.Command { +func (d *deps) newDoctorCommand() *cobra.Command { return &cobra.Command{ Use: "doctor", Short: "Check host and runtime readiness", Args: noArgsUsage("usage: banger doctor"), RunE: func(cmd *cobra.Command, args []string) error { - report, err := doctorFunc(cmd.Context()) + report, err := d.doctor(cmd.Context()) if err != nil { return err } diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 07a3412..d9030ca 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -121,11 +121,8 @@ func TestLegacyRemovedCommandIsRejected(t *testing.T) { } func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) { - original := doctorFunc - t.Cleanup(func() { - doctorFunc = original - }) - doctorFunc = func(context.Context) (system.Report, error) { + d := defaultDeps() + d.doctor = func(context.Context) (system.Report, error) { return system.Report{ Checks: []system.CheckResult{ {Name: "runtime bundle", Status: system.CheckStatusPass, Details: []string{"runtime dir /tmp/runtime"}}, @@ -134,7 +131,7 @@ func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) { }, nil } - cmd := NewBangerCommand() + cmd := d.newRootCommand() var stdout bytes.Buffer cmd.SetOut(&stdout) cmd.SetErr(&stdout) @@ -154,15 +151,12 @@ func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) { } func TestDoctorCommandReturnsUnderlyingError(t *testing.T) { - original := doctorFunc - t.Cleanup(func() { - doctorFunc = original - }) - doctorFunc = func(context.Context) (system.Report, error) { + d := defaultDeps() + d.doctor = func(context.Context) (system.Report, error) { return system.Report{}, errors.New("load failed") } - cmd := NewBangerCommand() + cmd := d.newRootCommand() cmd.SetArgs([]string{"doctor"}) err := cmd.Execute() if err == nil || !strings.Contains(err.Error(), "load failed") { @@ -509,14 +503,7 @@ func TestVMCreateParamsFromFlagsRejectsNonPositiveCPUAndMemory(t *testing.T) { } func TestRunVMCreatePollsUntilDone(t *testing.T) { - origBegin := vmCreateBeginFunc - origStatus := vmCreateStatusFunc - origCancel := vmCreateCancelFunc - t.Cleanup(func() { - vmCreateBeginFunc = origBegin - vmCreateStatusFunc = origStatus - vmCreateCancelFunc = origCancel - }) + d := defaultDeps() vm := model.VMRecord{ ID: "vm-id", @@ -528,7 +515,7 @@ func TestRunVMCreatePollsUntilDone(t *testing.T) { DNSName: "devbox.vm", }, } - vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { + d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { return api.VMCreateBeginResult{ Operation: api.VMCreateOperation{ ID: "op-1", @@ -538,7 +525,7 @@ func TestRunVMCreatePollsUntilDone(t *testing.T) { }, nil } statusCalls := 0 - vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) { + d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) { statusCalls++ if statusCalls == 1 { return api.VMCreateStatusResult{ @@ -560,14 +547,14 @@ func TestRunVMCreatePollsUntilDone(t *testing.T) { }, }, nil } - vmCreateCancelFunc = func(context.Context, string, string) error { + d.vmCreateCancel = func(context.Context, string, string) error { t.Fatal("cancel should not be called") return nil } - got, err := runVMCreate(context.Background(), "/tmp/bangerd.sock", &bytes.Buffer{}, api.VMCreateParams{Name: "devbox"}) + got, err := d.runVMCreate(context.Background(), "/tmp/bangerd.sock", &bytes.Buffer{}, api.VMCreateParams{Name: "devbox"}) if err != nil { - t.Fatalf("runVMCreate: %v", err) + t.Fatalf("d.runVMCreate: %v", err) } if got.Name != vm.Name || got.Runtime.GuestIP != vm.Runtime.GuestIP { t.Fatalf("vm = %+v, want %+v", got, vm) @@ -878,23 +865,18 @@ func TestPrintVMPortsTableSortsAndRendersURLEndpoints(t *testing.T) { } func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) { - origSSHExec := sshExecFunc - origHealth := vmHealthFunc - t.Cleanup(func() { - sshExecFunc = origSSHExec - vmHealthFunc = origHealth - }) + d := defaultDeps() - sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { + d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { return nil } - vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { + d.vmHealth = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { return api.VMHealthResult{Name: "devbox", Healthy: true}, nil } var stderr bytes.Buffer - if err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false); err != nil { - t.Fatalf("runSSHSession: %v", err) + if err := d.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false); err != nil { + t.Fatalf("d.runSSHSession: %v", err) } if !strings.Contains(stderr.String(), "devbox is still running") { t.Fatalf("stderr = %q, want reminder", stderr.String()) @@ -902,25 +884,20 @@ func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) { } func TestRunSSHSessionPreservesSSHExitStatusOnHealthWarning(t *testing.T) { - origSSHExec := sshExecFunc - origHealth := vmHealthFunc - t.Cleanup(func() { - sshExecFunc = origSSHExec - vmHealthFunc = origHealth - }) + d := defaultDeps() - sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { + d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { return exitErrorWithCode(t, 1) } - vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { + d.vmHealth = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { return api.VMHealthResult{}, errors.New("dial failed") } var stderr bytes.Buffer - err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false) + err := d.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false) var exitErr *exec.ExitError if !errors.As(err, &exitErr) { - t.Fatalf("runSSHSession error = %v, want exit error", err) + t.Fatalf("d.runSSHSession error = %v, want exit error", err) } if !strings.Contains(stderr.String(), "failed to check whether devbox is still running") { t.Fatalf("stderr = %q, want warning", stderr.String()) @@ -928,27 +905,22 @@ func TestRunSSHSessionPreservesSSHExitStatusOnHealthWarning(t *testing.T) { } func TestRunSSHSessionSkipsReminderOnSSHAuthFailure(t *testing.T) { - origSSHExec := sshExecFunc - origHealth := vmHealthFunc - t.Cleanup(func() { - sshExecFunc = origSSHExec - vmHealthFunc = origHealth - }) + d := defaultDeps() healthCalled := false - sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { + d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { return exitErrorWithCode(t, 255) } - vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { + d.vmHealth = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { healthCalled = true return api.VMHealthResult{Name: "devbox", Healthy: true}, nil } var stderr bytes.Buffer - err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false) + err := d.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false) var exitErr *exec.ExitError if !errors.As(err, &exitErr) || exitErr.ExitCode() != 255 { - t.Fatalf("runSSHSession error = %v, want exit 255", err) + t.Fatalf("d.runSSHSession error = %v, want exit 255", err) } if healthCalled { t.Fatal("vm health should not run after ssh auth failure") @@ -1141,6 +1113,7 @@ func TestValidateSSHPrereqsFailsForMissingKey(t *testing.T) { // gets a fast error instead of an orphaned VM. func TestVMRunPreflightRejectsSubmodules(t *testing.T) { + d := defaultDeps() repoRoot := t.TempDir() origHostCommandOutput := workspace.HostCommandOutputFunc @@ -1166,36 +1139,16 @@ func TestVMRunPreflightRejectsSubmodules(t *testing.T) { } } - _, err := vmRunPreflightRepo(context.Background(), repoRoot) + _, err := d.vmRunPreflightRepo(context.Background(), repoRoot) if err == nil || !strings.Contains(err.Error(), "submodules") { - t.Fatalf("vmRunPreflightRepo() error = %v, want submodule rejection", err) + t.Fatalf("d.vmRunPreflightRepo() error = %v, want submodule rejection", err) } } func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) { + d := defaultDeps() repoRoot := t.TempDir() - origBegin := vmCreateBeginFunc - origStatus := vmCreateStatusFunc - origCancel := vmCreateCancelFunc - origWaitForSSH := guestWaitForSSHFunc - origGuestDial := guestDialFunc - origBuildVMRunToolingPlan := buildVMRunToolingPlanFunc - origVMWorkspacePrepare := vmWorkspacePrepareFunc - origSSHExec := sshExecFunc - origHealth := vmHealthFunc - t.Cleanup(func() { - vmCreateBeginFunc = origBegin - vmCreateStatusFunc = origStatus - vmCreateCancelFunc = origCancel - guestWaitForSSHFunc = origWaitForSSH - guestDialFunc = origGuestDial - buildVMRunToolingPlanFunc = origBuildVMRunToolingPlan - vmWorkspacePrepareFunc = origVMWorkspacePrepare - sshExecFunc = origSSHExec - vmHealthFunc = origHealth - }) - vm := model.VMRecord{ ID: "vm-id", Name: "devbox", @@ -1205,7 +1158,7 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) { DNSName: "devbox.vm", }, } - vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { + d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { return api.VMCreateBeginResult{ Operation: api.VMCreateOperation{ ID: "op-1", Stage: "ready", Detail: "vm is ready", @@ -1213,45 +1166,45 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) { }, }, nil } - vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) { - t.Fatal("vmCreateStatusFunc should not be called") + d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) { + t.Fatal("d.vmCreateStatus should not be called") return api.VMCreateStatusResult{}, nil } - vmCreateCancelFunc = func(context.Context, string, string) error { - t.Fatal("vmCreateCancelFunc should not be called") + d.vmCreateCancel = func(context.Context, string, string) error { + t.Fatal("d.vmCreateCancel should not be called") return nil } fakeClient := &testVMRunGuestClient{} - guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { + d.guestWaitForSSH = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { return nil } - guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { + d.guestDial = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { return fakeClient, nil } var workspaceParams api.VMWorkspacePrepareParams - vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { + d.vmWorkspacePrepare = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { workspaceParams = params return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil } - buildVMRunToolingPlanFunc = func(context.Context, string) toolingplan.Plan { + d.buildVMRunToolingPlan = func(context.Context, string) toolingplan.Plan { return toolingplan.Plan{ RepoManagedTools: []string{"go"}, Steps: []toolingplan.InstallStep{{Tool: "go", Version: "1.25.0", Source: "go.mod"}}, } } var sshArgsSeen []string - sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { + d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { sshArgsSeen = args return nil } - vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) { + d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) { return api.VMHealthResult{Name: "devbox", Healthy: false}, nil } repo := vmRunRepo{sourcePath: repoRoot} var stdout, stderr bytes.Buffer - err := runVMRun( + err := d.runVMRun( context.Background(), "/tmp/bangerd.sock", model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, @@ -1263,7 +1216,7 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) { false, ) if err != nil { - t.Fatalf("runVMRun: %v", err) + t.Fatalf("d.runVMRun: %v", err) } if workspaceParams.IDOrName != "devbox" || workspaceParams.SourcePath != repoRoot { t.Fatalf("workspaceParams = %+v", workspaceParams) @@ -1283,24 +1236,7 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) { } func TestVMRunPrintsPostCreateProgress(t *testing.T) { - origBegin := vmCreateBeginFunc - origStatus := vmCreateStatusFunc - origCancel := vmCreateCancelFunc - origWaitForSSH := guestWaitForSSHFunc - origGuestDial := guestDialFunc - origVMWorkspacePrepare := vmWorkspacePrepareFunc - origSSHExec := sshExecFunc - origHealth := vmHealthFunc - t.Cleanup(func() { - vmCreateBeginFunc = origBegin - vmCreateStatusFunc = origStatus - vmCreateCancelFunc = origCancel - guestWaitForSSHFunc = origWaitForSSH - guestDialFunc = origGuestDial - vmWorkspacePrepareFunc = origVMWorkspacePrepare - sshExecFunc = origSSHExec - vmHealthFunc = origHealth - }) + d := defaultDeps() vm := model.VMRecord{ ID: "vm-id", @@ -1310,7 +1246,7 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) { GuestIP: "172.16.0.2", }, } - vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { + d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { return api.VMCreateBeginResult{ Operation: api.VMCreateOperation{ ID: "op-1", Stage: "ready", Detail: "vm is ready", @@ -1318,33 +1254,33 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) { }, }, nil } - vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) { - t.Fatal("vmCreateStatusFunc should not be called") + d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) { + t.Fatal("d.vmCreateStatus should not be called") return api.VMCreateStatusResult{}, nil } - vmCreateCancelFunc = func(context.Context, string, string) error { - t.Fatal("vmCreateCancelFunc should not be called") + d.vmCreateCancel = func(context.Context, string, string) error { + t.Fatal("d.vmCreateCancel should not be called") return nil } - guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { + d.guestWaitForSSH = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { return nil } - guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { + d.guestDial = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { return &testVMRunGuestClient{}, nil } - vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { + d.vmWorkspacePrepare = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil } - sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { + d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { return nil } - vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) { + d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) { return api.VMHealthResult{Name: "devbox", Healthy: false}, nil } repo := vmRunRepo{sourcePath: t.TempDir()} var stdout, stderr bytes.Buffer - err := runVMRun( + err := d.runVMRun( context.Background(), "/tmp/bangerd.sock", model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, @@ -1356,7 +1292,7 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) { false, ) if err != nil { - t.Fatalf("runVMRun: %v", err) + t.Fatalf("d.runVMRun: %v", err) } output := stderr.String() @@ -1377,24 +1313,7 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) { } func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) { - origBegin := vmCreateBeginFunc - origStatus := vmCreateStatusFunc - origCancel := vmCreateCancelFunc - origWaitForSSH := guestWaitForSSHFunc - origGuestDial := guestDialFunc - origVMWorkspacePrepare := vmWorkspacePrepareFunc - origSSHExec := sshExecFunc - origHealth := vmHealthFunc - t.Cleanup(func() { - vmCreateBeginFunc = origBegin - vmCreateStatusFunc = origStatus - vmCreateCancelFunc = origCancel - guestWaitForSSHFunc = origWaitForSSH - guestDialFunc = origGuestDial - vmWorkspacePrepareFunc = origVMWorkspacePrepare - sshExecFunc = origSSHExec - vmHealthFunc = origHealth - }) + d := defaultDeps() vm := model.VMRecord{ ID: "vm-id", @@ -1404,39 +1323,39 @@ func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) { GuestIP: "172.16.0.2", }, } - vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { + d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Detail: "vm is ready", Done: true, Success: true, VM: &vm}}, nil } - vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) { - t.Fatal("vmCreateStatusFunc should not be called") + d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) { + t.Fatal("d.vmCreateStatus should not be called") return api.VMCreateStatusResult{}, nil } - vmCreateCancelFunc = func(context.Context, string, string) error { - t.Fatal("vmCreateCancelFunc should not be called") + d.vmCreateCancel = func(context.Context, string, string) error { + t.Fatal("d.vmCreateCancel should not be called") return nil } - guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { + d.guestWaitForSSH = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { return nil } fakeClient := &testVMRunGuestClient{launchErr: errors.New("launch failed")} - guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { + d.guestDial = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { return fakeClient, nil } - vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { + d.vmWorkspacePrepare = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil } sshExecCalls := 0 - sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { + d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { sshExecCalls++ return nil } - vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) { + d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) { return api.VMHealthResult{Healthy: false}, nil } repo := vmRunRepo{sourcePath: t.TempDir()} var stdout, stderr bytes.Buffer - err := runVMRun( + err := d.runVMRun( context.Background(), "/tmp/bangerd.sock", model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, @@ -1448,7 +1367,7 @@ func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) { false, ) if err != nil { - t.Fatalf("runVMRun: %v", err) + t.Fatalf("d.runVMRun: %v", err) } if !strings.Contains(stderr.String(), "[vm run] warning: guest tooling bootstrap start failed: launch guest tooling bootstrap") { t.Fatalf("stderr = %q, want tooling bootstrap warning", stderr.String()) @@ -1459,48 +1378,35 @@ func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) { } func TestRunVMRunBareModeSkipsWorkspaceAndTooling(t *testing.T) { - origBegin := vmCreateBeginFunc - origWaitForSSH := guestWaitForSSHFunc - origGuestDial := guestDialFunc - origVMWorkspacePrepare := vmWorkspacePrepareFunc - origSSHExec := sshExecFunc - origHealth := vmHealthFunc - t.Cleanup(func() { - vmCreateBeginFunc = origBegin - guestWaitForSSHFunc = origWaitForSSH - guestDialFunc = origGuestDial - vmWorkspacePrepareFunc = origVMWorkspacePrepare - sshExecFunc = origSSHExec - vmHealthFunc = origHealth - }) + d := defaultDeps() vm := model.VMRecord{ ID: "vm-id", Name: "bare", Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"}, } - vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { + d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil } - guestWaitForSSHFunc = func(context.Context, string, string, time.Duration) error { return nil } - guestDialFunc = func(context.Context, string, string) (vmRunGuestClient, error) { - t.Fatal("guestDialFunc should not be called in bare mode") + d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil } + d.guestDial = func(context.Context, string, string) (vmRunGuestClient, error) { + t.Fatal("d.guestDial should not be called in bare mode") return nil, nil } - vmWorkspacePrepareFunc = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { - t.Fatal("vmWorkspacePrepareFunc should not be called in bare mode") + d.vmWorkspacePrepare = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { + t.Fatal("d.vmWorkspacePrepare should not be called in bare mode") return api.VMWorkspacePrepareResult{}, nil } sshExecCalls := 0 - sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { + d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { sshExecCalls++ return nil } - vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) { + d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) { return api.VMHealthResult{Healthy: false}, nil } var stdout, stderr bytes.Buffer - err := runVMRun( + err := d.runVMRun( context.Background(), "/tmp/bangerd.sock", model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, @@ -1512,7 +1418,7 @@ func TestRunVMRunBareModeSkipsWorkspaceAndTooling(t *testing.T) { false, ) if err != nil { - t.Fatalf("runVMRun: %v", err) + t.Fatalf("d.runVMRun: %v", err) } if sshExecCalls != 1 { t.Fatalf("sshExec calls = %d, want 1", sshExecCalls) @@ -1523,39 +1429,28 @@ func TestRunVMRunBareModeSkipsWorkspaceAndTooling(t *testing.T) { } func TestRunVMRunRMDeletesAfterSessionExits(t *testing.T) { - origBegin := vmCreateBeginFunc - origWaitForSSH := guestWaitForSSHFunc - origSSHExec := sshExecFunc - origHealth := vmHealthFunc - origDelete := vmDeleteFunc - t.Cleanup(func() { - vmCreateBeginFunc = origBegin - guestWaitForSSHFunc = origWaitForSSH - sshExecFunc = origSSHExec - vmHealthFunc = origHealth - vmDeleteFunc = origDelete - }) + d := defaultDeps() vm := model.VMRecord{ ID: "vm-id", Name: "tmpbox", Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"}, } - vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { + d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil } - guestWaitForSSHFunc = func(context.Context, string, string, time.Duration) error { return nil } - sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { return nil } - vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) { + d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil } + d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { return nil } + d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) { return api.VMHealthResult{Healthy: false}, nil } deletedRef := "" - vmDeleteFunc = func(_ context.Context, _, idOrName string) error { + d.vmDelete = func(_ context.Context, _, idOrName string) error { deletedRef = idOrName return nil } var stdout, stderr bytes.Buffer - err := runVMRun( + err := d.runVMRun( context.Background(), "/tmp/bangerd.sock", model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, @@ -1567,7 +1462,7 @@ func TestRunVMRunRMDeletesAfterSessionExits(t *testing.T) { true, // --rm ) if err != nil { - t.Fatalf("runVMRun: %v", err) + t.Fatalf("d.runVMRun: %v", err) } if deletedRef != "tmpbox" { t.Fatalf("deletedRef = %q, want tmpbox", deletedRef) @@ -1580,15 +1475,10 @@ func TestRunVMRunRMDeletesAfterSessionExits(t *testing.T) { } func TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout(t *testing.T) { - origBegin := vmCreateBeginFunc - origWaitForSSH := guestWaitForSSHFunc - origDelete := vmDeleteFunc + d := defaultDeps() origTimeout := vmRunSSHTimeout vmRunSSHTimeout = 50 * time.Millisecond t.Cleanup(func() { - vmCreateBeginFunc = origBegin - guestWaitForSSHFunc = origWaitForSSH - vmDeleteFunc = origDelete vmRunSSHTimeout = origTimeout }) @@ -1596,21 +1486,21 @@ func TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout(t *testing.T) { ID: "vm-id", Name: "slowvm", Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"}, } - vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { + d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil } - guestWaitForSSHFunc = func(ctx context.Context, _, _ string, _ time.Duration) error { + d.guestWaitForSSH = func(ctx context.Context, _, _ string, _ time.Duration) error { <-ctx.Done() return ctx.Err() } deleteCalled := false - vmDeleteFunc = func(context.Context, string, string) error { + d.vmDelete = func(context.Context, string, string) error { deleteCalled = true return nil } var stdout, stderr bytes.Buffer - err := runVMRun( + err := d.runVMRun( context.Background(), "/tmp/bangerd.sock", model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, @@ -1630,13 +1520,10 @@ func TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout(t *testing.T) { } func TestRunVMRunSSHTimeoutReturnsActionableError(t *testing.T) { - origBegin := vmCreateBeginFunc - origWaitForSSH := guestWaitForSSHFunc + d := defaultDeps() origTimeout := vmRunSSHTimeout vmRunSSHTimeout = 50 * time.Millisecond t.Cleanup(func() { - vmCreateBeginFunc = origBegin - guestWaitForSSHFunc = origWaitForSSH vmRunSSHTimeout = origTimeout }) @@ -1644,18 +1531,18 @@ func TestRunVMRunSSHTimeoutReturnsActionableError(t *testing.T) { ID: "vm-id", Name: "slowvm", Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"}, } - vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { + d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil } // Simulate the guest never bringing sshd up — the wait-for-ssh // child context fires its deadline, returning a DeadlineExceeded. - guestWaitForSSHFunc = func(ctx context.Context, _, _ string, _ time.Duration) error { + d.guestWaitForSSH = func(ctx context.Context, _, _ string, _ time.Duration) error { <-ctx.Done() return ctx.Err() } var stdout, stderr bytes.Buffer - err := runVMRun( + err := d.runVMRun( context.Background(), "/tmp/bangerd.sock", model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, @@ -1683,37 +1570,28 @@ func TestRunVMRunSSHTimeoutReturnsActionableError(t *testing.T) { } func TestRunVMRunCommandModePropagatesExitCode(t *testing.T) { - origBegin := vmCreateBeginFunc - origWaitForSSH := guestWaitForSSHFunc - origVMWorkspacePrepare := vmWorkspacePrepareFunc - origSSHExec := sshExecFunc - t.Cleanup(func() { - vmCreateBeginFunc = origBegin - guestWaitForSSHFunc = origWaitForSSH - vmWorkspacePrepareFunc = origVMWorkspacePrepare - sshExecFunc = origSSHExec - }) + d := defaultDeps() vm := model.VMRecord{ ID: "vm-id", Name: "cmdbox", Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"}, } - vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { + d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil } - guestWaitForSSHFunc = func(context.Context, string, string, time.Duration) error { return nil } - vmWorkspacePrepareFunc = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { + d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil } + d.vmWorkspacePrepare = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { t.Fatal("workspace prepare should not run without spec") return api.VMWorkspacePrepareResult{}, nil } var sshArgsSeen []string - sshExecFunc = func(_ context.Context, _ io.Reader, _, _ io.Writer, args []string) error { + d.sshExec = func(_ context.Context, _ io.Reader, _, _ io.Writer, args []string) error { sshArgsSeen = args return exitErrorWithCode(t, 7) } var stdout, stderr bytes.Buffer - err := runVMRun( + err := d.runVMRun( context.Background(), "/tmp/bangerd.sock", model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, @@ -1726,7 +1604,7 @@ func TestRunVMRunCommandModePropagatesExitCode(t *testing.T) { ) var exitErr ExitCodeError if !errors.As(err, &exitErr) || exitErr.Code != 7 { - t.Fatalf("runVMRun error = %v, want ExitCodeError{7}", err) + t.Fatalf("d.runVMRun error = %v, want ExitCodeError{7}", err) } if len(sshArgsSeen) == 0 || sshArgsSeen[len(sshArgsSeen)-1] != "false" { t.Fatalf("sshArgsSeen = %v, want trailing command 'false'", sshArgsSeen) @@ -1843,6 +1721,7 @@ func TestNewBangerdCommandRejectsArgs(t *testing.T) { } func TestDaemonOutdated(t *testing.T) { + d := defaultDeps() dir := t.TempDir() current := filepath.Join(dir, "bangerd-current") same := filepath.Join(dir, "bangerd-same") @@ -1857,27 +1736,20 @@ func TestDaemonOutdated(t *testing.T) { t.Fatalf("write stale: %v", err) } - origBangerdPath := bangerdPathFunc - origDaemonExePath := daemonExePath - t.Cleanup(func() { - bangerdPathFunc = origBangerdPath - daemonExePath = origDaemonExePath - }) - - bangerdPathFunc = func() (string, error) { + d.bangerdPath = func() (string, error) { return current, nil } - daemonExePath = func(pid int) string { + d.daemonExePath = func(pid int) string { if pid == 1 { return same } return stale } - if daemonOutdated(1) { + if d.daemonOutdated(1) { t.Fatal("expected matching daemon executable to be current") } - if !daemonOutdated(2) { + if !d.daemonOutdated(2) { t.Fatal("expected replaced daemon executable to be outdated") } } @@ -1912,10 +1784,7 @@ func TestDaemonStatusIncludesLogPathWhenStopped(t *testing.T) { } func TestDaemonStatusIncludesDaemonBuildInfoWhenRunning(t *testing.T) { - origDaemonPing := daemonPingFunc - t.Cleanup(func() { - daemonPingFunc = origDaemonPing - }) + d := defaultDeps() configHome := filepath.Join(t.TempDir(), "config") stateHome := filepath.Join(t.TempDir(), "state") @@ -1924,7 +1793,7 @@ func TestDaemonStatusIncludesDaemonBuildInfoWhenRunning(t *testing.T) { t.Setenv("XDG_STATE_HOME", stateHome) t.Setenv("XDG_RUNTIME_DIR", runtimeHome) - daemonPingFunc = func(context.Context, string) (api.PingResult, error) { + d.daemonPing = func(context.Context, string) (api.PingResult, error) { return api.PingResult{ Status: "ok", PID: 42, @@ -1934,7 +1803,7 @@ func TestDaemonStatusIncludesDaemonBuildInfoWhenRunning(t *testing.T) { }, nil } - cmd := NewBangerCommand() + cmd := d.newRootCommand() var stdout bytes.Buffer cmd.SetOut(&stdout) cmd.SetErr(&stdout) @@ -2073,26 +1942,26 @@ func TestVMSessionSendRejectsWrongArgCount(t *testing.T) { } } -func stubEnsureDaemonForSend(t *testing.T) { +// stubEnsureDaemonForSend isolates XDG dirs and installs a daemon-ping +// fake onto the caller's *deps so `ensureDaemon` short-circuits without +// trying to spawn bangerd. `vm session send` uses this to avoid needing +// a built binary on disk. +func stubEnsureDaemonForSend(t *testing.T, d *deps) { t.Helper() t.Setenv("XDG_CONFIG_HOME", filepath.Join(t.TempDir(), "config")) t.Setenv("XDG_STATE_HOME", filepath.Join(t.TempDir(), "state")) t.Setenv("XDG_RUNTIME_DIR", filepath.Join(t.TempDir(), "run")) - origPing := daemonPingFunc - t.Cleanup(func() { daemonPingFunc = origPing }) - daemonPingFunc = func(context.Context, string) (api.PingResult, error) { + d.daemonPing = func(context.Context, string) (api.PingResult, error) { return api.PingResult{Status: "ok", PID: os.Getpid()}, nil } } func TestVMSessionSendWithMessageFlag(t *testing.T) { - stubEnsureDaemonForSend(t) - - original := guestSessionSendFunc - t.Cleanup(func() { guestSessionSendFunc = original }) + d := defaultDeps() + stubEnsureDaemonForSend(t, d) var capturedParams api.GuestSessionSendParams - guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { + d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { capturedParams = params return api.GuestSessionSendResult{ Session: model.GuestSession{ID: "sess-id", Name: "planner"}, @@ -2100,7 +1969,7 @@ func TestVMSessionSendWithMessageFlag(t *testing.T) { }, nil } - cmd := NewBangerCommand() + cmd := d.newRootCommand() var out bytes.Buffer cmd.SetOut(&out) cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner", "--message", `{"type":"abort"}`}) @@ -2124,13 +1993,11 @@ func TestVMSessionSendWithMessageFlag(t *testing.T) { } func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) { - stubEnsureDaemonForSend(t) - - original := guestSessionSendFunc - t.Cleanup(func() { guestSessionSendFunc = original }) + d := defaultDeps() + stubEnsureDaemonForSend(t, d) var capturedPayload []byte - guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { + d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { capturedPayload = params.Payload return api.GuestSessionSendResult{ Session: model.GuestSession{Name: "s"}, @@ -2138,7 +2005,7 @@ func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) { }, nil } - cmd := NewBangerCommand() + cmd := d.newRootCommand() cmd.SetOut(io.Discard) cmd.SetArgs([]string{"vm", "session", "send", "devbox", "s", "--message", "{\"type\":\"abort\"}\n"}) if err := cmd.Execute(); err != nil { @@ -2155,13 +2022,11 @@ func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) { } func TestVMSessionSendFromStdin(t *testing.T) { - stubEnsureDaemonForSend(t) - - original := guestSessionSendFunc - t.Cleanup(func() { guestSessionSendFunc = original }) + d := defaultDeps() + stubEnsureDaemonForSend(t, d) var capturedPayload []byte - guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { + d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { capturedPayload = params.Payload return api.GuestSessionSendResult{ Session: model.GuestSession{Name: "planner"}, @@ -2170,7 +2035,7 @@ func TestVMSessionSendFromStdin(t *testing.T) { } stdinPayload := `{"type":"steer","message":"Focus on src/"}` + "\n" - cmd := NewBangerCommand() + cmd := d.newRootCommand() cmd.SetOut(io.Discard) cmd.SetIn(strings.NewReader(stdinPayload)) cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner"}) @@ -2208,13 +2073,11 @@ func TestVMWorkspaceExportRejectsMissingArg(t *testing.T) { } func TestVMWorkspaceExportWritesToStdout(t *testing.T) { - stubEnsureDaemonForSend(t) - - origExport := vmWorkspaceExportFunc - t.Cleanup(func() { vmWorkspaceExportFunc = origExport }) + d := defaultDeps() + stubEnsureDaemonForSend(t, d) patch := []byte("diff --git a/main.go b/main.go\nindex 0000000..1111111 100644\n") - vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { return api.WorkspaceExportResult{ GuestPath: params.GuestPath, Patch: patch, @@ -2223,7 +2086,7 @@ func TestVMWorkspaceExportWritesToStdout(t *testing.T) { }, nil } - cmd := NewBangerCommand() + cmd := d.newRootCommand() var out bytes.Buffer cmd.SetOut(&out) cmd.SetErr(io.Discard) @@ -2237,13 +2100,11 @@ func TestVMWorkspaceExportWritesToStdout(t *testing.T) { } func TestVMWorkspaceExportWritesToFile(t *testing.T) { - stubEnsureDaemonForSend(t) - - origExport := vmWorkspaceExportFunc - t.Cleanup(func() { vmWorkspaceExportFunc = origExport }) + d := defaultDeps() + stubEnsureDaemonForSend(t, d) patch := []byte("diff --git a/main.go b/main.go\n") - vmWorkspaceExportFunc = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + d.vmWorkspaceExport = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { return api.WorkspaceExportResult{ GuestPath: "/root/repo", Patch: patch, @@ -2253,7 +2114,7 @@ func TestVMWorkspaceExportWritesToFile(t *testing.T) { } outFile := filepath.Join(t.TempDir(), "worker.diff") - cmd := NewBangerCommand() + cmd := d.newRootCommand() cmd.SetOut(io.Discard) var stderr bytes.Buffer cmd.SetErr(&stderr) @@ -2275,19 +2136,17 @@ func TestVMWorkspaceExportWritesToFile(t *testing.T) { } func TestVMWorkspaceExportNoChanges(t *testing.T) { - stubEnsureDaemonForSend(t) + d := defaultDeps() + stubEnsureDaemonForSend(t, d) - origExport := vmWorkspaceExportFunc - t.Cleanup(func() { vmWorkspaceExportFunc = origExport }) - - vmWorkspaceExportFunc = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + d.vmWorkspaceExport = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { return api.WorkspaceExportResult{ GuestPath: "/root/repo", HasChanges: false, }, nil } - cmd := NewBangerCommand() + cmd := d.newRootCommand() var out bytes.Buffer var stderr bytes.Buffer cmd.SetOut(&out) @@ -2305,18 +2164,16 @@ func TestVMWorkspaceExportNoChanges(t *testing.T) { } func TestVMWorkspaceExportGuestPathFlag(t *testing.T) { - stubEnsureDaemonForSend(t) - - origExport := vmWorkspaceExportFunc - t.Cleanup(func() { vmWorkspaceExportFunc = origExport }) + d := defaultDeps() + stubEnsureDaemonForSend(t, d) var capturedParams api.WorkspaceExportParams - vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { capturedParams = params return api.WorkspaceExportResult{HasChanges: false}, nil } - cmd := NewBangerCommand() + cmd := d.newRootCommand() cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--guest-path", "/root/project"}) @@ -2332,13 +2189,11 @@ func TestVMWorkspaceExportGuestPathFlag(t *testing.T) { } func TestVMWorkspaceExportBaseCommitFlag(t *testing.T) { - stubEnsureDaemonForSend(t) - - origExport := vmWorkspaceExportFunc - t.Cleanup(func() { vmWorkspaceExportFunc = origExport }) + d := defaultDeps() + stubEnsureDaemonForSend(t, d) var capturedParams api.WorkspaceExportParams - vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { capturedParams = params return api.WorkspaceExportResult{ HasChanges: false, @@ -2346,7 +2201,7 @@ func TestVMWorkspaceExportBaseCommitFlag(t *testing.T) { }, nil } - cmd := NewBangerCommand() + cmd := d.newRootCommand() cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--base-commit", "abc1234deadbeef"}) diff --git a/internal/cli/commands_daemon.go b/internal/cli/commands_daemon.go index 4342816..f2f1d86 100644 --- a/internal/cli/commands_daemon.go +++ b/internal/cli/commands_daemon.go @@ -15,7 +15,7 @@ import ( "github.com/spf13/cobra" ) -func newDaemonCommand() *cobra.Command { +func (d *deps) newDaemonCommand() *cobra.Command { cmd := &cobra.Command{ Use: "daemon", Short: "Manage the banger daemon", @@ -31,7 +31,7 @@ func newDaemonCommand() *cobra.Command { if err != nil { return err } - ping, pingErr := daemonPingFunc(cmd.Context(), layout.SocketPath) + ping, pingErr := d.daemonPing(cmd.Context(), layout.SocketPath) if pingErr != nil { _, err = fmt.Fprintf(cmd.OutOrStdout(), "stopped\nsocket: %s\nlog: %s\ndns: %s\n", layout.SocketPath, layout.DaemonLog, vmdns.DefaultListenAddr) return err diff --git a/internal/cli/commands_image.go b/internal/cli/commands_image.go index 6a30ee9..46e29fe 100644 --- a/internal/cli/commands_image.go +++ b/internal/cli/commands_image.go @@ -13,24 +13,24 @@ import ( "github.com/spf13/cobra" ) -func newImageCommand() *cobra.Command { +func (d *deps) newImageCommand() *cobra.Command { cmd := &cobra.Command{ Use: "image", Short: "Manage images", RunE: helpNoArgs, } cmd.AddCommand( - newImageRegisterCommand(), - newImagePullCommand(), - newImagePromoteCommand(), - newImageListCommand(), - newImageShowCommand(), - newImageDeleteCommand(), + d.newImageRegisterCommand(), + d.newImagePullCommand(), + d.newImagePromoteCommand(), + d.newImageListCommand(), + d.newImageShowCommand(), + d.newImageDeleteCommand(), ) return cmd } -func newImageRegisterCommand() *cobra.Command { +func (d *deps) newImageRegisterCommand() *cobra.Command { var params api.ImageRegisterParams cmd := &cobra.Command{ Use: "register", @@ -46,7 +46,7 @@ func newImageRegisterCommand() *cobra.Command { if err := system.EnsureSudo(cmd.Context()); err != nil { return err } - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -65,11 +65,11 @@ func newImageRegisterCommand() *cobra.Command { cmd.Flags().StringVar(¶ms.ModulesDir, "modules", "", "modules dir") cmd.Flags().StringVar(¶ms.KernelRef, "kernel-ref", "", "name of a cataloged kernel (see 'banger kernel list')") cmd.Flags().BoolVar(¶ms.Docker, "docker", false, "mark image as docker-prepared") - _ = cmd.RegisterFlagCompletionFunc("kernel-ref", completeKernelNames) + _ = cmd.RegisterFlagCompletionFunc("kernel-ref", d.completeKernelNames) return cmd } -func newImagePullCommand() *cobra.Command { +func (d *deps) newImagePullCommand() *cobra.Command { var ( params api.ImagePullParams sizeRaw string @@ -117,7 +117,7 @@ subcommand lands). if err := system.EnsureSudo(cmd.Context()); err != nil { return err } - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -139,21 +139,21 @@ subcommand lands). cmd.Flags().StringVar(¶ms.ModulesDir, "modules", "", "modules dir") cmd.Flags().StringVar(¶ms.KernelRef, "kernel-ref", "", "name of a cataloged kernel (see 'banger kernel list')") cmd.Flags().StringVar(&sizeRaw, "size", "", "ext4 image size (e.g. 4GiB); defaults to content + 25%, min 1GiB") - _ = cmd.RegisterFlagCompletionFunc("kernel-ref", completeKernelNames) + _ = cmd.RegisterFlagCompletionFunc("kernel-ref", d.completeKernelNames) return cmd } -func newImagePromoteCommand() *cobra.Command { +func (d *deps) newImagePromoteCommand() *cobra.Command { return &cobra.Command{ Use: "promote ", Short: "Promote an unmanaged image to a managed artifact", Args: exactArgsUsage(1, "usage: banger image promote "), - ValidArgsFunction: completeImageNameOnlyAtPos0, + ValidArgsFunction: d.completeImageNameOnlyAtPos0, RunE: func(cmd *cobra.Command, args []string) error { if err := system.EnsureSudo(cmd.Context()); err != nil { return err } - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -166,14 +166,14 @@ func newImagePromoteCommand() *cobra.Command { } } -func newImageListCommand() *cobra.Command { +func (d *deps) newImageListCommand() *cobra.Command { return &cobra.Command{ Use: "list", Aliases: []string{"ls"}, Short: "List images", Args: noArgsUsage("usage: banger image list"), RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -186,14 +186,14 @@ func newImageListCommand() *cobra.Command { } } -func newImageShowCommand() *cobra.Command { +func (d *deps) newImageShowCommand() *cobra.Command { return &cobra.Command{ Use: "show ", Short: "Show image details", Args: exactArgsUsage(1, "usage: banger image show "), - ValidArgsFunction: completeImageNameOnlyAtPos0, + ValidArgsFunction: d.completeImageNameOnlyAtPos0, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -206,18 +206,18 @@ func newImageShowCommand() *cobra.Command { } } -func newImageDeleteCommand() *cobra.Command { +func (d *deps) newImageDeleteCommand() *cobra.Command { return &cobra.Command{ Use: "delete ", Aliases: []string{"rm"}, Short: "Delete an image", Args: exactArgsUsage(1, "usage: banger image delete "), - ValidArgsFunction: completeImageNameOnlyAtPos0, + ValidArgsFunction: d.completeImageNameOnlyAtPos0, RunE: func(cmd *cobra.Command, args []string) error { if err := system.EnsureSudo(cmd.Context()); err != nil { return err } - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } diff --git a/internal/cli/commands_internal.go b/internal/cli/commands_internal.go index 3902aa2..2201b21 100644 --- a/internal/cli/commands_internal.go +++ b/internal/cli/commands_internal.go @@ -25,7 +25,7 @@ import ( "github.com/spf13/cobra" ) -func newInternalCommand() *cobra.Command { +func (d *deps) newInternalCommand() *cobra.Command { cmd := &cobra.Command{ Use: "internal", Hidden: true, diff --git a/internal/cli/commands_kernel.go b/internal/cli/commands_kernel.go index 27bd13b..5f7acbc 100644 --- a/internal/cli/commands_kernel.go +++ b/internal/cli/commands_kernel.go @@ -12,30 +12,30 @@ import ( "github.com/spf13/cobra" ) -func newKernelCommand() *cobra.Command { +func (d *deps) newKernelCommand() *cobra.Command { cmd := &cobra.Command{ Use: "kernel", Short: "Manage the local kernel catalog", RunE: helpNoArgs, } cmd.AddCommand( - newKernelListCommand(), - newKernelShowCommand(), - newKernelRmCommand(), - newKernelImportCommand(), - newKernelPullCommand(), + d.newKernelListCommand(), + d.newKernelShowCommand(), + d.newKernelRmCommand(), + d.newKernelImportCommand(), + d.newKernelPullCommand(), ) return cmd } -func newKernelPullCommand() *cobra.Command { +func (d *deps) newKernelPullCommand() *cobra.Command { var force bool cmd := &cobra.Command{ Use: "pull ", Short: "Download a cataloged kernel bundle", Args: exactArgsUsage(1, "usage: banger kernel pull [--force]"), RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -55,7 +55,7 @@ func newKernelPullCommand() *cobra.Command { return cmd } -func newKernelImportCommand() *cobra.Command { +func (d *deps) newKernelImportCommand() *cobra.Command { var params api.KernelImportParams cmd := &cobra.Command{ Use: "import ", @@ -72,7 +72,7 @@ func newKernelImportCommand() *cobra.Command { return err } params.FromDir = abs - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -89,7 +89,7 @@ func newKernelImportCommand() *cobra.Command { return cmd } -func newKernelListCommand() *cobra.Command { +func (d *deps) newKernelListCommand() *cobra.Command { var available bool cmd := &cobra.Command{ Use: "list", @@ -97,7 +97,7 @@ func newKernelListCommand() *cobra.Command { Short: "List kernels (local by default, or --available for the catalog)", Args: noArgsUsage("usage: banger kernel list [--available]"), RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -119,14 +119,14 @@ func newKernelListCommand() *cobra.Command { return cmd } -func newKernelShowCommand() *cobra.Command { +func (d *deps) newKernelShowCommand() *cobra.Command { return &cobra.Command{ Use: "show ", Short: "Show kernel catalog entry details", Args: exactArgsUsage(1, "usage: banger kernel show "), - ValidArgsFunction: completeKernelNameOnlyAtPos0, + ValidArgsFunction: d.completeKernelNameOnlyAtPos0, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -139,15 +139,15 @@ func newKernelShowCommand() *cobra.Command { } } -func newKernelRmCommand() *cobra.Command { +func (d *deps) newKernelRmCommand() *cobra.Command { return &cobra.Command{ Use: "rm ", Aliases: []string{"remove", "delete"}, Short: "Remove a kernel catalog entry", Args: exactArgsUsage(1, "usage: banger kernel rm "), - ValidArgsFunction: completeKernelNameOnlyAtPos0, + ValidArgsFunction: d.completeKernelNameOnlyAtPos0, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } diff --git a/internal/cli/commands_vm.go b/internal/cli/commands_vm.go index a642978..f85198b 100644 --- a/internal/cli/commands_vm.go +++ b/internal/cli/commands_vm.go @@ -22,35 +22,35 @@ import ( "github.com/spf13/cobra" ) -func newVMCommand() *cobra.Command { +func (d *deps) newVMCommand() *cobra.Command { cmd := &cobra.Command{ Use: "vm", Short: "Manage virtual machines", RunE: helpNoArgs, } cmd.AddCommand( - newVMCreateCommand(), - newVMRunCommand(), - newVMListCommand(), - newVMShowCommand(), - newVMActionCommand("start", "Start a VM", "vm.start"), - newVMActionCommand("stop", "Stop a VM", "vm.stop"), - newVMKillCommand(), - newVMActionCommand("restart", "Restart a VM", "vm.restart"), - newVMActionCommand("delete", "Delete a VM", "vm.delete", "rm"), - newVMPruneCommand(), - newVMSetCommand(), - newVMSSHCommand(), - newVMWorkspaceCommand(), - newVMSessionCommand(), - newVMLogsCommand(), - newVMStatsCommand(), - newVMPortsCommand(), + d.newVMCreateCommand(), + d.newVMRunCommand(), + d.newVMListCommand(), + d.newVMShowCommand(), + d.newVMActionCommand("start", "Start a VM", "vm.start"), + d.newVMActionCommand("stop", "Stop a VM", "vm.stop"), + d.newVMKillCommand(), + d.newVMActionCommand("restart", "Restart a VM", "vm.restart"), + d.newVMActionCommand("delete", "Delete a VM", "vm.delete", "rm"), + d.newVMPruneCommand(), + d.newVMSetCommand(), + d.newVMSSHCommand(), + d.newVMWorkspaceCommand(), + d.newVMSessionCommand(), + d.newVMLogsCommand(), + d.newVMStatsCommand(), + d.newVMPortsCommand(), ) return cmd } -func newVMRunCommand() *cobra.Command { +func (d *deps) newVMRunCommand() *cobra.Command { defaults := effectiveVMDefaults() var ( name string @@ -104,7 +104,7 @@ Three modes: var repoPtr *vmRunRepo if sourcePath != "" { - resolved, err := vmRunPreflightRepo(cmd.Context(), sourcePath) + resolved, err := d.vmRunPreflightRepo(cmd.Context(), sourcePath) if err != nil { return err } @@ -135,11 +135,11 @@ Three modes: if err := system.EnsureSudo(cmd.Context()); err != nil { return err } - layout, cfg, err = ensureDaemon(cmd.Context()) + layout, cfg, err = d.ensureDaemon(cmd.Context()) if err != nil { return err } - return runVMRun(cmd.Context(), layout.SocketPath, cfg, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), params, repoPtr, commandArgs, removeOnExit) + return d.runVMRun(cmd.Context(), layout.SocketPath, cfg, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), params, repoPtr, commandArgs, removeOnExit) }, } cmd.Flags().StringVar(&name, "name", "", "vm name") @@ -152,22 +152,22 @@ Three modes: cmd.Flags().StringVar(&branchName, "branch", "", "create and switch to a new guest branch") cmd.Flags().StringVar(&fromRef, "from", "HEAD", "base ref for --branch") cmd.Flags().BoolVar(&removeOnExit, "rm", false, "delete the VM after the ssh session / command exits") - _ = cmd.RegisterFlagCompletionFunc("image", completeImageNames) + _ = cmd.RegisterFlagCompletionFunc("image", d.completeImageNames) return cmd } -func newVMKillCommand() *cobra.Command { +func (d *deps) newVMKillCommand() *cobra.Command { var signal string cmd := &cobra.Command{ Use: "kill ...", Short: "Send a signal to a VM process", Args: minArgsUsage(1, "usage: banger vm kill [--signal SIGTERM|SIGKILL|...] ..."), - ValidArgsFunction: completeVMNames, + ValidArgsFunction: d.completeVMNames, RunE: func(cmd *cobra.Command, args []string) error { if err := system.EnsureSudo(cmd.Context()); err != nil { return err } - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -201,7 +201,7 @@ func newVMKillCommand() *cobra.Command { return cmd } -func newVMPruneCommand() *cobra.Command { +func (d *deps) newVMPruneCommand() *cobra.Command { var force bool cmd := &cobra.Command{ Use: "prune", @@ -212,23 +212,23 @@ func newVMPruneCommand() *cobra.Command { if err := system.EnsureSudo(cmd.Context()); err != nil { return err } - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } - return runVMPrune(cmd, layout.SocketPath, force) + return d.runVMPrune(cmd, layout.SocketPath, force) }, } cmd.Flags().BoolVarP(&force, "force", "f", false, "skip the confirmation prompt") return cmd } -func runVMPrune(cmd *cobra.Command, socketPath string, force bool) error { +func (d *deps) runVMPrune(cmd *cobra.Command, socketPath string, force bool) error { ctx := cmd.Context() stdout := cmd.OutOrStdout() stderr := cmd.ErrOrStderr() - list, err := vmListFunc(ctx, socketPath) + list, err := d.vmList(ctx, socketPath) if err != nil { return err } @@ -270,7 +270,7 @@ func runVMPrune(cmd *cobra.Command, socketPath string, force bool) error { if ref == "" { ref = shortID(vm.ID) } - if err := vmDeleteFunc(ctx, socketPath, vm.ID); err != nil { + if err := d.vmDelete(ctx, socketPath, vm.ID); err != nil { fmt.Fprintf(stderr, "delete %s: %v\n", ref, err) failed++ continue @@ -299,7 +299,7 @@ func promptYesNo(in io.Reader, out io.Writer, prompt string) (bool, error) { return answer == "y" || answer == "yes", nil } -func newVMCreateCommand() *cobra.Command { +func (d *deps) newVMCreateCommand() *cobra.Command { defaults := effectiveVMDefaults() var ( name string @@ -323,11 +323,11 @@ func newVMCreateCommand() *cobra.Command { if err := system.EnsureSudo(cmd.Context()); err != nil { return err } - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } - vm, err := runVMCreate(cmd.Context(), layout.SocketPath, cmd.ErrOrStderr(), params) + vm, err := d.runVMCreate(cmd.Context(), layout.SocketPath, cmd.ErrOrStderr(), params) if err != nil { return err } @@ -342,7 +342,7 @@ func newVMCreateCommand() *cobra.Command { cmd.Flags().StringVar(&workDiskSize, "disk-size", model.FormatSizeBytes(defaults.WorkDiskSizeBytes), "work disk size") cmd.Flags().BoolVar(&natEnabled, "nat", false, "enable NAT") cmd.Flags().BoolVar(&noStart, "no-start", false, "create without starting") - _ = cmd.RegisterFlagCompletionFunc("image", completeImageNames) + _ = cmd.RegisterFlagCompletionFunc("image", d.completeImageNames) return cmd } @@ -352,15 +352,15 @@ type vmListOptions struct { quiet bool } -func newPSCommand() *cobra.Command { - return newVMListLikeCommand("ps", nil, "usage: banger ps") +func (d *deps) newPSCommand() *cobra.Command { + return d.newVMListLikeCommand("ps", nil, "usage: banger ps") } -func newVMListCommand() *cobra.Command { - return newVMListLikeCommand("list", []string{"ls", "ps"}, "usage: banger vm list") +func (d *deps) newVMListCommand() *cobra.Command { + return d.newVMListLikeCommand("list", []string{"ls", "ps"}, "usage: banger vm list") } -func newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Command { +func (d *deps) newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Command { var opts vmListOptions cmd := &cobra.Command{ Use: use, @@ -368,7 +368,7 @@ func newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Com Short: "List VMs", Args: noArgsUsage(usage), RunE: func(cmd *cobra.Command, args []string) error { - return runVMList(cmd, opts) + return d.runVMList(cmd, opts) }, } cmd.Flags().BoolVarP(&opts.showAll, "all", "a", false, "show all VMs") @@ -377,8 +377,8 @@ func newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Com return cmd } -func runVMList(cmd *cobra.Command, opts vmListOptions) error { - layout, _, err := ensureDaemon(cmd.Context()) +func (d *deps) runVMList(cmd *cobra.Command, opts vmListOptions) error { + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -421,14 +421,14 @@ func selectVMListVMs(vms []model.VMRecord, showAll, latest bool) []model.VMRecor return []model.VMRecord{latestVM} } -func newVMShowCommand() *cobra.Command { +func (d *deps) newVMShowCommand() *cobra.Command { return &cobra.Command{ Use: "show ", Short: "Show VM details", Args: exactArgsUsage(1, "usage: banger vm show "), - ValidArgsFunction: completeVMNameOnlyAtPos0, + ValidArgsFunction: d.completeVMNameOnlyAtPos0, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -441,18 +441,18 @@ func newVMShowCommand() *cobra.Command { } } -func newVMActionCommand(use, short, method string, aliases ...string) *cobra.Command { +func (d *deps) newVMActionCommand(use, short, method string, aliases ...string) *cobra.Command { return &cobra.Command{ Use: use + " ...", Aliases: aliases, Short: short, Args: minArgsUsage(1, fmt.Sprintf("usage: banger vm %s ...", use)), - ValidArgsFunction: completeVMNames, + ValidArgsFunction: d.completeVMNames, RunE: func(cmd *cobra.Command, args []string) error { if err := system.EnsureSudo(cmd.Context()); err != nil { return err } - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -474,7 +474,7 @@ func newVMActionCommand(use, short, method string, aliases ...string) *cobra.Com } } -func newVMSetCommand() *cobra.Command { +func (d *deps) newVMSetCommand() *cobra.Command { var ( vcpu int memory int @@ -486,7 +486,7 @@ func newVMSetCommand() *cobra.Command { Use: "set ...", Short: "Update stopped VM settings", Args: minArgsUsage(1, "usage: banger vm set [--vcpu N] [--memory MiB] [--disk-size SIZE] [--nat|--no-nat] ..."), - ValidArgsFunction: completeVMNames, + ValidArgsFunction: d.completeVMNames, RunE: func(cmd *cobra.Command, args []string) error { params, err := vmSetParamsFromFlags(args[0], vcpu, memory, diskSize, nat, noNat) if err != nil { @@ -495,7 +495,7 @@ func newVMSetCommand() *cobra.Command { if err := system.EnsureSudo(cmd.Context()); err != nil { return err } - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -525,21 +525,21 @@ func newVMSetCommand() *cobra.Command { return cmd } -func newVMSSHCommand() *cobra.Command { +func (d *deps) newVMSSHCommand() *cobra.Command { return &cobra.Command{ Use: "ssh [ssh args...]", Short: "SSH into a running VM", Args: minArgsUsage(1, "usage: banger vm ssh [ssh args...]"), - ValidArgsFunction: completeVMNameOnlyAtPos0, + ValidArgsFunction: d.completeVMNameOnlyAtPos0, RunE: func(cmd *cobra.Command, args []string) error { - layout, cfg, err := ensureDaemon(cmd.Context()) + layout, cfg, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } if err := validateSSHPrereqs(cfg); err != nil { return err } - result, err := vmSSHFunc(cmd.Context(), layout.SocketPath, args[0]) + result, err := d.vmSSH(cmd.Context(), layout.SocketPath, args[0]) if err != nil { return err } @@ -547,25 +547,25 @@ func newVMSSHCommand() *cobra.Command { if err != nil { return err } - return runSSHSession(cmd.Context(), layout.SocketPath, result.Name, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), sshArgs, false) + return d.runSSHSession(cmd.Context(), layout.SocketPath, result.Name, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), sshArgs, false) }, } } -func newVMWorkspaceCommand() *cobra.Command { +func (d *deps) newVMWorkspaceCommand() *cobra.Command { cmd := &cobra.Command{ Use: "workspace", Short: "Manage repository workspaces inside a running VM", RunE: helpNoArgs, } cmd.AddCommand( - newVMWorkspacePrepareCommand(), - newVMWorkspaceExportCommand(), + d.newVMWorkspacePrepareCommand(), + d.newVMWorkspaceExportCommand(), ) return cmd } -func newVMWorkspacePrepareCommand() *cobra.Command { +func (d *deps) newVMWorkspacePrepareCommand() *cobra.Command { var guestPath string var branchName string var fromRef string @@ -576,14 +576,14 @@ func newVMWorkspacePrepareCommand() *cobra.Command { Short: "Copy a local repo into a running VM", Long: "Prepare a repository workspace from a local git checkout into a running VM. The default guest path is /root/repo and the default mode is shallow_overlay. Repositories with git submodules must use --mode full_copy.", Args: minArgsUsage(1, "usage: banger vm workspace prepare [path]"), - ValidArgsFunction: completeVMNameOnlyAtPos0, + ValidArgsFunction: d.completeVMNameOnlyAtPos0, Example: strings.TrimSpace(` banger vm workspace prepare devbox banger vm workspace prepare devbox ../repo --guest-path /root/repo --readonly banger vm workspace prepare devbox ../repo --mode full_copy `), RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -592,7 +592,7 @@ func newVMWorkspacePrepareCommand() *cobra.Command { sourcePath = args[1] } if strings.TrimSpace(sourcePath) == "" { - wd, err := cwdFunc() + wd, err := d.cwd() if err != nil { return err } @@ -606,7 +606,7 @@ func newVMWorkspacePrepareCommand() *cobra.Command { if strings.TrimSpace(branchName) != "" { prepareFrom = fromRef } - result, err := vmWorkspacePrepareFunc(cmd.Context(), layout.SocketPath, api.VMWorkspacePrepareParams{ + result, err := d.vmWorkspacePrepare(cmd.Context(), layout.SocketPath, api.VMWorkspacePrepareParams{ IDOrName: args[0], SourcePath: resolvedPath, GuestPath: guestPath, @@ -629,7 +629,7 @@ func newVMWorkspacePrepareCommand() *cobra.Command { return cmd } -func newVMWorkspaceExportCommand() *cobra.Command { +func (d *deps) newVMWorkspaceExportCommand() *cobra.Command { var guestPath string var outputPath string var baseCommit string @@ -638,7 +638,7 @@ func newVMWorkspaceExportCommand() *cobra.Command { Short: "Pull changes from a guest workspace back to the host as a patch", Long: "Emit a binary-safe unified diff of every change inside the guest workspace (committed since base + uncommitted + untracked, minus .gitignore). Non-mutating — the guest's index and working tree are untouched. Pass --base-commit with the head_commit from workspace prepare to capture changes even when the worker ran git commit inside the VM. Without --base-commit the diff is against the current guest HEAD, which misses committed changes.", Args: exactArgsUsage(1, "usage: banger vm workspace export "), - ValidArgsFunction: completeVMNameOnlyAtPos0, + ValidArgsFunction: d.completeVMNameOnlyAtPos0, Example: strings.TrimSpace(` banger vm workspace export devbox | git apply banger vm workspace export devbox --base-commit abc1234 | git apply @@ -646,11 +646,11 @@ func newVMWorkspaceExportCommand() *cobra.Command { banger vm workspace export devbox --guest-path /root/project --output changes.diff `), RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } - result, err := vmWorkspaceExportFunc(cmd.Context(), layout.SocketPath, api.WorkspaceExportParams{ + result, err := d.vmWorkspaceExport(cmd.Context(), layout.SocketPath, api.WorkspaceExportParams{ IDOrName: args[0], GuestPath: guestPath, BaseCommit: baseCommit, @@ -680,15 +680,15 @@ func newVMWorkspaceExportCommand() *cobra.Command { return cmd } -func newVMLogsCommand() *cobra.Command { +func (d *deps) newVMLogsCommand() *cobra.Command { var follow bool cmd := &cobra.Command{ Use: "logs ", Short: "Show VM logs", Args: exactArgsUsage(1, "usage: banger vm logs [-f] "), - ValidArgsFunction: completeVMNameOnlyAtPos0, + ValidArgsFunction: d.completeVMNameOnlyAtPos0, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -706,14 +706,14 @@ func newVMLogsCommand() *cobra.Command { return cmd } -func newVMStatsCommand() *cobra.Command { +func (d *deps) newVMStatsCommand() *cobra.Command { return &cobra.Command{ Use: "stats ", Short: "Show VM stats", Args: exactArgsUsage(1, "usage: banger vm stats "), - ValidArgsFunction: completeVMNameOnlyAtPos0, + ValidArgsFunction: d.completeVMNameOnlyAtPos0, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -726,18 +726,18 @@ func newVMStatsCommand() *cobra.Command { } } -func newVMPortsCommand() *cobra.Command { +func (d *deps) newVMPortsCommand() *cobra.Command { return &cobra.Command{ Use: "ports ", Short: "Show host-reachable listening guest ports", Args: exactArgsUsage(1, "usage: banger vm ports "), - ValidArgsFunction: completeVMNameOnlyAtPos0, + ValidArgsFunction: d.completeVMNameOnlyAtPos0, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } - result, err := vmPortsFunc(cmd.Context(), layout.SocketPath, args[0]) + result, err := d.vmPorts(cmd.Context(), layout.SocketPath, args[0]) if err != nil { return err } diff --git a/internal/cli/commands_vm_session.go b/internal/cli/commands_vm_session.go index 16d8cb4..d539445 100644 --- a/internal/cli/commands_vm_session.go +++ b/internal/cli/commands_vm_session.go @@ -15,7 +15,7 @@ import ( "github.com/spf13/cobra" ) -func newVMSessionCommand() *cobra.Command { +func (d *deps) newVMSessionCommand() *cobra.Command { cmd := &cobra.Command{ Use: "session", Short: "Manage long-lived guest commands inside a VM", @@ -23,19 +23,19 @@ func newVMSessionCommand() *cobra.Command { RunE: helpNoArgs, } cmd.AddCommand( - newVMSessionStartCommand(), - newVMSessionListCommand(), - newVMSessionShowCommand(), - newVMSessionLogsCommand(), - newVMSessionStopCommand(), - newVMSessionKillCommand(), - newVMSessionAttachCommand(), - newVMSessionSendCommand(), + d.newVMSessionStartCommand(), + d.newVMSessionListCommand(), + d.newVMSessionShowCommand(), + d.newVMSessionLogsCommand(), + d.newVMSessionStopCommand(), + d.newVMSessionKillCommand(), + d.newVMSessionAttachCommand(), + d.newVMSessionSendCommand(), ) return cmd } -func newVMSessionStartCommand() *cobra.Command { +func (d *deps) newVMSessionStartCommand() *cobra.Command { var name string var cwd string var stdinMode string @@ -47,13 +47,13 @@ func newVMSessionStartCommand() *cobra.Command { Short: "Start a managed guest command", Long: "Start a daemon-managed guest command. The daemon verifies that the guest working directory exists and that the requested command is present in guest PATH before launch. Use --stdin-mode pipe when you need live attach.", Args: minArgsUsage(2, "usage: banger vm session start [flags] -- [args...]"), - ValidArgsFunction: completeVMNameOnlyAtPos0, + ValidArgsFunction: d.completeVMNameOnlyAtPos0, Example: strings.TrimSpace(` banger vm session start devbox --name planner --cwd /root/repo --stdin-mode pipe --require-command git -- pi --mode rpc --no-session banger vm session start devbox --name shell --stdin-mode pipe -- bash -lc 'exec bash' `), RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -65,7 +65,7 @@ func newVMSessionStartCommand() *cobra.Command { if err != nil { return err } - result, err := guestSessionStartFunc(cmd.Context(), layout.SocketPath, api.GuestSessionStartParams{ + result, err := d.guestSessionStart(cmd.Context(), layout.SocketPath, api.GuestSessionStartParams{ VMIDOrName: args[0], Name: name, Command: args[1], @@ -97,19 +97,19 @@ func newVMSessionStartCommand() *cobra.Command { return cmd } -func newVMSessionListCommand() *cobra.Command { +func (d *deps) newVMSessionListCommand() *cobra.Command { return &cobra.Command{ Use: "list ", Aliases: []string{"ls"}, Short: "List managed guest commands for a VM", Args: exactArgsUsage(1, "usage: banger vm session list "), - ValidArgsFunction: completeVMNameOnlyAtPos0, + ValidArgsFunction: d.completeVMNameOnlyAtPos0, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } - result, err := guestSessionListFunc(cmd.Context(), layout.SocketPath, args[0]) + result, err := d.guestSessionList(cmd.Context(), layout.SocketPath, args[0]) if err != nil { return err } @@ -118,18 +118,18 @@ func newVMSessionListCommand() *cobra.Command { } } -func newVMSessionShowCommand() *cobra.Command { +func (d *deps) newVMSessionShowCommand() *cobra.Command { return &cobra.Command{ Use: "show ", Short: "Show managed guest command details", Args: exactArgsUsage(2, "usage: banger vm session show "), - ValidArgsFunction: completeSessionNames, + ValidArgsFunction: d.completeSessionNames, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } - result, err := guestSessionGetFunc(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) + result, err := d.guestSessionGet(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) if err != nil { return err } @@ -138,20 +138,20 @@ func newVMSessionShowCommand() *cobra.Command { } } -func newVMSessionLogsCommand() *cobra.Command { +func (d *deps) newVMSessionLogsCommand() *cobra.Command { var stream string var tailLines int cmd := &cobra.Command{ Use: "logs ", Short: "Show stdout or stderr for a guest session", Args: exactArgsUsage(2, "usage: banger vm session logs [--stream stdout|stderr] [-n LINES] "), - ValidArgsFunction: completeSessionNames, + ValidArgsFunction: d.completeSessionNames, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } - result, err := guestSessionLogsFunc(cmd.Context(), layout.SocketPath, api.GuestSessionLogsParams{VMIDOrName: args[0], SessionIDOrName: args[1], Stream: stream, TailLines: tailLines}) + result, err := d.guestSessionLogs(cmd.Context(), layout.SocketPath, api.GuestSessionLogsParams{VMIDOrName: args[0], SessionIDOrName: args[1], Stream: stream, TailLines: tailLines}) if err != nil { return err } @@ -164,18 +164,18 @@ func newVMSessionLogsCommand() *cobra.Command { return cmd } -func newVMSessionStopCommand() *cobra.Command { +func (d *deps) newVMSessionStopCommand() *cobra.Command { return &cobra.Command{ Use: "stop ", Short: "Send SIGTERM to a guest session", Args: exactArgsUsage(2, "usage: banger vm session stop "), - ValidArgsFunction: completeSessionNames, + ValidArgsFunction: d.completeSessionNames, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } - result, err := guestSessionStopFunc(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) + result, err := d.guestSessionStop(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) if err != nil { return err } @@ -184,18 +184,18 @@ func newVMSessionStopCommand() *cobra.Command { } } -func newVMSessionKillCommand() *cobra.Command { +func (d *deps) newVMSessionKillCommand() *cobra.Command { return &cobra.Command{ Use: "kill ", Short: "Send SIGKILL to a guest session", Args: exactArgsUsage(2, "usage: banger vm session kill "), - ValidArgsFunction: completeSessionNames, + ValidArgsFunction: d.completeSessionNames, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } - result, err := guestSessionKillFunc(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) + result, err := d.guestSessionKill(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) if err != nil { return err } @@ -204,19 +204,19 @@ func newVMSessionKillCommand() *cobra.Command { } } -func newVMSessionAttachCommand() *cobra.Command { +func (d *deps) newVMSessionAttachCommand() *cobra.Command { return &cobra.Command{ Use: "attach ", Short: "Attach local stdio to an attachable guest session", Long: "Attach local stdio to a pipe-mode session through a daemon-created local Unix socket bridge. Only one active attach is allowed at a time, and the client must run on the same host as the daemon.", Args: exactArgsUsage(2, "usage: banger vm session attach "), - ValidArgsFunction: completeSessionNames, + ValidArgsFunction: d.completeSessionNames, RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } - result, err := guestSessionAttachBeginFunc(cmd.Context(), layout.SocketPath, api.GuestSessionAttachBeginParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) + result, err := d.guestSessionAttachBegin(cmd.Context(), layout.SocketPath, api.GuestSessionAttachBeginParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) if err != nil { return err } @@ -229,21 +229,21 @@ func newVMSessionAttachCommand() *cobra.Command { } } -func newVMSessionSendCommand() *cobra.Command { +func (d *deps) newVMSessionSendCommand() *cobra.Command { var message string cmd := &cobra.Command{ Use: "send ", Short: "Write bytes to a running guest session's stdin pipe", Long: "Write a payload to the stdin pipe of a running pipe-mode guest session without holding the exclusive attach. Use --message for an inline JSONL string, or pipe bytes via stdin when --message is omitted. A trailing newline is appended to --message values that lack one.", Args: exactArgsUsage(2, "usage: banger vm session send [--message '']"), - ValidArgsFunction: completeSessionNames, + ValidArgsFunction: d.completeSessionNames, Example: strings.TrimSpace(` banger vm session send devbox planner --message '{"type":"abort"}' banger vm session send devbox planner --message '{"type":"steer","message":"Focus on src/"}' echo '{"type":"prompt","prompt":"Summarize."}' | banger vm session send devbox planner `), RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := ensureDaemon(cmd.Context()) + layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } @@ -259,7 +259,7 @@ func newVMSessionSendCommand() *cobra.Command { return fmt.Errorf("read stdin: %w", err) } } - result, err := guestSessionSendFunc(cmd.Context(), layout.SocketPath, api.GuestSessionSendParams{ + result, err := d.guestSessionSend(cmd.Context(), layout.SocketPath, api.GuestSessionSendParams{ VMIDOrName: args[0], SessionIDOrName: args[1], Payload: payload, diff --git a/internal/cli/completion.go b/internal/cli/completion.go index 871ce85..db42627 100644 --- a/internal/cli/completion.go +++ b/internal/cli/completion.go @@ -21,9 +21,10 @@ import ( // - Fail silently. Completion is advisory; any error path returns an // empty suggestion list rather than propagating to the user. -// completionListerFunc is the seam used by tests to avoid touching a -// real daemon socket. -var completionListerFunc = func(ctx context.Context, socketPath, method string) ([]string, error) { +// defaultCompletionLister + defaultCompletionSessionLister back the +// corresponding *deps fields; tests inject their own fakes via the +// struct instead of mutating package-level vars. +func defaultCompletionLister(ctx context.Context, socketPath, method string) ([]string, error) { switch method { case "vm.list": result, err := rpc.Call[api.VMListResult](ctx, socketPath, method, api.Empty{}) @@ -65,9 +66,7 @@ var completionListerFunc = func(ctx context.Context, socketPath, method string) return nil, nil } -// completionSessionListerFunc is the seam for guest-session name lookups -// scoped to a VM. -var completionSessionListerFunc = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) { +func defaultCompletionSessionLister(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) { result, err := rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: vmIDOrName}) if err != nil { return nil, err @@ -84,12 +83,12 @@ var completionSessionListerFunc = func(ctx context.Context, socketPath, vmIDOrNa // daemonSocketForCompletion returns the socket path IFF the daemon is // already running. Returns "", false when no daemon is up — completion // callers use this as the bail signal. -func daemonSocketForCompletion(ctx context.Context) (string, bool) { +func (d *deps) daemonSocketForCompletion(ctx context.Context) (string, bool) { layout, err := paths.Resolve() if err != nil { return "", false } - if _, err := daemonPingFunc(ctx, layout.SocketPath); err != nil { + if _, err := d.daemonPing(ctx, layout.SocketPath); err != nil { return "", false } return layout.SocketPath, true @@ -119,12 +118,12 @@ func hasPrefix(s, prefix string) bool { return len(s) >= len(prefix) && s[:len(prefix)] == prefix } -func completeVMNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - socket, ok := daemonSocketForCompletion(cmd.Context()) +func (d *deps) completeVMNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + socket, ok := d.daemonSocketForCompletion(cmd.Context()) if !ok { return nil, cobra.ShellCompDirectiveNoFileComp } - names, err := completionListerFunc(cmd.Context(), socket, "vm.list") + names, err := d.completionLister(cmd.Context(), socket, "vm.list") if err != nil { return nil, cobra.ShellCompDirectiveNoFileComp } @@ -134,45 +133,45 @@ func completeVMNames(cmd *cobra.Command, args []string, toComplete string) ([]st // completeVMNameOnlyAtPos0 restricts VM-name completion to the first // positional argument. Used by commands like `vm ssh [ssh args...]` // where args after pos 0 are free-form. -func completeVMNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { +func (d *deps) completeVMNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { if len(args) > 0 { return nil, cobra.ShellCompDirectiveNoFileComp } - return completeVMNames(cmd, args, toComplete) + return d.completeVMNames(cmd, args, toComplete) } -func completeImageNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { +func (d *deps) completeImageNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { if len(args) > 0 { return nil, cobra.ShellCompDirectiveNoFileComp } - return completeImageNames(cmd, args, toComplete) + return d.completeImageNames(cmd, args, toComplete) } -func completeKernelNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { +func (d *deps) completeKernelNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { if len(args) > 0 { return nil, cobra.ShellCompDirectiveNoFileComp } - return completeKernelNames(cmd, args, toComplete) + return d.completeKernelNames(cmd, args, toComplete) } -func completeImageNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - socket, ok := daemonSocketForCompletion(cmd.Context()) +func (d *deps) completeImageNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + socket, ok := d.daemonSocketForCompletion(cmd.Context()) if !ok { return nil, cobra.ShellCompDirectiveNoFileComp } - names, err := completionListerFunc(cmd.Context(), socket, "image.list") + names, err := d.completionLister(cmd.Context(), socket, "image.list") if err != nil { return nil, cobra.ShellCompDirectiveNoFileComp } return filterPrefix(names, args, toComplete), cobra.ShellCompDirectiveNoFileComp } -func completeKernelNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - socket, ok := daemonSocketForCompletion(cmd.Context()) +func (d *deps) completeKernelNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + socket, ok := d.daemonSocketForCompletion(cmd.Context()) if !ok { return nil, cobra.ShellCompDirectiveNoFileComp } - names, err := completionListerFunc(cmd.Context(), socket, "kernel.list") + names, err := d.completionLister(cmd.Context(), socket, "kernel.list") if err != nil { return nil, cobra.ShellCompDirectiveNoFileComp } @@ -182,16 +181,16 @@ func completeKernelNames(cmd *cobra.Command, args []string, toComplete string) ( // completeSessionNames handles `... ` commands: pos 0 // completes VMs, pos 1 completes sessions owned by args[0], pos 2+ is // silent. -func completeSessionNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { +func (d *deps) completeSessionNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { switch len(args) { case 0: - return completeVMNames(cmd, args, toComplete) + return d.completeVMNames(cmd, args, toComplete) case 1: - socket, ok := daemonSocketForCompletion(cmd.Context()) + socket, ok := d.daemonSocketForCompletion(cmd.Context()) if !ok { return nil, cobra.ShellCompDirectiveNoFileComp } - names, err := completionSessionListerFunc(cmd.Context(), socket, args[0]) + names, err := d.completionSessionLister(cmd.Context(), socket, args[0]) if err != nil { return nil, cobra.ShellCompDirectiveNoFileComp } diff --git a/internal/cli/completion_test.go b/internal/cli/completion_test.go index 6ef2dee..e552732 100644 --- a/internal/cli/completion_test.go +++ b/internal/cli/completion_test.go @@ -12,10 +12,11 @@ import ( ) // stubCompletionSeams installs test doubles for the daemon ping + lister -// seams and restores the originals on cleanup. Tests opt into the -// sub-functions they actually need. +// seams on the caller's *deps. Tests opt into the sub-functions they +// actually need. func stubCompletionSeams( t *testing.T, + d *deps, pingErr error, names map[string][]string, listErr error, @@ -24,28 +25,19 @@ func stubCompletionSeams( ) { t.Helper() - origPing := daemonPingFunc - origLister := completionListerFunc - origSessionLister := completionSessionListerFunc - t.Cleanup(func() { - daemonPingFunc = origPing - completionListerFunc = origLister - completionSessionListerFunc = origSessionLister - }) - - daemonPingFunc = func(ctx context.Context, socketPath string) (api.PingResult, error) { + d.daemonPing = func(ctx context.Context, socketPath string) (api.PingResult, error) { if pingErr != nil { return api.PingResult{}, pingErr } return api.PingResult{}, nil } - completionListerFunc = func(ctx context.Context, socketPath, method string) ([]string, error) { + d.completionLister = func(ctx context.Context, socketPath, method string) ([]string, error) { if listErr != nil { return nil, listErr } return names[method], nil } - completionSessionListerFunc = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) { + d.completionSessionLister = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) { if sessionErr != nil { return nil, sessionErr } @@ -89,9 +81,10 @@ func testCmdWithCtx() *cobra.Command { } func TestCompleteVMNamesHappyPath(t *testing.T) { - stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil) + d := defaultDeps() + stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil) - got, directive := completeVMNames(testCmdWithCtx(), nil, "") + got, directive := d.completeVMNames(testCmdWithCtx(), nil, "") if directive != cobra.ShellCompDirectiveNoFileComp { t.Errorf("directive = %d, want NoFileComp", directive) } @@ -101,9 +94,10 @@ func TestCompleteVMNamesHappyPath(t *testing.T) { } func TestCompleteVMNamesDaemonDown(t *testing.T) { - stubCompletionSeams(t, errors.New("connection refused"), nil, nil, nil, nil) + d := defaultDeps() + stubCompletionSeams(t, d, errors.New("connection refused"), nil, nil, nil, nil) - got, directive := completeVMNames(testCmdWithCtx(), nil, "") + got, directive := d.completeVMNames(testCmdWithCtx(), nil, "") if len(got) != 0 { t.Errorf("daemon-down should return no suggestions, got %v", got) } @@ -113,18 +107,20 @@ func TestCompleteVMNamesDaemonDown(t *testing.T) { } func TestCompleteVMNamesRPCError(t *testing.T) { - stubCompletionSeams(t, nil, nil, errors.New("rpc failed"), nil, nil) + d := defaultDeps() + stubCompletionSeams(t, d, nil, nil, errors.New("rpc failed"), nil, nil) - got, _ := completeVMNames(testCmdWithCtx(), nil, "") + got, _ := d.completeVMNames(testCmdWithCtx(), nil, "") if len(got) != 0 { t.Errorf("rpc error should return no suggestions, got %v", got) } } func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) { - stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil) + d := defaultDeps() + stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil) - got, _ := completeVMNames(testCmdWithCtx(), []string{"alpha"}, "") + got, _ := d.completeVMNames(testCmdWithCtx(), []string{"alpha"}, "") want := []string{"beta", "gamma"} if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) @@ -132,9 +128,10 @@ func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) { } func TestCompleteVMNamesPrefixFilter(t *testing.T) { - stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil, nil, nil) + d := defaultDeps() + stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil, nil, nil) - got, _ := completeVMNames(testCmdWithCtx(), nil, "alp") + got, _ := d.completeVMNames(testCmdWithCtx(), nil, "alp") want := []string{"alpha", "alphabet"} if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) @@ -142,49 +139,53 @@ func TestCompleteVMNamesPrefixFilter(t *testing.T) { } func TestCompleteVMNameOnlyAtPos0(t *testing.T) { - stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha"}}, nil, nil, nil) + d := defaultDeps() + stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha"}}, nil, nil, nil) - atPos0, _ := completeVMNameOnlyAtPos0(testCmdWithCtx(), nil, "") + atPos0, _ := d.completeVMNameOnlyAtPos0(testCmdWithCtx(), nil, "") if len(atPos0) != 1 || atPos0[0] != "alpha" { t.Errorf("pos 0: got %v", atPos0) } - atPos1, _ := completeVMNameOnlyAtPos0(testCmdWithCtx(), []string{"alpha"}, "") + atPos1, _ := d.completeVMNameOnlyAtPos0(testCmdWithCtx(), []string{"alpha"}, "") if len(atPos1) != 0 { t.Errorf("pos 1+ should be silent, got %v", atPos1) } } func TestCompleteImageNames(t *testing.T) { - stubCompletionSeams(t, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil, nil, nil) + d := defaultDeps() + stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil, nil, nil) - got, _ := completeImageNames(testCmdWithCtx(), nil, "") + got, _ := d.completeImageNames(testCmdWithCtx(), nil, "") if !reflect.DeepEqual(got, []string{"debian-bookworm", "alpine"}) { t.Errorf("got %v", got) } } func TestCompleteKernelNames(t *testing.T) { - stubCompletionSeams(t, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil, nil, nil) + d := defaultDeps() + stubCompletionSeams(t, d, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil, nil, nil) - got, _ := completeKernelNames(testCmdWithCtx(), nil, "") + got, _ := d.completeKernelNames(testCmdWithCtx(), nil, "") if len(got) != 1 || got[0] != "generic-6.12" { t.Errorf("got %v", got) } } func TestCompleteImageNameOnlyAtPos0SilentAfterFirst(t *testing.T) { - stubCompletionSeams(t, nil, map[string][]string{"image.list": {"alpine"}}, nil, nil, nil) + d := defaultDeps() + stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"alpine"}}, nil, nil, nil) - after, _ := completeImageNameOnlyAtPos0(testCmdWithCtx(), []string{"alpine"}, "") + after, _ := d.completeImageNameOnlyAtPos0(testCmdWithCtx(), []string{"alpine"}, "") if len(after) != 0 { t.Errorf("expected silence at pos 1+, got %v", after) } } func TestCompleteSessionNames(t *testing.T) { - stubCompletionSeams( - t, + d := defaultDeps() + stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"devbox"}}, nil, @@ -193,34 +194,35 @@ func TestCompleteSessionNames(t *testing.T) { ) // Position 0 → VMs. - vms, _ := completeSessionNames(testCmdWithCtx(), nil, "") + vms, _ := d.completeSessionNames(testCmdWithCtx(), nil, "") if len(vms) != 1 || vms[0] != "devbox" { t.Errorf("pos 0: got %v", vms) } // Position 1 → sessions scoped to args[0]. - sessions, _ := completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "") + sessions, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "") if !reflect.DeepEqual(sessions, []string{"planner", "worker"}) { t.Errorf("pos 1: got %v", sessions) } // Position 1 with prefix filter. - filtered, _ := completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "wor") + filtered, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "wor") if len(filtered) != 1 || filtered[0] != "worker" { t.Errorf("pos 1 prefix: got %v", filtered) } // Position 2+ silent. - past, _ := completeSessionNames(testCmdWithCtx(), []string{"devbox", "planner"}, "") + past, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox", "planner"}, "") if len(past) != 0 { t.Errorf("pos 2+: got %v", past) } } func TestCompleteSessionNamesDaemonDown(t *testing.T) { - stubCompletionSeams(t, errors.New("down"), nil, nil, nil, nil) + d := defaultDeps() + stubCompletionSeams(t, d, errors.New("down"), nil, nil, nil, nil) - got, directive := completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "") + got, directive := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "") if len(got) != 0 { t.Errorf("expected no suggestions when daemon down, got %v", got) } diff --git a/internal/cli/daemon_lifecycle.go b/internal/cli/daemon_lifecycle.go index 70d5910..5b8822b 100644 --- a/internal/cli/daemon_lifecycle.go +++ b/internal/cli/daemon_lifecycle.go @@ -18,7 +18,7 @@ import ( // ensureDaemon pings the socket; on miss it auto-starts bangerd, on // version mismatch it restarts. Every CLI command that needs to talk // to the daemon routes through here. -func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) { +func (d *deps) ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) { layout, err := paths.Resolve() if err != nil { return paths.Layout{}, model.DaemonConfig{}, err @@ -27,16 +27,16 @@ func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) if err != nil { return paths.Layout{}, model.DaemonConfig{}, err } - if ping, err := daemonPingFunc(ctx, layout.SocketPath); err == nil { - if daemonOutdated(ping.PID) { - if err := restartDaemon(ctx, layout, ping.PID); err != nil { + if ping, err := d.daemonPing(ctx, layout.SocketPath); err == nil { + if d.daemonOutdated(ping.PID) { + if err := d.restartDaemon(ctx, layout, ping.PID); err != nil { return paths.Layout{}, model.DaemonConfig{}, err } return layout, cfg, nil } return layout, cfg, nil } - if err := startDaemon(ctx, layout); err != nil { + if err := d.startDaemon(ctx, layout); err != nil { return paths.Layout{}, model.DaemonConfig{}, err } return layout, cfg, nil @@ -47,11 +47,11 @@ func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) // session still holds a handle to an old daemon. os.SameFile compares // inode + dev, so a fresh binary at the same path registers as // different. -func daemonOutdated(pid int) bool { +func (d *deps) daemonOutdated(pid int) bool { if pid <= 0 { return false } - daemonBin, err := bangerdPathFunc() + daemonBin, err := d.bangerdPath() if err != nil { return false } @@ -59,20 +59,20 @@ func daemonOutdated(pid int) bool { if err != nil { return false } - runningInfo, err := os.Stat(daemonExePath(pid)) + runningInfo, err := os.Stat(d.daemonExePath(pid)) if err != nil { return false } return !os.SameFile(currentInfo, runningInfo) } -func restartDaemon(ctx context.Context, layout paths.Layout, pid int) error { +func (d *deps) restartDaemon(ctx context.Context, layout paths.Layout, pid int) error { stopCtx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() _, _ = rpc.Call[api.ShutdownResult](stopCtx, layout.SocketPath, "shutdown", api.Empty{}) if waitForPIDExit(pid, 2*time.Second) { - return startDaemon(ctx, layout) + return d.startDaemon(ctx, layout) } if proc, err := os.FindProcess(pid); err == nil { _ = proc.Signal(syscall.SIGTERM) @@ -80,7 +80,7 @@ func restartDaemon(ctx context.Context, layout paths.Layout, pid int) error { if !waitForPIDExit(pid, 2*time.Second) { return fmt.Errorf("timed out restarting stale daemon pid %d", pid) } - return startDaemon(ctx, layout) + return d.startDaemon(ctx, layout) } func waitForPIDExit(pid int, timeout time.Duration) bool { @@ -105,7 +105,7 @@ func pidRunning(pid int) bool { return proc.Signal(syscall.Signal(0)) == nil } -func startDaemon(ctx context.Context, layout paths.Layout) error { +func (d *deps) startDaemon(ctx context.Context, layout paths.Layout) error { if err := paths.Ensure(layout); err != nil { return err } diff --git a/internal/cli/deps.go b/internal/cli/deps.go new file mode 100644 index 0000000..e18bff3 --- /dev/null +++ b/internal/cli/deps.go @@ -0,0 +1,165 @@ +package cli + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "banger/internal/api" + "banger/internal/daemon" + "banger/internal/guest" + "banger/internal/paths" + "banger/internal/rpc" + "banger/internal/system" + "banger/internal/toolingplan" +) + +// deps holds the function seams production code dispatches through and +// tests replace with fakes. Keeping these on a per-invocation struct +// (instead of package-level mutable vars) makes the CLI's external +// surface explicit and lets tests run in parallel without leaking fakes +// across test cases. +// +// Every command builder, orchestrator, and helper that touches the RPC +// socket, spawns a subprocess, or reads host state hangs off a *deps +// receiver. Pure helpers (formatters, path resolvers, arg-count +// validators) stay package-level because they hold no references to +// external systems. +type deps struct { + bangerdPath func() (string, error) + daemonExePath func(pid int) string + doctor func(ctx context.Context) (system.Report, error) + sshExec func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error + hostCommandOutput func(ctx context.Context, name string, args ...string) ([]byte, error) + vmHealth func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) + vmSSH func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error) + vmDelete func(ctx context.Context, socketPath, idOrName string) error + vmList func(ctx context.Context, socketPath string) (api.VMListResult, error) + daemonPing func(ctx context.Context, socketPath string) (api.PingResult, error) + vmCreateBegin func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error) + vmCreateStatus func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error) + vmCreateCancel func(ctx context.Context, socketPath, operationID string) error + vmPorts func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) + vmWorkspacePrepare func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) + vmWorkspaceExport func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) + guestSessionStart func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) + guestSessionGet func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) + guestSessionList func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error) + guestSessionStop func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) + guestSessionKill func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) + guestSessionLogs func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) + guestSessionAttachBegin func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) + guestSessionSend func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) + guestWaitForSSH func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error + guestDial func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) + buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan + cwd func() (string, error) + completionLister func(ctx context.Context, socketPath, method string) ([]string, error) + completionSessionLister func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) +} + +func defaultDeps() *deps { + return &deps{ + bangerdPath: paths.BangerdPath, + daemonExePath: func(pid int) string { + return filepath.Join("/proc", fmt.Sprintf("%d", pid), "exe") + }, + doctor: daemon.Doctor, + sshExec: func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { + sshCmd := exec.CommandContext(ctx, "ssh", args...) + sshCmd.Stdout = stdout + sshCmd.Stderr = stderr + sshCmd.Stdin = stdin + return sshCmd.Run() + }, + hostCommandOutput: func(ctx context.Context, name string, args ...string) ([]byte, error) { + cmd := exec.CommandContext(ctx, name, args...) + output, err := cmd.CombinedOutput() + if err == nil { + return output, nil + } + command := strings.TrimSpace(strings.Join(append([]string{name}, args...), " ")) + detail := strings.TrimSpace(string(output)) + if detail == "" { + return output, fmt.Errorf("%s: %w", command, err) + } + return output, fmt.Errorf("%s: %w: %s", command, err, detail) + }, + vmHealth: func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { + return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName}) + }, + vmSSH: func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error) { + return rpc.Call[api.VMSSHResult](ctx, socketPath, "vm.ssh", api.VMRefParams{IDOrName: idOrName}) + }, + vmDelete: func(ctx context.Context, socketPath, idOrName string) error { + _, err := rpc.Call[api.VMShowResult](ctx, socketPath, "vm.delete", api.VMRefParams{IDOrName: idOrName}) + return err + }, + vmList: func(ctx context.Context, socketPath string) (api.VMListResult, error) { + return rpc.Call[api.VMListResult](ctx, socketPath, "vm.list", api.Empty{}) + }, + daemonPing: func(ctx context.Context, socketPath string) (api.PingResult, error) { + return rpc.Call[api.PingResult](ctx, socketPath, "ping", api.Empty{}) + }, + vmCreateBegin: func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error) { + return rpc.Call[api.VMCreateBeginResult](ctx, socketPath, "vm.create.begin", params) + }, + vmCreateStatus: func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error) { + return rpc.Call[api.VMCreateStatusResult](ctx, socketPath, "vm.create.status", api.VMCreateStatusParams{ID: operationID}) + }, + vmCreateCancel: func(ctx context.Context, socketPath, operationID string) error { + _, err := rpc.Call[api.Empty](ctx, socketPath, "vm.create.cancel", api.VMCreateStatusParams{ID: operationID}) + return err + }, + vmPorts: func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) { + return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName}) + }, + vmWorkspacePrepare: func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { + return rpc.Call[api.VMWorkspacePrepareResult](ctx, socketPath, "vm.workspace.prepare", params) + }, + vmWorkspaceExport: func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + return rpc.Call[api.WorkspaceExportResult](ctx, socketPath, "vm.workspace.export", params) + }, + guestSessionStart: func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) { + return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.start", params) + }, + guestSessionGet: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) { + return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.get", params) + }, + guestSessionList: func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error) { + return rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: idOrName}) + }, + guestSessionStop: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) { + return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.stop", params) + }, + guestSessionKill: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) { + return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.kill", params) + }, + guestSessionLogs: func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) { + return rpc.Call[api.GuestSessionLogsResult](ctx, socketPath, "guest.session.logs", params) + }, + guestSessionAttachBegin: func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) { + return rpc.Call[api.GuestSessionAttachBeginResult](ctx, socketPath, "guest.session.attach.begin", params) + }, + guestSessionSend: func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { + return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params) + }, + guestWaitForSSH: func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { + knownHosts, _ := bangerKnownHostsPath() + return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval) + }, + guestDial: func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { + knownHosts, _ := bangerKnownHostsPath() + return guest.Dial(ctx, address, privateKeyPath, knownHosts) + }, + buildVMRunToolingPlan: toolingplan.Build, + cwd: os.Getwd, + completionLister: defaultCompletionLister, + completionSessionLister: defaultCompletionSessionLister, + } +} diff --git a/internal/cli/prune_test.go b/internal/cli/prune_test.go index 32372cb..cdf86c8 100644 --- a/internal/cli/prune_test.go +++ b/internal/cli/prune_test.go @@ -14,22 +14,17 @@ import ( "github.com/spf13/cobra" ) -// stubPruneSeams installs fakes for vmListFunc and vmDeleteFunc, and -// restores originals on cleanup. -func stubPruneSeams(t *testing.T, vms []model.VMRecord, listErr error, deleteErr map[string]error) *[]string { +// stubPruneSeams installs list + delete fakes onto the caller's *deps +// and returns a pointer to a slice that records every ID passed to the +// delete fake. +func stubPruneSeams(t *testing.T, d *deps, vms []model.VMRecord, listErr error, deleteErr map[string]error) *[]string { t.Helper() - origList := vmListFunc - origDelete := vmDeleteFunc - t.Cleanup(func() { - vmListFunc = origList - vmDeleteFunc = origDelete - }) var deleted []string - vmListFunc = func(ctx context.Context, socketPath string) (api.VMListResult, error) { + d.vmList = func(ctx context.Context, socketPath string) (api.VMListResult, error) { return api.VMListResult{VMs: vms}, listErr } - vmDeleteFunc = func(ctx context.Context, socketPath, idOrName string) error { + d.vmDelete = func(ctx context.Context, socketPath, idOrName string) error { if err, ok := deleteErr[idOrName]; ok { return err } @@ -89,13 +84,14 @@ func TestPromptYesNoEOF(t *testing.T) { } func TestRunVMPruneNoVictims(t *testing.T) { - stubPruneSeams(t, []model.VMRecord{ + d := defaultDeps() + stubPruneSeams(t, d, []model.VMRecord{ {ID: "id-1", Name: "running-vm", State: model.VMStateRunning}, }, nil, nil) cmd, stdout, _ := newPruneTestCmd("") - if err := runVMPrune(cmd, "sock", false); err != nil { - t.Fatalf("runVMPrune: %v", err) + if err := d.runVMPrune(cmd, "sock", false); err != nil { + t.Fatalf("d.runVMPrune: %v", err) } if !strings.Contains(stdout.String(), "no non-running VMs") { t.Errorf("expected no-op message, got %q", stdout.String()) @@ -103,13 +99,14 @@ func TestRunVMPruneNoVictims(t *testing.T) { } func TestRunVMPruneAbortedByUser(t *testing.T) { - deleted := stubPruneSeams(t, []model.VMRecord{ + d := defaultDeps() + deleted := stubPruneSeams(t, d, []model.VMRecord{ {ID: "id-1", Name: "stale", State: model.VMStateStopped}, }, nil, nil) cmd, stdout, _ := newPruneTestCmd("n\n") - if err := runVMPrune(cmd, "sock", false); err != nil { - t.Fatalf("runVMPrune: %v", err) + if err := d.runVMPrune(cmd, "sock", false); err != nil { + t.Fatalf("d.runVMPrune: %v", err) } if !strings.Contains(stdout.String(), "aborted") { t.Errorf("expected 'aborted' output, got %q", stdout.String()) @@ -120,7 +117,8 @@ func TestRunVMPruneAbortedByUser(t *testing.T) { } func TestRunVMPruneConfirmedDeletesNonRunning(t *testing.T) { - deleted := stubPruneSeams(t, []model.VMRecord{ + d := defaultDeps() + deleted := stubPruneSeams(t, d, []model.VMRecord{ {ID: "id-run", Name: "keeper", State: model.VMStateRunning}, {ID: "id-stop", Name: "stale", State: model.VMStateStopped}, {ID: "id-err", Name: "broken", State: model.VMStateError}, @@ -128,8 +126,8 @@ func TestRunVMPruneConfirmedDeletesNonRunning(t *testing.T) { }, nil, nil) cmd, stdout, _ := newPruneTestCmd("y\n") - if err := runVMPrune(cmd, "sock", false); err != nil { - t.Fatalf("runVMPrune: %v", err) + if err := d.runVMPrune(cmd, "sock", false); err != nil { + t.Fatalf("d.runVMPrune: %v", err) } // Deleted must be exactly the three non-running IDs, in list order. want := []string{"id-stop", "id-err", "id-created"} @@ -152,14 +150,15 @@ func TestRunVMPruneConfirmedDeletesNonRunning(t *testing.T) { } func TestRunVMPruneForceSkipsPrompt(t *testing.T) { - deleted := stubPruneSeams(t, []model.VMRecord{ + d := defaultDeps() + deleted := stubPruneSeams(t, d, []model.VMRecord{ {ID: "id-1", Name: "stale", State: model.VMStateStopped}, }, nil, nil) // Empty stdin + force=true: must not block on prompt. cmd, stdout, _ := newPruneTestCmd("") - if err := runVMPrune(cmd, "sock", true); err != nil { - t.Fatalf("runVMPrune: %v", err) + if err := d.runVMPrune(cmd, "sock", true); err != nil { + t.Fatalf("d.runVMPrune: %v", err) } if len(*deleted) != 1 || (*deleted)[0] != "id-1" { t.Errorf("deleted = %v, want [id-1]", *deleted) @@ -171,7 +170,8 @@ func TestRunVMPruneForceSkipsPrompt(t *testing.T) { } func TestRunVMPruneReportsPartialFailure(t *testing.T) { - stubPruneSeams(t, + d := defaultDeps() + stubPruneSeams(t, d, []model.VMRecord{ {ID: "id-a", Name: "a", State: model.VMStateStopped}, {ID: "id-b", Name: "b", State: model.VMStateStopped}, @@ -181,7 +181,7 @@ func TestRunVMPruneReportsPartialFailure(t *testing.T) { ) cmd, _, stderr := newPruneTestCmd("") - err := runVMPrune(cmd, "sock", true) + err := d.runVMPrune(cmd, "sock", true) if err == nil { t.Fatal("expected non-zero exit when any delete fails") } @@ -194,10 +194,11 @@ func TestRunVMPruneReportsPartialFailure(t *testing.T) { } func TestRunVMPruneListErrorPropagates(t *testing.T) { - stubPruneSeams(t, nil, fmt.Errorf("rpc failed"), nil) + d := defaultDeps() + stubPruneSeams(t, d, nil, fmt.Errorf("rpc failed"), nil) cmd, _, _ := newPruneTestCmd("") - err := runVMPrune(cmd, "sock", true) + err := d.runVMPrune(cmd, "sock", true) if err == nil || !strings.Contains(err.Error(), "rpc failed") { t.Fatalf("expected rpc error to propagate, got %v", err) } diff --git a/internal/cli/ssh.go b/internal/cli/ssh.go index a17bbb3..436ef8a 100644 --- a/internal/cli/ssh.go +++ b/internal/cli/ssh.go @@ -20,14 +20,14 @@ import ( // the caller asked (e.g. --rm is about to delete the VM), if the // ctx is already done, or if the ssh error isn't the one that // typically means "user disconnected cleanly". -func runSSHSession(ctx context.Context, socketPath, vmRef string, stdin io.Reader, stdout, stderr io.Writer, sshArgs []string, skipReminder bool) error { - sshErr := sshExecFunc(ctx, stdin, stdout, stderr, sshArgs) +func (d *deps) runSSHSession(ctx context.Context, socketPath, vmRef string, stdin io.Reader, stdout, stderr io.Writer, sshArgs []string, skipReminder bool) error { + sshErr := d.sshExec(ctx, stdin, stdout, stderr, sshArgs) if skipReminder || !shouldCheckSSHReminder(sshErr) || ctx.Err() != nil { return sshErr } pingCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - health, err := vmHealthFunc(pingCtx, socketPath, vmRef) + health, err := d.vmHealth(pingCtx, socketPath, vmRef) if err != nil { _, _ = fmt.Fprintln(stderr, vsockagent.WarningMessage(vmRef, err)) return sshErr diff --git a/internal/cli/vm_create.go b/internal/cli/vm_create.go index af52f0a..a1e8238 100644 --- a/internal/cli/vm_create.go +++ b/internal/cli/vm_create.go @@ -60,9 +60,9 @@ func printVMSpecLine(out io.Writer, params api.VMCreateParams) { // gets the spec line up front and the progress renderer thereafter. // On context cancel we cooperate with the daemon to cancel the // in-flight op so it doesn't leak partially-created VM state. -func runVMCreate(ctx context.Context, socketPath string, stderr io.Writer, params api.VMCreateParams) (model.VMRecord, error) { +func (d *deps) runVMCreate(ctx context.Context, socketPath string, stderr io.Writer, params api.VMCreateParams) (model.VMRecord, error) { printVMSpecLine(stderr, params) - begin, err := vmCreateBeginFunc(ctx, socketPath, params) + begin, err := d.vmCreateBegin(ctx, socketPath, params) if err != nil { return model.VMRecord{}, err } @@ -86,17 +86,17 @@ func runVMCreate(ctx context.Context, socketPath string, stderr io.Writer, param case <-ctx.Done(): cancelCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - _ = vmCreateCancelFunc(cancelCtx, socketPath, op.ID) + _ = d.vmCreateCancel(cancelCtx, socketPath, op.ID) return model.VMRecord{}, ctx.Err() case <-time.After(200 * time.Millisecond): } - status, err := vmCreateStatusFunc(ctx, socketPath, op.ID) + status, err := d.vmCreateStatus(ctx, socketPath, op.ID) if err != nil { if ctx.Err() != nil { cancelCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - _ = vmCreateCancelFunc(cancelCtx, socketPath, op.ID) + _ = d.vmCreateCancel(cancelCtx, socketPath, op.ID) return model.VMRecord{}, ctx.Err() } return model.VMRecord{}, err diff --git a/internal/cli/vm_run.go b/internal/cli/vm_run.go index cab039e..2dce5cd 100644 --- a/internal/cli/vm_run.go +++ b/internal/cli/vm_run.go @@ -80,9 +80,9 @@ func (e ExitCodeError) Error() string { // - it sits inside a non-bare git repository, // - the repository has no submodules (unsupported in the shallow // overlay mode vm run uses). -func vmRunPreflightRepo(ctx context.Context, rawPath string) (string, error) { +func (d *deps) vmRunPreflightRepo(ctx context.Context, rawPath string) (string, error) { if strings.TrimSpace(rawPath) == "" { - wd, err := cwdFunc() + wd, err := d.cwd() if err != nil { return "", err } @@ -131,9 +131,9 @@ func splitVMRunArgs(cmd *cobra.Command, args []string) (pathArgs, commandArgs [] // for guest ssh, optionally materialise a workspace and kick off the // tooling bootstrap, then either attach interactively or run the // user's command and propagate its exit status. -func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, stdin io.Reader, stdout, stderr io.Writer, params api.VMCreateParams, repo *vmRunRepo, command []string, removeOnExit bool) error { +func (d *deps) runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, stdin io.Reader, stdout, stderr io.Writer, params api.VMCreateParams, repo *vmRunRepo, command []string, removeOnExit bool) error { progress := newVMRunProgressRenderer(stderr) - vm, err := runVMCreate(ctx, socketPath, stderr, params) + vm, err := d.runVMCreate(ctx, socketPath, stderr, params) if err != nil { return err } @@ -155,7 +155,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st // doesn't abort the delete RPC. cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if err := vmDeleteFunc(cleanupCtx, socketPath, vmRef); err != nil { + if err := d.vmDelete(cleanupCtx, socketPath, vmRef); err != nil { printVMRunWarning(stderr, fmt.Sprintf("--rm cleanup failed: %v (leaked vm %q; delete manually)", err, vmRef)) } }() @@ -163,7 +163,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st sshAddress := net.JoinHostPort(vm.Runtime.GuestIP, "22") progress.render("waiting for guest ssh") sshCtx, cancelSSH := context.WithTimeout(ctx, vmRunSSHTimeout) - if err := guestWaitForSSHFunc(sshCtx, sshAddress, cfg.SSHKeyPath, 250*time.Millisecond); err != nil { + if err := d.guestWaitForSSH(sshCtx, sshAddress, cfg.SSHKeyPath, 250*time.Millisecond); err != nil { cancelSSH() // Surface parent-context cancellation (Ctrl-C, caller // timeout) as-is. Only the guest-side timeout needs the @@ -193,7 +193,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st if strings.TrimSpace(repo.branchName) != "" { fromRef = repo.fromRef } - prepared, err := vmWorkspacePrepareFunc(ctx, socketPath, api.VMWorkspacePrepareParams{ + prepared, err := d.vmWorkspacePrepare(ctx, socketPath, api.VMWorkspacePrepareParams{ IDOrName: vmRef, SourcePath: repo.sourcePath, GuestPath: vmRunGuestDir(), @@ -208,11 +208,11 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st // daemon side; grab what the tooling harness needs from its // result instead of re-inspecting here. if len(command) == 0 { - client, err := guestDialFunc(ctx, sshAddress, cfg.SSHKeyPath) + client, err := d.guestDial(ctx, sshAddress, cfg.SSHKeyPath) if err != nil { return fmt.Errorf("vm %q is running but guest ssh is unavailable: %w", vmRef, err) } - if err := startVMRunToolingHarness(ctx, client, prepared.Workspace.RepoRoot, prepared.Workspace.RepoName, progress); err != nil { + if err := d.startVMRunToolingHarness(ctx, client, prepared.Workspace.RepoRoot, prepared.Workspace.RepoName, progress); err != nil { printVMRunWarning(stderr, fmt.Sprintf("guest tooling bootstrap start failed: %v", err)) } _ = client.Close() @@ -224,7 +224,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st } if len(command) > 0 { progress.render("running command in guest") - if err := sshExecFunc(ctx, stdin, stdout, stderr, sshArgs); err != nil { + if err := d.sshExec(ctx, stdin, stdout, stderr, sshArgs); err != nil { var exitErr *exec.ExitError if errors.As(err, &exitErr) { return ExitCodeError{Code: exitErr.ExitCode()} @@ -234,7 +234,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st return nil } progress.render("attaching to guest") - return runSSHSession(ctx, socketPath, vmRef, stdin, stdout, stderr, sshArgs, removeOnExit) + return d.runSSHSession(ctx, socketPath, vmRef, stdin, stdout, stderr, sshArgs, removeOnExit) } func vmRunGuestDir() string { @@ -253,11 +253,11 @@ func vmRunToolingHarnessLogPath(repoName string) string { // script inside the guest. repoRoot / repoName both come from the // daemon's workspace.prepare RPC response — the CLI no longer does // its own git inspection. -func startVMRunToolingHarness(ctx context.Context, client vmRunGuestClient, repoRoot, repoName string, progress *vmRunProgressRenderer) error { +func (d *deps) startVMRunToolingHarness(ctx context.Context, client vmRunGuestClient, repoRoot, repoName string, progress *vmRunProgressRenderer) error { if progress != nil { progress.render("starting guest tooling bootstrap") } - plan := buildVMRunToolingPlanFunc(ctx, repoRoot) + plan := d.buildVMRunToolingPlan(ctx, repoRoot) var uploadLog bytes.Buffer if err := client.UploadFile(ctx, vmRunToolingHarnessPath(repoName), 0o755, []byte(vmRunToolingHarnessScript(plan)), &uploadLog); err != nil { return formatVMRunStepError("upload guest tooling bootstrap", err, uploadLog.String()) diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index ea3f0fc..b06ee80 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -18,6 +18,7 @@ import ( "banger/internal/buildinfo" "banger/internal/config" "banger/internal/daemon/opstate" + ws "banger/internal/daemon/workspace" "banger/internal/imagecat" "banger/internal/imagepull" "banger/internal/model" @@ -66,6 +67,8 @@ type Daemon struct { guestWaitForSSH func(context.Context, string, string, time.Duration) error guestDial func(context.Context, string, string) (guestSSHClient, error) waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error) + workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error) + workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error } func Open(ctx context.Context) (d *Daemon, err error) { diff --git a/internal/daemon/workspace.go b/internal/daemon/workspace.go index 531e98c..e285c94 100644 --- a/internal/daemon/workspace.go +++ b/internal/daemon/workspace.go @@ -15,13 +15,25 @@ import ( "banger/internal/model" ) -// Test seams. Tests swap these to observe or stall the guest-I/O -// phase without needing a real git repo or SSH server. Production -// callers see the real implementations from the workspace package. -var ( - workspaceInspectRepoFunc = ws.InspectRepo - workspaceImportFunc = ws.ImportRepoToGuest -) +// workspaceInspectRepoHook + workspaceImportHook dispatch through the +// per-instance Daemon seams when set, falling back to the real +// workspace package implementations. Keeping the fallbacks here (as +// opposed to always requiring callers to populate d.workspaceInspectRepo +// in a constructor) lets tests selectively override one hook without +// having to wire both. +func (d *Daemon) workspaceInspectRepoHook(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error) { + if d != nil && d.workspaceInspectRepo != nil { + return d.workspaceInspectRepo(ctx, sourcePath, branchName, fromRef) + } + return ws.InspectRepo(ctx, sourcePath, branchName, fromRef) +} + +func (d *Daemon) workspaceImportHook(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error { + if d != nil && d.workspaceImport != nil { + return d.workspaceImport(ctx, client, spec, guestPath, mode) + } + return ws.ImportRepoToGuest(ctx, client, spec, guestPath, mode) +} func (d *Daemon) ExportVMWorkspace(ctx context.Context, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { guestPath := strings.TrimSpace(params.GuestPath) @@ -156,7 +168,7 @@ func (d *Daemon) PrepareVMWorkspace(ctx context.Context, params api.VMWorkspaceP // inspect the local repo, dial SSH, stream the tar, optionally chmod // readonly. It is called without holding the VM mutex. func (d *Daemon) prepareVMWorkspaceGuestIO(ctx context.Context, vm model.VMRecord, sourcePath, guestPath, branchName, fromRef string, mode model.WorkspacePrepareMode, readOnly bool) (model.WorkspacePrepareResult, error) { - spec, err := workspaceInspectRepoFunc(ctx, sourcePath, branchName, fromRef) + spec, err := d.workspaceInspectRepoHook(ctx, sourcePath, branchName, fromRef) if err != nil { return model.WorkspacePrepareResult{}, err } @@ -172,7 +184,7 @@ func (d *Daemon) prepareVMWorkspaceGuestIO(ctx context.Context, vm model.VMRecor return model.WorkspacePrepareResult{}, fmt.Errorf("dial guest ssh: %w", err) } defer client.Close() - if err := workspaceImportFunc(ctx, client, spec, guestPath, mode); err != nil { + if err := d.workspaceImportHook(ctx, client, spec, guestPath, mode); err != nil { return model.WorkspacePrepareResult{}, err } if readOnly { diff --git a/internal/daemon/workspace_test.go b/internal/daemon/workspace_test.go index 2194dce..cfe92ff 100644 --- a/internal/daemon/workspace_test.go +++ b/internal/daemon/workspace_test.go @@ -370,9 +370,7 @@ func TestExportVMWorkspace_MultipleChangedFiles(t *testing.T) { // inside the import step and then asserts the VM mutex is acquirable // while the prepare is mid-flight. func TestPrepareVMWorkspace_ReleasesVMLockDuringGuestIO(t *testing.T) { - // Not parallel: mutates package-level workspaceInspectRepoFunc / - // workspaceImportFunc seams, which the other prepare-concurrency - // test also swaps. + t.Parallel() ctx := context.Background() apiSock := filepath.Join(t.TempDir(), "fc.sock") @@ -395,21 +393,15 @@ func TestPrepareVMWorkspace_ReleasesVMLockDuringGuestIO(t *testing.T) { upsertDaemonVM(t, ctx, d.store, vm) d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) - // Replace the seams. InspectRepo returns a trivial spec so the - // real filesystem isn't touched; Import blocks until we say go. - origInspect := workspaceInspectRepoFunc - origImport := workspaceImportFunc - t.Cleanup(func() { - workspaceInspectRepoFunc = origInspect - workspaceImportFunc = origImport - }) - + // Install the workspace seams on this daemon instance. InspectRepo + // returns a trivial spec so the real filesystem isn't touched; + // Import blocks until we say go. importStarted := make(chan struct{}) releaseImport := make(chan struct{}) - workspaceInspectRepoFunc = func(context.Context, string, string, string) (workspace.RepoSpec, error) { + d.workspaceInspectRepo = func(context.Context, string, string, string) (workspace.RepoSpec, error) { return workspace.RepoSpec{RepoName: "fake", RepoRoot: "/tmp/fake"}, nil } - workspaceImportFunc = func(context.Context, workspace.GuestClient, workspace.RepoSpec, string, model.WorkspacePrepareMode) error { + d.workspaceImport = func(context.Context, workspace.GuestClient, workspace.RepoSpec, string, model.WorkspacePrepareMode) error { close(importStarted) <-releaseImport return nil @@ -465,7 +457,7 @@ func TestPrepareVMWorkspace_ReleasesVMLockDuringGuestIO(t *testing.T) { // the workspaceLocks scope: two concurrent prepares on the same VM do // NOT interleave, even though they no longer take the core VM mutex. func TestPrepareVMWorkspace_SerialisesConcurrentPreparesOnSameVM(t *testing.T) { - // Not parallel: see note on ReleasesVMLockDuringGuestIO. + t.Parallel() ctx := context.Background() apiSock := filepath.Join(t.TempDir(), "fc.sock") @@ -488,14 +480,7 @@ func TestPrepareVMWorkspace_SerialisesConcurrentPreparesOnSameVM(t *testing.T) { upsertDaemonVM(t, ctx, d.store, vm) d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) - origInspect := workspaceInspectRepoFunc - origImport := workspaceImportFunc - t.Cleanup(func() { - workspaceInspectRepoFunc = origInspect - workspaceImportFunc = origImport - }) - - workspaceInspectRepoFunc = func(context.Context, string, string, string) (workspace.RepoSpec, error) { + d.workspaceInspectRepo = func(context.Context, string, string, string) (workspace.RepoSpec, error) { return workspace.RepoSpec{RepoName: "fake", RepoRoot: "/tmp/fake"}, nil } @@ -503,7 +488,7 @@ func TestPrepareVMWorkspace_SerialisesConcurrentPreparesOnSameVM(t *testing.T) { var active int32 var maxObserved int32 release := make(chan struct{}) - workspaceImportFunc = func(context.Context, workspace.GuestClient, workspace.RepoSpec, string, model.WorkspacePrepareMode) error { + d.workspaceImport = func(context.Context, workspace.GuestClient, workspace.RepoSpec, string, model.WorkspacePrepareMode) error { n := atomic.AddInt32(&active, 1) for { prev := atomic.LoadInt32(&maxObserved)