cli + daemon: move test seams off package globals onto injected structs

CLI: introduce internal/cli.deps which owns every RPC/SSH/host-command
seam the tree used to reach through mutable package vars. Command
builders, orchestrators, and the completion helpers become methods on
*deps. Tests construct their own deps per case, so fakes no longer leak
across cases and tests are free to run in parallel.

Daemon: move workspaceInspectRepoFunc + workspaceImportFunc onto the
Daemon struct (workspaceInspectRepo / workspaceImport), mirroring the
existing guestWaitForSSH / guestDial pattern. Workspace-prepare tests
drop t.Parallel() guards now that they no longer mutate process-wide
state.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Thales Maciel 2026-04-19 19:03:55 -03:00
parent d38f580e00
commit c42fcbe012
No known key found for this signature in database
GPG key ID: 33112E6833C34679
19 changed files with 664 additions and 733 deletions

View file

@ -1,125 +1,25 @@
package cli
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"banger/internal/api"
"banger/internal/buildinfo"
"banger/internal/daemon"
"banger/internal/guest"
"banger/internal/paths"
"banger/internal/rpc"
"banger/internal/toolingplan"
"github.com/spf13/cobra"
)
var (
bangerdPathFunc = paths.BangerdPath
daemonExePath = func(pid int) string {
return filepath.Join("/proc", fmt.Sprintf("%d", pid), "exe")
}
doctorFunc = daemon.Doctor
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
sshCmd := exec.CommandContext(ctx, "ssh", args...)
sshCmd.Stdout = stdout
sshCmd.Stderr = stderr
sshCmd.Stdin = stdin
return sshCmd.Run()
}
hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) {
cmd := exec.CommandContext(ctx, name, args...)
output, err := cmd.CombinedOutput()
if err == nil {
return output, nil
}
command := strings.TrimSpace(strings.Join(append([]string{name}, args...), " "))
detail := strings.TrimSpace(string(output))
if detail == "" {
return output, fmt.Errorf("%s: %w", command, err)
}
return output, fmt.Errorf("%s: %w: %s", command, err, detail)
}
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName})
}
vmSSHFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error) {
return rpc.Call[api.VMSSHResult](ctx, socketPath, "vm.ssh", api.VMRefParams{IDOrName: idOrName})
}
vmDeleteFunc = func(ctx context.Context, socketPath, idOrName string) error {
_, err := rpc.Call[api.VMShowResult](ctx, socketPath, "vm.delete", api.VMRefParams{IDOrName: idOrName})
return err
}
vmListFunc = func(ctx context.Context, socketPath string) (api.VMListResult, error) {
return rpc.Call[api.VMListResult](ctx, socketPath, "vm.list", api.Empty{})
}
daemonPingFunc = func(ctx context.Context, socketPath string) (api.PingResult, error) {
return rpc.Call[api.PingResult](ctx, socketPath, "ping", api.Empty{})
}
vmCreateBeginFunc = func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error) {
return rpc.Call[api.VMCreateBeginResult](ctx, socketPath, "vm.create.begin", params)
}
vmCreateStatusFunc = func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error) {
return rpc.Call[api.VMCreateStatusResult](ctx, socketPath, "vm.create.status", api.VMCreateStatusParams{ID: operationID})
}
vmCreateCancelFunc = func(ctx context.Context, socketPath, operationID string) error {
_, err := rpc.Call[api.Empty](ctx, socketPath, "vm.create.cancel", api.VMCreateStatusParams{ID: operationID})
return err
}
vmPortsFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) {
return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName})
}
vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
return rpc.Call[api.VMWorkspacePrepareResult](ctx, socketPath, "vm.workspace.prepare", params)
}
vmWorkspaceExportFunc = func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
return rpc.Call[api.WorkspaceExportResult](ctx, socketPath, "vm.workspace.export", params)
}
guestSessionStartFunc = func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.start", params)
}
guestSessionGetFunc = func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.get", params)
}
guestSessionListFunc = func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error) {
return rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: idOrName})
}
guestSessionStopFunc = func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.stop", params)
}
guestSessionKillFunc = func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.kill", params)
}
guestSessionLogsFunc = func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) {
return rpc.Call[api.GuestSessionLogsResult](ctx, socketPath, "guest.session.logs", params)
}
guestSessionAttachBeginFunc = func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
return rpc.Call[api.GuestSessionAttachBeginResult](ctx, socketPath, "guest.session.attach.begin", params)
}
guestSessionSendFunc = func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params)
}
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
knownHosts, _ := bangerKnownHostsPath()
return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval)
}
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
knownHosts, _ := bangerKnownHostsPath()
return guest.Dial(ctx, address, privateKeyPath, knownHosts)
}
buildVMRunToolingPlanFunc = toolingplan.Build
cwdFunc = os.Getwd
)
// NewBangerCommand builds the top-level cobra tree with production
// defaults wired into the dependency struct. Tests reach into the
// package directly — see newRootCommand + defaultDeps.
func NewBangerCommand() *cobra.Command {
return defaultDeps().newRootCommand()
}
func (d *deps) newRootCommand() *cobra.Command {
root := &cobra.Command{
Use: "banger",
Short: "Manage development VMs and images",
@ -127,17 +27,26 @@ func NewBangerCommand() *cobra.Command {
SilenceErrors: true,
RunE: helpNoArgs,
}
root.AddCommand(newDaemonCommand(), newDoctorCommand(), newImageCommand(), newInternalCommand(), newKernelCommand(), newVersionCommand(), newPSCommand(), newVMCommand())
root.AddCommand(
d.newDaemonCommand(),
d.newDoctorCommand(),
d.newImageCommand(),
d.newInternalCommand(),
d.newKernelCommand(),
newVersionCommand(),
d.newPSCommand(),
d.newVMCommand(),
)
return root
}
func newDoctorCommand() *cobra.Command {
func (d *deps) newDoctorCommand() *cobra.Command {
return &cobra.Command{
Use: "doctor",
Short: "Check host and runtime readiness",
Args: noArgsUsage("usage: banger doctor"),
RunE: func(cmd *cobra.Command, args []string) error {
report, err := doctorFunc(cmd.Context())
report, err := d.doctor(cmd.Context())
if err != nil {
return err
}

View file

@ -121,11 +121,8 @@ func TestLegacyRemovedCommandIsRejected(t *testing.T) {
}
func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) {
original := doctorFunc
t.Cleanup(func() {
doctorFunc = original
})
doctorFunc = func(context.Context) (system.Report, error) {
d := defaultDeps()
d.doctor = func(context.Context) (system.Report, error) {
return system.Report{
Checks: []system.CheckResult{
{Name: "runtime bundle", Status: system.CheckStatusPass, Details: []string{"runtime dir /tmp/runtime"}},
@ -134,7 +131,7 @@ func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) {
}, nil
}
cmd := NewBangerCommand()
cmd := d.newRootCommand()
var stdout bytes.Buffer
cmd.SetOut(&stdout)
cmd.SetErr(&stdout)
@ -154,15 +151,12 @@ func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) {
}
func TestDoctorCommandReturnsUnderlyingError(t *testing.T) {
original := doctorFunc
t.Cleanup(func() {
doctorFunc = original
})
doctorFunc = func(context.Context) (system.Report, error) {
d := defaultDeps()
d.doctor = func(context.Context) (system.Report, error) {
return system.Report{}, errors.New("load failed")
}
cmd := NewBangerCommand()
cmd := d.newRootCommand()
cmd.SetArgs([]string{"doctor"})
err := cmd.Execute()
if err == nil || !strings.Contains(err.Error(), "load failed") {
@ -509,14 +503,7 @@ func TestVMCreateParamsFromFlagsRejectsNonPositiveCPUAndMemory(t *testing.T) {
}
func TestRunVMCreatePollsUntilDone(t *testing.T) {
origBegin := vmCreateBeginFunc
origStatus := vmCreateStatusFunc
origCancel := vmCreateCancelFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
vmCreateStatusFunc = origStatus
vmCreateCancelFunc = origCancel
})
d := defaultDeps()
vm := model.VMRecord{
ID: "vm-id",
@ -528,7 +515,7 @@ func TestRunVMCreatePollsUntilDone(t *testing.T) {
DNSName: "devbox.vm",
},
}
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{
Operation: api.VMCreateOperation{
ID: "op-1",
@ -538,7 +525,7 @@ func TestRunVMCreatePollsUntilDone(t *testing.T) {
}, nil
}
statusCalls := 0
vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
statusCalls++
if statusCalls == 1 {
return api.VMCreateStatusResult{
@ -560,14 +547,14 @@ func TestRunVMCreatePollsUntilDone(t *testing.T) {
},
}, nil
}
vmCreateCancelFunc = func(context.Context, string, string) error {
d.vmCreateCancel = func(context.Context, string, string) error {
t.Fatal("cancel should not be called")
return nil
}
got, err := runVMCreate(context.Background(), "/tmp/bangerd.sock", &bytes.Buffer{}, api.VMCreateParams{Name: "devbox"})
got, err := d.runVMCreate(context.Background(), "/tmp/bangerd.sock", &bytes.Buffer{}, api.VMCreateParams{Name: "devbox"})
if err != nil {
t.Fatalf("runVMCreate: %v", err)
t.Fatalf("d.runVMCreate: %v", err)
}
if got.Name != vm.Name || got.Runtime.GuestIP != vm.Runtime.GuestIP {
t.Fatalf("vm = %+v, want %+v", got, vm)
@ -878,23 +865,18 @@ func TestPrintVMPortsTableSortsAndRendersURLEndpoints(t *testing.T) {
}
func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) {
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
t.Cleanup(func() {
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
d := defaultDeps()
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
return nil
}
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
d.vmHealth = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
return api.VMHealthResult{Name: "devbox", Healthy: true}, nil
}
var stderr bytes.Buffer
if err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false); err != nil {
t.Fatalf("runSSHSession: %v", err)
if err := d.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false); err != nil {
t.Fatalf("d.runSSHSession: %v", err)
}
if !strings.Contains(stderr.String(), "devbox is still running") {
t.Fatalf("stderr = %q, want reminder", stderr.String())
@ -902,25 +884,20 @@ func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) {
}
func TestRunSSHSessionPreservesSSHExitStatusOnHealthWarning(t *testing.T) {
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
t.Cleanup(func() {
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
d := defaultDeps()
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
return exitErrorWithCode(t, 1)
}
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
d.vmHealth = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
return api.VMHealthResult{}, errors.New("dial failed")
}
var stderr bytes.Buffer
err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false)
err := d.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false)
var exitErr *exec.ExitError
if !errors.As(err, &exitErr) {
t.Fatalf("runSSHSession error = %v, want exit error", err)
t.Fatalf("d.runSSHSession error = %v, want exit error", err)
}
if !strings.Contains(stderr.String(), "failed to check whether devbox is still running") {
t.Fatalf("stderr = %q, want warning", stderr.String())
@ -928,27 +905,22 @@ func TestRunSSHSessionPreservesSSHExitStatusOnHealthWarning(t *testing.T) {
}
func TestRunSSHSessionSkipsReminderOnSSHAuthFailure(t *testing.T) {
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
t.Cleanup(func() {
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
d := defaultDeps()
healthCalled := false
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
return exitErrorWithCode(t, 255)
}
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
d.vmHealth = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
healthCalled = true
return api.VMHealthResult{Name: "devbox", Healthy: true}, nil
}
var stderr bytes.Buffer
err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false)
err := d.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false)
var exitErr *exec.ExitError
if !errors.As(err, &exitErr) || exitErr.ExitCode() != 255 {
t.Fatalf("runSSHSession error = %v, want exit 255", err)
t.Fatalf("d.runSSHSession error = %v, want exit 255", err)
}
if healthCalled {
t.Fatal("vm health should not run after ssh auth failure")
@ -1141,6 +1113,7 @@ func TestValidateSSHPrereqsFailsForMissingKey(t *testing.T) {
// gets a fast error instead of an orphaned VM.
func TestVMRunPreflightRejectsSubmodules(t *testing.T) {
d := defaultDeps()
repoRoot := t.TempDir()
origHostCommandOutput := workspace.HostCommandOutputFunc
@ -1166,36 +1139,16 @@ func TestVMRunPreflightRejectsSubmodules(t *testing.T) {
}
}
_, err := vmRunPreflightRepo(context.Background(), repoRoot)
_, err := d.vmRunPreflightRepo(context.Background(), repoRoot)
if err == nil || !strings.Contains(err.Error(), "submodules") {
t.Fatalf("vmRunPreflightRepo() error = %v, want submodule rejection", err)
t.Fatalf("d.vmRunPreflightRepo() error = %v, want submodule rejection", err)
}
}
func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) {
d := defaultDeps()
repoRoot := t.TempDir()
origBegin := vmCreateBeginFunc
origStatus := vmCreateStatusFunc
origCancel := vmCreateCancelFunc
origWaitForSSH := guestWaitForSSHFunc
origGuestDial := guestDialFunc
origBuildVMRunToolingPlan := buildVMRunToolingPlanFunc
origVMWorkspacePrepare := vmWorkspacePrepareFunc
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
vmCreateStatusFunc = origStatus
vmCreateCancelFunc = origCancel
guestWaitForSSHFunc = origWaitForSSH
guestDialFunc = origGuestDial
buildVMRunToolingPlanFunc = origBuildVMRunToolingPlan
vmWorkspacePrepareFunc = origVMWorkspacePrepare
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
vm := model.VMRecord{
ID: "vm-id",
Name: "devbox",
@ -1205,7 +1158,7 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) {
DNSName: "devbox.vm",
},
}
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{
Operation: api.VMCreateOperation{
ID: "op-1", Stage: "ready", Detail: "vm is ready",
@ -1213,45 +1166,45 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) {
},
}, nil
}
vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
t.Fatal("vmCreateStatusFunc should not be called")
d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
t.Fatal("d.vmCreateStatus should not be called")
return api.VMCreateStatusResult{}, nil
}
vmCreateCancelFunc = func(context.Context, string, string) error {
t.Fatal("vmCreateCancelFunc should not be called")
d.vmCreateCancel = func(context.Context, string, string) error {
t.Fatal("d.vmCreateCancel should not be called")
return nil
}
fakeClient := &testVMRunGuestClient{}
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
d.guestWaitForSSH = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
return nil
}
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
d.guestDial = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
return fakeClient, nil
}
var workspaceParams api.VMWorkspacePrepareParams
vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
d.vmWorkspacePrepare = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
workspaceParams = params
return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil
}
buildVMRunToolingPlanFunc = func(context.Context, string) toolingplan.Plan {
d.buildVMRunToolingPlan = func(context.Context, string) toolingplan.Plan {
return toolingplan.Plan{
RepoManagedTools: []string{"go"},
Steps: []toolingplan.InstallStep{{Tool: "go", Version: "1.25.0", Source: "go.mod"}},
}
}
var sshArgsSeen []string
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
sshArgsSeen = args
return nil
}
vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) {
d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
return api.VMHealthResult{Name: "devbox", Healthy: false}, nil
}
repo := vmRunRepo{sourcePath: repoRoot}
var stdout, stderr bytes.Buffer
err := runVMRun(
err := d.runVMRun(
context.Background(),
"/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1263,7 +1216,7 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) {
false,
)
if err != nil {
t.Fatalf("runVMRun: %v", err)
t.Fatalf("d.runVMRun: %v", err)
}
if workspaceParams.IDOrName != "devbox" || workspaceParams.SourcePath != repoRoot {
t.Fatalf("workspaceParams = %+v", workspaceParams)
@ -1283,24 +1236,7 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) {
}
func TestVMRunPrintsPostCreateProgress(t *testing.T) {
origBegin := vmCreateBeginFunc
origStatus := vmCreateStatusFunc
origCancel := vmCreateCancelFunc
origWaitForSSH := guestWaitForSSHFunc
origGuestDial := guestDialFunc
origVMWorkspacePrepare := vmWorkspacePrepareFunc
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
vmCreateStatusFunc = origStatus
vmCreateCancelFunc = origCancel
guestWaitForSSHFunc = origWaitForSSH
guestDialFunc = origGuestDial
vmWorkspacePrepareFunc = origVMWorkspacePrepare
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
d := defaultDeps()
vm := model.VMRecord{
ID: "vm-id",
@ -1310,7 +1246,7 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) {
GuestIP: "172.16.0.2",
},
}
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{
Operation: api.VMCreateOperation{
ID: "op-1", Stage: "ready", Detail: "vm is ready",
@ -1318,33 +1254,33 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) {
},
}, nil
}
vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
t.Fatal("vmCreateStatusFunc should not be called")
d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
t.Fatal("d.vmCreateStatus should not be called")
return api.VMCreateStatusResult{}, nil
}
vmCreateCancelFunc = func(context.Context, string, string) error {
t.Fatal("vmCreateCancelFunc should not be called")
d.vmCreateCancel = func(context.Context, string, string) error {
t.Fatal("d.vmCreateCancel should not be called")
return nil
}
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
d.guestWaitForSSH = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
return nil
}
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
d.guestDial = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
return &testVMRunGuestClient{}, nil
}
vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
d.vmWorkspacePrepare = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil
}
sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
return nil
}
vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) {
d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
return api.VMHealthResult{Name: "devbox", Healthy: false}, nil
}
repo := vmRunRepo{sourcePath: t.TempDir()}
var stdout, stderr bytes.Buffer
err := runVMRun(
err := d.runVMRun(
context.Background(),
"/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1356,7 +1292,7 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) {
false,
)
if err != nil {
t.Fatalf("runVMRun: %v", err)
t.Fatalf("d.runVMRun: %v", err)
}
output := stderr.String()
@ -1377,24 +1313,7 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) {
}
func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) {
origBegin := vmCreateBeginFunc
origStatus := vmCreateStatusFunc
origCancel := vmCreateCancelFunc
origWaitForSSH := guestWaitForSSHFunc
origGuestDial := guestDialFunc
origVMWorkspacePrepare := vmWorkspacePrepareFunc
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
vmCreateStatusFunc = origStatus
vmCreateCancelFunc = origCancel
guestWaitForSSHFunc = origWaitForSSH
guestDialFunc = origGuestDial
vmWorkspacePrepareFunc = origVMWorkspacePrepare
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
d := defaultDeps()
vm := model.VMRecord{
ID: "vm-id",
@ -1404,39 +1323,39 @@ func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) {
GuestIP: "172.16.0.2",
},
}
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Detail: "vm is ready", Done: true, Success: true, VM: &vm}}, nil
}
vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
t.Fatal("vmCreateStatusFunc should not be called")
d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
t.Fatal("d.vmCreateStatus should not be called")
return api.VMCreateStatusResult{}, nil
}
vmCreateCancelFunc = func(context.Context, string, string) error {
t.Fatal("vmCreateCancelFunc should not be called")
d.vmCreateCancel = func(context.Context, string, string) error {
t.Fatal("d.vmCreateCancel should not be called")
return nil
}
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
d.guestWaitForSSH = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
return nil
}
fakeClient := &testVMRunGuestClient{launchErr: errors.New("launch failed")}
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
d.guestDial = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
return fakeClient, nil
}
vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
d.vmWorkspacePrepare = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil
}
sshExecCalls := 0
sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
sshExecCalls++
return nil
}
vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) {
d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
return api.VMHealthResult{Healthy: false}, nil
}
repo := vmRunRepo{sourcePath: t.TempDir()}
var stdout, stderr bytes.Buffer
err := runVMRun(
err := d.runVMRun(
context.Background(),
"/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1448,7 +1367,7 @@ func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) {
false,
)
if err != nil {
t.Fatalf("runVMRun: %v", err)
t.Fatalf("d.runVMRun: %v", err)
}
if !strings.Contains(stderr.String(), "[vm run] warning: guest tooling bootstrap start failed: launch guest tooling bootstrap") {
t.Fatalf("stderr = %q, want tooling bootstrap warning", stderr.String())
@ -1459,48 +1378,35 @@ func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) {
}
func TestRunVMRunBareModeSkipsWorkspaceAndTooling(t *testing.T) {
origBegin := vmCreateBeginFunc
origWaitForSSH := guestWaitForSSHFunc
origGuestDial := guestDialFunc
origVMWorkspacePrepare := vmWorkspacePrepareFunc
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
guestWaitForSSHFunc = origWaitForSSH
guestDialFunc = origGuestDial
vmWorkspacePrepareFunc = origVMWorkspacePrepare
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
d := defaultDeps()
vm := model.VMRecord{
ID: "vm-id", Name: "bare",
Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
}
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
}
guestWaitForSSHFunc = func(context.Context, string, string, time.Duration) error { return nil }
guestDialFunc = func(context.Context, string, string) (vmRunGuestClient, error) {
t.Fatal("guestDialFunc should not be called in bare mode")
d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
d.guestDial = func(context.Context, string, string) (vmRunGuestClient, error) {
t.Fatal("d.guestDial should not be called in bare mode")
return nil, nil
}
vmWorkspacePrepareFunc = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
t.Fatal("vmWorkspacePrepareFunc should not be called in bare mode")
d.vmWorkspacePrepare = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
t.Fatal("d.vmWorkspacePrepare should not be called in bare mode")
return api.VMWorkspacePrepareResult{}, nil
}
sshExecCalls := 0
sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
sshExecCalls++
return nil
}
vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) {
d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
return api.VMHealthResult{Healthy: false}, nil
}
var stdout, stderr bytes.Buffer
err := runVMRun(
err := d.runVMRun(
context.Background(),
"/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1512,7 +1418,7 @@ func TestRunVMRunBareModeSkipsWorkspaceAndTooling(t *testing.T) {
false,
)
if err != nil {
t.Fatalf("runVMRun: %v", err)
t.Fatalf("d.runVMRun: %v", err)
}
if sshExecCalls != 1 {
t.Fatalf("sshExec calls = %d, want 1", sshExecCalls)
@ -1523,39 +1429,28 @@ func TestRunVMRunBareModeSkipsWorkspaceAndTooling(t *testing.T) {
}
func TestRunVMRunRMDeletesAfterSessionExits(t *testing.T) {
origBegin := vmCreateBeginFunc
origWaitForSSH := guestWaitForSSHFunc
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
origDelete := vmDeleteFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
guestWaitForSSHFunc = origWaitForSSH
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
vmDeleteFunc = origDelete
})
d := defaultDeps()
vm := model.VMRecord{
ID: "vm-id", Name: "tmpbox",
Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
}
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
}
guestWaitForSSHFunc = func(context.Context, string, string, time.Duration) error { return nil }
sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { return nil }
vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) {
d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { return nil }
d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
return api.VMHealthResult{Healthy: false}, nil
}
deletedRef := ""
vmDeleteFunc = func(_ context.Context, _, idOrName string) error {
d.vmDelete = func(_ context.Context, _, idOrName string) error {
deletedRef = idOrName
return nil
}
var stdout, stderr bytes.Buffer
err := runVMRun(
err := d.runVMRun(
context.Background(),
"/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1567,7 +1462,7 @@ func TestRunVMRunRMDeletesAfterSessionExits(t *testing.T) {
true, // --rm
)
if err != nil {
t.Fatalf("runVMRun: %v", err)
t.Fatalf("d.runVMRun: %v", err)
}
if deletedRef != "tmpbox" {
t.Fatalf("deletedRef = %q, want tmpbox", deletedRef)
@ -1580,15 +1475,10 @@ func TestRunVMRunRMDeletesAfterSessionExits(t *testing.T) {
}
func TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout(t *testing.T) {
origBegin := vmCreateBeginFunc
origWaitForSSH := guestWaitForSSHFunc
origDelete := vmDeleteFunc
d := defaultDeps()
origTimeout := vmRunSSHTimeout
vmRunSSHTimeout = 50 * time.Millisecond
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
guestWaitForSSHFunc = origWaitForSSH
vmDeleteFunc = origDelete
vmRunSSHTimeout = origTimeout
})
@ -1596,21 +1486,21 @@ func TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout(t *testing.T) {
ID: "vm-id", Name: "slowvm",
Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
}
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
}
guestWaitForSSHFunc = func(ctx context.Context, _, _ string, _ time.Duration) error {
d.guestWaitForSSH = func(ctx context.Context, _, _ string, _ time.Duration) error {
<-ctx.Done()
return ctx.Err()
}
deleteCalled := false
vmDeleteFunc = func(context.Context, string, string) error {
d.vmDelete = func(context.Context, string, string) error {
deleteCalled = true
return nil
}
var stdout, stderr bytes.Buffer
err := runVMRun(
err := d.runVMRun(
context.Background(),
"/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1630,13 +1520,10 @@ func TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout(t *testing.T) {
}
func TestRunVMRunSSHTimeoutReturnsActionableError(t *testing.T) {
origBegin := vmCreateBeginFunc
origWaitForSSH := guestWaitForSSHFunc
d := defaultDeps()
origTimeout := vmRunSSHTimeout
vmRunSSHTimeout = 50 * time.Millisecond
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
guestWaitForSSHFunc = origWaitForSSH
vmRunSSHTimeout = origTimeout
})
@ -1644,18 +1531,18 @@ func TestRunVMRunSSHTimeoutReturnsActionableError(t *testing.T) {
ID: "vm-id", Name: "slowvm",
Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
}
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
}
// Simulate the guest never bringing sshd up — the wait-for-ssh
// child context fires its deadline, returning a DeadlineExceeded.
guestWaitForSSHFunc = func(ctx context.Context, _, _ string, _ time.Duration) error {
d.guestWaitForSSH = func(ctx context.Context, _, _ string, _ time.Duration) error {
<-ctx.Done()
return ctx.Err()
}
var stdout, stderr bytes.Buffer
err := runVMRun(
err := d.runVMRun(
context.Background(),
"/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1683,37 +1570,28 @@ func TestRunVMRunSSHTimeoutReturnsActionableError(t *testing.T) {
}
func TestRunVMRunCommandModePropagatesExitCode(t *testing.T) {
origBegin := vmCreateBeginFunc
origWaitForSSH := guestWaitForSSHFunc
origVMWorkspacePrepare := vmWorkspacePrepareFunc
origSSHExec := sshExecFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
guestWaitForSSHFunc = origWaitForSSH
vmWorkspacePrepareFunc = origVMWorkspacePrepare
sshExecFunc = origSSHExec
})
d := defaultDeps()
vm := model.VMRecord{
ID: "vm-id", Name: "cmdbox",
Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
}
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
}
guestWaitForSSHFunc = func(context.Context, string, string, time.Duration) error { return nil }
vmWorkspacePrepareFunc = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
d.vmWorkspacePrepare = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
t.Fatal("workspace prepare should not run without spec")
return api.VMWorkspacePrepareResult{}, nil
}
var sshArgsSeen []string
sshExecFunc = func(_ context.Context, _ io.Reader, _, _ io.Writer, args []string) error {
d.sshExec = func(_ context.Context, _ io.Reader, _, _ io.Writer, args []string) error {
sshArgsSeen = args
return exitErrorWithCode(t, 7)
}
var stdout, stderr bytes.Buffer
err := runVMRun(
err := d.runVMRun(
context.Background(),
"/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1726,7 +1604,7 @@ func TestRunVMRunCommandModePropagatesExitCode(t *testing.T) {
)
var exitErr ExitCodeError
if !errors.As(err, &exitErr) || exitErr.Code != 7 {
t.Fatalf("runVMRun error = %v, want ExitCodeError{7}", err)
t.Fatalf("d.runVMRun error = %v, want ExitCodeError{7}", err)
}
if len(sshArgsSeen) == 0 || sshArgsSeen[len(sshArgsSeen)-1] != "false" {
t.Fatalf("sshArgsSeen = %v, want trailing command 'false'", sshArgsSeen)
@ -1843,6 +1721,7 @@ func TestNewBangerdCommandRejectsArgs(t *testing.T) {
}
func TestDaemonOutdated(t *testing.T) {
d := defaultDeps()
dir := t.TempDir()
current := filepath.Join(dir, "bangerd-current")
same := filepath.Join(dir, "bangerd-same")
@ -1857,27 +1736,20 @@ func TestDaemonOutdated(t *testing.T) {
t.Fatalf("write stale: %v", err)
}
origBangerdPath := bangerdPathFunc
origDaemonExePath := daemonExePath
t.Cleanup(func() {
bangerdPathFunc = origBangerdPath
daemonExePath = origDaemonExePath
})
bangerdPathFunc = func() (string, error) {
d.bangerdPath = func() (string, error) {
return current, nil
}
daemonExePath = func(pid int) string {
d.daemonExePath = func(pid int) string {
if pid == 1 {
return same
}
return stale
}
if daemonOutdated(1) {
if d.daemonOutdated(1) {
t.Fatal("expected matching daemon executable to be current")
}
if !daemonOutdated(2) {
if !d.daemonOutdated(2) {
t.Fatal("expected replaced daemon executable to be outdated")
}
}
@ -1912,10 +1784,7 @@ func TestDaemonStatusIncludesLogPathWhenStopped(t *testing.T) {
}
func TestDaemonStatusIncludesDaemonBuildInfoWhenRunning(t *testing.T) {
origDaemonPing := daemonPingFunc
t.Cleanup(func() {
daemonPingFunc = origDaemonPing
})
d := defaultDeps()
configHome := filepath.Join(t.TempDir(), "config")
stateHome := filepath.Join(t.TempDir(), "state")
@ -1924,7 +1793,7 @@ func TestDaemonStatusIncludesDaemonBuildInfoWhenRunning(t *testing.T) {
t.Setenv("XDG_STATE_HOME", stateHome)
t.Setenv("XDG_RUNTIME_DIR", runtimeHome)
daemonPingFunc = func(context.Context, string) (api.PingResult, error) {
d.daemonPing = func(context.Context, string) (api.PingResult, error) {
return api.PingResult{
Status: "ok",
PID: 42,
@ -1934,7 +1803,7 @@ func TestDaemonStatusIncludesDaemonBuildInfoWhenRunning(t *testing.T) {
}, nil
}
cmd := NewBangerCommand()
cmd := d.newRootCommand()
var stdout bytes.Buffer
cmd.SetOut(&stdout)
cmd.SetErr(&stdout)
@ -2073,26 +1942,26 @@ func TestVMSessionSendRejectsWrongArgCount(t *testing.T) {
}
}
func stubEnsureDaemonForSend(t *testing.T) {
// stubEnsureDaemonForSend isolates XDG dirs and installs a daemon-ping
// fake onto the caller's *deps so `ensureDaemon` short-circuits without
// trying to spawn bangerd. `vm session send` uses this to avoid needing
// a built binary on disk.
func stubEnsureDaemonForSend(t *testing.T, d *deps) {
t.Helper()
t.Setenv("XDG_CONFIG_HOME", filepath.Join(t.TempDir(), "config"))
t.Setenv("XDG_STATE_HOME", filepath.Join(t.TempDir(), "state"))
t.Setenv("XDG_RUNTIME_DIR", filepath.Join(t.TempDir(), "run"))
origPing := daemonPingFunc
t.Cleanup(func() { daemonPingFunc = origPing })
daemonPingFunc = func(context.Context, string) (api.PingResult, error) {
d.daemonPing = func(context.Context, string) (api.PingResult, error) {
return api.PingResult{Status: "ok", PID: os.Getpid()}, nil
}
}
func TestVMSessionSendWithMessageFlag(t *testing.T) {
stubEnsureDaemonForSend(t)
original := guestSessionSendFunc
t.Cleanup(func() { guestSessionSendFunc = original })
d := defaultDeps()
stubEnsureDaemonForSend(t, d)
var capturedParams api.GuestSessionSendParams
guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
capturedParams = params
return api.GuestSessionSendResult{
Session: model.GuestSession{ID: "sess-id", Name: "planner"},
@ -2100,7 +1969,7 @@ func TestVMSessionSendWithMessageFlag(t *testing.T) {
}, nil
}
cmd := NewBangerCommand()
cmd := d.newRootCommand()
var out bytes.Buffer
cmd.SetOut(&out)
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner", "--message", `{"type":"abort"}`})
@ -2124,13 +1993,11 @@ func TestVMSessionSendWithMessageFlag(t *testing.T) {
}
func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) {
stubEnsureDaemonForSend(t)
original := guestSessionSendFunc
t.Cleanup(func() { guestSessionSendFunc = original })
d := defaultDeps()
stubEnsureDaemonForSend(t, d)
var capturedPayload []byte
guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
capturedPayload = params.Payload
return api.GuestSessionSendResult{
Session: model.GuestSession{Name: "s"},
@ -2138,7 +2005,7 @@ func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) {
}, nil
}
cmd := NewBangerCommand()
cmd := d.newRootCommand()
cmd.SetOut(io.Discard)
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "s", "--message", "{\"type\":\"abort\"}\n"})
if err := cmd.Execute(); err != nil {
@ -2155,13 +2022,11 @@ func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) {
}
func TestVMSessionSendFromStdin(t *testing.T) {
stubEnsureDaemonForSend(t)
original := guestSessionSendFunc
t.Cleanup(func() { guestSessionSendFunc = original })
d := defaultDeps()
stubEnsureDaemonForSend(t, d)
var capturedPayload []byte
guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
capturedPayload = params.Payload
return api.GuestSessionSendResult{
Session: model.GuestSession{Name: "planner"},
@ -2170,7 +2035,7 @@ func TestVMSessionSendFromStdin(t *testing.T) {
}
stdinPayload := `{"type":"steer","message":"Focus on src/"}` + "\n"
cmd := NewBangerCommand()
cmd := d.newRootCommand()
cmd.SetOut(io.Discard)
cmd.SetIn(strings.NewReader(stdinPayload))
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner"})
@ -2208,13 +2073,11 @@ func TestVMWorkspaceExportRejectsMissingArg(t *testing.T) {
}
func TestVMWorkspaceExportWritesToStdout(t *testing.T) {
stubEnsureDaemonForSend(t)
origExport := vmWorkspaceExportFunc
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
d := defaultDeps()
stubEnsureDaemonForSend(t, d)
patch := []byte("diff --git a/main.go b/main.go\nindex 0000000..1111111 100644\n")
vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
return api.WorkspaceExportResult{
GuestPath: params.GuestPath,
Patch: patch,
@ -2223,7 +2086,7 @@ func TestVMWorkspaceExportWritesToStdout(t *testing.T) {
}, nil
}
cmd := NewBangerCommand()
cmd := d.newRootCommand()
var out bytes.Buffer
cmd.SetOut(&out)
cmd.SetErr(io.Discard)
@ -2237,13 +2100,11 @@ func TestVMWorkspaceExportWritesToStdout(t *testing.T) {
}
func TestVMWorkspaceExportWritesToFile(t *testing.T) {
stubEnsureDaemonForSend(t)
origExport := vmWorkspaceExportFunc
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
d := defaultDeps()
stubEnsureDaemonForSend(t, d)
patch := []byte("diff --git a/main.go b/main.go\n")
vmWorkspaceExportFunc = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
d.vmWorkspaceExport = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
return api.WorkspaceExportResult{
GuestPath: "/root/repo",
Patch: patch,
@ -2253,7 +2114,7 @@ func TestVMWorkspaceExportWritesToFile(t *testing.T) {
}
outFile := filepath.Join(t.TempDir(), "worker.diff")
cmd := NewBangerCommand()
cmd := d.newRootCommand()
cmd.SetOut(io.Discard)
var stderr bytes.Buffer
cmd.SetErr(&stderr)
@ -2275,19 +2136,17 @@ func TestVMWorkspaceExportWritesToFile(t *testing.T) {
}
func TestVMWorkspaceExportNoChanges(t *testing.T) {
stubEnsureDaemonForSend(t)
d := defaultDeps()
stubEnsureDaemonForSend(t, d)
origExport := vmWorkspaceExportFunc
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
vmWorkspaceExportFunc = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
d.vmWorkspaceExport = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
return api.WorkspaceExportResult{
GuestPath: "/root/repo",
HasChanges: false,
}, nil
}
cmd := NewBangerCommand()
cmd := d.newRootCommand()
var out bytes.Buffer
var stderr bytes.Buffer
cmd.SetOut(&out)
@ -2305,18 +2164,16 @@ func TestVMWorkspaceExportNoChanges(t *testing.T) {
}
func TestVMWorkspaceExportGuestPathFlag(t *testing.T) {
stubEnsureDaemonForSend(t)
origExport := vmWorkspaceExportFunc
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
d := defaultDeps()
stubEnsureDaemonForSend(t, d)
var capturedParams api.WorkspaceExportParams
vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
capturedParams = params
return api.WorkspaceExportResult{HasChanges: false}, nil
}
cmd := NewBangerCommand()
cmd := d.newRootCommand()
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--guest-path", "/root/project"})
@ -2332,13 +2189,11 @@ func TestVMWorkspaceExportGuestPathFlag(t *testing.T) {
}
func TestVMWorkspaceExportBaseCommitFlag(t *testing.T) {
stubEnsureDaemonForSend(t)
origExport := vmWorkspaceExportFunc
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
d := defaultDeps()
stubEnsureDaemonForSend(t, d)
var capturedParams api.WorkspaceExportParams
vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
capturedParams = params
return api.WorkspaceExportResult{
HasChanges: false,
@ -2346,7 +2201,7 @@ func TestVMWorkspaceExportBaseCommitFlag(t *testing.T) {
}, nil
}
cmd := NewBangerCommand()
cmd := d.newRootCommand()
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--base-commit", "abc1234deadbeef"})

View file

@ -15,7 +15,7 @@ import (
"github.com/spf13/cobra"
)
func newDaemonCommand() *cobra.Command {
func (d *deps) newDaemonCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "daemon",
Short: "Manage the banger daemon",
@ -31,7 +31,7 @@ func newDaemonCommand() *cobra.Command {
if err != nil {
return err
}
ping, pingErr := daemonPingFunc(cmd.Context(), layout.SocketPath)
ping, pingErr := d.daemonPing(cmd.Context(), layout.SocketPath)
if pingErr != nil {
_, err = fmt.Fprintf(cmd.OutOrStdout(), "stopped\nsocket: %s\nlog: %s\ndns: %s\n", layout.SocketPath, layout.DaemonLog, vmdns.DefaultListenAddr)
return err

View file

@ -13,24 +13,24 @@ import (
"github.com/spf13/cobra"
)
func newImageCommand() *cobra.Command {
func (d *deps) newImageCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "image",
Short: "Manage images",
RunE: helpNoArgs,
}
cmd.AddCommand(
newImageRegisterCommand(),
newImagePullCommand(),
newImagePromoteCommand(),
newImageListCommand(),
newImageShowCommand(),
newImageDeleteCommand(),
d.newImageRegisterCommand(),
d.newImagePullCommand(),
d.newImagePromoteCommand(),
d.newImageListCommand(),
d.newImageShowCommand(),
d.newImageDeleteCommand(),
)
return cmd
}
func newImageRegisterCommand() *cobra.Command {
func (d *deps) newImageRegisterCommand() *cobra.Command {
var params api.ImageRegisterParams
cmd := &cobra.Command{
Use: "register",
@ -46,7 +46,7 @@ func newImageRegisterCommand() *cobra.Command {
if err := system.EnsureSudo(cmd.Context()); err != nil {
return err
}
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -65,11 +65,11 @@ func newImageRegisterCommand() *cobra.Command {
cmd.Flags().StringVar(&params.ModulesDir, "modules", "", "modules dir")
cmd.Flags().StringVar(&params.KernelRef, "kernel-ref", "", "name of a cataloged kernel (see 'banger kernel list')")
cmd.Flags().BoolVar(&params.Docker, "docker", false, "mark image as docker-prepared")
_ = cmd.RegisterFlagCompletionFunc("kernel-ref", completeKernelNames)
_ = cmd.RegisterFlagCompletionFunc("kernel-ref", d.completeKernelNames)
return cmd
}
func newImagePullCommand() *cobra.Command {
func (d *deps) newImagePullCommand() *cobra.Command {
var (
params api.ImagePullParams
sizeRaw string
@ -117,7 +117,7 @@ subcommand lands).
if err := system.EnsureSudo(cmd.Context()); err != nil {
return err
}
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -139,21 +139,21 @@ subcommand lands).
cmd.Flags().StringVar(&params.ModulesDir, "modules", "", "modules dir")
cmd.Flags().StringVar(&params.KernelRef, "kernel-ref", "", "name of a cataloged kernel (see 'banger kernel list')")
cmd.Flags().StringVar(&sizeRaw, "size", "", "ext4 image size (e.g. 4GiB); defaults to content + 25%, min 1GiB")
_ = cmd.RegisterFlagCompletionFunc("kernel-ref", completeKernelNames)
_ = cmd.RegisterFlagCompletionFunc("kernel-ref", d.completeKernelNames)
return cmd
}
func newImagePromoteCommand() *cobra.Command {
func (d *deps) newImagePromoteCommand() *cobra.Command {
return &cobra.Command{
Use: "promote <id-or-name>",
Short: "Promote an unmanaged image to a managed artifact",
Args: exactArgsUsage(1, "usage: banger image promote <id-or-name>"),
ValidArgsFunction: completeImageNameOnlyAtPos0,
ValidArgsFunction: d.completeImageNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
if err := system.EnsureSudo(cmd.Context()); err != nil {
return err
}
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -166,14 +166,14 @@ func newImagePromoteCommand() *cobra.Command {
}
}
func newImageListCommand() *cobra.Command {
func (d *deps) newImageListCommand() *cobra.Command {
return &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List images",
Args: noArgsUsage("usage: banger image list"),
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -186,14 +186,14 @@ func newImageListCommand() *cobra.Command {
}
}
func newImageShowCommand() *cobra.Command {
func (d *deps) newImageShowCommand() *cobra.Command {
return &cobra.Command{
Use: "show <id-or-name>",
Short: "Show image details",
Args: exactArgsUsage(1, "usage: banger image show <id-or-name>"),
ValidArgsFunction: completeImageNameOnlyAtPos0,
ValidArgsFunction: d.completeImageNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -206,18 +206,18 @@ func newImageShowCommand() *cobra.Command {
}
}
func newImageDeleteCommand() *cobra.Command {
func (d *deps) newImageDeleteCommand() *cobra.Command {
return &cobra.Command{
Use: "delete <id-or-name>",
Aliases: []string{"rm"},
Short: "Delete an image",
Args: exactArgsUsage(1, "usage: banger image delete <id-or-name>"),
ValidArgsFunction: completeImageNameOnlyAtPos0,
ValidArgsFunction: d.completeImageNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
if err := system.EnsureSudo(cmd.Context()); err != nil {
return err
}
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}

View file

@ -25,7 +25,7 @@ import (
"github.com/spf13/cobra"
)
func newInternalCommand() *cobra.Command {
func (d *deps) newInternalCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "internal",
Hidden: true,

View file

@ -12,30 +12,30 @@ import (
"github.com/spf13/cobra"
)
func newKernelCommand() *cobra.Command {
func (d *deps) newKernelCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "kernel",
Short: "Manage the local kernel catalog",
RunE: helpNoArgs,
}
cmd.AddCommand(
newKernelListCommand(),
newKernelShowCommand(),
newKernelRmCommand(),
newKernelImportCommand(),
newKernelPullCommand(),
d.newKernelListCommand(),
d.newKernelShowCommand(),
d.newKernelRmCommand(),
d.newKernelImportCommand(),
d.newKernelPullCommand(),
)
return cmd
}
func newKernelPullCommand() *cobra.Command {
func (d *deps) newKernelPullCommand() *cobra.Command {
var force bool
cmd := &cobra.Command{
Use: "pull <name>",
Short: "Download a cataloged kernel bundle",
Args: exactArgsUsage(1, "usage: banger kernel pull <name> [--force]"),
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -55,7 +55,7 @@ func newKernelPullCommand() *cobra.Command {
return cmd
}
func newKernelImportCommand() *cobra.Command {
func (d *deps) newKernelImportCommand() *cobra.Command {
var params api.KernelImportParams
cmd := &cobra.Command{
Use: "import <name>",
@ -72,7 +72,7 @@ func newKernelImportCommand() *cobra.Command {
return err
}
params.FromDir = abs
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -89,7 +89,7 @@ func newKernelImportCommand() *cobra.Command {
return cmd
}
func newKernelListCommand() *cobra.Command {
func (d *deps) newKernelListCommand() *cobra.Command {
var available bool
cmd := &cobra.Command{
Use: "list",
@ -97,7 +97,7 @@ func newKernelListCommand() *cobra.Command {
Short: "List kernels (local by default, or --available for the catalog)",
Args: noArgsUsage("usage: banger kernel list [--available]"),
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -119,14 +119,14 @@ func newKernelListCommand() *cobra.Command {
return cmd
}
func newKernelShowCommand() *cobra.Command {
func (d *deps) newKernelShowCommand() *cobra.Command {
return &cobra.Command{
Use: "show <name>",
Short: "Show kernel catalog entry details",
Args: exactArgsUsage(1, "usage: banger kernel show <name>"),
ValidArgsFunction: completeKernelNameOnlyAtPos0,
ValidArgsFunction: d.completeKernelNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -139,15 +139,15 @@ func newKernelShowCommand() *cobra.Command {
}
}
func newKernelRmCommand() *cobra.Command {
func (d *deps) newKernelRmCommand() *cobra.Command {
return &cobra.Command{
Use: "rm <name>",
Aliases: []string{"remove", "delete"},
Short: "Remove a kernel catalog entry",
Args: exactArgsUsage(1, "usage: banger kernel rm <name>"),
ValidArgsFunction: completeKernelNameOnlyAtPos0,
ValidArgsFunction: d.completeKernelNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}

View file

@ -22,35 +22,35 @@ import (
"github.com/spf13/cobra"
)
func newVMCommand() *cobra.Command {
func (d *deps) newVMCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "vm",
Short: "Manage virtual machines",
RunE: helpNoArgs,
}
cmd.AddCommand(
newVMCreateCommand(),
newVMRunCommand(),
newVMListCommand(),
newVMShowCommand(),
newVMActionCommand("start", "Start a VM", "vm.start"),
newVMActionCommand("stop", "Stop a VM", "vm.stop"),
newVMKillCommand(),
newVMActionCommand("restart", "Restart a VM", "vm.restart"),
newVMActionCommand("delete", "Delete a VM", "vm.delete", "rm"),
newVMPruneCommand(),
newVMSetCommand(),
newVMSSHCommand(),
newVMWorkspaceCommand(),
newVMSessionCommand(),
newVMLogsCommand(),
newVMStatsCommand(),
newVMPortsCommand(),
d.newVMCreateCommand(),
d.newVMRunCommand(),
d.newVMListCommand(),
d.newVMShowCommand(),
d.newVMActionCommand("start", "Start a VM", "vm.start"),
d.newVMActionCommand("stop", "Stop a VM", "vm.stop"),
d.newVMKillCommand(),
d.newVMActionCommand("restart", "Restart a VM", "vm.restart"),
d.newVMActionCommand("delete", "Delete a VM", "vm.delete", "rm"),
d.newVMPruneCommand(),
d.newVMSetCommand(),
d.newVMSSHCommand(),
d.newVMWorkspaceCommand(),
d.newVMSessionCommand(),
d.newVMLogsCommand(),
d.newVMStatsCommand(),
d.newVMPortsCommand(),
)
return cmd
}
func newVMRunCommand() *cobra.Command {
func (d *deps) newVMRunCommand() *cobra.Command {
defaults := effectiveVMDefaults()
var (
name string
@ -104,7 +104,7 @@ Three modes:
var repoPtr *vmRunRepo
if sourcePath != "" {
resolved, err := vmRunPreflightRepo(cmd.Context(), sourcePath)
resolved, err := d.vmRunPreflightRepo(cmd.Context(), sourcePath)
if err != nil {
return err
}
@ -135,11 +135,11 @@ Three modes:
if err := system.EnsureSudo(cmd.Context()); err != nil {
return err
}
layout, cfg, err = ensureDaemon(cmd.Context())
layout, cfg, err = d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
return runVMRun(cmd.Context(), layout.SocketPath, cfg, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), params, repoPtr, commandArgs, removeOnExit)
return d.runVMRun(cmd.Context(), layout.SocketPath, cfg, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), params, repoPtr, commandArgs, removeOnExit)
},
}
cmd.Flags().StringVar(&name, "name", "", "vm name")
@ -152,22 +152,22 @@ Three modes:
cmd.Flags().StringVar(&branchName, "branch", "", "create and switch to a new guest branch")
cmd.Flags().StringVar(&fromRef, "from", "HEAD", "base ref for --branch")
cmd.Flags().BoolVar(&removeOnExit, "rm", false, "delete the VM after the ssh session / command exits")
_ = cmd.RegisterFlagCompletionFunc("image", completeImageNames)
_ = cmd.RegisterFlagCompletionFunc("image", d.completeImageNames)
return cmd
}
func newVMKillCommand() *cobra.Command {
func (d *deps) newVMKillCommand() *cobra.Command {
var signal string
cmd := &cobra.Command{
Use: "kill <id-or-name>...",
Short: "Send a signal to a VM process",
Args: minArgsUsage(1, "usage: banger vm kill [--signal SIGTERM|SIGKILL|...] <id-or-name>..."),
ValidArgsFunction: completeVMNames,
ValidArgsFunction: d.completeVMNames,
RunE: func(cmd *cobra.Command, args []string) error {
if err := system.EnsureSudo(cmd.Context()); err != nil {
return err
}
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -201,7 +201,7 @@ func newVMKillCommand() *cobra.Command {
return cmd
}
func newVMPruneCommand() *cobra.Command {
func (d *deps) newVMPruneCommand() *cobra.Command {
var force bool
cmd := &cobra.Command{
Use: "prune",
@ -212,23 +212,23 @@ func newVMPruneCommand() *cobra.Command {
if err := system.EnsureSudo(cmd.Context()); err != nil {
return err
}
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
return runVMPrune(cmd, layout.SocketPath, force)
return d.runVMPrune(cmd, layout.SocketPath, force)
},
}
cmd.Flags().BoolVarP(&force, "force", "f", false, "skip the confirmation prompt")
return cmd
}
func runVMPrune(cmd *cobra.Command, socketPath string, force bool) error {
func (d *deps) runVMPrune(cmd *cobra.Command, socketPath string, force bool) error {
ctx := cmd.Context()
stdout := cmd.OutOrStdout()
stderr := cmd.ErrOrStderr()
list, err := vmListFunc(ctx, socketPath)
list, err := d.vmList(ctx, socketPath)
if err != nil {
return err
}
@ -270,7 +270,7 @@ func runVMPrune(cmd *cobra.Command, socketPath string, force bool) error {
if ref == "" {
ref = shortID(vm.ID)
}
if err := vmDeleteFunc(ctx, socketPath, vm.ID); err != nil {
if err := d.vmDelete(ctx, socketPath, vm.ID); err != nil {
fmt.Fprintf(stderr, "delete %s: %v\n", ref, err)
failed++
continue
@ -299,7 +299,7 @@ func promptYesNo(in io.Reader, out io.Writer, prompt string) (bool, error) {
return answer == "y" || answer == "yes", nil
}
func newVMCreateCommand() *cobra.Command {
func (d *deps) newVMCreateCommand() *cobra.Command {
defaults := effectiveVMDefaults()
var (
name string
@ -323,11 +323,11 @@ func newVMCreateCommand() *cobra.Command {
if err := system.EnsureSudo(cmd.Context()); err != nil {
return err
}
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
vm, err := runVMCreate(cmd.Context(), layout.SocketPath, cmd.ErrOrStderr(), params)
vm, err := d.runVMCreate(cmd.Context(), layout.SocketPath, cmd.ErrOrStderr(), params)
if err != nil {
return err
}
@ -342,7 +342,7 @@ func newVMCreateCommand() *cobra.Command {
cmd.Flags().StringVar(&workDiskSize, "disk-size", model.FormatSizeBytes(defaults.WorkDiskSizeBytes), "work disk size")
cmd.Flags().BoolVar(&natEnabled, "nat", false, "enable NAT")
cmd.Flags().BoolVar(&noStart, "no-start", false, "create without starting")
_ = cmd.RegisterFlagCompletionFunc("image", completeImageNames)
_ = cmd.RegisterFlagCompletionFunc("image", d.completeImageNames)
return cmd
}
@ -352,15 +352,15 @@ type vmListOptions struct {
quiet bool
}
func newPSCommand() *cobra.Command {
return newVMListLikeCommand("ps", nil, "usage: banger ps")
func (d *deps) newPSCommand() *cobra.Command {
return d.newVMListLikeCommand("ps", nil, "usage: banger ps")
}
func newVMListCommand() *cobra.Command {
return newVMListLikeCommand("list", []string{"ls", "ps"}, "usage: banger vm list")
func (d *deps) newVMListCommand() *cobra.Command {
return d.newVMListLikeCommand("list", []string{"ls", "ps"}, "usage: banger vm list")
}
func newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Command {
func (d *deps) newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Command {
var opts vmListOptions
cmd := &cobra.Command{
Use: use,
@ -368,7 +368,7 @@ func newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Com
Short: "List VMs",
Args: noArgsUsage(usage),
RunE: func(cmd *cobra.Command, args []string) error {
return runVMList(cmd, opts)
return d.runVMList(cmd, opts)
},
}
cmd.Flags().BoolVarP(&opts.showAll, "all", "a", false, "show all VMs")
@ -377,8 +377,8 @@ func newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Com
return cmd
}
func runVMList(cmd *cobra.Command, opts vmListOptions) error {
layout, _, err := ensureDaemon(cmd.Context())
func (d *deps) runVMList(cmd *cobra.Command, opts vmListOptions) error {
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -421,14 +421,14 @@ func selectVMListVMs(vms []model.VMRecord, showAll, latest bool) []model.VMRecor
return []model.VMRecord{latestVM}
}
func newVMShowCommand() *cobra.Command {
func (d *deps) newVMShowCommand() *cobra.Command {
return &cobra.Command{
Use: "show <id-or-name>",
Short: "Show VM details",
Args: exactArgsUsage(1, "usage: banger vm show <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0,
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -441,18 +441,18 @@ func newVMShowCommand() *cobra.Command {
}
}
func newVMActionCommand(use, short, method string, aliases ...string) *cobra.Command {
func (d *deps) newVMActionCommand(use, short, method string, aliases ...string) *cobra.Command {
return &cobra.Command{
Use: use + " <id-or-name>...",
Aliases: aliases,
Short: short,
Args: minArgsUsage(1, fmt.Sprintf("usage: banger vm %s <id-or-name>...", use)),
ValidArgsFunction: completeVMNames,
ValidArgsFunction: d.completeVMNames,
RunE: func(cmd *cobra.Command, args []string) error {
if err := system.EnsureSudo(cmd.Context()); err != nil {
return err
}
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -474,7 +474,7 @@ func newVMActionCommand(use, short, method string, aliases ...string) *cobra.Com
}
}
func newVMSetCommand() *cobra.Command {
func (d *deps) newVMSetCommand() *cobra.Command {
var (
vcpu int
memory int
@ -486,7 +486,7 @@ func newVMSetCommand() *cobra.Command {
Use: "set <id-or-name>...",
Short: "Update stopped VM settings",
Args: minArgsUsage(1, "usage: banger vm set [--vcpu N] [--memory MiB] [--disk-size SIZE] [--nat|--no-nat] <id-or-name>..."),
ValidArgsFunction: completeVMNames,
ValidArgsFunction: d.completeVMNames,
RunE: func(cmd *cobra.Command, args []string) error {
params, err := vmSetParamsFromFlags(args[0], vcpu, memory, diskSize, nat, noNat)
if err != nil {
@ -495,7 +495,7 @@ func newVMSetCommand() *cobra.Command {
if err := system.EnsureSudo(cmd.Context()); err != nil {
return err
}
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -525,21 +525,21 @@ func newVMSetCommand() *cobra.Command {
return cmd
}
func newVMSSHCommand() *cobra.Command {
func (d *deps) newVMSSHCommand() *cobra.Command {
return &cobra.Command{
Use: "ssh <id-or-name> [ssh args...]",
Short: "SSH into a running VM",
Args: minArgsUsage(1, "usage: banger vm ssh <id-or-name> [ssh args...]"),
ValidArgsFunction: completeVMNameOnlyAtPos0,
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
layout, cfg, err := ensureDaemon(cmd.Context())
layout, cfg, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
if err := validateSSHPrereqs(cfg); err != nil {
return err
}
result, err := vmSSHFunc(cmd.Context(), layout.SocketPath, args[0])
result, err := d.vmSSH(cmd.Context(), layout.SocketPath, args[0])
if err != nil {
return err
}
@ -547,25 +547,25 @@ func newVMSSHCommand() *cobra.Command {
if err != nil {
return err
}
return runSSHSession(cmd.Context(), layout.SocketPath, result.Name, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), sshArgs, false)
return d.runSSHSession(cmd.Context(), layout.SocketPath, result.Name, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), sshArgs, false)
},
}
}
func newVMWorkspaceCommand() *cobra.Command {
func (d *deps) newVMWorkspaceCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "workspace",
Short: "Manage repository workspaces inside a running VM",
RunE: helpNoArgs,
}
cmd.AddCommand(
newVMWorkspacePrepareCommand(),
newVMWorkspaceExportCommand(),
d.newVMWorkspacePrepareCommand(),
d.newVMWorkspaceExportCommand(),
)
return cmd
}
func newVMWorkspacePrepareCommand() *cobra.Command {
func (d *deps) newVMWorkspacePrepareCommand() *cobra.Command {
var guestPath string
var branchName string
var fromRef string
@ -576,14 +576,14 @@ func newVMWorkspacePrepareCommand() *cobra.Command {
Short: "Copy a local repo into a running VM",
Long: "Prepare a repository workspace from a local git checkout into a running VM. The default guest path is /root/repo and the default mode is shallow_overlay. Repositories with git submodules must use --mode full_copy.",
Args: minArgsUsage(1, "usage: banger vm workspace prepare <id-or-name> [path]"),
ValidArgsFunction: completeVMNameOnlyAtPos0,
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
Example: strings.TrimSpace(`
banger vm workspace prepare devbox
banger vm workspace prepare devbox ../repo --guest-path /root/repo --readonly
banger vm workspace prepare devbox ../repo --mode full_copy
`),
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -592,7 +592,7 @@ func newVMWorkspacePrepareCommand() *cobra.Command {
sourcePath = args[1]
}
if strings.TrimSpace(sourcePath) == "" {
wd, err := cwdFunc()
wd, err := d.cwd()
if err != nil {
return err
}
@ -606,7 +606,7 @@ func newVMWorkspacePrepareCommand() *cobra.Command {
if strings.TrimSpace(branchName) != "" {
prepareFrom = fromRef
}
result, err := vmWorkspacePrepareFunc(cmd.Context(), layout.SocketPath, api.VMWorkspacePrepareParams{
result, err := d.vmWorkspacePrepare(cmd.Context(), layout.SocketPath, api.VMWorkspacePrepareParams{
IDOrName: args[0],
SourcePath: resolvedPath,
GuestPath: guestPath,
@ -629,7 +629,7 @@ func newVMWorkspacePrepareCommand() *cobra.Command {
return cmd
}
func newVMWorkspaceExportCommand() *cobra.Command {
func (d *deps) newVMWorkspaceExportCommand() *cobra.Command {
var guestPath string
var outputPath string
var baseCommit string
@ -638,7 +638,7 @@ func newVMWorkspaceExportCommand() *cobra.Command {
Short: "Pull changes from a guest workspace back to the host as a patch",
Long: "Emit a binary-safe unified diff of every change inside the guest workspace (committed since base + uncommitted + untracked, minus .gitignore). Non-mutating — the guest's index and working tree are untouched. Pass --base-commit with the head_commit from workspace prepare to capture changes even when the worker ran git commit inside the VM. Without --base-commit the diff is against the current guest HEAD, which misses committed changes.",
Args: exactArgsUsage(1, "usage: banger vm workspace export <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0,
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
Example: strings.TrimSpace(`
banger vm workspace export devbox | git apply
banger vm workspace export devbox --base-commit abc1234 | git apply
@ -646,11 +646,11 @@ func newVMWorkspaceExportCommand() *cobra.Command {
banger vm workspace export devbox --guest-path /root/project --output changes.diff
`),
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := vmWorkspaceExportFunc(cmd.Context(), layout.SocketPath, api.WorkspaceExportParams{
result, err := d.vmWorkspaceExport(cmd.Context(), layout.SocketPath, api.WorkspaceExportParams{
IDOrName: args[0],
GuestPath: guestPath,
BaseCommit: baseCommit,
@ -680,15 +680,15 @@ func newVMWorkspaceExportCommand() *cobra.Command {
return cmd
}
func newVMLogsCommand() *cobra.Command {
func (d *deps) newVMLogsCommand() *cobra.Command {
var follow bool
cmd := &cobra.Command{
Use: "logs <id-or-name>",
Short: "Show VM logs",
Args: exactArgsUsage(1, "usage: banger vm logs [-f] <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0,
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -706,14 +706,14 @@ func newVMLogsCommand() *cobra.Command {
return cmd
}
func newVMStatsCommand() *cobra.Command {
func (d *deps) newVMStatsCommand() *cobra.Command {
return &cobra.Command{
Use: "stats <id-or-name>",
Short: "Show VM stats",
Args: exactArgsUsage(1, "usage: banger vm stats <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0,
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -726,18 +726,18 @@ func newVMStatsCommand() *cobra.Command {
}
}
func newVMPortsCommand() *cobra.Command {
func (d *deps) newVMPortsCommand() *cobra.Command {
return &cobra.Command{
Use: "ports <id-or-name>",
Short: "Show host-reachable listening guest ports",
Args: exactArgsUsage(1, "usage: banger vm ports <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0,
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := vmPortsFunc(cmd.Context(), layout.SocketPath, args[0])
result, err := d.vmPorts(cmd.Context(), layout.SocketPath, args[0])
if err != nil {
return err
}

View file

@ -15,7 +15,7 @@ import (
"github.com/spf13/cobra"
)
func newVMSessionCommand() *cobra.Command {
func (d *deps) newVMSessionCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "session",
Short: "Manage long-lived guest commands inside a VM",
@ -23,19 +23,19 @@ func newVMSessionCommand() *cobra.Command {
RunE: helpNoArgs,
}
cmd.AddCommand(
newVMSessionStartCommand(),
newVMSessionListCommand(),
newVMSessionShowCommand(),
newVMSessionLogsCommand(),
newVMSessionStopCommand(),
newVMSessionKillCommand(),
newVMSessionAttachCommand(),
newVMSessionSendCommand(),
d.newVMSessionStartCommand(),
d.newVMSessionListCommand(),
d.newVMSessionShowCommand(),
d.newVMSessionLogsCommand(),
d.newVMSessionStopCommand(),
d.newVMSessionKillCommand(),
d.newVMSessionAttachCommand(),
d.newVMSessionSendCommand(),
)
return cmd
}
func newVMSessionStartCommand() *cobra.Command {
func (d *deps) newVMSessionStartCommand() *cobra.Command {
var name string
var cwd string
var stdinMode string
@ -47,13 +47,13 @@ func newVMSessionStartCommand() *cobra.Command {
Short: "Start a managed guest command",
Long: "Start a daemon-managed guest command. The daemon verifies that the guest working directory exists and that the requested command is present in guest PATH before launch. Use --stdin-mode pipe when you need live attach.",
Args: minArgsUsage(2, "usage: banger vm session start <id-or-name> [flags] -- <command> [args...]"),
ValidArgsFunction: completeVMNameOnlyAtPos0,
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
Example: strings.TrimSpace(`
banger vm session start devbox --name planner --cwd /root/repo --stdin-mode pipe --require-command git -- pi --mode rpc --no-session
banger vm session start devbox --name shell --stdin-mode pipe -- bash -lc 'exec bash'
`),
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -65,7 +65,7 @@ func newVMSessionStartCommand() *cobra.Command {
if err != nil {
return err
}
result, err := guestSessionStartFunc(cmd.Context(), layout.SocketPath, api.GuestSessionStartParams{
result, err := d.guestSessionStart(cmd.Context(), layout.SocketPath, api.GuestSessionStartParams{
VMIDOrName: args[0],
Name: name,
Command: args[1],
@ -97,19 +97,19 @@ func newVMSessionStartCommand() *cobra.Command {
return cmd
}
func newVMSessionListCommand() *cobra.Command {
func (d *deps) newVMSessionListCommand() *cobra.Command {
return &cobra.Command{
Use: "list <id-or-name>",
Aliases: []string{"ls"},
Short: "List managed guest commands for a VM",
Args: exactArgsUsage(1, "usage: banger vm session list <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0,
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := guestSessionListFunc(cmd.Context(), layout.SocketPath, args[0])
result, err := d.guestSessionList(cmd.Context(), layout.SocketPath, args[0])
if err != nil {
return err
}
@ -118,18 +118,18 @@ func newVMSessionListCommand() *cobra.Command {
}
}
func newVMSessionShowCommand() *cobra.Command {
func (d *deps) newVMSessionShowCommand() *cobra.Command {
return &cobra.Command{
Use: "show <id-or-name> <session>",
Short: "Show managed guest command details",
Args: exactArgsUsage(2, "usage: banger vm session show <id-or-name> <session>"),
ValidArgsFunction: completeSessionNames,
ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := guestSessionGetFunc(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
result, err := d.guestSessionGet(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil {
return err
}
@ -138,20 +138,20 @@ func newVMSessionShowCommand() *cobra.Command {
}
}
func newVMSessionLogsCommand() *cobra.Command {
func (d *deps) newVMSessionLogsCommand() *cobra.Command {
var stream string
var tailLines int
cmd := &cobra.Command{
Use: "logs <id-or-name> <session>",
Short: "Show stdout or stderr for a guest session",
Args: exactArgsUsage(2, "usage: banger vm session logs [--stream stdout|stderr] [-n LINES] <id-or-name> <session>"),
ValidArgsFunction: completeSessionNames,
ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := guestSessionLogsFunc(cmd.Context(), layout.SocketPath, api.GuestSessionLogsParams{VMIDOrName: args[0], SessionIDOrName: args[1], Stream: stream, TailLines: tailLines})
result, err := d.guestSessionLogs(cmd.Context(), layout.SocketPath, api.GuestSessionLogsParams{VMIDOrName: args[0], SessionIDOrName: args[1], Stream: stream, TailLines: tailLines})
if err != nil {
return err
}
@ -164,18 +164,18 @@ func newVMSessionLogsCommand() *cobra.Command {
return cmd
}
func newVMSessionStopCommand() *cobra.Command {
func (d *deps) newVMSessionStopCommand() *cobra.Command {
return &cobra.Command{
Use: "stop <id-or-name> <session>",
Short: "Send SIGTERM to a guest session",
Args: exactArgsUsage(2, "usage: banger vm session stop <id-or-name> <session>"),
ValidArgsFunction: completeSessionNames,
ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := guestSessionStopFunc(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
result, err := d.guestSessionStop(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil {
return err
}
@ -184,18 +184,18 @@ func newVMSessionStopCommand() *cobra.Command {
}
}
func newVMSessionKillCommand() *cobra.Command {
func (d *deps) newVMSessionKillCommand() *cobra.Command {
return &cobra.Command{
Use: "kill <id-or-name> <session>",
Short: "Send SIGKILL to a guest session",
Args: exactArgsUsage(2, "usage: banger vm session kill <id-or-name> <session>"),
ValidArgsFunction: completeSessionNames,
ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := guestSessionKillFunc(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
result, err := d.guestSessionKill(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil {
return err
}
@ -204,19 +204,19 @@ func newVMSessionKillCommand() *cobra.Command {
}
}
func newVMSessionAttachCommand() *cobra.Command {
func (d *deps) newVMSessionAttachCommand() *cobra.Command {
return &cobra.Command{
Use: "attach <id-or-name> <session>",
Short: "Attach local stdio to an attachable guest session",
Long: "Attach local stdio to a pipe-mode session through a daemon-created local Unix socket bridge. Only one active attach is allowed at a time, and the client must run on the same host as the daemon.",
Args: exactArgsUsage(2, "usage: banger vm session attach <id-or-name> <session>"),
ValidArgsFunction: completeSessionNames,
ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := guestSessionAttachBeginFunc(cmd.Context(), layout.SocketPath, api.GuestSessionAttachBeginParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
result, err := d.guestSessionAttachBegin(cmd.Context(), layout.SocketPath, api.GuestSessionAttachBeginParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil {
return err
}
@ -229,21 +229,21 @@ func newVMSessionAttachCommand() *cobra.Command {
}
}
func newVMSessionSendCommand() *cobra.Command {
func (d *deps) newVMSessionSendCommand() *cobra.Command {
var message string
cmd := &cobra.Command{
Use: "send <id-or-name> <session>",
Short: "Write bytes to a running guest session's stdin pipe",
Long: "Write a payload to the stdin pipe of a running pipe-mode guest session without holding the exclusive attach. Use --message for an inline JSONL string, or pipe bytes via stdin when --message is omitted. A trailing newline is appended to --message values that lack one.",
Args: exactArgsUsage(2, "usage: banger vm session send <id-or-name> <session> [--message '<json>']"),
ValidArgsFunction: completeSessionNames,
ValidArgsFunction: d.completeSessionNames,
Example: strings.TrimSpace(`
banger vm session send devbox planner --message '{"type":"abort"}'
banger vm session send devbox planner --message '{"type":"steer","message":"Focus on src/"}'
echo '{"type":"prompt","prompt":"Summarize."}' | banger vm session send devbox planner
`),
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
@ -259,7 +259,7 @@ func newVMSessionSendCommand() *cobra.Command {
return fmt.Errorf("read stdin: %w", err)
}
}
result, err := guestSessionSendFunc(cmd.Context(), layout.SocketPath, api.GuestSessionSendParams{
result, err := d.guestSessionSend(cmd.Context(), layout.SocketPath, api.GuestSessionSendParams{
VMIDOrName: args[0],
SessionIDOrName: args[1],
Payload: payload,

View file

@ -21,9 +21,10 @@ import (
// - Fail silently. Completion is advisory; any error path returns an
// empty suggestion list rather than propagating to the user.
// completionListerFunc is the seam used by tests to avoid touching a
// real daemon socket.
var completionListerFunc = func(ctx context.Context, socketPath, method string) ([]string, error) {
// defaultCompletionLister + defaultCompletionSessionLister back the
// corresponding *deps fields; tests inject their own fakes via the
// struct instead of mutating package-level vars.
func defaultCompletionLister(ctx context.Context, socketPath, method string) ([]string, error) {
switch method {
case "vm.list":
result, err := rpc.Call[api.VMListResult](ctx, socketPath, method, api.Empty{})
@ -65,9 +66,7 @@ var completionListerFunc = func(ctx context.Context, socketPath, method string)
return nil, nil
}
// completionSessionListerFunc is the seam for guest-session name lookups
// scoped to a VM.
var completionSessionListerFunc = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) {
func defaultCompletionSessionLister(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) {
result, err := rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: vmIDOrName})
if err != nil {
return nil, err
@ -84,12 +83,12 @@ var completionSessionListerFunc = func(ctx context.Context, socketPath, vmIDOrNa
// daemonSocketForCompletion returns the socket path IFF the daemon is
// already running. Returns "", false when no daemon is up — completion
// callers use this as the bail signal.
func daemonSocketForCompletion(ctx context.Context) (string, bool) {
func (d *deps) daemonSocketForCompletion(ctx context.Context) (string, bool) {
layout, err := paths.Resolve()
if err != nil {
return "", false
}
if _, err := daemonPingFunc(ctx, layout.SocketPath); err != nil {
if _, err := d.daemonPing(ctx, layout.SocketPath); err != nil {
return "", false
}
return layout.SocketPath, true
@ -119,12 +118,12 @@ func hasPrefix(s, prefix string) bool {
return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}
func completeVMNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
socket, ok := daemonSocketForCompletion(cmd.Context())
func (d *deps) completeVMNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
socket, ok := d.daemonSocketForCompletion(cmd.Context())
if !ok {
return nil, cobra.ShellCompDirectiveNoFileComp
}
names, err := completionListerFunc(cmd.Context(), socket, "vm.list")
names, err := d.completionLister(cmd.Context(), socket, "vm.list")
if err != nil {
return nil, cobra.ShellCompDirectiveNoFileComp
}
@ -134,45 +133,45 @@ func completeVMNames(cmd *cobra.Command, args []string, toComplete string) ([]st
// completeVMNameOnlyAtPos0 restricts VM-name completion to the first
// positional argument. Used by commands like `vm ssh <vm> [ssh args...]`
// where args after pos 0 are free-form.
func completeVMNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
func (d *deps) completeVMNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) > 0 {
return nil, cobra.ShellCompDirectiveNoFileComp
}
return completeVMNames(cmd, args, toComplete)
return d.completeVMNames(cmd, args, toComplete)
}
func completeImageNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
func (d *deps) completeImageNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) > 0 {
return nil, cobra.ShellCompDirectiveNoFileComp
}
return completeImageNames(cmd, args, toComplete)
return d.completeImageNames(cmd, args, toComplete)
}
func completeKernelNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
func (d *deps) completeKernelNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) > 0 {
return nil, cobra.ShellCompDirectiveNoFileComp
}
return completeKernelNames(cmd, args, toComplete)
return d.completeKernelNames(cmd, args, toComplete)
}
func completeImageNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
socket, ok := daemonSocketForCompletion(cmd.Context())
func (d *deps) completeImageNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
socket, ok := d.daemonSocketForCompletion(cmd.Context())
if !ok {
return nil, cobra.ShellCompDirectiveNoFileComp
}
names, err := completionListerFunc(cmd.Context(), socket, "image.list")
names, err := d.completionLister(cmd.Context(), socket, "image.list")
if err != nil {
return nil, cobra.ShellCompDirectiveNoFileComp
}
return filterPrefix(names, args, toComplete), cobra.ShellCompDirectiveNoFileComp
}
func completeKernelNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
socket, ok := daemonSocketForCompletion(cmd.Context())
func (d *deps) completeKernelNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
socket, ok := d.daemonSocketForCompletion(cmd.Context())
if !ok {
return nil, cobra.ShellCompDirectiveNoFileComp
}
names, err := completionListerFunc(cmd.Context(), socket, "kernel.list")
names, err := d.completionLister(cmd.Context(), socket, "kernel.list")
if err != nil {
return nil, cobra.ShellCompDirectiveNoFileComp
}
@ -182,16 +181,16 @@ func completeKernelNames(cmd *cobra.Command, args []string, toComplete string) (
// completeSessionNames handles `... <vm> <session>` commands: pos 0
// completes VMs, pos 1 completes sessions owned by args[0], pos 2+ is
// silent.
func completeSessionNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
func (d *deps) completeSessionNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0:
return completeVMNames(cmd, args, toComplete)
return d.completeVMNames(cmd, args, toComplete)
case 1:
socket, ok := daemonSocketForCompletion(cmd.Context())
socket, ok := d.daemonSocketForCompletion(cmd.Context())
if !ok {
return nil, cobra.ShellCompDirectiveNoFileComp
}
names, err := completionSessionListerFunc(cmd.Context(), socket, args[0])
names, err := d.completionSessionLister(cmd.Context(), socket, args[0])
if err != nil {
return nil, cobra.ShellCompDirectiveNoFileComp
}

View file

@ -12,10 +12,11 @@ import (
)
// stubCompletionSeams installs test doubles for the daemon ping + lister
// seams and restores the originals on cleanup. Tests opt into the
// sub-functions they actually need.
// seams on the caller's *deps. Tests opt into the sub-functions they
// actually need.
func stubCompletionSeams(
t *testing.T,
d *deps,
pingErr error,
names map[string][]string,
listErr error,
@ -24,28 +25,19 @@ func stubCompletionSeams(
) {
t.Helper()
origPing := daemonPingFunc
origLister := completionListerFunc
origSessionLister := completionSessionListerFunc
t.Cleanup(func() {
daemonPingFunc = origPing
completionListerFunc = origLister
completionSessionListerFunc = origSessionLister
})
daemonPingFunc = func(ctx context.Context, socketPath string) (api.PingResult, error) {
d.daemonPing = func(ctx context.Context, socketPath string) (api.PingResult, error) {
if pingErr != nil {
return api.PingResult{}, pingErr
}
return api.PingResult{}, nil
}
completionListerFunc = func(ctx context.Context, socketPath, method string) ([]string, error) {
d.completionLister = func(ctx context.Context, socketPath, method string) ([]string, error) {
if listErr != nil {
return nil, listErr
}
return names[method], nil
}
completionSessionListerFunc = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) {
d.completionSessionLister = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) {
if sessionErr != nil {
return nil, sessionErr
}
@ -89,9 +81,10 @@ func testCmdWithCtx() *cobra.Command {
}
func TestCompleteVMNamesHappyPath(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil)
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil)
got, directive := completeVMNames(testCmdWithCtx(), nil, "")
got, directive := d.completeVMNames(testCmdWithCtx(), nil, "")
if directive != cobra.ShellCompDirectiveNoFileComp {
t.Errorf("directive = %d, want NoFileComp", directive)
}
@ -101,9 +94,10 @@ func TestCompleteVMNamesHappyPath(t *testing.T) {
}
func TestCompleteVMNamesDaemonDown(t *testing.T) {
stubCompletionSeams(t, errors.New("connection refused"), nil, nil, nil, nil)
d := defaultDeps()
stubCompletionSeams(t, d, errors.New("connection refused"), nil, nil, nil, nil)
got, directive := completeVMNames(testCmdWithCtx(), nil, "")
got, directive := d.completeVMNames(testCmdWithCtx(), nil, "")
if len(got) != 0 {
t.Errorf("daemon-down should return no suggestions, got %v", got)
}
@ -113,18 +107,20 @@ func TestCompleteVMNamesDaemonDown(t *testing.T) {
}
func TestCompleteVMNamesRPCError(t *testing.T) {
stubCompletionSeams(t, nil, nil, errors.New("rpc failed"), nil, nil)
d := defaultDeps()
stubCompletionSeams(t, d, nil, nil, errors.New("rpc failed"), nil, nil)
got, _ := completeVMNames(testCmdWithCtx(), nil, "")
got, _ := d.completeVMNames(testCmdWithCtx(), nil, "")
if len(got) != 0 {
t.Errorf("rpc error should return no suggestions, got %v", got)
}
}
func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil)
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil)
got, _ := completeVMNames(testCmdWithCtx(), []string{"alpha"}, "")
got, _ := d.completeVMNames(testCmdWithCtx(), []string{"alpha"}, "")
want := []string{"beta", "gamma"}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
@ -132,9 +128,10 @@ func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) {
}
func TestCompleteVMNamesPrefixFilter(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil, nil, nil)
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil, nil, nil)
got, _ := completeVMNames(testCmdWithCtx(), nil, "alp")
got, _ := d.completeVMNames(testCmdWithCtx(), nil, "alp")
want := []string{"alpha", "alphabet"}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
@ -142,49 +139,53 @@ func TestCompleteVMNamesPrefixFilter(t *testing.T) {
}
func TestCompleteVMNameOnlyAtPos0(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha"}}, nil, nil, nil)
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha"}}, nil, nil, nil)
atPos0, _ := completeVMNameOnlyAtPos0(testCmdWithCtx(), nil, "")
atPos0, _ := d.completeVMNameOnlyAtPos0(testCmdWithCtx(), nil, "")
if len(atPos0) != 1 || atPos0[0] != "alpha" {
t.Errorf("pos 0: got %v", atPos0)
}
atPos1, _ := completeVMNameOnlyAtPos0(testCmdWithCtx(), []string{"alpha"}, "")
atPos1, _ := d.completeVMNameOnlyAtPos0(testCmdWithCtx(), []string{"alpha"}, "")
if len(atPos1) != 0 {
t.Errorf("pos 1+ should be silent, got %v", atPos1)
}
}
func TestCompleteImageNames(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil, nil, nil)
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil, nil, nil)
got, _ := completeImageNames(testCmdWithCtx(), nil, "")
got, _ := d.completeImageNames(testCmdWithCtx(), nil, "")
if !reflect.DeepEqual(got, []string{"debian-bookworm", "alpine"}) {
t.Errorf("got %v", got)
}
}
func TestCompleteKernelNames(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil, nil, nil)
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil, nil, nil)
got, _ := completeKernelNames(testCmdWithCtx(), nil, "")
got, _ := d.completeKernelNames(testCmdWithCtx(), nil, "")
if len(got) != 1 || got[0] != "generic-6.12" {
t.Errorf("got %v", got)
}
}
func TestCompleteImageNameOnlyAtPos0SilentAfterFirst(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"image.list": {"alpine"}}, nil, nil, nil)
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"alpine"}}, nil, nil, nil)
after, _ := completeImageNameOnlyAtPos0(testCmdWithCtx(), []string{"alpine"}, "")
after, _ := d.completeImageNameOnlyAtPos0(testCmdWithCtx(), []string{"alpine"}, "")
if len(after) != 0 {
t.Errorf("expected silence at pos 1+, got %v", after)
}
}
func TestCompleteSessionNames(t *testing.T) {
stubCompletionSeams(
t,
d := defaultDeps()
stubCompletionSeams(t, d,
nil,
map[string][]string{"vm.list": {"devbox"}},
nil,
@ -193,34 +194,35 @@ func TestCompleteSessionNames(t *testing.T) {
)
// Position 0 → VMs.
vms, _ := completeSessionNames(testCmdWithCtx(), nil, "")
vms, _ := d.completeSessionNames(testCmdWithCtx(), nil, "")
if len(vms) != 1 || vms[0] != "devbox" {
t.Errorf("pos 0: got %v", vms)
}
// Position 1 → sessions scoped to args[0].
sessions, _ := completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "")
sessions, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "")
if !reflect.DeepEqual(sessions, []string{"planner", "worker"}) {
t.Errorf("pos 1: got %v", sessions)
}
// Position 1 with prefix filter.
filtered, _ := completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "wor")
filtered, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "wor")
if len(filtered) != 1 || filtered[0] != "worker" {
t.Errorf("pos 1 prefix: got %v", filtered)
}
// Position 2+ silent.
past, _ := completeSessionNames(testCmdWithCtx(), []string{"devbox", "planner"}, "")
past, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox", "planner"}, "")
if len(past) != 0 {
t.Errorf("pos 2+: got %v", past)
}
}
func TestCompleteSessionNamesDaemonDown(t *testing.T) {
stubCompletionSeams(t, errors.New("down"), nil, nil, nil, nil)
d := defaultDeps()
stubCompletionSeams(t, d, errors.New("down"), nil, nil, nil, nil)
got, directive := completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "")
got, directive := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "")
if len(got) != 0 {
t.Errorf("expected no suggestions when daemon down, got %v", got)
}

View file

@ -18,7 +18,7 @@ import (
// ensureDaemon pings the socket; on miss it auto-starts bangerd, on
// version mismatch it restarts. Every CLI command that needs to talk
// to the daemon routes through here.
func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) {
func (d *deps) ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) {
layout, err := paths.Resolve()
if err != nil {
return paths.Layout{}, model.DaemonConfig{}, err
@ -27,16 +27,16 @@ func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error)
if err != nil {
return paths.Layout{}, model.DaemonConfig{}, err
}
if ping, err := daemonPingFunc(ctx, layout.SocketPath); err == nil {
if daemonOutdated(ping.PID) {
if err := restartDaemon(ctx, layout, ping.PID); err != nil {
if ping, err := d.daemonPing(ctx, layout.SocketPath); err == nil {
if d.daemonOutdated(ping.PID) {
if err := d.restartDaemon(ctx, layout, ping.PID); err != nil {
return paths.Layout{}, model.DaemonConfig{}, err
}
return layout, cfg, nil
}
return layout, cfg, nil
}
if err := startDaemon(ctx, layout); err != nil {
if err := d.startDaemon(ctx, layout); err != nil {
return paths.Layout{}, model.DaemonConfig{}, err
}
return layout, cfg, nil
@ -47,11 +47,11 @@ func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error)
// session still holds a handle to an old daemon. os.SameFile compares
// inode + dev, so a fresh binary at the same path registers as
// different.
func daemonOutdated(pid int) bool {
func (d *deps) daemonOutdated(pid int) bool {
if pid <= 0 {
return false
}
daemonBin, err := bangerdPathFunc()
daemonBin, err := d.bangerdPath()
if err != nil {
return false
}
@ -59,20 +59,20 @@ func daemonOutdated(pid int) bool {
if err != nil {
return false
}
runningInfo, err := os.Stat(daemonExePath(pid))
runningInfo, err := os.Stat(d.daemonExePath(pid))
if err != nil {
return false
}
return !os.SameFile(currentInfo, runningInfo)
}
func restartDaemon(ctx context.Context, layout paths.Layout, pid int) error {
func (d *deps) restartDaemon(ctx context.Context, layout paths.Layout, pid int) error {
stopCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
_, _ = rpc.Call[api.ShutdownResult](stopCtx, layout.SocketPath, "shutdown", api.Empty{})
if waitForPIDExit(pid, 2*time.Second) {
return startDaemon(ctx, layout)
return d.startDaemon(ctx, layout)
}
if proc, err := os.FindProcess(pid); err == nil {
_ = proc.Signal(syscall.SIGTERM)
@ -80,7 +80,7 @@ func restartDaemon(ctx context.Context, layout paths.Layout, pid int) error {
if !waitForPIDExit(pid, 2*time.Second) {
return fmt.Errorf("timed out restarting stale daemon pid %d", pid)
}
return startDaemon(ctx, layout)
return d.startDaemon(ctx, layout)
}
func waitForPIDExit(pid int, timeout time.Duration) bool {
@ -105,7 +105,7 @@ func pidRunning(pid int) bool {
return proc.Signal(syscall.Signal(0)) == nil
}
func startDaemon(ctx context.Context, layout paths.Layout) error {
func (d *deps) startDaemon(ctx context.Context, layout paths.Layout) error {
if err := paths.Ensure(layout); err != nil {
return err
}

165
internal/cli/deps.go Normal file
View file

@ -0,0 +1,165 @@
package cli
import (
"context"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"banger/internal/api"
"banger/internal/daemon"
"banger/internal/guest"
"banger/internal/paths"
"banger/internal/rpc"
"banger/internal/system"
"banger/internal/toolingplan"
)
// deps holds the function seams production code dispatches through and
// tests replace with fakes. Keeping these on a per-invocation struct
// (instead of package-level mutable vars) makes the CLI's external
// surface explicit and lets tests run in parallel without leaking fakes
// across test cases.
//
// Every command builder, orchestrator, and helper that touches the RPC
// socket, spawns a subprocess, or reads host state hangs off a *deps
// receiver. Pure helpers (formatters, path resolvers, arg-count
// validators) stay package-level because they hold no references to
// external systems.
type deps struct {
bangerdPath func() (string, error)
daemonExePath func(pid int) string
doctor func(ctx context.Context) (system.Report, error)
sshExec func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error
hostCommandOutput func(ctx context.Context, name string, args ...string) ([]byte, error)
vmHealth func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error)
vmSSH func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error)
vmDelete func(ctx context.Context, socketPath, idOrName string) error
vmList func(ctx context.Context, socketPath string) (api.VMListResult, error)
daemonPing func(ctx context.Context, socketPath string) (api.PingResult, error)
vmCreateBegin func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error)
vmCreateStatus func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error)
vmCreateCancel func(ctx context.Context, socketPath, operationID string) error
vmPorts func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error)
vmWorkspacePrepare func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error)
vmWorkspaceExport func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error)
guestSessionStart func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error)
guestSessionGet func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
guestSessionList func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error)
guestSessionStop func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
guestSessionKill func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
guestSessionLogs func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error)
guestSessionAttachBegin func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error)
guestSessionSend func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error)
guestWaitForSSH func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error
guestDial func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error)
buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan
cwd func() (string, error)
completionLister func(ctx context.Context, socketPath, method string) ([]string, error)
completionSessionLister func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error)
}
func defaultDeps() *deps {
return &deps{
bangerdPath: paths.BangerdPath,
daemonExePath: func(pid int) string {
return filepath.Join("/proc", fmt.Sprintf("%d", pid), "exe")
},
doctor: daemon.Doctor,
sshExec: func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
sshCmd := exec.CommandContext(ctx, "ssh", args...)
sshCmd.Stdout = stdout
sshCmd.Stderr = stderr
sshCmd.Stdin = stdin
return sshCmd.Run()
},
hostCommandOutput: func(ctx context.Context, name string, args ...string) ([]byte, error) {
cmd := exec.CommandContext(ctx, name, args...)
output, err := cmd.CombinedOutput()
if err == nil {
return output, nil
}
command := strings.TrimSpace(strings.Join(append([]string{name}, args...), " "))
detail := strings.TrimSpace(string(output))
if detail == "" {
return output, fmt.Errorf("%s: %w", command, err)
}
return output, fmt.Errorf("%s: %w: %s", command, err, detail)
},
vmHealth: func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName})
},
vmSSH: func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error) {
return rpc.Call[api.VMSSHResult](ctx, socketPath, "vm.ssh", api.VMRefParams{IDOrName: idOrName})
},
vmDelete: func(ctx context.Context, socketPath, idOrName string) error {
_, err := rpc.Call[api.VMShowResult](ctx, socketPath, "vm.delete", api.VMRefParams{IDOrName: idOrName})
return err
},
vmList: func(ctx context.Context, socketPath string) (api.VMListResult, error) {
return rpc.Call[api.VMListResult](ctx, socketPath, "vm.list", api.Empty{})
},
daemonPing: func(ctx context.Context, socketPath string) (api.PingResult, error) {
return rpc.Call[api.PingResult](ctx, socketPath, "ping", api.Empty{})
},
vmCreateBegin: func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error) {
return rpc.Call[api.VMCreateBeginResult](ctx, socketPath, "vm.create.begin", params)
},
vmCreateStatus: func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error) {
return rpc.Call[api.VMCreateStatusResult](ctx, socketPath, "vm.create.status", api.VMCreateStatusParams{ID: operationID})
},
vmCreateCancel: func(ctx context.Context, socketPath, operationID string) error {
_, err := rpc.Call[api.Empty](ctx, socketPath, "vm.create.cancel", api.VMCreateStatusParams{ID: operationID})
return err
},
vmPorts: func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) {
return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName})
},
vmWorkspacePrepare: func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
return rpc.Call[api.VMWorkspacePrepareResult](ctx, socketPath, "vm.workspace.prepare", params)
},
vmWorkspaceExport: func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
return rpc.Call[api.WorkspaceExportResult](ctx, socketPath, "vm.workspace.export", params)
},
guestSessionStart: func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.start", params)
},
guestSessionGet: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.get", params)
},
guestSessionList: func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error) {
return rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: idOrName})
},
guestSessionStop: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.stop", params)
},
guestSessionKill: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.kill", params)
},
guestSessionLogs: func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) {
return rpc.Call[api.GuestSessionLogsResult](ctx, socketPath, "guest.session.logs", params)
},
guestSessionAttachBegin: func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
return rpc.Call[api.GuestSessionAttachBeginResult](ctx, socketPath, "guest.session.attach.begin", params)
},
guestSessionSend: func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params)
},
guestWaitForSSH: func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
knownHosts, _ := bangerKnownHostsPath()
return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval)
},
guestDial: func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
knownHosts, _ := bangerKnownHostsPath()
return guest.Dial(ctx, address, privateKeyPath, knownHosts)
},
buildVMRunToolingPlan: toolingplan.Build,
cwd: os.Getwd,
completionLister: defaultCompletionLister,
completionSessionLister: defaultCompletionSessionLister,
}
}

View file

@ -14,22 +14,17 @@ import (
"github.com/spf13/cobra"
)
// stubPruneSeams installs fakes for vmListFunc and vmDeleteFunc, and
// restores originals on cleanup.
func stubPruneSeams(t *testing.T, vms []model.VMRecord, listErr error, deleteErr map[string]error) *[]string {
// stubPruneSeams installs list + delete fakes onto the caller's *deps
// and returns a pointer to a slice that records every ID passed to the
// delete fake.
func stubPruneSeams(t *testing.T, d *deps, vms []model.VMRecord, listErr error, deleteErr map[string]error) *[]string {
t.Helper()
origList := vmListFunc
origDelete := vmDeleteFunc
t.Cleanup(func() {
vmListFunc = origList
vmDeleteFunc = origDelete
})
var deleted []string
vmListFunc = func(ctx context.Context, socketPath string) (api.VMListResult, error) {
d.vmList = func(ctx context.Context, socketPath string) (api.VMListResult, error) {
return api.VMListResult{VMs: vms}, listErr
}
vmDeleteFunc = func(ctx context.Context, socketPath, idOrName string) error {
d.vmDelete = func(ctx context.Context, socketPath, idOrName string) error {
if err, ok := deleteErr[idOrName]; ok {
return err
}
@ -89,13 +84,14 @@ func TestPromptYesNoEOF(t *testing.T) {
}
func TestRunVMPruneNoVictims(t *testing.T) {
stubPruneSeams(t, []model.VMRecord{
d := defaultDeps()
stubPruneSeams(t, d, []model.VMRecord{
{ID: "id-1", Name: "running-vm", State: model.VMStateRunning},
}, nil, nil)
cmd, stdout, _ := newPruneTestCmd("")
if err := runVMPrune(cmd, "sock", false); err != nil {
t.Fatalf("runVMPrune: %v", err)
if err := d.runVMPrune(cmd, "sock", false); err != nil {
t.Fatalf("d.runVMPrune: %v", err)
}
if !strings.Contains(stdout.String(), "no non-running VMs") {
t.Errorf("expected no-op message, got %q", stdout.String())
@ -103,13 +99,14 @@ func TestRunVMPruneNoVictims(t *testing.T) {
}
func TestRunVMPruneAbortedByUser(t *testing.T) {
deleted := stubPruneSeams(t, []model.VMRecord{
d := defaultDeps()
deleted := stubPruneSeams(t, d, []model.VMRecord{
{ID: "id-1", Name: "stale", State: model.VMStateStopped},
}, nil, nil)
cmd, stdout, _ := newPruneTestCmd("n\n")
if err := runVMPrune(cmd, "sock", false); err != nil {
t.Fatalf("runVMPrune: %v", err)
if err := d.runVMPrune(cmd, "sock", false); err != nil {
t.Fatalf("d.runVMPrune: %v", err)
}
if !strings.Contains(stdout.String(), "aborted") {
t.Errorf("expected 'aborted' output, got %q", stdout.String())
@ -120,7 +117,8 @@ func TestRunVMPruneAbortedByUser(t *testing.T) {
}
func TestRunVMPruneConfirmedDeletesNonRunning(t *testing.T) {
deleted := stubPruneSeams(t, []model.VMRecord{
d := defaultDeps()
deleted := stubPruneSeams(t, d, []model.VMRecord{
{ID: "id-run", Name: "keeper", State: model.VMStateRunning},
{ID: "id-stop", Name: "stale", State: model.VMStateStopped},
{ID: "id-err", Name: "broken", State: model.VMStateError},
@ -128,8 +126,8 @@ func TestRunVMPruneConfirmedDeletesNonRunning(t *testing.T) {
}, nil, nil)
cmd, stdout, _ := newPruneTestCmd("y\n")
if err := runVMPrune(cmd, "sock", false); err != nil {
t.Fatalf("runVMPrune: %v", err)
if err := d.runVMPrune(cmd, "sock", false); err != nil {
t.Fatalf("d.runVMPrune: %v", err)
}
// Deleted must be exactly the three non-running IDs, in list order.
want := []string{"id-stop", "id-err", "id-created"}
@ -152,14 +150,15 @@ func TestRunVMPruneConfirmedDeletesNonRunning(t *testing.T) {
}
func TestRunVMPruneForceSkipsPrompt(t *testing.T) {
deleted := stubPruneSeams(t, []model.VMRecord{
d := defaultDeps()
deleted := stubPruneSeams(t, d, []model.VMRecord{
{ID: "id-1", Name: "stale", State: model.VMStateStopped},
}, nil, nil)
// Empty stdin + force=true: must not block on prompt.
cmd, stdout, _ := newPruneTestCmd("")
if err := runVMPrune(cmd, "sock", true); err != nil {
t.Fatalf("runVMPrune: %v", err)
if err := d.runVMPrune(cmd, "sock", true); err != nil {
t.Fatalf("d.runVMPrune: %v", err)
}
if len(*deleted) != 1 || (*deleted)[0] != "id-1" {
t.Errorf("deleted = %v, want [id-1]", *deleted)
@ -171,7 +170,8 @@ func TestRunVMPruneForceSkipsPrompt(t *testing.T) {
}
func TestRunVMPruneReportsPartialFailure(t *testing.T) {
stubPruneSeams(t,
d := defaultDeps()
stubPruneSeams(t, d,
[]model.VMRecord{
{ID: "id-a", Name: "a", State: model.VMStateStopped},
{ID: "id-b", Name: "b", State: model.VMStateStopped},
@ -181,7 +181,7 @@ func TestRunVMPruneReportsPartialFailure(t *testing.T) {
)
cmd, _, stderr := newPruneTestCmd("")
err := runVMPrune(cmd, "sock", true)
err := d.runVMPrune(cmd, "sock", true)
if err == nil {
t.Fatal("expected non-zero exit when any delete fails")
}
@ -194,10 +194,11 @@ func TestRunVMPruneReportsPartialFailure(t *testing.T) {
}
func TestRunVMPruneListErrorPropagates(t *testing.T) {
stubPruneSeams(t, nil, fmt.Errorf("rpc failed"), nil)
d := defaultDeps()
stubPruneSeams(t, d, nil, fmt.Errorf("rpc failed"), nil)
cmd, _, _ := newPruneTestCmd("")
err := runVMPrune(cmd, "sock", true)
err := d.runVMPrune(cmd, "sock", true)
if err == nil || !strings.Contains(err.Error(), "rpc failed") {
t.Fatalf("expected rpc error to propagate, got %v", err)
}

View file

@ -20,14 +20,14 @@ import (
// the caller asked (e.g. --rm is about to delete the VM), if the
// ctx is already done, or if the ssh error isn't the one that
// typically means "user disconnected cleanly".
func runSSHSession(ctx context.Context, socketPath, vmRef string, stdin io.Reader, stdout, stderr io.Writer, sshArgs []string, skipReminder bool) error {
sshErr := sshExecFunc(ctx, stdin, stdout, stderr, sshArgs)
func (d *deps) runSSHSession(ctx context.Context, socketPath, vmRef string, stdin io.Reader, stdout, stderr io.Writer, sshArgs []string, skipReminder bool) error {
sshErr := d.sshExec(ctx, stdin, stdout, stderr, sshArgs)
if skipReminder || !shouldCheckSSHReminder(sshErr) || ctx.Err() != nil {
return sshErr
}
pingCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
health, err := vmHealthFunc(pingCtx, socketPath, vmRef)
health, err := d.vmHealth(pingCtx, socketPath, vmRef)
if err != nil {
_, _ = fmt.Fprintln(stderr, vsockagent.WarningMessage(vmRef, err))
return sshErr

View file

@ -60,9 +60,9 @@ func printVMSpecLine(out io.Writer, params api.VMCreateParams) {
// gets the spec line up front and the progress renderer thereafter.
// On context cancel we cooperate with the daemon to cancel the
// in-flight op so it doesn't leak partially-created VM state.
func runVMCreate(ctx context.Context, socketPath string, stderr io.Writer, params api.VMCreateParams) (model.VMRecord, error) {
func (d *deps) runVMCreate(ctx context.Context, socketPath string, stderr io.Writer, params api.VMCreateParams) (model.VMRecord, error) {
printVMSpecLine(stderr, params)
begin, err := vmCreateBeginFunc(ctx, socketPath, params)
begin, err := d.vmCreateBegin(ctx, socketPath, params)
if err != nil {
return model.VMRecord{}, err
}
@ -86,17 +86,17 @@ func runVMCreate(ctx context.Context, socketPath string, stderr io.Writer, param
case <-ctx.Done():
cancelCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_ = vmCreateCancelFunc(cancelCtx, socketPath, op.ID)
_ = d.vmCreateCancel(cancelCtx, socketPath, op.ID)
return model.VMRecord{}, ctx.Err()
case <-time.After(200 * time.Millisecond):
}
status, err := vmCreateStatusFunc(ctx, socketPath, op.ID)
status, err := d.vmCreateStatus(ctx, socketPath, op.ID)
if err != nil {
if ctx.Err() != nil {
cancelCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_ = vmCreateCancelFunc(cancelCtx, socketPath, op.ID)
_ = d.vmCreateCancel(cancelCtx, socketPath, op.ID)
return model.VMRecord{}, ctx.Err()
}
return model.VMRecord{}, err

View file

@ -80,9 +80,9 @@ func (e ExitCodeError) Error() string {
// - it sits inside a non-bare git repository,
// - the repository has no submodules (unsupported in the shallow
// overlay mode vm run uses).
func vmRunPreflightRepo(ctx context.Context, rawPath string) (string, error) {
func (d *deps) vmRunPreflightRepo(ctx context.Context, rawPath string) (string, error) {
if strings.TrimSpace(rawPath) == "" {
wd, err := cwdFunc()
wd, err := d.cwd()
if err != nil {
return "", err
}
@ -131,9 +131,9 @@ func splitVMRunArgs(cmd *cobra.Command, args []string) (pathArgs, commandArgs []
// for guest ssh, optionally materialise a workspace and kick off the
// tooling bootstrap, then either attach interactively or run the
// user's command and propagate its exit status.
func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, stdin io.Reader, stdout, stderr io.Writer, params api.VMCreateParams, repo *vmRunRepo, command []string, removeOnExit bool) error {
func (d *deps) runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, stdin io.Reader, stdout, stderr io.Writer, params api.VMCreateParams, repo *vmRunRepo, command []string, removeOnExit bool) error {
progress := newVMRunProgressRenderer(stderr)
vm, err := runVMCreate(ctx, socketPath, stderr, params)
vm, err := d.runVMCreate(ctx, socketPath, stderr, params)
if err != nil {
return err
}
@ -155,7 +155,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
// doesn't abort the delete RPC.
cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := vmDeleteFunc(cleanupCtx, socketPath, vmRef); err != nil {
if err := d.vmDelete(cleanupCtx, socketPath, vmRef); err != nil {
printVMRunWarning(stderr, fmt.Sprintf("--rm cleanup failed: %v (leaked vm %q; delete manually)", err, vmRef))
}
}()
@ -163,7 +163,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
sshAddress := net.JoinHostPort(vm.Runtime.GuestIP, "22")
progress.render("waiting for guest ssh")
sshCtx, cancelSSH := context.WithTimeout(ctx, vmRunSSHTimeout)
if err := guestWaitForSSHFunc(sshCtx, sshAddress, cfg.SSHKeyPath, 250*time.Millisecond); err != nil {
if err := d.guestWaitForSSH(sshCtx, sshAddress, cfg.SSHKeyPath, 250*time.Millisecond); err != nil {
cancelSSH()
// Surface parent-context cancellation (Ctrl-C, caller
// timeout) as-is. Only the guest-side timeout needs the
@ -193,7 +193,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
if strings.TrimSpace(repo.branchName) != "" {
fromRef = repo.fromRef
}
prepared, err := vmWorkspacePrepareFunc(ctx, socketPath, api.VMWorkspacePrepareParams{
prepared, err := d.vmWorkspacePrepare(ctx, socketPath, api.VMWorkspacePrepareParams{
IDOrName: vmRef,
SourcePath: repo.sourcePath,
GuestPath: vmRunGuestDir(),
@ -208,11 +208,11 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
// daemon side; grab what the tooling harness needs from its
// result instead of re-inspecting here.
if len(command) == 0 {
client, err := guestDialFunc(ctx, sshAddress, cfg.SSHKeyPath)
client, err := d.guestDial(ctx, sshAddress, cfg.SSHKeyPath)
if err != nil {
return fmt.Errorf("vm %q is running but guest ssh is unavailable: %w", vmRef, err)
}
if err := startVMRunToolingHarness(ctx, client, prepared.Workspace.RepoRoot, prepared.Workspace.RepoName, progress); err != nil {
if err := d.startVMRunToolingHarness(ctx, client, prepared.Workspace.RepoRoot, prepared.Workspace.RepoName, progress); err != nil {
printVMRunWarning(stderr, fmt.Sprintf("guest tooling bootstrap start failed: %v", err))
}
_ = client.Close()
@ -224,7 +224,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
}
if len(command) > 0 {
progress.render("running command in guest")
if err := sshExecFunc(ctx, stdin, stdout, stderr, sshArgs); err != nil {
if err := d.sshExec(ctx, stdin, stdout, stderr, sshArgs); err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
return ExitCodeError{Code: exitErr.ExitCode()}
@ -234,7 +234,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
return nil
}
progress.render("attaching to guest")
return runSSHSession(ctx, socketPath, vmRef, stdin, stdout, stderr, sshArgs, removeOnExit)
return d.runSSHSession(ctx, socketPath, vmRef, stdin, stdout, stderr, sshArgs, removeOnExit)
}
func vmRunGuestDir() string {
@ -253,11 +253,11 @@ func vmRunToolingHarnessLogPath(repoName string) string {
// script inside the guest. repoRoot / repoName both come from the
// daemon's workspace.prepare RPC response — the CLI no longer does
// its own git inspection.
func startVMRunToolingHarness(ctx context.Context, client vmRunGuestClient, repoRoot, repoName string, progress *vmRunProgressRenderer) error {
func (d *deps) startVMRunToolingHarness(ctx context.Context, client vmRunGuestClient, repoRoot, repoName string, progress *vmRunProgressRenderer) error {
if progress != nil {
progress.render("starting guest tooling bootstrap")
}
plan := buildVMRunToolingPlanFunc(ctx, repoRoot)
plan := d.buildVMRunToolingPlan(ctx, repoRoot)
var uploadLog bytes.Buffer
if err := client.UploadFile(ctx, vmRunToolingHarnessPath(repoName), 0o755, []byte(vmRunToolingHarnessScript(plan)), &uploadLog); err != nil {
return formatVMRunStepError("upload guest tooling bootstrap", err, uploadLog.String())