cli + daemon: move test seams off package globals onto injected structs

CLI: introduce internal/cli.deps which owns every RPC/SSH/host-command
seam the tree used to reach through mutable package vars. Command
builders, orchestrators, and the completion helpers become methods on
*deps. Tests construct their own deps per case, so fakes no longer leak
across cases and tests are free to run in parallel.

Daemon: move workspaceInspectRepoFunc + workspaceImportFunc onto the
Daemon struct (workspaceInspectRepo / workspaceImport), mirroring the
existing guestWaitForSSH / guestDial pattern. Workspace-prepare tests
drop t.Parallel() guards now that they no longer mutate process-wide
state.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Thales Maciel 2026-04-19 19:03:55 -03:00
parent d38f580e00
commit c42fcbe012
No known key found for this signature in database
GPG key ID: 33112E6833C34679
19 changed files with 664 additions and 733 deletions

View file

@ -1,125 +1,25 @@
package cli package cli
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io"
"os"
"os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
"time"
"banger/internal/api" "banger/internal/api"
"banger/internal/buildinfo" "banger/internal/buildinfo"
"banger/internal/daemon"
"banger/internal/guest"
"banger/internal/paths"
"banger/internal/rpc"
"banger/internal/toolingplan"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var ( // NewBangerCommand builds the top-level cobra tree with production
bangerdPathFunc = paths.BangerdPath // defaults wired into the dependency struct. Tests reach into the
daemonExePath = func(pid int) string { // package directly — see newRootCommand + defaultDeps.
return filepath.Join("/proc", fmt.Sprintf("%d", pid), "exe")
}
doctorFunc = daemon.Doctor
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
sshCmd := exec.CommandContext(ctx, "ssh", args...)
sshCmd.Stdout = stdout
sshCmd.Stderr = stderr
sshCmd.Stdin = stdin
return sshCmd.Run()
}
hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) {
cmd := exec.CommandContext(ctx, name, args...)
output, err := cmd.CombinedOutput()
if err == nil {
return output, nil
}
command := strings.TrimSpace(strings.Join(append([]string{name}, args...), " "))
detail := strings.TrimSpace(string(output))
if detail == "" {
return output, fmt.Errorf("%s: %w", command, err)
}
return output, fmt.Errorf("%s: %w: %s", command, err, detail)
}
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName})
}
vmSSHFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error) {
return rpc.Call[api.VMSSHResult](ctx, socketPath, "vm.ssh", api.VMRefParams{IDOrName: idOrName})
}
vmDeleteFunc = func(ctx context.Context, socketPath, idOrName string) error {
_, err := rpc.Call[api.VMShowResult](ctx, socketPath, "vm.delete", api.VMRefParams{IDOrName: idOrName})
return err
}
vmListFunc = func(ctx context.Context, socketPath string) (api.VMListResult, error) {
return rpc.Call[api.VMListResult](ctx, socketPath, "vm.list", api.Empty{})
}
daemonPingFunc = func(ctx context.Context, socketPath string) (api.PingResult, error) {
return rpc.Call[api.PingResult](ctx, socketPath, "ping", api.Empty{})
}
vmCreateBeginFunc = func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error) {
return rpc.Call[api.VMCreateBeginResult](ctx, socketPath, "vm.create.begin", params)
}
vmCreateStatusFunc = func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error) {
return rpc.Call[api.VMCreateStatusResult](ctx, socketPath, "vm.create.status", api.VMCreateStatusParams{ID: operationID})
}
vmCreateCancelFunc = func(ctx context.Context, socketPath, operationID string) error {
_, err := rpc.Call[api.Empty](ctx, socketPath, "vm.create.cancel", api.VMCreateStatusParams{ID: operationID})
return err
}
vmPortsFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) {
return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName})
}
vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
return rpc.Call[api.VMWorkspacePrepareResult](ctx, socketPath, "vm.workspace.prepare", params)
}
vmWorkspaceExportFunc = func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
return rpc.Call[api.WorkspaceExportResult](ctx, socketPath, "vm.workspace.export", params)
}
guestSessionStartFunc = func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.start", params)
}
guestSessionGetFunc = func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.get", params)
}
guestSessionListFunc = func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error) {
return rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: idOrName})
}
guestSessionStopFunc = func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.stop", params)
}
guestSessionKillFunc = func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.kill", params)
}
guestSessionLogsFunc = func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) {
return rpc.Call[api.GuestSessionLogsResult](ctx, socketPath, "guest.session.logs", params)
}
guestSessionAttachBeginFunc = func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
return rpc.Call[api.GuestSessionAttachBeginResult](ctx, socketPath, "guest.session.attach.begin", params)
}
guestSessionSendFunc = func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params)
}
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
knownHosts, _ := bangerKnownHostsPath()
return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval)
}
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
knownHosts, _ := bangerKnownHostsPath()
return guest.Dial(ctx, address, privateKeyPath, knownHosts)
}
buildVMRunToolingPlanFunc = toolingplan.Build
cwdFunc = os.Getwd
)
func NewBangerCommand() *cobra.Command { func NewBangerCommand() *cobra.Command {
return defaultDeps().newRootCommand()
}
func (d *deps) newRootCommand() *cobra.Command {
root := &cobra.Command{ root := &cobra.Command{
Use: "banger", Use: "banger",
Short: "Manage development VMs and images", Short: "Manage development VMs and images",
@ -127,17 +27,26 @@ func NewBangerCommand() *cobra.Command {
SilenceErrors: true, SilenceErrors: true,
RunE: helpNoArgs, RunE: helpNoArgs,
} }
root.AddCommand(newDaemonCommand(), newDoctorCommand(), newImageCommand(), newInternalCommand(), newKernelCommand(), newVersionCommand(), newPSCommand(), newVMCommand()) root.AddCommand(
d.newDaemonCommand(),
d.newDoctorCommand(),
d.newImageCommand(),
d.newInternalCommand(),
d.newKernelCommand(),
newVersionCommand(),
d.newPSCommand(),
d.newVMCommand(),
)
return root return root
} }
func newDoctorCommand() *cobra.Command { func (d *deps) newDoctorCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "doctor", Use: "doctor",
Short: "Check host and runtime readiness", Short: "Check host and runtime readiness",
Args: noArgsUsage("usage: banger doctor"), Args: noArgsUsage("usage: banger doctor"),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
report, err := doctorFunc(cmd.Context()) report, err := d.doctor(cmd.Context())
if err != nil { if err != nil {
return err return err
} }

View file

@ -121,11 +121,8 @@ func TestLegacyRemovedCommandIsRejected(t *testing.T) {
} }
func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) { func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) {
original := doctorFunc d := defaultDeps()
t.Cleanup(func() { d.doctor = func(context.Context) (system.Report, error) {
doctorFunc = original
})
doctorFunc = func(context.Context) (system.Report, error) {
return system.Report{ return system.Report{
Checks: []system.CheckResult{ Checks: []system.CheckResult{
{Name: "runtime bundle", Status: system.CheckStatusPass, Details: []string{"runtime dir /tmp/runtime"}}, {Name: "runtime bundle", Status: system.CheckStatusPass, Details: []string{"runtime dir /tmp/runtime"}},
@ -134,7 +131,7 @@ func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) {
}, nil }, nil
} }
cmd := NewBangerCommand() cmd := d.newRootCommand()
var stdout bytes.Buffer var stdout bytes.Buffer
cmd.SetOut(&stdout) cmd.SetOut(&stdout)
cmd.SetErr(&stdout) cmd.SetErr(&stdout)
@ -154,15 +151,12 @@ func TestDoctorCommandPrintsReportAndFailsOnHardFailures(t *testing.T) {
} }
func TestDoctorCommandReturnsUnderlyingError(t *testing.T) { func TestDoctorCommandReturnsUnderlyingError(t *testing.T) {
original := doctorFunc d := defaultDeps()
t.Cleanup(func() { d.doctor = func(context.Context) (system.Report, error) {
doctorFunc = original
})
doctorFunc = func(context.Context) (system.Report, error) {
return system.Report{}, errors.New("load failed") return system.Report{}, errors.New("load failed")
} }
cmd := NewBangerCommand() cmd := d.newRootCommand()
cmd.SetArgs([]string{"doctor"}) cmd.SetArgs([]string{"doctor"})
err := cmd.Execute() err := cmd.Execute()
if err == nil || !strings.Contains(err.Error(), "load failed") { if err == nil || !strings.Contains(err.Error(), "load failed") {
@ -509,14 +503,7 @@ func TestVMCreateParamsFromFlagsRejectsNonPositiveCPUAndMemory(t *testing.T) {
} }
func TestRunVMCreatePollsUntilDone(t *testing.T) { func TestRunVMCreatePollsUntilDone(t *testing.T) {
origBegin := vmCreateBeginFunc d := defaultDeps()
origStatus := vmCreateStatusFunc
origCancel := vmCreateCancelFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
vmCreateStatusFunc = origStatus
vmCreateCancelFunc = origCancel
})
vm := model.VMRecord{ vm := model.VMRecord{
ID: "vm-id", ID: "vm-id",
@ -528,7 +515,7 @@ func TestRunVMCreatePollsUntilDone(t *testing.T) {
DNSName: "devbox.vm", DNSName: "devbox.vm",
}, },
} }
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{ return api.VMCreateBeginResult{
Operation: api.VMCreateOperation{ Operation: api.VMCreateOperation{
ID: "op-1", ID: "op-1",
@ -538,7 +525,7 @@ func TestRunVMCreatePollsUntilDone(t *testing.T) {
}, nil }, nil
} }
statusCalls := 0 statusCalls := 0
vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) { d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
statusCalls++ statusCalls++
if statusCalls == 1 { if statusCalls == 1 {
return api.VMCreateStatusResult{ return api.VMCreateStatusResult{
@ -560,14 +547,14 @@ func TestRunVMCreatePollsUntilDone(t *testing.T) {
}, },
}, nil }, nil
} }
vmCreateCancelFunc = func(context.Context, string, string) error { d.vmCreateCancel = func(context.Context, string, string) error {
t.Fatal("cancel should not be called") t.Fatal("cancel should not be called")
return nil return nil
} }
got, err := runVMCreate(context.Background(), "/tmp/bangerd.sock", &bytes.Buffer{}, api.VMCreateParams{Name: "devbox"}) got, err := d.runVMCreate(context.Background(), "/tmp/bangerd.sock", &bytes.Buffer{}, api.VMCreateParams{Name: "devbox"})
if err != nil { if err != nil {
t.Fatalf("runVMCreate: %v", err) t.Fatalf("d.runVMCreate: %v", err)
} }
if got.Name != vm.Name || got.Runtime.GuestIP != vm.Runtime.GuestIP { if got.Name != vm.Name || got.Runtime.GuestIP != vm.Runtime.GuestIP {
t.Fatalf("vm = %+v, want %+v", got, vm) t.Fatalf("vm = %+v, want %+v", got, vm)
@ -878,23 +865,18 @@ func TestPrintVMPortsTableSortsAndRendersURLEndpoints(t *testing.T) {
} }
func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) { func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) {
origSSHExec := sshExecFunc d := defaultDeps()
origHealth := vmHealthFunc
t.Cleanup(func() {
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
return nil return nil
} }
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { d.vmHealth = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
return api.VMHealthResult{Name: "devbox", Healthy: true}, nil return api.VMHealthResult{Name: "devbox", Healthy: true}, nil
} }
var stderr bytes.Buffer var stderr bytes.Buffer
if err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false); err != nil { if err := d.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false); err != nil {
t.Fatalf("runSSHSession: %v", err) t.Fatalf("d.runSSHSession: %v", err)
} }
if !strings.Contains(stderr.String(), "devbox is still running") { if !strings.Contains(stderr.String(), "devbox is still running") {
t.Fatalf("stderr = %q, want reminder", stderr.String()) t.Fatalf("stderr = %q, want reminder", stderr.String())
@ -902,25 +884,20 @@ func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) {
} }
func TestRunSSHSessionPreservesSSHExitStatusOnHealthWarning(t *testing.T) { func TestRunSSHSessionPreservesSSHExitStatusOnHealthWarning(t *testing.T) {
origSSHExec := sshExecFunc d := defaultDeps()
origHealth := vmHealthFunc
t.Cleanup(func() {
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
return exitErrorWithCode(t, 1) return exitErrorWithCode(t, 1)
} }
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { d.vmHealth = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
return api.VMHealthResult{}, errors.New("dial failed") return api.VMHealthResult{}, errors.New("dial failed")
} }
var stderr bytes.Buffer var stderr bytes.Buffer
err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false) err := d.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false)
var exitErr *exec.ExitError var exitErr *exec.ExitError
if !errors.As(err, &exitErr) { if !errors.As(err, &exitErr) {
t.Fatalf("runSSHSession error = %v, want exit error", err) t.Fatalf("d.runSSHSession error = %v, want exit error", err)
} }
if !strings.Contains(stderr.String(), "failed to check whether devbox is still running") { if !strings.Contains(stderr.String(), "failed to check whether devbox is still running") {
t.Fatalf("stderr = %q, want warning", stderr.String()) t.Fatalf("stderr = %q, want warning", stderr.String())
@ -928,27 +905,22 @@ func TestRunSSHSessionPreservesSSHExitStatusOnHealthWarning(t *testing.T) {
} }
func TestRunSSHSessionSkipsReminderOnSSHAuthFailure(t *testing.T) { func TestRunSSHSessionSkipsReminderOnSSHAuthFailure(t *testing.T) {
origSSHExec := sshExecFunc d := defaultDeps()
origHealth := vmHealthFunc
t.Cleanup(func() {
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
healthCalled := false healthCalled := false
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
return exitErrorWithCode(t, 255) return exitErrorWithCode(t, 255)
} }
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { d.vmHealth = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
healthCalled = true healthCalled = true
return api.VMHealthResult{Name: "devbox", Healthy: true}, nil return api.VMHealthResult{Name: "devbox", Healthy: true}, nil
} }
var stderr bytes.Buffer var stderr bytes.Buffer
err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false) err := d.runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}, false)
var exitErr *exec.ExitError var exitErr *exec.ExitError
if !errors.As(err, &exitErr) || exitErr.ExitCode() != 255 { if !errors.As(err, &exitErr) || exitErr.ExitCode() != 255 {
t.Fatalf("runSSHSession error = %v, want exit 255", err) t.Fatalf("d.runSSHSession error = %v, want exit 255", err)
} }
if healthCalled { if healthCalled {
t.Fatal("vm health should not run after ssh auth failure") t.Fatal("vm health should not run after ssh auth failure")
@ -1141,6 +1113,7 @@ func TestValidateSSHPrereqsFailsForMissingKey(t *testing.T) {
// gets a fast error instead of an orphaned VM. // gets a fast error instead of an orphaned VM.
func TestVMRunPreflightRejectsSubmodules(t *testing.T) { func TestVMRunPreflightRejectsSubmodules(t *testing.T) {
d := defaultDeps()
repoRoot := t.TempDir() repoRoot := t.TempDir()
origHostCommandOutput := workspace.HostCommandOutputFunc origHostCommandOutput := workspace.HostCommandOutputFunc
@ -1166,36 +1139,16 @@ func TestVMRunPreflightRejectsSubmodules(t *testing.T) {
} }
} }
_, err := vmRunPreflightRepo(context.Background(), repoRoot) _, err := d.vmRunPreflightRepo(context.Background(), repoRoot)
if err == nil || !strings.Contains(err.Error(), "submodules") { if err == nil || !strings.Contains(err.Error(), "submodules") {
t.Fatalf("vmRunPreflightRepo() error = %v, want submodule rejection", err) t.Fatalf("d.vmRunPreflightRepo() error = %v, want submodule rejection", err)
} }
} }
func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) { func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) {
d := defaultDeps()
repoRoot := t.TempDir() repoRoot := t.TempDir()
origBegin := vmCreateBeginFunc
origStatus := vmCreateStatusFunc
origCancel := vmCreateCancelFunc
origWaitForSSH := guestWaitForSSHFunc
origGuestDial := guestDialFunc
origBuildVMRunToolingPlan := buildVMRunToolingPlanFunc
origVMWorkspacePrepare := vmWorkspacePrepareFunc
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
vmCreateStatusFunc = origStatus
vmCreateCancelFunc = origCancel
guestWaitForSSHFunc = origWaitForSSH
guestDialFunc = origGuestDial
buildVMRunToolingPlanFunc = origBuildVMRunToolingPlan
vmWorkspacePrepareFunc = origVMWorkspacePrepare
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
vm := model.VMRecord{ vm := model.VMRecord{
ID: "vm-id", ID: "vm-id",
Name: "devbox", Name: "devbox",
@ -1205,7 +1158,7 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) {
DNSName: "devbox.vm", DNSName: "devbox.vm",
}, },
} }
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{ return api.VMCreateBeginResult{
Operation: api.VMCreateOperation{ Operation: api.VMCreateOperation{
ID: "op-1", Stage: "ready", Detail: "vm is ready", ID: "op-1", Stage: "ready", Detail: "vm is ready",
@ -1213,45 +1166,45 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) {
}, },
}, nil }, nil
} }
vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) { d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
t.Fatal("vmCreateStatusFunc should not be called") t.Fatal("d.vmCreateStatus should not be called")
return api.VMCreateStatusResult{}, nil return api.VMCreateStatusResult{}, nil
} }
vmCreateCancelFunc = func(context.Context, string, string) error { d.vmCreateCancel = func(context.Context, string, string) error {
t.Fatal("vmCreateCancelFunc should not be called") t.Fatal("d.vmCreateCancel should not be called")
return nil return nil
} }
fakeClient := &testVMRunGuestClient{} fakeClient := &testVMRunGuestClient{}
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { d.guestWaitForSSH = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
return nil return nil
} }
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { d.guestDial = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
return fakeClient, nil return fakeClient, nil
} }
var workspaceParams api.VMWorkspacePrepareParams var workspaceParams api.VMWorkspacePrepareParams
vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { d.vmWorkspacePrepare = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
workspaceParams = params workspaceParams = params
return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil
} }
buildVMRunToolingPlanFunc = func(context.Context, string) toolingplan.Plan { d.buildVMRunToolingPlan = func(context.Context, string) toolingplan.Plan {
return toolingplan.Plan{ return toolingplan.Plan{
RepoManagedTools: []string{"go"}, RepoManagedTools: []string{"go"},
Steps: []toolingplan.InstallStep{{Tool: "go", Version: "1.25.0", Source: "go.mod"}}, Steps: []toolingplan.InstallStep{{Tool: "go", Version: "1.25.0", Source: "go.mod"}},
} }
} }
var sshArgsSeen []string var sshArgsSeen []string
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { d.sshExec = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
sshArgsSeen = args sshArgsSeen = args
return nil return nil
} }
vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) { d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
return api.VMHealthResult{Name: "devbox", Healthy: false}, nil return api.VMHealthResult{Name: "devbox", Healthy: false}, nil
} }
repo := vmRunRepo{sourcePath: repoRoot} repo := vmRunRepo{sourcePath: repoRoot}
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
err := runVMRun( err := d.runVMRun(
context.Background(), context.Background(),
"/tmp/bangerd.sock", "/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1263,7 +1216,7 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) {
false, false,
) )
if err != nil { if err != nil {
t.Fatalf("runVMRun: %v", err) t.Fatalf("d.runVMRun: %v", err)
} }
if workspaceParams.IDOrName != "devbox" || workspaceParams.SourcePath != repoRoot { if workspaceParams.IDOrName != "devbox" || workspaceParams.SourcePath != repoRoot {
t.Fatalf("workspaceParams = %+v", workspaceParams) t.Fatalf("workspaceParams = %+v", workspaceParams)
@ -1283,24 +1236,7 @@ func TestRunVMRunWorkspacePreparesAndAttaches(t *testing.T) {
} }
func TestVMRunPrintsPostCreateProgress(t *testing.T) { func TestVMRunPrintsPostCreateProgress(t *testing.T) {
origBegin := vmCreateBeginFunc d := defaultDeps()
origStatus := vmCreateStatusFunc
origCancel := vmCreateCancelFunc
origWaitForSSH := guestWaitForSSHFunc
origGuestDial := guestDialFunc
origVMWorkspacePrepare := vmWorkspacePrepareFunc
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
vmCreateStatusFunc = origStatus
vmCreateCancelFunc = origCancel
guestWaitForSSHFunc = origWaitForSSH
guestDialFunc = origGuestDial
vmWorkspacePrepareFunc = origVMWorkspacePrepare
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
vm := model.VMRecord{ vm := model.VMRecord{
ID: "vm-id", ID: "vm-id",
@ -1310,7 +1246,7 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) {
GuestIP: "172.16.0.2", GuestIP: "172.16.0.2",
}, },
} }
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{ return api.VMCreateBeginResult{
Operation: api.VMCreateOperation{ Operation: api.VMCreateOperation{
ID: "op-1", Stage: "ready", Detail: "vm is ready", ID: "op-1", Stage: "ready", Detail: "vm is ready",
@ -1318,33 +1254,33 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) {
}, },
}, nil }, nil
} }
vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) { d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
t.Fatal("vmCreateStatusFunc should not be called") t.Fatal("d.vmCreateStatus should not be called")
return api.VMCreateStatusResult{}, nil return api.VMCreateStatusResult{}, nil
} }
vmCreateCancelFunc = func(context.Context, string, string) error { d.vmCreateCancel = func(context.Context, string, string) error {
t.Fatal("vmCreateCancelFunc should not be called") t.Fatal("d.vmCreateCancel should not be called")
return nil return nil
} }
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { d.guestWaitForSSH = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
return nil return nil
} }
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { d.guestDial = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
return &testVMRunGuestClient{}, nil return &testVMRunGuestClient{}, nil
} }
vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { d.vmWorkspacePrepare = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil
} }
sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
return nil return nil
} }
vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) { d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
return api.VMHealthResult{Name: "devbox", Healthy: false}, nil return api.VMHealthResult{Name: "devbox", Healthy: false}, nil
} }
repo := vmRunRepo{sourcePath: t.TempDir()} repo := vmRunRepo{sourcePath: t.TempDir()}
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
err := runVMRun( err := d.runVMRun(
context.Background(), context.Background(),
"/tmp/bangerd.sock", "/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1356,7 +1292,7 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) {
false, false,
) )
if err != nil { if err != nil {
t.Fatalf("runVMRun: %v", err) t.Fatalf("d.runVMRun: %v", err)
} }
output := stderr.String() output := stderr.String()
@ -1377,24 +1313,7 @@ func TestVMRunPrintsPostCreateProgress(t *testing.T) {
} }
func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) { func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) {
origBegin := vmCreateBeginFunc d := defaultDeps()
origStatus := vmCreateStatusFunc
origCancel := vmCreateCancelFunc
origWaitForSSH := guestWaitForSSHFunc
origGuestDial := guestDialFunc
origVMWorkspacePrepare := vmWorkspacePrepareFunc
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
vmCreateStatusFunc = origStatus
vmCreateCancelFunc = origCancel
guestWaitForSSHFunc = origWaitForSSH
guestDialFunc = origGuestDial
vmWorkspacePrepareFunc = origVMWorkspacePrepare
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
vm := model.VMRecord{ vm := model.VMRecord{
ID: "vm-id", ID: "vm-id",
@ -1404,39 +1323,39 @@ func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) {
GuestIP: "172.16.0.2", GuestIP: "172.16.0.2",
}, },
} }
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Detail: "vm is ready", Done: true, Success: true, VM: &vm}}, nil return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Detail: "vm is ready", Done: true, Success: true, VM: &vm}}, nil
} }
vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) { d.vmCreateStatus = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
t.Fatal("vmCreateStatusFunc should not be called") t.Fatal("d.vmCreateStatus should not be called")
return api.VMCreateStatusResult{}, nil return api.VMCreateStatusResult{}, nil
} }
vmCreateCancelFunc = func(context.Context, string, string) error { d.vmCreateCancel = func(context.Context, string, string) error {
t.Fatal("vmCreateCancelFunc should not be called") t.Fatal("d.vmCreateCancel should not be called")
return nil return nil
} }
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { d.guestWaitForSSH = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
return nil return nil
} }
fakeClient := &testVMRunGuestClient{launchErr: errors.New("launch failed")} fakeClient := &testVMRunGuestClient{launchErr: errors.New("launch failed")}
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { d.guestDial = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
return fakeClient, nil return fakeClient, nil
} }
vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { d.vmWorkspacePrepare = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil return api.VMWorkspacePrepareResult{Workspace: model.WorkspacePrepareResult{VMID: vm.ID, GuestPath: "/root/repo", RepoName: "repo", RepoRoot: "/tmp/repo"}}, nil
} }
sshExecCalls := 0 sshExecCalls := 0
sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
sshExecCalls++ sshExecCalls++
return nil return nil
} }
vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) { d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
return api.VMHealthResult{Healthy: false}, nil return api.VMHealthResult{Healthy: false}, nil
} }
repo := vmRunRepo{sourcePath: t.TempDir()} repo := vmRunRepo{sourcePath: t.TempDir()}
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
err := runVMRun( err := d.runVMRun(
context.Background(), context.Background(),
"/tmp/bangerd.sock", "/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1448,7 +1367,7 @@ func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) {
false, false,
) )
if err != nil { if err != nil {
t.Fatalf("runVMRun: %v", err) t.Fatalf("d.runVMRun: %v", err)
} }
if !strings.Contains(stderr.String(), "[vm run] warning: guest tooling bootstrap start failed: launch guest tooling bootstrap") { if !strings.Contains(stderr.String(), "[vm run] warning: guest tooling bootstrap start failed: launch guest tooling bootstrap") {
t.Fatalf("stderr = %q, want tooling bootstrap warning", stderr.String()) t.Fatalf("stderr = %q, want tooling bootstrap warning", stderr.String())
@ -1459,48 +1378,35 @@ func TestRunVMRunWarnsWhenToolingHarnessStartFails(t *testing.T) {
} }
func TestRunVMRunBareModeSkipsWorkspaceAndTooling(t *testing.T) { func TestRunVMRunBareModeSkipsWorkspaceAndTooling(t *testing.T) {
origBegin := vmCreateBeginFunc d := defaultDeps()
origWaitForSSH := guestWaitForSSHFunc
origGuestDial := guestDialFunc
origVMWorkspacePrepare := vmWorkspacePrepareFunc
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
guestWaitForSSHFunc = origWaitForSSH
guestDialFunc = origGuestDial
vmWorkspacePrepareFunc = origVMWorkspacePrepare
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
})
vm := model.VMRecord{ vm := model.VMRecord{
ID: "vm-id", Name: "bare", ID: "vm-id", Name: "bare",
Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"}, Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
} }
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
} }
guestWaitForSSHFunc = func(context.Context, string, string, time.Duration) error { return nil } d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
guestDialFunc = func(context.Context, string, string) (vmRunGuestClient, error) { d.guestDial = func(context.Context, string, string) (vmRunGuestClient, error) {
t.Fatal("guestDialFunc should not be called in bare mode") t.Fatal("d.guestDial should not be called in bare mode")
return nil, nil return nil, nil
} }
vmWorkspacePrepareFunc = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { d.vmWorkspacePrepare = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
t.Fatal("vmWorkspacePrepareFunc should not be called in bare mode") t.Fatal("d.vmWorkspacePrepare should not be called in bare mode")
return api.VMWorkspacePrepareResult{}, nil return api.VMWorkspacePrepareResult{}, nil
} }
sshExecCalls := 0 sshExecCalls := 0
sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error {
sshExecCalls++ sshExecCalls++
return nil return nil
} }
vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) { d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
return api.VMHealthResult{Healthy: false}, nil return api.VMHealthResult{Healthy: false}, nil
} }
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
err := runVMRun( err := d.runVMRun(
context.Background(), context.Background(),
"/tmp/bangerd.sock", "/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1512,7 +1418,7 @@ func TestRunVMRunBareModeSkipsWorkspaceAndTooling(t *testing.T) {
false, false,
) )
if err != nil { if err != nil {
t.Fatalf("runVMRun: %v", err) t.Fatalf("d.runVMRun: %v", err)
} }
if sshExecCalls != 1 { if sshExecCalls != 1 {
t.Fatalf("sshExec calls = %d, want 1", sshExecCalls) t.Fatalf("sshExec calls = %d, want 1", sshExecCalls)
@ -1523,39 +1429,28 @@ func TestRunVMRunBareModeSkipsWorkspaceAndTooling(t *testing.T) {
} }
func TestRunVMRunRMDeletesAfterSessionExits(t *testing.T) { func TestRunVMRunRMDeletesAfterSessionExits(t *testing.T) {
origBegin := vmCreateBeginFunc d := defaultDeps()
origWaitForSSH := guestWaitForSSHFunc
origSSHExec := sshExecFunc
origHealth := vmHealthFunc
origDelete := vmDeleteFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
guestWaitForSSHFunc = origWaitForSSH
sshExecFunc = origSSHExec
vmHealthFunc = origHealth
vmDeleteFunc = origDelete
})
vm := model.VMRecord{ vm := model.VMRecord{
ID: "vm-id", Name: "tmpbox", ID: "vm-id", Name: "tmpbox",
Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"}, Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
} }
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
} }
guestWaitForSSHFunc = func(context.Context, string, string, time.Duration) error { return nil } d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
sshExecFunc = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { return nil } d.sshExec = func(context.Context, io.Reader, io.Writer, io.Writer, []string) error { return nil }
vmHealthFunc = func(context.Context, string, string) (api.VMHealthResult, error) { d.vmHealth = func(context.Context, string, string) (api.VMHealthResult, error) {
return api.VMHealthResult{Healthy: false}, nil return api.VMHealthResult{Healthy: false}, nil
} }
deletedRef := "" deletedRef := ""
vmDeleteFunc = func(_ context.Context, _, idOrName string) error { d.vmDelete = func(_ context.Context, _, idOrName string) error {
deletedRef = idOrName deletedRef = idOrName
return nil return nil
} }
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
err := runVMRun( err := d.runVMRun(
context.Background(), context.Background(),
"/tmp/bangerd.sock", "/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1567,7 +1462,7 @@ func TestRunVMRunRMDeletesAfterSessionExits(t *testing.T) {
true, // --rm true, // --rm
) )
if err != nil { if err != nil {
t.Fatalf("runVMRun: %v", err) t.Fatalf("d.runVMRun: %v", err)
} }
if deletedRef != "tmpbox" { if deletedRef != "tmpbox" {
t.Fatalf("deletedRef = %q, want tmpbox", deletedRef) t.Fatalf("deletedRef = %q, want tmpbox", deletedRef)
@ -1580,15 +1475,10 @@ func TestRunVMRunRMDeletesAfterSessionExits(t *testing.T) {
} }
func TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout(t *testing.T) { func TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout(t *testing.T) {
origBegin := vmCreateBeginFunc d := defaultDeps()
origWaitForSSH := guestWaitForSSHFunc
origDelete := vmDeleteFunc
origTimeout := vmRunSSHTimeout origTimeout := vmRunSSHTimeout
vmRunSSHTimeout = 50 * time.Millisecond vmRunSSHTimeout = 50 * time.Millisecond
t.Cleanup(func() { t.Cleanup(func() {
vmCreateBeginFunc = origBegin
guestWaitForSSHFunc = origWaitForSSH
vmDeleteFunc = origDelete
vmRunSSHTimeout = origTimeout vmRunSSHTimeout = origTimeout
}) })
@ -1596,21 +1486,21 @@ func TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout(t *testing.T) {
ID: "vm-id", Name: "slowvm", ID: "vm-id", Name: "slowvm",
Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"}, Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
} }
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
} }
guestWaitForSSHFunc = func(ctx context.Context, _, _ string, _ time.Duration) error { d.guestWaitForSSH = func(ctx context.Context, _, _ string, _ time.Duration) error {
<-ctx.Done() <-ctx.Done()
return ctx.Err() return ctx.Err()
} }
deleteCalled := false deleteCalled := false
vmDeleteFunc = func(context.Context, string, string) error { d.vmDelete = func(context.Context, string, string) error {
deleteCalled = true deleteCalled = true
return nil return nil
} }
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
err := runVMRun( err := d.runVMRun(
context.Background(), context.Background(),
"/tmp/bangerd.sock", "/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1630,13 +1520,10 @@ func TestRunVMRunRMSkipsDeleteOnSSHWaitTimeout(t *testing.T) {
} }
func TestRunVMRunSSHTimeoutReturnsActionableError(t *testing.T) { func TestRunVMRunSSHTimeoutReturnsActionableError(t *testing.T) {
origBegin := vmCreateBeginFunc d := defaultDeps()
origWaitForSSH := guestWaitForSSHFunc
origTimeout := vmRunSSHTimeout origTimeout := vmRunSSHTimeout
vmRunSSHTimeout = 50 * time.Millisecond vmRunSSHTimeout = 50 * time.Millisecond
t.Cleanup(func() { t.Cleanup(func() {
vmCreateBeginFunc = origBegin
guestWaitForSSHFunc = origWaitForSSH
vmRunSSHTimeout = origTimeout vmRunSSHTimeout = origTimeout
}) })
@ -1644,18 +1531,18 @@ func TestRunVMRunSSHTimeoutReturnsActionableError(t *testing.T) {
ID: "vm-id", Name: "slowvm", ID: "vm-id", Name: "slowvm",
Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"}, Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
} }
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
} }
// Simulate the guest never bringing sshd up — the wait-for-ssh // Simulate the guest never bringing sshd up — the wait-for-ssh
// child context fires its deadline, returning a DeadlineExceeded. // child context fires its deadline, returning a DeadlineExceeded.
guestWaitForSSHFunc = func(ctx context.Context, _, _ string, _ time.Duration) error { d.guestWaitForSSH = func(ctx context.Context, _, _ string, _ time.Duration) error {
<-ctx.Done() <-ctx.Done()
return ctx.Err() return ctx.Err()
} }
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
err := runVMRun( err := d.runVMRun(
context.Background(), context.Background(),
"/tmp/bangerd.sock", "/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1683,37 +1570,28 @@ func TestRunVMRunSSHTimeoutReturnsActionableError(t *testing.T) {
} }
func TestRunVMRunCommandModePropagatesExitCode(t *testing.T) { func TestRunVMRunCommandModePropagatesExitCode(t *testing.T) {
origBegin := vmCreateBeginFunc d := defaultDeps()
origWaitForSSH := guestWaitForSSHFunc
origVMWorkspacePrepare := vmWorkspacePrepareFunc
origSSHExec := sshExecFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
guestWaitForSSHFunc = origWaitForSSH
vmWorkspacePrepareFunc = origVMWorkspacePrepare
sshExecFunc = origSSHExec
})
vm := model.VMRecord{ vm := model.VMRecord{
ID: "vm-id", Name: "cmdbox", ID: "vm-id", Name: "cmdbox",
Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"}, Runtime: model.VMRuntime{State: model.VMStateRunning, GuestIP: "172.16.0.2"},
} }
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { d.vmCreateBegin = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil return api.VMCreateBeginResult{Operation: api.VMCreateOperation{ID: "op-1", Stage: "ready", Done: true, Success: true, VM: &vm}}, nil
} }
guestWaitForSSHFunc = func(context.Context, string, string, time.Duration) error { return nil } d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
vmWorkspacePrepareFunc = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { d.vmWorkspacePrepare = func(context.Context, string, api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
t.Fatal("workspace prepare should not run without spec") t.Fatal("workspace prepare should not run without spec")
return api.VMWorkspacePrepareResult{}, nil return api.VMWorkspacePrepareResult{}, nil
} }
var sshArgsSeen []string var sshArgsSeen []string
sshExecFunc = func(_ context.Context, _ io.Reader, _, _ io.Writer, args []string) error { d.sshExec = func(_ context.Context, _ io.Reader, _, _ io.Writer, args []string) error {
sshArgsSeen = args sshArgsSeen = args
return exitErrorWithCode(t, 7) return exitErrorWithCode(t, 7)
} }
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
err := runVMRun( err := d.runVMRun(
context.Background(), context.Background(),
"/tmp/bangerd.sock", "/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
@ -1726,7 +1604,7 @@ func TestRunVMRunCommandModePropagatesExitCode(t *testing.T) {
) )
var exitErr ExitCodeError var exitErr ExitCodeError
if !errors.As(err, &exitErr) || exitErr.Code != 7 { if !errors.As(err, &exitErr) || exitErr.Code != 7 {
t.Fatalf("runVMRun error = %v, want ExitCodeError{7}", err) t.Fatalf("d.runVMRun error = %v, want ExitCodeError{7}", err)
} }
if len(sshArgsSeen) == 0 || sshArgsSeen[len(sshArgsSeen)-1] != "false" { if len(sshArgsSeen) == 0 || sshArgsSeen[len(sshArgsSeen)-1] != "false" {
t.Fatalf("sshArgsSeen = %v, want trailing command 'false'", sshArgsSeen) t.Fatalf("sshArgsSeen = %v, want trailing command 'false'", sshArgsSeen)
@ -1843,6 +1721,7 @@ func TestNewBangerdCommandRejectsArgs(t *testing.T) {
} }
func TestDaemonOutdated(t *testing.T) { func TestDaemonOutdated(t *testing.T) {
d := defaultDeps()
dir := t.TempDir() dir := t.TempDir()
current := filepath.Join(dir, "bangerd-current") current := filepath.Join(dir, "bangerd-current")
same := filepath.Join(dir, "bangerd-same") same := filepath.Join(dir, "bangerd-same")
@ -1857,27 +1736,20 @@ func TestDaemonOutdated(t *testing.T) {
t.Fatalf("write stale: %v", err) t.Fatalf("write stale: %v", err)
} }
origBangerdPath := bangerdPathFunc d.bangerdPath = func() (string, error) {
origDaemonExePath := daemonExePath
t.Cleanup(func() {
bangerdPathFunc = origBangerdPath
daemonExePath = origDaemonExePath
})
bangerdPathFunc = func() (string, error) {
return current, nil return current, nil
} }
daemonExePath = func(pid int) string { d.daemonExePath = func(pid int) string {
if pid == 1 { if pid == 1 {
return same return same
} }
return stale return stale
} }
if daemonOutdated(1) { if d.daemonOutdated(1) {
t.Fatal("expected matching daemon executable to be current") t.Fatal("expected matching daemon executable to be current")
} }
if !daemonOutdated(2) { if !d.daemonOutdated(2) {
t.Fatal("expected replaced daemon executable to be outdated") t.Fatal("expected replaced daemon executable to be outdated")
} }
} }
@ -1912,10 +1784,7 @@ func TestDaemonStatusIncludesLogPathWhenStopped(t *testing.T) {
} }
func TestDaemonStatusIncludesDaemonBuildInfoWhenRunning(t *testing.T) { func TestDaemonStatusIncludesDaemonBuildInfoWhenRunning(t *testing.T) {
origDaemonPing := daemonPingFunc d := defaultDeps()
t.Cleanup(func() {
daemonPingFunc = origDaemonPing
})
configHome := filepath.Join(t.TempDir(), "config") configHome := filepath.Join(t.TempDir(), "config")
stateHome := filepath.Join(t.TempDir(), "state") stateHome := filepath.Join(t.TempDir(), "state")
@ -1924,7 +1793,7 @@ func TestDaemonStatusIncludesDaemonBuildInfoWhenRunning(t *testing.T) {
t.Setenv("XDG_STATE_HOME", stateHome) t.Setenv("XDG_STATE_HOME", stateHome)
t.Setenv("XDG_RUNTIME_DIR", runtimeHome) t.Setenv("XDG_RUNTIME_DIR", runtimeHome)
daemonPingFunc = func(context.Context, string) (api.PingResult, error) { d.daemonPing = func(context.Context, string) (api.PingResult, error) {
return api.PingResult{ return api.PingResult{
Status: "ok", Status: "ok",
PID: 42, PID: 42,
@ -1934,7 +1803,7 @@ func TestDaemonStatusIncludesDaemonBuildInfoWhenRunning(t *testing.T) {
}, nil }, nil
} }
cmd := NewBangerCommand() cmd := d.newRootCommand()
var stdout bytes.Buffer var stdout bytes.Buffer
cmd.SetOut(&stdout) cmd.SetOut(&stdout)
cmd.SetErr(&stdout) cmd.SetErr(&stdout)
@ -2073,26 +1942,26 @@ func TestVMSessionSendRejectsWrongArgCount(t *testing.T) {
} }
} }
func stubEnsureDaemonForSend(t *testing.T) { // stubEnsureDaemonForSend isolates XDG dirs and installs a daemon-ping
// fake onto the caller's *deps so `ensureDaemon` short-circuits without
// trying to spawn bangerd. `vm session send` uses this to avoid needing
// a built binary on disk.
func stubEnsureDaemonForSend(t *testing.T, d *deps) {
t.Helper() t.Helper()
t.Setenv("XDG_CONFIG_HOME", filepath.Join(t.TempDir(), "config")) t.Setenv("XDG_CONFIG_HOME", filepath.Join(t.TempDir(), "config"))
t.Setenv("XDG_STATE_HOME", filepath.Join(t.TempDir(), "state")) t.Setenv("XDG_STATE_HOME", filepath.Join(t.TempDir(), "state"))
t.Setenv("XDG_RUNTIME_DIR", filepath.Join(t.TempDir(), "run")) t.Setenv("XDG_RUNTIME_DIR", filepath.Join(t.TempDir(), "run"))
origPing := daemonPingFunc d.daemonPing = func(context.Context, string) (api.PingResult, error) {
t.Cleanup(func() { daemonPingFunc = origPing })
daemonPingFunc = func(context.Context, string) (api.PingResult, error) {
return api.PingResult{Status: "ok", PID: os.Getpid()}, nil return api.PingResult{Status: "ok", PID: os.Getpid()}, nil
} }
} }
func TestVMSessionSendWithMessageFlag(t *testing.T) { func TestVMSessionSendWithMessageFlag(t *testing.T) {
stubEnsureDaemonForSend(t) d := defaultDeps()
stubEnsureDaemonForSend(t, d)
original := guestSessionSendFunc
t.Cleanup(func() { guestSessionSendFunc = original })
var capturedParams api.GuestSessionSendParams var capturedParams api.GuestSessionSendParams
guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
capturedParams = params capturedParams = params
return api.GuestSessionSendResult{ return api.GuestSessionSendResult{
Session: model.GuestSession{ID: "sess-id", Name: "planner"}, Session: model.GuestSession{ID: "sess-id", Name: "planner"},
@ -2100,7 +1969,7 @@ func TestVMSessionSendWithMessageFlag(t *testing.T) {
}, nil }, nil
} }
cmd := NewBangerCommand() cmd := d.newRootCommand()
var out bytes.Buffer var out bytes.Buffer
cmd.SetOut(&out) cmd.SetOut(&out)
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner", "--message", `{"type":"abort"}`}) cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner", "--message", `{"type":"abort"}`})
@ -2124,13 +1993,11 @@ func TestVMSessionSendWithMessageFlag(t *testing.T) {
} }
func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) { func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) {
stubEnsureDaemonForSend(t) d := defaultDeps()
stubEnsureDaemonForSend(t, d)
original := guestSessionSendFunc
t.Cleanup(func() { guestSessionSendFunc = original })
var capturedPayload []byte var capturedPayload []byte
guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
capturedPayload = params.Payload capturedPayload = params.Payload
return api.GuestSessionSendResult{ return api.GuestSessionSendResult{
Session: model.GuestSession{Name: "s"}, Session: model.GuestSession{Name: "s"},
@ -2138,7 +2005,7 @@ func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) {
}, nil }, nil
} }
cmd := NewBangerCommand() cmd := d.newRootCommand()
cmd.SetOut(io.Discard) cmd.SetOut(io.Discard)
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "s", "--message", "{\"type\":\"abort\"}\n"}) cmd.SetArgs([]string{"vm", "session", "send", "devbox", "s", "--message", "{\"type\":\"abort\"}\n"})
if err := cmd.Execute(); err != nil { if err := cmd.Execute(); err != nil {
@ -2155,13 +2022,11 @@ func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) {
} }
func TestVMSessionSendFromStdin(t *testing.T) { func TestVMSessionSendFromStdin(t *testing.T) {
stubEnsureDaemonForSend(t) d := defaultDeps()
stubEnsureDaemonForSend(t, d)
original := guestSessionSendFunc
t.Cleanup(func() { guestSessionSendFunc = original })
var capturedPayload []byte var capturedPayload []byte
guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
capturedPayload = params.Payload capturedPayload = params.Payload
return api.GuestSessionSendResult{ return api.GuestSessionSendResult{
Session: model.GuestSession{Name: "planner"}, Session: model.GuestSession{Name: "planner"},
@ -2170,7 +2035,7 @@ func TestVMSessionSendFromStdin(t *testing.T) {
} }
stdinPayload := `{"type":"steer","message":"Focus on src/"}` + "\n" stdinPayload := `{"type":"steer","message":"Focus on src/"}` + "\n"
cmd := NewBangerCommand() cmd := d.newRootCommand()
cmd.SetOut(io.Discard) cmd.SetOut(io.Discard)
cmd.SetIn(strings.NewReader(stdinPayload)) cmd.SetIn(strings.NewReader(stdinPayload))
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner"}) cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner"})
@ -2208,13 +2073,11 @@ func TestVMWorkspaceExportRejectsMissingArg(t *testing.T) {
} }
func TestVMWorkspaceExportWritesToStdout(t *testing.T) { func TestVMWorkspaceExportWritesToStdout(t *testing.T) {
stubEnsureDaemonForSend(t) d := defaultDeps()
stubEnsureDaemonForSend(t, d)
origExport := vmWorkspaceExportFunc
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
patch := []byte("diff --git a/main.go b/main.go\nindex 0000000..1111111 100644\n") patch := []byte("diff --git a/main.go b/main.go\nindex 0000000..1111111 100644\n")
vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
return api.WorkspaceExportResult{ return api.WorkspaceExportResult{
GuestPath: params.GuestPath, GuestPath: params.GuestPath,
Patch: patch, Patch: patch,
@ -2223,7 +2086,7 @@ func TestVMWorkspaceExportWritesToStdout(t *testing.T) {
}, nil }, nil
} }
cmd := NewBangerCommand() cmd := d.newRootCommand()
var out bytes.Buffer var out bytes.Buffer
cmd.SetOut(&out) cmd.SetOut(&out)
cmd.SetErr(io.Discard) cmd.SetErr(io.Discard)
@ -2237,13 +2100,11 @@ func TestVMWorkspaceExportWritesToStdout(t *testing.T) {
} }
func TestVMWorkspaceExportWritesToFile(t *testing.T) { func TestVMWorkspaceExportWritesToFile(t *testing.T) {
stubEnsureDaemonForSend(t) d := defaultDeps()
stubEnsureDaemonForSend(t, d)
origExport := vmWorkspaceExportFunc
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
patch := []byte("diff --git a/main.go b/main.go\n") patch := []byte("diff --git a/main.go b/main.go\n")
vmWorkspaceExportFunc = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { d.vmWorkspaceExport = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
return api.WorkspaceExportResult{ return api.WorkspaceExportResult{
GuestPath: "/root/repo", GuestPath: "/root/repo",
Patch: patch, Patch: patch,
@ -2253,7 +2114,7 @@ func TestVMWorkspaceExportWritesToFile(t *testing.T) {
} }
outFile := filepath.Join(t.TempDir(), "worker.diff") outFile := filepath.Join(t.TempDir(), "worker.diff")
cmd := NewBangerCommand() cmd := d.newRootCommand()
cmd.SetOut(io.Discard) cmd.SetOut(io.Discard)
var stderr bytes.Buffer var stderr bytes.Buffer
cmd.SetErr(&stderr) cmd.SetErr(&stderr)
@ -2275,19 +2136,17 @@ func TestVMWorkspaceExportWritesToFile(t *testing.T) {
} }
func TestVMWorkspaceExportNoChanges(t *testing.T) { func TestVMWorkspaceExportNoChanges(t *testing.T) {
stubEnsureDaemonForSend(t) d := defaultDeps()
stubEnsureDaemonForSend(t, d)
origExport := vmWorkspaceExportFunc d.vmWorkspaceExport = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
vmWorkspaceExportFunc = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
return api.WorkspaceExportResult{ return api.WorkspaceExportResult{
GuestPath: "/root/repo", GuestPath: "/root/repo",
HasChanges: false, HasChanges: false,
}, nil }, nil
} }
cmd := NewBangerCommand() cmd := d.newRootCommand()
var out bytes.Buffer var out bytes.Buffer
var stderr bytes.Buffer var stderr bytes.Buffer
cmd.SetOut(&out) cmd.SetOut(&out)
@ -2305,18 +2164,16 @@ func TestVMWorkspaceExportNoChanges(t *testing.T) {
} }
func TestVMWorkspaceExportGuestPathFlag(t *testing.T) { func TestVMWorkspaceExportGuestPathFlag(t *testing.T) {
stubEnsureDaemonForSend(t) d := defaultDeps()
stubEnsureDaemonForSend(t, d)
origExport := vmWorkspaceExportFunc
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
var capturedParams api.WorkspaceExportParams var capturedParams api.WorkspaceExportParams
vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
capturedParams = params capturedParams = params
return api.WorkspaceExportResult{HasChanges: false}, nil return api.WorkspaceExportResult{HasChanges: false}, nil
} }
cmd := NewBangerCommand() cmd := d.newRootCommand()
cmd.SetOut(io.Discard) cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard) cmd.SetErr(io.Discard)
cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--guest-path", "/root/project"}) cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--guest-path", "/root/project"})
@ -2332,13 +2189,11 @@ func TestVMWorkspaceExportGuestPathFlag(t *testing.T) {
} }
func TestVMWorkspaceExportBaseCommitFlag(t *testing.T) { func TestVMWorkspaceExportBaseCommitFlag(t *testing.T) {
stubEnsureDaemonForSend(t) d := defaultDeps()
stubEnsureDaemonForSend(t, d)
origExport := vmWorkspaceExportFunc
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
var capturedParams api.WorkspaceExportParams var capturedParams api.WorkspaceExportParams
vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { d.vmWorkspaceExport = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
capturedParams = params capturedParams = params
return api.WorkspaceExportResult{ return api.WorkspaceExportResult{
HasChanges: false, HasChanges: false,
@ -2346,7 +2201,7 @@ func TestVMWorkspaceExportBaseCommitFlag(t *testing.T) {
}, nil }, nil
} }
cmd := NewBangerCommand() cmd := d.newRootCommand()
cmd.SetOut(io.Discard) cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard) cmd.SetErr(io.Discard)
cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--base-commit", "abc1234deadbeef"}) cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--base-commit", "abc1234deadbeef"})

View file

@ -15,7 +15,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func newDaemonCommand() *cobra.Command { func (d *deps) newDaemonCommand() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "daemon", Use: "daemon",
Short: "Manage the banger daemon", Short: "Manage the banger daemon",
@ -31,7 +31,7 @@ func newDaemonCommand() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
ping, pingErr := daemonPingFunc(cmd.Context(), layout.SocketPath) ping, pingErr := d.daemonPing(cmd.Context(), layout.SocketPath)
if pingErr != nil { if pingErr != nil {
_, err = fmt.Fprintf(cmd.OutOrStdout(), "stopped\nsocket: %s\nlog: %s\ndns: %s\n", layout.SocketPath, layout.DaemonLog, vmdns.DefaultListenAddr) _, err = fmt.Fprintf(cmd.OutOrStdout(), "stopped\nsocket: %s\nlog: %s\ndns: %s\n", layout.SocketPath, layout.DaemonLog, vmdns.DefaultListenAddr)
return err return err

View file

@ -13,24 +13,24 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func newImageCommand() *cobra.Command { func (d *deps) newImageCommand() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "image", Use: "image",
Short: "Manage images", Short: "Manage images",
RunE: helpNoArgs, RunE: helpNoArgs,
} }
cmd.AddCommand( cmd.AddCommand(
newImageRegisterCommand(), d.newImageRegisterCommand(),
newImagePullCommand(), d.newImagePullCommand(),
newImagePromoteCommand(), d.newImagePromoteCommand(),
newImageListCommand(), d.newImageListCommand(),
newImageShowCommand(), d.newImageShowCommand(),
newImageDeleteCommand(), d.newImageDeleteCommand(),
) )
return cmd return cmd
} }
func newImageRegisterCommand() *cobra.Command { func (d *deps) newImageRegisterCommand() *cobra.Command {
var params api.ImageRegisterParams var params api.ImageRegisterParams
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "register", Use: "register",
@ -46,7 +46,7 @@ func newImageRegisterCommand() *cobra.Command {
if err := system.EnsureSudo(cmd.Context()); err != nil { if err := system.EnsureSudo(cmd.Context()); err != nil {
return err return err
} }
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -65,11 +65,11 @@ func newImageRegisterCommand() *cobra.Command {
cmd.Flags().StringVar(&params.ModulesDir, "modules", "", "modules dir") cmd.Flags().StringVar(&params.ModulesDir, "modules", "", "modules dir")
cmd.Flags().StringVar(&params.KernelRef, "kernel-ref", "", "name of a cataloged kernel (see 'banger kernel list')") cmd.Flags().StringVar(&params.KernelRef, "kernel-ref", "", "name of a cataloged kernel (see 'banger kernel list')")
cmd.Flags().BoolVar(&params.Docker, "docker", false, "mark image as docker-prepared") cmd.Flags().BoolVar(&params.Docker, "docker", false, "mark image as docker-prepared")
_ = cmd.RegisterFlagCompletionFunc("kernel-ref", completeKernelNames) _ = cmd.RegisterFlagCompletionFunc("kernel-ref", d.completeKernelNames)
return cmd return cmd
} }
func newImagePullCommand() *cobra.Command { func (d *deps) newImagePullCommand() *cobra.Command {
var ( var (
params api.ImagePullParams params api.ImagePullParams
sizeRaw string sizeRaw string
@ -117,7 +117,7 @@ subcommand lands).
if err := system.EnsureSudo(cmd.Context()); err != nil { if err := system.EnsureSudo(cmd.Context()); err != nil {
return err return err
} }
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -139,21 +139,21 @@ subcommand lands).
cmd.Flags().StringVar(&params.ModulesDir, "modules", "", "modules dir") cmd.Flags().StringVar(&params.ModulesDir, "modules", "", "modules dir")
cmd.Flags().StringVar(&params.KernelRef, "kernel-ref", "", "name of a cataloged kernel (see 'banger kernel list')") cmd.Flags().StringVar(&params.KernelRef, "kernel-ref", "", "name of a cataloged kernel (see 'banger kernel list')")
cmd.Flags().StringVar(&sizeRaw, "size", "", "ext4 image size (e.g. 4GiB); defaults to content + 25%, min 1GiB") cmd.Flags().StringVar(&sizeRaw, "size", "", "ext4 image size (e.g. 4GiB); defaults to content + 25%, min 1GiB")
_ = cmd.RegisterFlagCompletionFunc("kernel-ref", completeKernelNames) _ = cmd.RegisterFlagCompletionFunc("kernel-ref", d.completeKernelNames)
return cmd return cmd
} }
func newImagePromoteCommand() *cobra.Command { func (d *deps) newImagePromoteCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "promote <id-or-name>", Use: "promote <id-or-name>",
Short: "Promote an unmanaged image to a managed artifact", Short: "Promote an unmanaged image to a managed artifact",
Args: exactArgsUsage(1, "usage: banger image promote <id-or-name>"), Args: exactArgsUsage(1, "usage: banger image promote <id-or-name>"),
ValidArgsFunction: completeImageNameOnlyAtPos0, ValidArgsFunction: d.completeImageNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := system.EnsureSudo(cmd.Context()); err != nil { if err := system.EnsureSudo(cmd.Context()); err != nil {
return err return err
} }
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -166,14 +166,14 @@ func newImagePromoteCommand() *cobra.Command {
} }
} }
func newImageListCommand() *cobra.Command { func (d *deps) newImageListCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "list", Use: "list",
Aliases: []string{"ls"}, Aliases: []string{"ls"},
Short: "List images", Short: "List images",
Args: noArgsUsage("usage: banger image list"), Args: noArgsUsage("usage: banger image list"),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -186,14 +186,14 @@ func newImageListCommand() *cobra.Command {
} }
} }
func newImageShowCommand() *cobra.Command { func (d *deps) newImageShowCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "show <id-or-name>", Use: "show <id-or-name>",
Short: "Show image details", Short: "Show image details",
Args: exactArgsUsage(1, "usage: banger image show <id-or-name>"), Args: exactArgsUsage(1, "usage: banger image show <id-or-name>"),
ValidArgsFunction: completeImageNameOnlyAtPos0, ValidArgsFunction: d.completeImageNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -206,18 +206,18 @@ func newImageShowCommand() *cobra.Command {
} }
} }
func newImageDeleteCommand() *cobra.Command { func (d *deps) newImageDeleteCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "delete <id-or-name>", Use: "delete <id-or-name>",
Aliases: []string{"rm"}, Aliases: []string{"rm"},
Short: "Delete an image", Short: "Delete an image",
Args: exactArgsUsage(1, "usage: banger image delete <id-or-name>"), Args: exactArgsUsage(1, "usage: banger image delete <id-or-name>"),
ValidArgsFunction: completeImageNameOnlyAtPos0, ValidArgsFunction: d.completeImageNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := system.EnsureSudo(cmd.Context()); err != nil { if err := system.EnsureSudo(cmd.Context()); err != nil {
return err return err
} }
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }

View file

@ -25,7 +25,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func newInternalCommand() *cobra.Command { func (d *deps) newInternalCommand() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "internal", Use: "internal",
Hidden: true, Hidden: true,

View file

@ -12,30 +12,30 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func newKernelCommand() *cobra.Command { func (d *deps) newKernelCommand() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "kernel", Use: "kernel",
Short: "Manage the local kernel catalog", Short: "Manage the local kernel catalog",
RunE: helpNoArgs, RunE: helpNoArgs,
} }
cmd.AddCommand( cmd.AddCommand(
newKernelListCommand(), d.newKernelListCommand(),
newKernelShowCommand(), d.newKernelShowCommand(),
newKernelRmCommand(), d.newKernelRmCommand(),
newKernelImportCommand(), d.newKernelImportCommand(),
newKernelPullCommand(), d.newKernelPullCommand(),
) )
return cmd return cmd
} }
func newKernelPullCommand() *cobra.Command { func (d *deps) newKernelPullCommand() *cobra.Command {
var force bool var force bool
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "pull <name>", Use: "pull <name>",
Short: "Download a cataloged kernel bundle", Short: "Download a cataloged kernel bundle",
Args: exactArgsUsage(1, "usage: banger kernel pull <name> [--force]"), Args: exactArgsUsage(1, "usage: banger kernel pull <name> [--force]"),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -55,7 +55,7 @@ func newKernelPullCommand() *cobra.Command {
return cmd return cmd
} }
func newKernelImportCommand() *cobra.Command { func (d *deps) newKernelImportCommand() *cobra.Command {
var params api.KernelImportParams var params api.KernelImportParams
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "import <name>", Use: "import <name>",
@ -72,7 +72,7 @@ func newKernelImportCommand() *cobra.Command {
return err return err
} }
params.FromDir = abs params.FromDir = abs
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -89,7 +89,7 @@ func newKernelImportCommand() *cobra.Command {
return cmd return cmd
} }
func newKernelListCommand() *cobra.Command { func (d *deps) newKernelListCommand() *cobra.Command {
var available bool var available bool
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "list", Use: "list",
@ -97,7 +97,7 @@ func newKernelListCommand() *cobra.Command {
Short: "List kernels (local by default, or --available for the catalog)", Short: "List kernels (local by default, or --available for the catalog)",
Args: noArgsUsage("usage: banger kernel list [--available]"), Args: noArgsUsage("usage: banger kernel list [--available]"),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -119,14 +119,14 @@ func newKernelListCommand() *cobra.Command {
return cmd return cmd
} }
func newKernelShowCommand() *cobra.Command { func (d *deps) newKernelShowCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "show <name>", Use: "show <name>",
Short: "Show kernel catalog entry details", Short: "Show kernel catalog entry details",
Args: exactArgsUsage(1, "usage: banger kernel show <name>"), Args: exactArgsUsage(1, "usage: banger kernel show <name>"),
ValidArgsFunction: completeKernelNameOnlyAtPos0, ValidArgsFunction: d.completeKernelNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -139,15 +139,15 @@ func newKernelShowCommand() *cobra.Command {
} }
} }
func newKernelRmCommand() *cobra.Command { func (d *deps) newKernelRmCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "rm <name>", Use: "rm <name>",
Aliases: []string{"remove", "delete"}, Aliases: []string{"remove", "delete"},
Short: "Remove a kernel catalog entry", Short: "Remove a kernel catalog entry",
Args: exactArgsUsage(1, "usage: banger kernel rm <name>"), Args: exactArgsUsage(1, "usage: banger kernel rm <name>"),
ValidArgsFunction: completeKernelNameOnlyAtPos0, ValidArgsFunction: d.completeKernelNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }

View file

@ -22,35 +22,35 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func newVMCommand() *cobra.Command { func (d *deps) newVMCommand() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "vm", Use: "vm",
Short: "Manage virtual machines", Short: "Manage virtual machines",
RunE: helpNoArgs, RunE: helpNoArgs,
} }
cmd.AddCommand( cmd.AddCommand(
newVMCreateCommand(), d.newVMCreateCommand(),
newVMRunCommand(), d.newVMRunCommand(),
newVMListCommand(), d.newVMListCommand(),
newVMShowCommand(), d.newVMShowCommand(),
newVMActionCommand("start", "Start a VM", "vm.start"), d.newVMActionCommand("start", "Start a VM", "vm.start"),
newVMActionCommand("stop", "Stop a VM", "vm.stop"), d.newVMActionCommand("stop", "Stop a VM", "vm.stop"),
newVMKillCommand(), d.newVMKillCommand(),
newVMActionCommand("restart", "Restart a VM", "vm.restart"), d.newVMActionCommand("restart", "Restart a VM", "vm.restart"),
newVMActionCommand("delete", "Delete a VM", "vm.delete", "rm"), d.newVMActionCommand("delete", "Delete a VM", "vm.delete", "rm"),
newVMPruneCommand(), d.newVMPruneCommand(),
newVMSetCommand(), d.newVMSetCommand(),
newVMSSHCommand(), d.newVMSSHCommand(),
newVMWorkspaceCommand(), d.newVMWorkspaceCommand(),
newVMSessionCommand(), d.newVMSessionCommand(),
newVMLogsCommand(), d.newVMLogsCommand(),
newVMStatsCommand(), d.newVMStatsCommand(),
newVMPortsCommand(), d.newVMPortsCommand(),
) )
return cmd return cmd
} }
func newVMRunCommand() *cobra.Command { func (d *deps) newVMRunCommand() *cobra.Command {
defaults := effectiveVMDefaults() defaults := effectiveVMDefaults()
var ( var (
name string name string
@ -104,7 +104,7 @@ Three modes:
var repoPtr *vmRunRepo var repoPtr *vmRunRepo
if sourcePath != "" { if sourcePath != "" {
resolved, err := vmRunPreflightRepo(cmd.Context(), sourcePath) resolved, err := d.vmRunPreflightRepo(cmd.Context(), sourcePath)
if err != nil { if err != nil {
return err return err
} }
@ -135,11 +135,11 @@ Three modes:
if err := system.EnsureSudo(cmd.Context()); err != nil { if err := system.EnsureSudo(cmd.Context()); err != nil {
return err return err
} }
layout, cfg, err = ensureDaemon(cmd.Context()) layout, cfg, err = d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
return runVMRun(cmd.Context(), layout.SocketPath, cfg, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), params, repoPtr, commandArgs, removeOnExit) return d.runVMRun(cmd.Context(), layout.SocketPath, cfg, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), params, repoPtr, commandArgs, removeOnExit)
}, },
} }
cmd.Flags().StringVar(&name, "name", "", "vm name") cmd.Flags().StringVar(&name, "name", "", "vm name")
@ -152,22 +152,22 @@ Three modes:
cmd.Flags().StringVar(&branchName, "branch", "", "create and switch to a new guest branch") cmd.Flags().StringVar(&branchName, "branch", "", "create and switch to a new guest branch")
cmd.Flags().StringVar(&fromRef, "from", "HEAD", "base ref for --branch") cmd.Flags().StringVar(&fromRef, "from", "HEAD", "base ref for --branch")
cmd.Flags().BoolVar(&removeOnExit, "rm", false, "delete the VM after the ssh session / command exits") cmd.Flags().BoolVar(&removeOnExit, "rm", false, "delete the VM after the ssh session / command exits")
_ = cmd.RegisterFlagCompletionFunc("image", completeImageNames) _ = cmd.RegisterFlagCompletionFunc("image", d.completeImageNames)
return cmd return cmd
} }
func newVMKillCommand() *cobra.Command { func (d *deps) newVMKillCommand() *cobra.Command {
var signal string var signal string
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "kill <id-or-name>...", Use: "kill <id-or-name>...",
Short: "Send a signal to a VM process", Short: "Send a signal to a VM process",
Args: minArgsUsage(1, "usage: banger vm kill [--signal SIGTERM|SIGKILL|...] <id-or-name>..."), Args: minArgsUsage(1, "usage: banger vm kill [--signal SIGTERM|SIGKILL|...] <id-or-name>..."),
ValidArgsFunction: completeVMNames, ValidArgsFunction: d.completeVMNames,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := system.EnsureSudo(cmd.Context()); err != nil { if err := system.EnsureSudo(cmd.Context()); err != nil {
return err return err
} }
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -201,7 +201,7 @@ func newVMKillCommand() *cobra.Command {
return cmd return cmd
} }
func newVMPruneCommand() *cobra.Command { func (d *deps) newVMPruneCommand() *cobra.Command {
var force bool var force bool
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "prune", Use: "prune",
@ -212,23 +212,23 @@ func newVMPruneCommand() *cobra.Command {
if err := system.EnsureSudo(cmd.Context()); err != nil { if err := system.EnsureSudo(cmd.Context()); err != nil {
return err return err
} }
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
return runVMPrune(cmd, layout.SocketPath, force) return d.runVMPrune(cmd, layout.SocketPath, force)
}, },
} }
cmd.Flags().BoolVarP(&force, "force", "f", false, "skip the confirmation prompt") cmd.Flags().BoolVarP(&force, "force", "f", false, "skip the confirmation prompt")
return cmd return cmd
} }
func runVMPrune(cmd *cobra.Command, socketPath string, force bool) error { func (d *deps) runVMPrune(cmd *cobra.Command, socketPath string, force bool) error {
ctx := cmd.Context() ctx := cmd.Context()
stdout := cmd.OutOrStdout() stdout := cmd.OutOrStdout()
stderr := cmd.ErrOrStderr() stderr := cmd.ErrOrStderr()
list, err := vmListFunc(ctx, socketPath) list, err := d.vmList(ctx, socketPath)
if err != nil { if err != nil {
return err return err
} }
@ -270,7 +270,7 @@ func runVMPrune(cmd *cobra.Command, socketPath string, force bool) error {
if ref == "" { if ref == "" {
ref = shortID(vm.ID) ref = shortID(vm.ID)
} }
if err := vmDeleteFunc(ctx, socketPath, vm.ID); err != nil { if err := d.vmDelete(ctx, socketPath, vm.ID); err != nil {
fmt.Fprintf(stderr, "delete %s: %v\n", ref, err) fmt.Fprintf(stderr, "delete %s: %v\n", ref, err)
failed++ failed++
continue continue
@ -299,7 +299,7 @@ func promptYesNo(in io.Reader, out io.Writer, prompt string) (bool, error) {
return answer == "y" || answer == "yes", nil return answer == "y" || answer == "yes", nil
} }
func newVMCreateCommand() *cobra.Command { func (d *deps) newVMCreateCommand() *cobra.Command {
defaults := effectiveVMDefaults() defaults := effectiveVMDefaults()
var ( var (
name string name string
@ -323,11 +323,11 @@ func newVMCreateCommand() *cobra.Command {
if err := system.EnsureSudo(cmd.Context()); err != nil { if err := system.EnsureSudo(cmd.Context()); err != nil {
return err return err
} }
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
vm, err := runVMCreate(cmd.Context(), layout.SocketPath, cmd.ErrOrStderr(), params) vm, err := d.runVMCreate(cmd.Context(), layout.SocketPath, cmd.ErrOrStderr(), params)
if err != nil { if err != nil {
return err return err
} }
@ -342,7 +342,7 @@ func newVMCreateCommand() *cobra.Command {
cmd.Flags().StringVar(&workDiskSize, "disk-size", model.FormatSizeBytes(defaults.WorkDiskSizeBytes), "work disk size") cmd.Flags().StringVar(&workDiskSize, "disk-size", model.FormatSizeBytes(defaults.WorkDiskSizeBytes), "work disk size")
cmd.Flags().BoolVar(&natEnabled, "nat", false, "enable NAT") cmd.Flags().BoolVar(&natEnabled, "nat", false, "enable NAT")
cmd.Flags().BoolVar(&noStart, "no-start", false, "create without starting") cmd.Flags().BoolVar(&noStart, "no-start", false, "create without starting")
_ = cmd.RegisterFlagCompletionFunc("image", completeImageNames) _ = cmd.RegisterFlagCompletionFunc("image", d.completeImageNames)
return cmd return cmd
} }
@ -352,15 +352,15 @@ type vmListOptions struct {
quiet bool quiet bool
} }
func newPSCommand() *cobra.Command { func (d *deps) newPSCommand() *cobra.Command {
return newVMListLikeCommand("ps", nil, "usage: banger ps") return d.newVMListLikeCommand("ps", nil, "usage: banger ps")
} }
func newVMListCommand() *cobra.Command { func (d *deps) newVMListCommand() *cobra.Command {
return newVMListLikeCommand("list", []string{"ls", "ps"}, "usage: banger vm list") return d.newVMListLikeCommand("list", []string{"ls", "ps"}, "usage: banger vm list")
} }
func newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Command { func (d *deps) newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Command {
var opts vmListOptions var opts vmListOptions
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: use, Use: use,
@ -368,7 +368,7 @@ func newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Com
Short: "List VMs", Short: "List VMs",
Args: noArgsUsage(usage), Args: noArgsUsage(usage),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
return runVMList(cmd, opts) return d.runVMList(cmd, opts)
}, },
} }
cmd.Flags().BoolVarP(&opts.showAll, "all", "a", false, "show all VMs") cmd.Flags().BoolVarP(&opts.showAll, "all", "a", false, "show all VMs")
@ -377,8 +377,8 @@ func newVMListLikeCommand(use string, aliases []string, usage string) *cobra.Com
return cmd return cmd
} }
func runVMList(cmd *cobra.Command, opts vmListOptions) error { func (d *deps) runVMList(cmd *cobra.Command, opts vmListOptions) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -421,14 +421,14 @@ func selectVMListVMs(vms []model.VMRecord, showAll, latest bool) []model.VMRecor
return []model.VMRecord{latestVM} return []model.VMRecord{latestVM}
} }
func newVMShowCommand() *cobra.Command { func (d *deps) newVMShowCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "show <id-or-name>", Use: "show <id-or-name>",
Short: "Show VM details", Short: "Show VM details",
Args: exactArgsUsage(1, "usage: banger vm show <id-or-name>"), Args: exactArgsUsage(1, "usage: banger vm show <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0, ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -441,18 +441,18 @@ func newVMShowCommand() *cobra.Command {
} }
} }
func newVMActionCommand(use, short, method string, aliases ...string) *cobra.Command { func (d *deps) newVMActionCommand(use, short, method string, aliases ...string) *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: use + " <id-or-name>...", Use: use + " <id-or-name>...",
Aliases: aliases, Aliases: aliases,
Short: short, Short: short,
Args: minArgsUsage(1, fmt.Sprintf("usage: banger vm %s <id-or-name>...", use)), Args: minArgsUsage(1, fmt.Sprintf("usage: banger vm %s <id-or-name>...", use)),
ValidArgsFunction: completeVMNames, ValidArgsFunction: d.completeVMNames,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := system.EnsureSudo(cmd.Context()); err != nil { if err := system.EnsureSudo(cmd.Context()); err != nil {
return err return err
} }
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -474,7 +474,7 @@ func newVMActionCommand(use, short, method string, aliases ...string) *cobra.Com
} }
} }
func newVMSetCommand() *cobra.Command { func (d *deps) newVMSetCommand() *cobra.Command {
var ( var (
vcpu int vcpu int
memory int memory int
@ -486,7 +486,7 @@ func newVMSetCommand() *cobra.Command {
Use: "set <id-or-name>...", Use: "set <id-or-name>...",
Short: "Update stopped VM settings", Short: "Update stopped VM settings",
Args: minArgsUsage(1, "usage: banger vm set [--vcpu N] [--memory MiB] [--disk-size SIZE] [--nat|--no-nat] <id-or-name>..."), Args: minArgsUsage(1, "usage: banger vm set [--vcpu N] [--memory MiB] [--disk-size SIZE] [--nat|--no-nat] <id-or-name>..."),
ValidArgsFunction: completeVMNames, ValidArgsFunction: d.completeVMNames,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
params, err := vmSetParamsFromFlags(args[0], vcpu, memory, diskSize, nat, noNat) params, err := vmSetParamsFromFlags(args[0], vcpu, memory, diskSize, nat, noNat)
if err != nil { if err != nil {
@ -495,7 +495,7 @@ func newVMSetCommand() *cobra.Command {
if err := system.EnsureSudo(cmd.Context()); err != nil { if err := system.EnsureSudo(cmd.Context()); err != nil {
return err return err
} }
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -525,21 +525,21 @@ func newVMSetCommand() *cobra.Command {
return cmd return cmd
} }
func newVMSSHCommand() *cobra.Command { func (d *deps) newVMSSHCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "ssh <id-or-name> [ssh args...]", Use: "ssh <id-or-name> [ssh args...]",
Short: "SSH into a running VM", Short: "SSH into a running VM",
Args: minArgsUsage(1, "usage: banger vm ssh <id-or-name> [ssh args...]"), Args: minArgsUsage(1, "usage: banger vm ssh <id-or-name> [ssh args...]"),
ValidArgsFunction: completeVMNameOnlyAtPos0, ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, cfg, err := ensureDaemon(cmd.Context()) layout, cfg, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
if err := validateSSHPrereqs(cfg); err != nil { if err := validateSSHPrereqs(cfg); err != nil {
return err return err
} }
result, err := vmSSHFunc(cmd.Context(), layout.SocketPath, args[0]) result, err := d.vmSSH(cmd.Context(), layout.SocketPath, args[0])
if err != nil { if err != nil {
return err return err
} }
@ -547,25 +547,25 @@ func newVMSSHCommand() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
return runSSHSession(cmd.Context(), layout.SocketPath, result.Name, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), sshArgs, false) return d.runSSHSession(cmd.Context(), layout.SocketPath, result.Name, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), sshArgs, false)
}, },
} }
} }
func newVMWorkspaceCommand() *cobra.Command { func (d *deps) newVMWorkspaceCommand() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "workspace", Use: "workspace",
Short: "Manage repository workspaces inside a running VM", Short: "Manage repository workspaces inside a running VM",
RunE: helpNoArgs, RunE: helpNoArgs,
} }
cmd.AddCommand( cmd.AddCommand(
newVMWorkspacePrepareCommand(), d.newVMWorkspacePrepareCommand(),
newVMWorkspaceExportCommand(), d.newVMWorkspaceExportCommand(),
) )
return cmd return cmd
} }
func newVMWorkspacePrepareCommand() *cobra.Command { func (d *deps) newVMWorkspacePrepareCommand() *cobra.Command {
var guestPath string var guestPath string
var branchName string var branchName string
var fromRef string var fromRef string
@ -576,14 +576,14 @@ func newVMWorkspacePrepareCommand() *cobra.Command {
Short: "Copy a local repo into a running VM", Short: "Copy a local repo into a running VM",
Long: "Prepare a repository workspace from a local git checkout into a running VM. The default guest path is /root/repo and the default mode is shallow_overlay. Repositories with git submodules must use --mode full_copy.", Long: "Prepare a repository workspace from a local git checkout into a running VM. The default guest path is /root/repo and the default mode is shallow_overlay. Repositories with git submodules must use --mode full_copy.",
Args: minArgsUsage(1, "usage: banger vm workspace prepare <id-or-name> [path]"), Args: minArgsUsage(1, "usage: banger vm workspace prepare <id-or-name> [path]"),
ValidArgsFunction: completeVMNameOnlyAtPos0, ValidArgsFunction: d.completeVMNameOnlyAtPos0,
Example: strings.TrimSpace(` Example: strings.TrimSpace(`
banger vm workspace prepare devbox banger vm workspace prepare devbox
banger vm workspace prepare devbox ../repo --guest-path /root/repo --readonly banger vm workspace prepare devbox ../repo --guest-path /root/repo --readonly
banger vm workspace prepare devbox ../repo --mode full_copy banger vm workspace prepare devbox ../repo --mode full_copy
`), `),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -592,7 +592,7 @@ func newVMWorkspacePrepareCommand() *cobra.Command {
sourcePath = args[1] sourcePath = args[1]
} }
if strings.TrimSpace(sourcePath) == "" { if strings.TrimSpace(sourcePath) == "" {
wd, err := cwdFunc() wd, err := d.cwd()
if err != nil { if err != nil {
return err return err
} }
@ -606,7 +606,7 @@ func newVMWorkspacePrepareCommand() *cobra.Command {
if strings.TrimSpace(branchName) != "" { if strings.TrimSpace(branchName) != "" {
prepareFrom = fromRef prepareFrom = fromRef
} }
result, err := vmWorkspacePrepareFunc(cmd.Context(), layout.SocketPath, api.VMWorkspacePrepareParams{ result, err := d.vmWorkspacePrepare(cmd.Context(), layout.SocketPath, api.VMWorkspacePrepareParams{
IDOrName: args[0], IDOrName: args[0],
SourcePath: resolvedPath, SourcePath: resolvedPath,
GuestPath: guestPath, GuestPath: guestPath,
@ -629,7 +629,7 @@ func newVMWorkspacePrepareCommand() *cobra.Command {
return cmd return cmd
} }
func newVMWorkspaceExportCommand() *cobra.Command { func (d *deps) newVMWorkspaceExportCommand() *cobra.Command {
var guestPath string var guestPath string
var outputPath string var outputPath string
var baseCommit string var baseCommit string
@ -638,7 +638,7 @@ func newVMWorkspaceExportCommand() *cobra.Command {
Short: "Pull changes from a guest workspace back to the host as a patch", Short: "Pull changes from a guest workspace back to the host as a patch",
Long: "Emit a binary-safe unified diff of every change inside the guest workspace (committed since base + uncommitted + untracked, minus .gitignore). Non-mutating — the guest's index and working tree are untouched. Pass --base-commit with the head_commit from workspace prepare to capture changes even when the worker ran git commit inside the VM. Without --base-commit the diff is against the current guest HEAD, which misses committed changes.", Long: "Emit a binary-safe unified diff of every change inside the guest workspace (committed since base + uncommitted + untracked, minus .gitignore). Non-mutating — the guest's index and working tree are untouched. Pass --base-commit with the head_commit from workspace prepare to capture changes even when the worker ran git commit inside the VM. Without --base-commit the diff is against the current guest HEAD, which misses committed changes.",
Args: exactArgsUsage(1, "usage: banger vm workspace export <id-or-name>"), Args: exactArgsUsage(1, "usage: banger vm workspace export <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0, ValidArgsFunction: d.completeVMNameOnlyAtPos0,
Example: strings.TrimSpace(` Example: strings.TrimSpace(`
banger vm workspace export devbox | git apply banger vm workspace export devbox | git apply
banger vm workspace export devbox --base-commit abc1234 | git apply banger vm workspace export devbox --base-commit abc1234 | git apply
@ -646,11 +646,11 @@ func newVMWorkspaceExportCommand() *cobra.Command {
banger vm workspace export devbox --guest-path /root/project --output changes.diff banger vm workspace export devbox --guest-path /root/project --output changes.diff
`), `),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
result, err := vmWorkspaceExportFunc(cmd.Context(), layout.SocketPath, api.WorkspaceExportParams{ result, err := d.vmWorkspaceExport(cmd.Context(), layout.SocketPath, api.WorkspaceExportParams{
IDOrName: args[0], IDOrName: args[0],
GuestPath: guestPath, GuestPath: guestPath,
BaseCommit: baseCommit, BaseCommit: baseCommit,
@ -680,15 +680,15 @@ func newVMWorkspaceExportCommand() *cobra.Command {
return cmd return cmd
} }
func newVMLogsCommand() *cobra.Command { func (d *deps) newVMLogsCommand() *cobra.Command {
var follow bool var follow bool
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "logs <id-or-name>", Use: "logs <id-or-name>",
Short: "Show VM logs", Short: "Show VM logs",
Args: exactArgsUsage(1, "usage: banger vm logs [-f] <id-or-name>"), Args: exactArgsUsage(1, "usage: banger vm logs [-f] <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0, ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -706,14 +706,14 @@ func newVMLogsCommand() *cobra.Command {
return cmd return cmd
} }
func newVMStatsCommand() *cobra.Command { func (d *deps) newVMStatsCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "stats <id-or-name>", Use: "stats <id-or-name>",
Short: "Show VM stats", Short: "Show VM stats",
Args: exactArgsUsage(1, "usage: banger vm stats <id-or-name>"), Args: exactArgsUsage(1, "usage: banger vm stats <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0, ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -726,18 +726,18 @@ func newVMStatsCommand() *cobra.Command {
} }
} }
func newVMPortsCommand() *cobra.Command { func (d *deps) newVMPortsCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "ports <id-or-name>", Use: "ports <id-or-name>",
Short: "Show host-reachable listening guest ports", Short: "Show host-reachable listening guest ports",
Args: exactArgsUsage(1, "usage: banger vm ports <id-or-name>"), Args: exactArgsUsage(1, "usage: banger vm ports <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0, ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
result, err := vmPortsFunc(cmd.Context(), layout.SocketPath, args[0]) result, err := d.vmPorts(cmd.Context(), layout.SocketPath, args[0])
if err != nil { if err != nil {
return err return err
} }

View file

@ -15,7 +15,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func newVMSessionCommand() *cobra.Command { func (d *deps) newVMSessionCommand() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "session", Use: "session",
Short: "Manage long-lived guest commands inside a VM", Short: "Manage long-lived guest commands inside a VM",
@ -23,19 +23,19 @@ func newVMSessionCommand() *cobra.Command {
RunE: helpNoArgs, RunE: helpNoArgs,
} }
cmd.AddCommand( cmd.AddCommand(
newVMSessionStartCommand(), d.newVMSessionStartCommand(),
newVMSessionListCommand(), d.newVMSessionListCommand(),
newVMSessionShowCommand(), d.newVMSessionShowCommand(),
newVMSessionLogsCommand(), d.newVMSessionLogsCommand(),
newVMSessionStopCommand(), d.newVMSessionStopCommand(),
newVMSessionKillCommand(), d.newVMSessionKillCommand(),
newVMSessionAttachCommand(), d.newVMSessionAttachCommand(),
newVMSessionSendCommand(), d.newVMSessionSendCommand(),
) )
return cmd return cmd
} }
func newVMSessionStartCommand() *cobra.Command { func (d *deps) newVMSessionStartCommand() *cobra.Command {
var name string var name string
var cwd string var cwd string
var stdinMode string var stdinMode string
@ -47,13 +47,13 @@ func newVMSessionStartCommand() *cobra.Command {
Short: "Start a managed guest command", Short: "Start a managed guest command",
Long: "Start a daemon-managed guest command. The daemon verifies that the guest working directory exists and that the requested command is present in guest PATH before launch. Use --stdin-mode pipe when you need live attach.", Long: "Start a daemon-managed guest command. The daemon verifies that the guest working directory exists and that the requested command is present in guest PATH before launch. Use --stdin-mode pipe when you need live attach.",
Args: minArgsUsage(2, "usage: banger vm session start <id-or-name> [flags] -- <command> [args...]"), Args: minArgsUsage(2, "usage: banger vm session start <id-or-name> [flags] -- <command> [args...]"),
ValidArgsFunction: completeVMNameOnlyAtPos0, ValidArgsFunction: d.completeVMNameOnlyAtPos0,
Example: strings.TrimSpace(` Example: strings.TrimSpace(`
banger vm session start devbox --name planner --cwd /root/repo --stdin-mode pipe --require-command git -- pi --mode rpc --no-session banger vm session start devbox --name planner --cwd /root/repo --stdin-mode pipe --require-command git -- pi --mode rpc --no-session
banger vm session start devbox --name shell --stdin-mode pipe -- bash -lc 'exec bash' banger vm session start devbox --name shell --stdin-mode pipe -- bash -lc 'exec bash'
`), `),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -65,7 +65,7 @@ func newVMSessionStartCommand() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
result, err := guestSessionStartFunc(cmd.Context(), layout.SocketPath, api.GuestSessionStartParams{ result, err := d.guestSessionStart(cmd.Context(), layout.SocketPath, api.GuestSessionStartParams{
VMIDOrName: args[0], VMIDOrName: args[0],
Name: name, Name: name,
Command: args[1], Command: args[1],
@ -97,19 +97,19 @@ func newVMSessionStartCommand() *cobra.Command {
return cmd return cmd
} }
func newVMSessionListCommand() *cobra.Command { func (d *deps) newVMSessionListCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "list <id-or-name>", Use: "list <id-or-name>",
Aliases: []string{"ls"}, Aliases: []string{"ls"},
Short: "List managed guest commands for a VM", Short: "List managed guest commands for a VM",
Args: exactArgsUsage(1, "usage: banger vm session list <id-or-name>"), Args: exactArgsUsage(1, "usage: banger vm session list <id-or-name>"),
ValidArgsFunction: completeVMNameOnlyAtPos0, ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
result, err := guestSessionListFunc(cmd.Context(), layout.SocketPath, args[0]) result, err := d.guestSessionList(cmd.Context(), layout.SocketPath, args[0])
if err != nil { if err != nil {
return err return err
} }
@ -118,18 +118,18 @@ func newVMSessionListCommand() *cobra.Command {
} }
} }
func newVMSessionShowCommand() *cobra.Command { func (d *deps) newVMSessionShowCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "show <id-or-name> <session>", Use: "show <id-or-name> <session>",
Short: "Show managed guest command details", Short: "Show managed guest command details",
Args: exactArgsUsage(2, "usage: banger vm session show <id-or-name> <session>"), Args: exactArgsUsage(2, "usage: banger vm session show <id-or-name> <session>"),
ValidArgsFunction: completeSessionNames, ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
result, err := guestSessionGetFunc(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) result, err := d.guestSessionGet(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil { if err != nil {
return err return err
} }
@ -138,20 +138,20 @@ func newVMSessionShowCommand() *cobra.Command {
} }
} }
func newVMSessionLogsCommand() *cobra.Command { func (d *deps) newVMSessionLogsCommand() *cobra.Command {
var stream string var stream string
var tailLines int var tailLines int
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "logs <id-or-name> <session>", Use: "logs <id-or-name> <session>",
Short: "Show stdout or stderr for a guest session", Short: "Show stdout or stderr for a guest session",
Args: exactArgsUsage(2, "usage: banger vm session logs [--stream stdout|stderr] [-n LINES] <id-or-name> <session>"), Args: exactArgsUsage(2, "usage: banger vm session logs [--stream stdout|stderr] [-n LINES] <id-or-name> <session>"),
ValidArgsFunction: completeSessionNames, ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
result, err := guestSessionLogsFunc(cmd.Context(), layout.SocketPath, api.GuestSessionLogsParams{VMIDOrName: args[0], SessionIDOrName: args[1], Stream: stream, TailLines: tailLines}) result, err := d.guestSessionLogs(cmd.Context(), layout.SocketPath, api.GuestSessionLogsParams{VMIDOrName: args[0], SessionIDOrName: args[1], Stream: stream, TailLines: tailLines})
if err != nil { if err != nil {
return err return err
} }
@ -164,18 +164,18 @@ func newVMSessionLogsCommand() *cobra.Command {
return cmd return cmd
} }
func newVMSessionStopCommand() *cobra.Command { func (d *deps) newVMSessionStopCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "stop <id-or-name> <session>", Use: "stop <id-or-name> <session>",
Short: "Send SIGTERM to a guest session", Short: "Send SIGTERM to a guest session",
Args: exactArgsUsage(2, "usage: banger vm session stop <id-or-name> <session>"), Args: exactArgsUsage(2, "usage: banger vm session stop <id-or-name> <session>"),
ValidArgsFunction: completeSessionNames, ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
result, err := guestSessionStopFunc(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) result, err := d.guestSessionStop(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil { if err != nil {
return err return err
} }
@ -184,18 +184,18 @@ func newVMSessionStopCommand() *cobra.Command {
} }
} }
func newVMSessionKillCommand() *cobra.Command { func (d *deps) newVMSessionKillCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "kill <id-or-name> <session>", Use: "kill <id-or-name> <session>",
Short: "Send SIGKILL to a guest session", Short: "Send SIGKILL to a guest session",
Args: exactArgsUsage(2, "usage: banger vm session kill <id-or-name> <session>"), Args: exactArgsUsage(2, "usage: banger vm session kill <id-or-name> <session>"),
ValidArgsFunction: completeSessionNames, ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
result, err := guestSessionKillFunc(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) result, err := d.guestSessionKill(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil { if err != nil {
return err return err
} }
@ -204,19 +204,19 @@ func newVMSessionKillCommand() *cobra.Command {
} }
} }
func newVMSessionAttachCommand() *cobra.Command { func (d *deps) newVMSessionAttachCommand() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "attach <id-or-name> <session>", Use: "attach <id-or-name> <session>",
Short: "Attach local stdio to an attachable guest session", Short: "Attach local stdio to an attachable guest session",
Long: "Attach local stdio to a pipe-mode session through a daemon-created local Unix socket bridge. Only one active attach is allowed at a time, and the client must run on the same host as the daemon.", Long: "Attach local stdio to a pipe-mode session through a daemon-created local Unix socket bridge. Only one active attach is allowed at a time, and the client must run on the same host as the daemon.",
Args: exactArgsUsage(2, "usage: banger vm session attach <id-or-name> <session>"), Args: exactArgsUsage(2, "usage: banger vm session attach <id-or-name> <session>"),
ValidArgsFunction: completeSessionNames, ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
result, err := guestSessionAttachBeginFunc(cmd.Context(), layout.SocketPath, api.GuestSessionAttachBeginParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) result, err := d.guestSessionAttachBegin(cmd.Context(), layout.SocketPath, api.GuestSessionAttachBeginParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil { if err != nil {
return err return err
} }
@ -229,21 +229,21 @@ func newVMSessionAttachCommand() *cobra.Command {
} }
} }
func newVMSessionSendCommand() *cobra.Command { func (d *deps) newVMSessionSendCommand() *cobra.Command {
var message string var message string
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "send <id-or-name> <session>", Use: "send <id-or-name> <session>",
Short: "Write bytes to a running guest session's stdin pipe", Short: "Write bytes to a running guest session's stdin pipe",
Long: "Write a payload to the stdin pipe of a running pipe-mode guest session without holding the exclusive attach. Use --message for an inline JSONL string, or pipe bytes via stdin when --message is omitted. A trailing newline is appended to --message values that lack one.", Long: "Write a payload to the stdin pipe of a running pipe-mode guest session without holding the exclusive attach. Use --message for an inline JSONL string, or pipe bytes via stdin when --message is omitted. A trailing newline is appended to --message values that lack one.",
Args: exactArgsUsage(2, "usage: banger vm session send <id-or-name> <session> [--message '<json>']"), Args: exactArgsUsage(2, "usage: banger vm session send <id-or-name> <session> [--message '<json>']"),
ValidArgsFunction: completeSessionNames, ValidArgsFunction: d.completeSessionNames,
Example: strings.TrimSpace(` Example: strings.TrimSpace(`
banger vm session send devbox planner --message '{"type":"abort"}' banger vm session send devbox planner --message '{"type":"abort"}'
banger vm session send devbox planner --message '{"type":"steer","message":"Focus on src/"}' banger vm session send devbox planner --message '{"type":"steer","message":"Focus on src/"}'
echo '{"type":"prompt","prompt":"Summarize."}' | banger vm session send devbox planner echo '{"type":"prompt","prompt":"Summarize."}' | banger vm session send devbox planner
`), `),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@ -259,7 +259,7 @@ func newVMSessionSendCommand() *cobra.Command {
return fmt.Errorf("read stdin: %w", err) return fmt.Errorf("read stdin: %w", err)
} }
} }
result, err := guestSessionSendFunc(cmd.Context(), layout.SocketPath, api.GuestSessionSendParams{ result, err := d.guestSessionSend(cmd.Context(), layout.SocketPath, api.GuestSessionSendParams{
VMIDOrName: args[0], VMIDOrName: args[0],
SessionIDOrName: args[1], SessionIDOrName: args[1],
Payload: payload, Payload: payload,

View file

@ -21,9 +21,10 @@ import (
// - Fail silently. Completion is advisory; any error path returns an // - Fail silently. Completion is advisory; any error path returns an
// empty suggestion list rather than propagating to the user. // empty suggestion list rather than propagating to the user.
// completionListerFunc is the seam used by tests to avoid touching a // defaultCompletionLister + defaultCompletionSessionLister back the
// real daemon socket. // corresponding *deps fields; tests inject their own fakes via the
var completionListerFunc = func(ctx context.Context, socketPath, method string) ([]string, error) { // struct instead of mutating package-level vars.
func defaultCompletionLister(ctx context.Context, socketPath, method string) ([]string, error) {
switch method { switch method {
case "vm.list": case "vm.list":
result, err := rpc.Call[api.VMListResult](ctx, socketPath, method, api.Empty{}) result, err := rpc.Call[api.VMListResult](ctx, socketPath, method, api.Empty{})
@ -65,9 +66,7 @@ var completionListerFunc = func(ctx context.Context, socketPath, method string)
return nil, nil return nil, nil
} }
// completionSessionListerFunc is the seam for guest-session name lookups func defaultCompletionSessionLister(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) {
// scoped to a VM.
var completionSessionListerFunc = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) {
result, err := rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: vmIDOrName}) result, err := rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: vmIDOrName})
if err != nil { if err != nil {
return nil, err return nil, err
@ -84,12 +83,12 @@ var completionSessionListerFunc = func(ctx context.Context, socketPath, vmIDOrNa
// daemonSocketForCompletion returns the socket path IFF the daemon is // daemonSocketForCompletion returns the socket path IFF the daemon is
// already running. Returns "", false when no daemon is up — completion // already running. Returns "", false when no daemon is up — completion
// callers use this as the bail signal. // callers use this as the bail signal.
func daemonSocketForCompletion(ctx context.Context) (string, bool) { func (d *deps) daemonSocketForCompletion(ctx context.Context) (string, bool) {
layout, err := paths.Resolve() layout, err := paths.Resolve()
if err != nil { if err != nil {
return "", false return "", false
} }
if _, err := daemonPingFunc(ctx, layout.SocketPath); err != nil { if _, err := d.daemonPing(ctx, layout.SocketPath); err != nil {
return "", false return "", false
} }
return layout.SocketPath, true return layout.SocketPath, true
@ -119,12 +118,12 @@ func hasPrefix(s, prefix string) bool {
return len(s) >= len(prefix) && s[:len(prefix)] == prefix return len(s) >= len(prefix) && s[:len(prefix)] == prefix
} }
func completeVMNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { func (d *deps) completeVMNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
socket, ok := daemonSocketForCompletion(cmd.Context()) socket, ok := d.daemonSocketForCompletion(cmd.Context())
if !ok { if !ok {
return nil, cobra.ShellCompDirectiveNoFileComp return nil, cobra.ShellCompDirectiveNoFileComp
} }
names, err := completionListerFunc(cmd.Context(), socket, "vm.list") names, err := d.completionLister(cmd.Context(), socket, "vm.list")
if err != nil { if err != nil {
return nil, cobra.ShellCompDirectiveNoFileComp return nil, cobra.ShellCompDirectiveNoFileComp
} }
@ -134,45 +133,45 @@ func completeVMNames(cmd *cobra.Command, args []string, toComplete string) ([]st
// completeVMNameOnlyAtPos0 restricts VM-name completion to the first // completeVMNameOnlyAtPos0 restricts VM-name completion to the first
// positional argument. Used by commands like `vm ssh <vm> [ssh args...]` // positional argument. Used by commands like `vm ssh <vm> [ssh args...]`
// where args after pos 0 are free-form. // where args after pos 0 are free-form.
func completeVMNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { func (d *deps) completeVMNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) > 0 { if len(args) > 0 {
return nil, cobra.ShellCompDirectiveNoFileComp return nil, cobra.ShellCompDirectiveNoFileComp
} }
return completeVMNames(cmd, args, toComplete) return d.completeVMNames(cmd, args, toComplete)
} }
func completeImageNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { func (d *deps) completeImageNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) > 0 { if len(args) > 0 {
return nil, cobra.ShellCompDirectiveNoFileComp return nil, cobra.ShellCompDirectiveNoFileComp
} }
return completeImageNames(cmd, args, toComplete) return d.completeImageNames(cmd, args, toComplete)
} }
func completeKernelNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { func (d *deps) completeKernelNameOnlyAtPos0(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) > 0 { if len(args) > 0 {
return nil, cobra.ShellCompDirectiveNoFileComp return nil, cobra.ShellCompDirectiveNoFileComp
} }
return completeKernelNames(cmd, args, toComplete) return d.completeKernelNames(cmd, args, toComplete)
} }
func completeImageNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { func (d *deps) completeImageNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
socket, ok := daemonSocketForCompletion(cmd.Context()) socket, ok := d.daemonSocketForCompletion(cmd.Context())
if !ok { if !ok {
return nil, cobra.ShellCompDirectiveNoFileComp return nil, cobra.ShellCompDirectiveNoFileComp
} }
names, err := completionListerFunc(cmd.Context(), socket, "image.list") names, err := d.completionLister(cmd.Context(), socket, "image.list")
if err != nil { if err != nil {
return nil, cobra.ShellCompDirectiveNoFileComp return nil, cobra.ShellCompDirectiveNoFileComp
} }
return filterPrefix(names, args, toComplete), cobra.ShellCompDirectiveNoFileComp return filterPrefix(names, args, toComplete), cobra.ShellCompDirectiveNoFileComp
} }
func completeKernelNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { func (d *deps) completeKernelNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
socket, ok := daemonSocketForCompletion(cmd.Context()) socket, ok := d.daemonSocketForCompletion(cmd.Context())
if !ok { if !ok {
return nil, cobra.ShellCompDirectiveNoFileComp return nil, cobra.ShellCompDirectiveNoFileComp
} }
names, err := completionListerFunc(cmd.Context(), socket, "kernel.list") names, err := d.completionLister(cmd.Context(), socket, "kernel.list")
if err != nil { if err != nil {
return nil, cobra.ShellCompDirectiveNoFileComp return nil, cobra.ShellCompDirectiveNoFileComp
} }
@ -182,16 +181,16 @@ func completeKernelNames(cmd *cobra.Command, args []string, toComplete string) (
// completeSessionNames handles `... <vm> <session>` commands: pos 0 // completeSessionNames handles `... <vm> <session>` commands: pos 0
// completes VMs, pos 1 completes sessions owned by args[0], pos 2+ is // completes VMs, pos 1 completes sessions owned by args[0], pos 2+ is
// silent. // silent.
func completeSessionNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { func (d *deps) completeSessionNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) { switch len(args) {
case 0: case 0:
return completeVMNames(cmd, args, toComplete) return d.completeVMNames(cmd, args, toComplete)
case 1: case 1:
socket, ok := daemonSocketForCompletion(cmd.Context()) socket, ok := d.daemonSocketForCompletion(cmd.Context())
if !ok { if !ok {
return nil, cobra.ShellCompDirectiveNoFileComp return nil, cobra.ShellCompDirectiveNoFileComp
} }
names, err := completionSessionListerFunc(cmd.Context(), socket, args[0]) names, err := d.completionSessionLister(cmd.Context(), socket, args[0])
if err != nil { if err != nil {
return nil, cobra.ShellCompDirectiveNoFileComp return nil, cobra.ShellCompDirectiveNoFileComp
} }

View file

@ -12,10 +12,11 @@ import (
) )
// stubCompletionSeams installs test doubles for the daemon ping + lister // stubCompletionSeams installs test doubles for the daemon ping + lister
// seams and restores the originals on cleanup. Tests opt into the // seams on the caller's *deps. Tests opt into the sub-functions they
// sub-functions they actually need. // actually need.
func stubCompletionSeams( func stubCompletionSeams(
t *testing.T, t *testing.T,
d *deps,
pingErr error, pingErr error,
names map[string][]string, names map[string][]string,
listErr error, listErr error,
@ -24,28 +25,19 @@ func stubCompletionSeams(
) { ) {
t.Helper() t.Helper()
origPing := daemonPingFunc d.daemonPing = func(ctx context.Context, socketPath string) (api.PingResult, error) {
origLister := completionListerFunc
origSessionLister := completionSessionListerFunc
t.Cleanup(func() {
daemonPingFunc = origPing
completionListerFunc = origLister
completionSessionListerFunc = origSessionLister
})
daemonPingFunc = func(ctx context.Context, socketPath string) (api.PingResult, error) {
if pingErr != nil { if pingErr != nil {
return api.PingResult{}, pingErr return api.PingResult{}, pingErr
} }
return api.PingResult{}, nil return api.PingResult{}, nil
} }
completionListerFunc = func(ctx context.Context, socketPath, method string) ([]string, error) { d.completionLister = func(ctx context.Context, socketPath, method string) ([]string, error) {
if listErr != nil { if listErr != nil {
return nil, listErr return nil, listErr
} }
return names[method], nil return names[method], nil
} }
completionSessionListerFunc = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) { d.completionSessionLister = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) {
if sessionErr != nil { if sessionErr != nil {
return nil, sessionErr return nil, sessionErr
} }
@ -89,9 +81,10 @@ func testCmdWithCtx() *cobra.Command {
} }
func TestCompleteVMNamesHappyPath(t *testing.T) { func TestCompleteVMNamesHappyPath(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil) d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil)
got, directive := completeVMNames(testCmdWithCtx(), nil, "") got, directive := d.completeVMNames(testCmdWithCtx(), nil, "")
if directive != cobra.ShellCompDirectiveNoFileComp { if directive != cobra.ShellCompDirectiveNoFileComp {
t.Errorf("directive = %d, want NoFileComp", directive) t.Errorf("directive = %d, want NoFileComp", directive)
} }
@ -101,9 +94,10 @@ func TestCompleteVMNamesHappyPath(t *testing.T) {
} }
func TestCompleteVMNamesDaemonDown(t *testing.T) { func TestCompleteVMNamesDaemonDown(t *testing.T) {
stubCompletionSeams(t, errors.New("connection refused"), nil, nil, nil, nil) d := defaultDeps()
stubCompletionSeams(t, d, errors.New("connection refused"), nil, nil, nil, nil)
got, directive := completeVMNames(testCmdWithCtx(), nil, "") got, directive := d.completeVMNames(testCmdWithCtx(), nil, "")
if len(got) != 0 { if len(got) != 0 {
t.Errorf("daemon-down should return no suggestions, got %v", got) t.Errorf("daemon-down should return no suggestions, got %v", got)
} }
@ -113,18 +107,20 @@ func TestCompleteVMNamesDaemonDown(t *testing.T) {
} }
func TestCompleteVMNamesRPCError(t *testing.T) { func TestCompleteVMNamesRPCError(t *testing.T) {
stubCompletionSeams(t, nil, nil, errors.New("rpc failed"), nil, nil) d := defaultDeps()
stubCompletionSeams(t, d, nil, nil, errors.New("rpc failed"), nil, nil)
got, _ := completeVMNames(testCmdWithCtx(), nil, "") got, _ := d.completeVMNames(testCmdWithCtx(), nil, "")
if len(got) != 0 { if len(got) != 0 {
t.Errorf("rpc error should return no suggestions, got %v", got) t.Errorf("rpc error should return no suggestions, got %v", got)
} }
} }
func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) { func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil) d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil)
got, _ := completeVMNames(testCmdWithCtx(), []string{"alpha"}, "") got, _ := d.completeVMNames(testCmdWithCtx(), []string{"alpha"}, "")
want := []string{"beta", "gamma"} want := []string{"beta", "gamma"}
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want) t.Errorf("got %v, want %v", got, want)
@ -132,9 +128,10 @@ func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) {
} }
func TestCompleteVMNamesPrefixFilter(t *testing.T) { func TestCompleteVMNamesPrefixFilter(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil, nil, nil) d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil, nil, nil)
got, _ := completeVMNames(testCmdWithCtx(), nil, "alp") got, _ := d.completeVMNames(testCmdWithCtx(), nil, "alp")
want := []string{"alpha", "alphabet"} want := []string{"alpha", "alphabet"}
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want) t.Errorf("got %v, want %v", got, want)
@ -142,49 +139,53 @@ func TestCompleteVMNamesPrefixFilter(t *testing.T) {
} }
func TestCompleteVMNameOnlyAtPos0(t *testing.T) { func TestCompleteVMNameOnlyAtPos0(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"vm.list": {"alpha"}}, nil, nil, nil) d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha"}}, nil, nil, nil)
atPos0, _ := completeVMNameOnlyAtPos0(testCmdWithCtx(), nil, "") atPos0, _ := d.completeVMNameOnlyAtPos0(testCmdWithCtx(), nil, "")
if len(atPos0) != 1 || atPos0[0] != "alpha" { if len(atPos0) != 1 || atPos0[0] != "alpha" {
t.Errorf("pos 0: got %v", atPos0) t.Errorf("pos 0: got %v", atPos0)
} }
atPos1, _ := completeVMNameOnlyAtPos0(testCmdWithCtx(), []string{"alpha"}, "") atPos1, _ := d.completeVMNameOnlyAtPos0(testCmdWithCtx(), []string{"alpha"}, "")
if len(atPos1) != 0 { if len(atPos1) != 0 {
t.Errorf("pos 1+ should be silent, got %v", atPos1) t.Errorf("pos 1+ should be silent, got %v", atPos1)
} }
} }
func TestCompleteImageNames(t *testing.T) { func TestCompleteImageNames(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil, nil, nil) d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil, nil, nil)
got, _ := completeImageNames(testCmdWithCtx(), nil, "") got, _ := d.completeImageNames(testCmdWithCtx(), nil, "")
if !reflect.DeepEqual(got, []string{"debian-bookworm", "alpine"}) { if !reflect.DeepEqual(got, []string{"debian-bookworm", "alpine"}) {
t.Errorf("got %v", got) t.Errorf("got %v", got)
} }
} }
func TestCompleteKernelNames(t *testing.T) { func TestCompleteKernelNames(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil, nil, nil) d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil, nil, nil)
got, _ := completeKernelNames(testCmdWithCtx(), nil, "") got, _ := d.completeKernelNames(testCmdWithCtx(), nil, "")
if len(got) != 1 || got[0] != "generic-6.12" { if len(got) != 1 || got[0] != "generic-6.12" {
t.Errorf("got %v", got) t.Errorf("got %v", got)
} }
} }
func TestCompleteImageNameOnlyAtPos0SilentAfterFirst(t *testing.T) { func TestCompleteImageNameOnlyAtPos0SilentAfterFirst(t *testing.T) {
stubCompletionSeams(t, nil, map[string][]string{"image.list": {"alpine"}}, nil, nil, nil) d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"alpine"}}, nil, nil, nil)
after, _ := completeImageNameOnlyAtPos0(testCmdWithCtx(), []string{"alpine"}, "") after, _ := d.completeImageNameOnlyAtPos0(testCmdWithCtx(), []string{"alpine"}, "")
if len(after) != 0 { if len(after) != 0 {
t.Errorf("expected silence at pos 1+, got %v", after) t.Errorf("expected silence at pos 1+, got %v", after)
} }
} }
func TestCompleteSessionNames(t *testing.T) { func TestCompleteSessionNames(t *testing.T) {
stubCompletionSeams( d := defaultDeps()
t, stubCompletionSeams(t, d,
nil, nil,
map[string][]string{"vm.list": {"devbox"}}, map[string][]string{"vm.list": {"devbox"}},
nil, nil,
@ -193,34 +194,35 @@ func TestCompleteSessionNames(t *testing.T) {
) )
// Position 0 → VMs. // Position 0 → VMs.
vms, _ := completeSessionNames(testCmdWithCtx(), nil, "") vms, _ := d.completeSessionNames(testCmdWithCtx(), nil, "")
if len(vms) != 1 || vms[0] != "devbox" { if len(vms) != 1 || vms[0] != "devbox" {
t.Errorf("pos 0: got %v", vms) t.Errorf("pos 0: got %v", vms)
} }
// Position 1 → sessions scoped to args[0]. // Position 1 → sessions scoped to args[0].
sessions, _ := completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "") sessions, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "")
if !reflect.DeepEqual(sessions, []string{"planner", "worker"}) { if !reflect.DeepEqual(sessions, []string{"planner", "worker"}) {
t.Errorf("pos 1: got %v", sessions) t.Errorf("pos 1: got %v", sessions)
} }
// Position 1 with prefix filter. // Position 1 with prefix filter.
filtered, _ := completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "wor") filtered, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "wor")
if len(filtered) != 1 || filtered[0] != "worker" { if len(filtered) != 1 || filtered[0] != "worker" {
t.Errorf("pos 1 prefix: got %v", filtered) t.Errorf("pos 1 prefix: got %v", filtered)
} }
// Position 2+ silent. // Position 2+ silent.
past, _ := completeSessionNames(testCmdWithCtx(), []string{"devbox", "planner"}, "") past, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox", "planner"}, "")
if len(past) != 0 { if len(past) != 0 {
t.Errorf("pos 2+: got %v", past) t.Errorf("pos 2+: got %v", past)
} }
} }
func TestCompleteSessionNamesDaemonDown(t *testing.T) { func TestCompleteSessionNamesDaemonDown(t *testing.T) {
stubCompletionSeams(t, errors.New("down"), nil, nil, nil, nil) d := defaultDeps()
stubCompletionSeams(t, d, errors.New("down"), nil, nil, nil, nil)
got, directive := completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "") got, directive := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "")
if len(got) != 0 { if len(got) != 0 {
t.Errorf("expected no suggestions when daemon down, got %v", got) t.Errorf("expected no suggestions when daemon down, got %v", got)
} }

View file

@ -18,7 +18,7 @@ import (
// ensureDaemon pings the socket; on miss it auto-starts bangerd, on // ensureDaemon pings the socket; on miss it auto-starts bangerd, on
// version mismatch it restarts. Every CLI command that needs to talk // version mismatch it restarts. Every CLI command that needs to talk
// to the daemon routes through here. // to the daemon routes through here.
func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) { func (d *deps) ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) {
layout, err := paths.Resolve() layout, err := paths.Resolve()
if err != nil { if err != nil {
return paths.Layout{}, model.DaemonConfig{}, err return paths.Layout{}, model.DaemonConfig{}, err
@ -27,16 +27,16 @@ func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error)
if err != nil { if err != nil {
return paths.Layout{}, model.DaemonConfig{}, err return paths.Layout{}, model.DaemonConfig{}, err
} }
if ping, err := daemonPingFunc(ctx, layout.SocketPath); err == nil { if ping, err := d.daemonPing(ctx, layout.SocketPath); err == nil {
if daemonOutdated(ping.PID) { if d.daemonOutdated(ping.PID) {
if err := restartDaemon(ctx, layout, ping.PID); err != nil { if err := d.restartDaemon(ctx, layout, ping.PID); err != nil {
return paths.Layout{}, model.DaemonConfig{}, err return paths.Layout{}, model.DaemonConfig{}, err
} }
return layout, cfg, nil return layout, cfg, nil
} }
return layout, cfg, nil return layout, cfg, nil
} }
if err := startDaemon(ctx, layout); err != nil { if err := d.startDaemon(ctx, layout); err != nil {
return paths.Layout{}, model.DaemonConfig{}, err return paths.Layout{}, model.DaemonConfig{}, err
} }
return layout, cfg, nil return layout, cfg, nil
@ -47,11 +47,11 @@ func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error)
// session still holds a handle to an old daemon. os.SameFile compares // session still holds a handle to an old daemon. os.SameFile compares
// inode + dev, so a fresh binary at the same path registers as // inode + dev, so a fresh binary at the same path registers as
// different. // different.
func daemonOutdated(pid int) bool { func (d *deps) daemonOutdated(pid int) bool {
if pid <= 0 { if pid <= 0 {
return false return false
} }
daemonBin, err := bangerdPathFunc() daemonBin, err := d.bangerdPath()
if err != nil { if err != nil {
return false return false
} }
@ -59,20 +59,20 @@ func daemonOutdated(pid int) bool {
if err != nil { if err != nil {
return false return false
} }
runningInfo, err := os.Stat(daemonExePath(pid)) runningInfo, err := os.Stat(d.daemonExePath(pid))
if err != nil { if err != nil {
return false return false
} }
return !os.SameFile(currentInfo, runningInfo) return !os.SameFile(currentInfo, runningInfo)
} }
func restartDaemon(ctx context.Context, layout paths.Layout, pid int) error { func (d *deps) restartDaemon(ctx context.Context, layout paths.Layout, pid int) error {
stopCtx, cancel := context.WithTimeout(ctx, 2*time.Second) stopCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel() defer cancel()
_, _ = rpc.Call[api.ShutdownResult](stopCtx, layout.SocketPath, "shutdown", api.Empty{}) _, _ = rpc.Call[api.ShutdownResult](stopCtx, layout.SocketPath, "shutdown", api.Empty{})
if waitForPIDExit(pid, 2*time.Second) { if waitForPIDExit(pid, 2*time.Second) {
return startDaemon(ctx, layout) return d.startDaemon(ctx, layout)
} }
if proc, err := os.FindProcess(pid); err == nil { if proc, err := os.FindProcess(pid); err == nil {
_ = proc.Signal(syscall.SIGTERM) _ = proc.Signal(syscall.SIGTERM)
@ -80,7 +80,7 @@ func restartDaemon(ctx context.Context, layout paths.Layout, pid int) error {
if !waitForPIDExit(pid, 2*time.Second) { if !waitForPIDExit(pid, 2*time.Second) {
return fmt.Errorf("timed out restarting stale daemon pid %d", pid) return fmt.Errorf("timed out restarting stale daemon pid %d", pid)
} }
return startDaemon(ctx, layout) return d.startDaemon(ctx, layout)
} }
func waitForPIDExit(pid int, timeout time.Duration) bool { func waitForPIDExit(pid int, timeout time.Duration) bool {
@ -105,7 +105,7 @@ func pidRunning(pid int) bool {
return proc.Signal(syscall.Signal(0)) == nil return proc.Signal(syscall.Signal(0)) == nil
} }
func startDaemon(ctx context.Context, layout paths.Layout) error { func (d *deps) startDaemon(ctx context.Context, layout paths.Layout) error {
if err := paths.Ensure(layout); err != nil { if err := paths.Ensure(layout); err != nil {
return err return err
} }

165
internal/cli/deps.go Normal file
View file

@ -0,0 +1,165 @@
package cli
import (
"context"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"banger/internal/api"
"banger/internal/daemon"
"banger/internal/guest"
"banger/internal/paths"
"banger/internal/rpc"
"banger/internal/system"
"banger/internal/toolingplan"
)
// deps holds the function seams production code dispatches through and
// tests replace with fakes. Keeping these on a per-invocation struct
// (instead of package-level mutable vars) makes the CLI's external
// surface explicit and lets tests run in parallel without leaking fakes
// across test cases.
//
// Every command builder, orchestrator, and helper that touches the RPC
// socket, spawns a subprocess, or reads host state hangs off a *deps
// receiver. Pure helpers (formatters, path resolvers, arg-count
// validators) stay package-level because they hold no references to
// external systems.
type deps struct {
bangerdPath func() (string, error)
daemonExePath func(pid int) string
doctor func(ctx context.Context) (system.Report, error)
sshExec func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error
hostCommandOutput func(ctx context.Context, name string, args ...string) ([]byte, error)
vmHealth func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error)
vmSSH func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error)
vmDelete func(ctx context.Context, socketPath, idOrName string) error
vmList func(ctx context.Context, socketPath string) (api.VMListResult, error)
daemonPing func(ctx context.Context, socketPath string) (api.PingResult, error)
vmCreateBegin func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error)
vmCreateStatus func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error)
vmCreateCancel func(ctx context.Context, socketPath, operationID string) error
vmPorts func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error)
vmWorkspacePrepare func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error)
vmWorkspaceExport func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error)
guestSessionStart func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error)
guestSessionGet func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
guestSessionList func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error)
guestSessionStop func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
guestSessionKill func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
guestSessionLogs func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error)
guestSessionAttachBegin func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error)
guestSessionSend func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error)
guestWaitForSSH func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error
guestDial func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error)
buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan
cwd func() (string, error)
completionLister func(ctx context.Context, socketPath, method string) ([]string, error)
completionSessionLister func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error)
}
func defaultDeps() *deps {
return &deps{
bangerdPath: paths.BangerdPath,
daemonExePath: func(pid int) string {
return filepath.Join("/proc", fmt.Sprintf("%d", pid), "exe")
},
doctor: daemon.Doctor,
sshExec: func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
sshCmd := exec.CommandContext(ctx, "ssh", args...)
sshCmd.Stdout = stdout
sshCmd.Stderr = stderr
sshCmd.Stdin = stdin
return sshCmd.Run()
},
hostCommandOutput: func(ctx context.Context, name string, args ...string) ([]byte, error) {
cmd := exec.CommandContext(ctx, name, args...)
output, err := cmd.CombinedOutput()
if err == nil {
return output, nil
}
command := strings.TrimSpace(strings.Join(append([]string{name}, args...), " "))
detail := strings.TrimSpace(string(output))
if detail == "" {
return output, fmt.Errorf("%s: %w", command, err)
}
return output, fmt.Errorf("%s: %w: %s", command, err, detail)
},
vmHealth: func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName})
},
vmSSH: func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error) {
return rpc.Call[api.VMSSHResult](ctx, socketPath, "vm.ssh", api.VMRefParams{IDOrName: idOrName})
},
vmDelete: func(ctx context.Context, socketPath, idOrName string) error {
_, err := rpc.Call[api.VMShowResult](ctx, socketPath, "vm.delete", api.VMRefParams{IDOrName: idOrName})
return err
},
vmList: func(ctx context.Context, socketPath string) (api.VMListResult, error) {
return rpc.Call[api.VMListResult](ctx, socketPath, "vm.list", api.Empty{})
},
daemonPing: func(ctx context.Context, socketPath string) (api.PingResult, error) {
return rpc.Call[api.PingResult](ctx, socketPath, "ping", api.Empty{})
},
vmCreateBegin: func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error) {
return rpc.Call[api.VMCreateBeginResult](ctx, socketPath, "vm.create.begin", params)
},
vmCreateStatus: func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error) {
return rpc.Call[api.VMCreateStatusResult](ctx, socketPath, "vm.create.status", api.VMCreateStatusParams{ID: operationID})
},
vmCreateCancel: func(ctx context.Context, socketPath, operationID string) error {
_, err := rpc.Call[api.Empty](ctx, socketPath, "vm.create.cancel", api.VMCreateStatusParams{ID: operationID})
return err
},
vmPorts: func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) {
return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName})
},
vmWorkspacePrepare: func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
return rpc.Call[api.VMWorkspacePrepareResult](ctx, socketPath, "vm.workspace.prepare", params)
},
vmWorkspaceExport: func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
return rpc.Call[api.WorkspaceExportResult](ctx, socketPath, "vm.workspace.export", params)
},
guestSessionStart: func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.start", params)
},
guestSessionGet: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.get", params)
},
guestSessionList: func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error) {
return rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: idOrName})
},
guestSessionStop: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.stop", params)
},
guestSessionKill: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.kill", params)
},
guestSessionLogs: func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) {
return rpc.Call[api.GuestSessionLogsResult](ctx, socketPath, "guest.session.logs", params)
},
guestSessionAttachBegin: func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
return rpc.Call[api.GuestSessionAttachBeginResult](ctx, socketPath, "guest.session.attach.begin", params)
},
guestSessionSend: func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params)
},
guestWaitForSSH: func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
knownHosts, _ := bangerKnownHostsPath()
return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval)
},
guestDial: func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
knownHosts, _ := bangerKnownHostsPath()
return guest.Dial(ctx, address, privateKeyPath, knownHosts)
},
buildVMRunToolingPlan: toolingplan.Build,
cwd: os.Getwd,
completionLister: defaultCompletionLister,
completionSessionLister: defaultCompletionSessionLister,
}
}

View file

@ -14,22 +14,17 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
// stubPruneSeams installs fakes for vmListFunc and vmDeleteFunc, and // stubPruneSeams installs list + delete fakes onto the caller's *deps
// restores originals on cleanup. // and returns a pointer to a slice that records every ID passed to the
func stubPruneSeams(t *testing.T, vms []model.VMRecord, listErr error, deleteErr map[string]error) *[]string { // delete fake.
func stubPruneSeams(t *testing.T, d *deps, vms []model.VMRecord, listErr error, deleteErr map[string]error) *[]string {
t.Helper() t.Helper()
origList := vmListFunc
origDelete := vmDeleteFunc
t.Cleanup(func() {
vmListFunc = origList
vmDeleteFunc = origDelete
})
var deleted []string var deleted []string
vmListFunc = func(ctx context.Context, socketPath string) (api.VMListResult, error) { d.vmList = func(ctx context.Context, socketPath string) (api.VMListResult, error) {
return api.VMListResult{VMs: vms}, listErr return api.VMListResult{VMs: vms}, listErr
} }
vmDeleteFunc = func(ctx context.Context, socketPath, idOrName string) error { d.vmDelete = func(ctx context.Context, socketPath, idOrName string) error {
if err, ok := deleteErr[idOrName]; ok { if err, ok := deleteErr[idOrName]; ok {
return err return err
} }
@ -89,13 +84,14 @@ func TestPromptYesNoEOF(t *testing.T) {
} }
func TestRunVMPruneNoVictims(t *testing.T) { func TestRunVMPruneNoVictims(t *testing.T) {
stubPruneSeams(t, []model.VMRecord{ d := defaultDeps()
stubPruneSeams(t, d, []model.VMRecord{
{ID: "id-1", Name: "running-vm", State: model.VMStateRunning}, {ID: "id-1", Name: "running-vm", State: model.VMStateRunning},
}, nil, nil) }, nil, nil)
cmd, stdout, _ := newPruneTestCmd("") cmd, stdout, _ := newPruneTestCmd("")
if err := runVMPrune(cmd, "sock", false); err != nil { if err := d.runVMPrune(cmd, "sock", false); err != nil {
t.Fatalf("runVMPrune: %v", err) t.Fatalf("d.runVMPrune: %v", err)
} }
if !strings.Contains(stdout.String(), "no non-running VMs") { if !strings.Contains(stdout.String(), "no non-running VMs") {
t.Errorf("expected no-op message, got %q", stdout.String()) t.Errorf("expected no-op message, got %q", stdout.String())
@ -103,13 +99,14 @@ func TestRunVMPruneNoVictims(t *testing.T) {
} }
func TestRunVMPruneAbortedByUser(t *testing.T) { func TestRunVMPruneAbortedByUser(t *testing.T) {
deleted := stubPruneSeams(t, []model.VMRecord{ d := defaultDeps()
deleted := stubPruneSeams(t, d, []model.VMRecord{
{ID: "id-1", Name: "stale", State: model.VMStateStopped}, {ID: "id-1", Name: "stale", State: model.VMStateStopped},
}, nil, nil) }, nil, nil)
cmd, stdout, _ := newPruneTestCmd("n\n") cmd, stdout, _ := newPruneTestCmd("n\n")
if err := runVMPrune(cmd, "sock", false); err != nil { if err := d.runVMPrune(cmd, "sock", false); err != nil {
t.Fatalf("runVMPrune: %v", err) t.Fatalf("d.runVMPrune: %v", err)
} }
if !strings.Contains(stdout.String(), "aborted") { if !strings.Contains(stdout.String(), "aborted") {
t.Errorf("expected 'aborted' output, got %q", stdout.String()) t.Errorf("expected 'aborted' output, got %q", stdout.String())
@ -120,7 +117,8 @@ func TestRunVMPruneAbortedByUser(t *testing.T) {
} }
func TestRunVMPruneConfirmedDeletesNonRunning(t *testing.T) { func TestRunVMPruneConfirmedDeletesNonRunning(t *testing.T) {
deleted := stubPruneSeams(t, []model.VMRecord{ d := defaultDeps()
deleted := stubPruneSeams(t, d, []model.VMRecord{
{ID: "id-run", Name: "keeper", State: model.VMStateRunning}, {ID: "id-run", Name: "keeper", State: model.VMStateRunning},
{ID: "id-stop", Name: "stale", State: model.VMStateStopped}, {ID: "id-stop", Name: "stale", State: model.VMStateStopped},
{ID: "id-err", Name: "broken", State: model.VMStateError}, {ID: "id-err", Name: "broken", State: model.VMStateError},
@ -128,8 +126,8 @@ func TestRunVMPruneConfirmedDeletesNonRunning(t *testing.T) {
}, nil, nil) }, nil, nil)
cmd, stdout, _ := newPruneTestCmd("y\n") cmd, stdout, _ := newPruneTestCmd("y\n")
if err := runVMPrune(cmd, "sock", false); err != nil { if err := d.runVMPrune(cmd, "sock", false); err != nil {
t.Fatalf("runVMPrune: %v", err) t.Fatalf("d.runVMPrune: %v", err)
} }
// Deleted must be exactly the three non-running IDs, in list order. // Deleted must be exactly the three non-running IDs, in list order.
want := []string{"id-stop", "id-err", "id-created"} want := []string{"id-stop", "id-err", "id-created"}
@ -152,14 +150,15 @@ func TestRunVMPruneConfirmedDeletesNonRunning(t *testing.T) {
} }
func TestRunVMPruneForceSkipsPrompt(t *testing.T) { func TestRunVMPruneForceSkipsPrompt(t *testing.T) {
deleted := stubPruneSeams(t, []model.VMRecord{ d := defaultDeps()
deleted := stubPruneSeams(t, d, []model.VMRecord{
{ID: "id-1", Name: "stale", State: model.VMStateStopped}, {ID: "id-1", Name: "stale", State: model.VMStateStopped},
}, nil, nil) }, nil, nil)
// Empty stdin + force=true: must not block on prompt. // Empty stdin + force=true: must not block on prompt.
cmd, stdout, _ := newPruneTestCmd("") cmd, stdout, _ := newPruneTestCmd("")
if err := runVMPrune(cmd, "sock", true); err != nil { if err := d.runVMPrune(cmd, "sock", true); err != nil {
t.Fatalf("runVMPrune: %v", err) t.Fatalf("d.runVMPrune: %v", err)
} }
if len(*deleted) != 1 || (*deleted)[0] != "id-1" { if len(*deleted) != 1 || (*deleted)[0] != "id-1" {
t.Errorf("deleted = %v, want [id-1]", *deleted) t.Errorf("deleted = %v, want [id-1]", *deleted)
@ -171,7 +170,8 @@ func TestRunVMPruneForceSkipsPrompt(t *testing.T) {
} }
func TestRunVMPruneReportsPartialFailure(t *testing.T) { func TestRunVMPruneReportsPartialFailure(t *testing.T) {
stubPruneSeams(t, d := defaultDeps()
stubPruneSeams(t, d,
[]model.VMRecord{ []model.VMRecord{
{ID: "id-a", Name: "a", State: model.VMStateStopped}, {ID: "id-a", Name: "a", State: model.VMStateStopped},
{ID: "id-b", Name: "b", State: model.VMStateStopped}, {ID: "id-b", Name: "b", State: model.VMStateStopped},
@ -181,7 +181,7 @@ func TestRunVMPruneReportsPartialFailure(t *testing.T) {
) )
cmd, _, stderr := newPruneTestCmd("") cmd, _, stderr := newPruneTestCmd("")
err := runVMPrune(cmd, "sock", true) err := d.runVMPrune(cmd, "sock", true)
if err == nil { if err == nil {
t.Fatal("expected non-zero exit when any delete fails") t.Fatal("expected non-zero exit when any delete fails")
} }
@ -194,10 +194,11 @@ func TestRunVMPruneReportsPartialFailure(t *testing.T) {
} }
func TestRunVMPruneListErrorPropagates(t *testing.T) { func TestRunVMPruneListErrorPropagates(t *testing.T) {
stubPruneSeams(t, nil, fmt.Errorf("rpc failed"), nil) d := defaultDeps()
stubPruneSeams(t, d, nil, fmt.Errorf("rpc failed"), nil)
cmd, _, _ := newPruneTestCmd("") cmd, _, _ := newPruneTestCmd("")
err := runVMPrune(cmd, "sock", true) err := d.runVMPrune(cmd, "sock", true)
if err == nil || !strings.Contains(err.Error(), "rpc failed") { if err == nil || !strings.Contains(err.Error(), "rpc failed") {
t.Fatalf("expected rpc error to propagate, got %v", err) t.Fatalf("expected rpc error to propagate, got %v", err)
} }

View file

@ -20,14 +20,14 @@ import (
// the caller asked (e.g. --rm is about to delete the VM), if the // the caller asked (e.g. --rm is about to delete the VM), if the
// ctx is already done, or if the ssh error isn't the one that // ctx is already done, or if the ssh error isn't the one that
// typically means "user disconnected cleanly". // typically means "user disconnected cleanly".
func runSSHSession(ctx context.Context, socketPath, vmRef string, stdin io.Reader, stdout, stderr io.Writer, sshArgs []string, skipReminder bool) error { func (d *deps) runSSHSession(ctx context.Context, socketPath, vmRef string, stdin io.Reader, stdout, stderr io.Writer, sshArgs []string, skipReminder bool) error {
sshErr := sshExecFunc(ctx, stdin, stdout, stderr, sshArgs) sshErr := d.sshExec(ctx, stdin, stdout, stderr, sshArgs)
if skipReminder || !shouldCheckSSHReminder(sshErr) || ctx.Err() != nil { if skipReminder || !shouldCheckSSHReminder(sshErr) || ctx.Err() != nil {
return sshErr return sshErr
} }
pingCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) pingCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() defer cancel()
health, err := vmHealthFunc(pingCtx, socketPath, vmRef) health, err := d.vmHealth(pingCtx, socketPath, vmRef)
if err != nil { if err != nil {
_, _ = fmt.Fprintln(stderr, vsockagent.WarningMessage(vmRef, err)) _, _ = fmt.Fprintln(stderr, vsockagent.WarningMessage(vmRef, err))
return sshErr return sshErr

View file

@ -60,9 +60,9 @@ func printVMSpecLine(out io.Writer, params api.VMCreateParams) {
// gets the spec line up front and the progress renderer thereafter. // gets the spec line up front and the progress renderer thereafter.
// On context cancel we cooperate with the daemon to cancel the // On context cancel we cooperate with the daemon to cancel the
// in-flight op so it doesn't leak partially-created VM state. // in-flight op so it doesn't leak partially-created VM state.
func runVMCreate(ctx context.Context, socketPath string, stderr io.Writer, params api.VMCreateParams) (model.VMRecord, error) { func (d *deps) runVMCreate(ctx context.Context, socketPath string, stderr io.Writer, params api.VMCreateParams) (model.VMRecord, error) {
printVMSpecLine(stderr, params) printVMSpecLine(stderr, params)
begin, err := vmCreateBeginFunc(ctx, socketPath, params) begin, err := d.vmCreateBegin(ctx, socketPath, params)
if err != nil { if err != nil {
return model.VMRecord{}, err return model.VMRecord{}, err
} }
@ -86,17 +86,17 @@ func runVMCreate(ctx context.Context, socketPath string, stderr io.Writer, param
case <-ctx.Done(): case <-ctx.Done():
cancelCtx, cancel := context.WithTimeout(context.Background(), time.Second) cancelCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
_ = vmCreateCancelFunc(cancelCtx, socketPath, op.ID) _ = d.vmCreateCancel(cancelCtx, socketPath, op.ID)
return model.VMRecord{}, ctx.Err() return model.VMRecord{}, ctx.Err()
case <-time.After(200 * time.Millisecond): case <-time.After(200 * time.Millisecond):
} }
status, err := vmCreateStatusFunc(ctx, socketPath, op.ID) status, err := d.vmCreateStatus(ctx, socketPath, op.ID)
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
cancelCtx, cancel := context.WithTimeout(context.Background(), time.Second) cancelCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
_ = vmCreateCancelFunc(cancelCtx, socketPath, op.ID) _ = d.vmCreateCancel(cancelCtx, socketPath, op.ID)
return model.VMRecord{}, ctx.Err() return model.VMRecord{}, ctx.Err()
} }
return model.VMRecord{}, err return model.VMRecord{}, err

View file

@ -80,9 +80,9 @@ func (e ExitCodeError) Error() string {
// - it sits inside a non-bare git repository, // - it sits inside a non-bare git repository,
// - the repository has no submodules (unsupported in the shallow // - the repository has no submodules (unsupported in the shallow
// overlay mode vm run uses). // overlay mode vm run uses).
func vmRunPreflightRepo(ctx context.Context, rawPath string) (string, error) { func (d *deps) vmRunPreflightRepo(ctx context.Context, rawPath string) (string, error) {
if strings.TrimSpace(rawPath) == "" { if strings.TrimSpace(rawPath) == "" {
wd, err := cwdFunc() wd, err := d.cwd()
if err != nil { if err != nil {
return "", err return "", err
} }
@ -131,9 +131,9 @@ func splitVMRunArgs(cmd *cobra.Command, args []string) (pathArgs, commandArgs []
// for guest ssh, optionally materialise a workspace and kick off the // for guest ssh, optionally materialise a workspace and kick off the
// tooling bootstrap, then either attach interactively or run the // tooling bootstrap, then either attach interactively or run the
// user's command and propagate its exit status. // user's command and propagate its exit status.
func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, stdin io.Reader, stdout, stderr io.Writer, params api.VMCreateParams, repo *vmRunRepo, command []string, removeOnExit bool) error { func (d *deps) runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, stdin io.Reader, stdout, stderr io.Writer, params api.VMCreateParams, repo *vmRunRepo, command []string, removeOnExit bool) error {
progress := newVMRunProgressRenderer(stderr) progress := newVMRunProgressRenderer(stderr)
vm, err := runVMCreate(ctx, socketPath, stderr, params) vm, err := d.runVMCreate(ctx, socketPath, stderr, params)
if err != nil { if err != nil {
return err return err
} }
@ -155,7 +155,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
// doesn't abort the delete RPC. // doesn't abort the delete RPC.
cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := vmDeleteFunc(cleanupCtx, socketPath, vmRef); err != nil { if err := d.vmDelete(cleanupCtx, socketPath, vmRef); err != nil {
printVMRunWarning(stderr, fmt.Sprintf("--rm cleanup failed: %v (leaked vm %q; delete manually)", err, vmRef)) printVMRunWarning(stderr, fmt.Sprintf("--rm cleanup failed: %v (leaked vm %q; delete manually)", err, vmRef))
} }
}() }()
@ -163,7 +163,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
sshAddress := net.JoinHostPort(vm.Runtime.GuestIP, "22") sshAddress := net.JoinHostPort(vm.Runtime.GuestIP, "22")
progress.render("waiting for guest ssh") progress.render("waiting for guest ssh")
sshCtx, cancelSSH := context.WithTimeout(ctx, vmRunSSHTimeout) sshCtx, cancelSSH := context.WithTimeout(ctx, vmRunSSHTimeout)
if err := guestWaitForSSHFunc(sshCtx, sshAddress, cfg.SSHKeyPath, 250*time.Millisecond); err != nil { if err := d.guestWaitForSSH(sshCtx, sshAddress, cfg.SSHKeyPath, 250*time.Millisecond); err != nil {
cancelSSH() cancelSSH()
// Surface parent-context cancellation (Ctrl-C, caller // Surface parent-context cancellation (Ctrl-C, caller
// timeout) as-is. Only the guest-side timeout needs the // timeout) as-is. Only the guest-side timeout needs the
@ -193,7 +193,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
if strings.TrimSpace(repo.branchName) != "" { if strings.TrimSpace(repo.branchName) != "" {
fromRef = repo.fromRef fromRef = repo.fromRef
} }
prepared, err := vmWorkspacePrepareFunc(ctx, socketPath, api.VMWorkspacePrepareParams{ prepared, err := d.vmWorkspacePrepare(ctx, socketPath, api.VMWorkspacePrepareParams{
IDOrName: vmRef, IDOrName: vmRef,
SourcePath: repo.sourcePath, SourcePath: repo.sourcePath,
GuestPath: vmRunGuestDir(), GuestPath: vmRunGuestDir(),
@ -208,11 +208,11 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
// daemon side; grab what the tooling harness needs from its // daemon side; grab what the tooling harness needs from its
// result instead of re-inspecting here. // result instead of re-inspecting here.
if len(command) == 0 { if len(command) == 0 {
client, err := guestDialFunc(ctx, sshAddress, cfg.SSHKeyPath) client, err := d.guestDial(ctx, sshAddress, cfg.SSHKeyPath)
if err != nil { if err != nil {
return fmt.Errorf("vm %q is running but guest ssh is unavailable: %w", vmRef, err) return fmt.Errorf("vm %q is running but guest ssh is unavailable: %w", vmRef, err)
} }
if err := startVMRunToolingHarness(ctx, client, prepared.Workspace.RepoRoot, prepared.Workspace.RepoName, progress); err != nil { if err := d.startVMRunToolingHarness(ctx, client, prepared.Workspace.RepoRoot, prepared.Workspace.RepoName, progress); err != nil {
printVMRunWarning(stderr, fmt.Sprintf("guest tooling bootstrap start failed: %v", err)) printVMRunWarning(stderr, fmt.Sprintf("guest tooling bootstrap start failed: %v", err))
} }
_ = client.Close() _ = client.Close()
@ -224,7 +224,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
} }
if len(command) > 0 { if len(command) > 0 {
progress.render("running command in guest") progress.render("running command in guest")
if err := sshExecFunc(ctx, stdin, stdout, stderr, sshArgs); err != nil { if err := d.sshExec(ctx, stdin, stdout, stderr, sshArgs); err != nil {
var exitErr *exec.ExitError var exitErr *exec.ExitError
if errors.As(err, &exitErr) { if errors.As(err, &exitErr) {
return ExitCodeError{Code: exitErr.ExitCode()} return ExitCodeError{Code: exitErr.ExitCode()}
@ -234,7 +234,7 @@ func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, st
return nil return nil
} }
progress.render("attaching to guest") progress.render("attaching to guest")
return runSSHSession(ctx, socketPath, vmRef, stdin, stdout, stderr, sshArgs, removeOnExit) return d.runSSHSession(ctx, socketPath, vmRef, stdin, stdout, stderr, sshArgs, removeOnExit)
} }
func vmRunGuestDir() string { func vmRunGuestDir() string {
@ -253,11 +253,11 @@ func vmRunToolingHarnessLogPath(repoName string) string {
// script inside the guest. repoRoot / repoName both come from the // script inside the guest. repoRoot / repoName both come from the
// daemon's workspace.prepare RPC response — the CLI no longer does // daemon's workspace.prepare RPC response — the CLI no longer does
// its own git inspection. // its own git inspection.
func startVMRunToolingHarness(ctx context.Context, client vmRunGuestClient, repoRoot, repoName string, progress *vmRunProgressRenderer) error { func (d *deps) startVMRunToolingHarness(ctx context.Context, client vmRunGuestClient, repoRoot, repoName string, progress *vmRunProgressRenderer) error {
if progress != nil { if progress != nil {
progress.render("starting guest tooling bootstrap") progress.render("starting guest tooling bootstrap")
} }
plan := buildVMRunToolingPlanFunc(ctx, repoRoot) plan := d.buildVMRunToolingPlan(ctx, repoRoot)
var uploadLog bytes.Buffer var uploadLog bytes.Buffer
if err := client.UploadFile(ctx, vmRunToolingHarnessPath(repoName), 0o755, []byte(vmRunToolingHarnessScript(plan)), &uploadLog); err != nil { if err := client.UploadFile(ctx, vmRunToolingHarnessPath(repoName), 0o755, []byte(vmRunToolingHarnessScript(plan)), &uploadLog); err != nil {
return formatVMRunStepError("upload guest tooling bootstrap", err, uploadLog.String()) return formatVMRunStepError("upload guest tooling bootstrap", err, uploadLog.String())

View file

@ -18,6 +18,7 @@ import (
"banger/internal/buildinfo" "banger/internal/buildinfo"
"banger/internal/config" "banger/internal/config"
"banger/internal/daemon/opstate" "banger/internal/daemon/opstate"
ws "banger/internal/daemon/workspace"
"banger/internal/imagecat" "banger/internal/imagecat"
"banger/internal/imagepull" "banger/internal/imagepull"
"banger/internal/model" "banger/internal/model"
@ -66,6 +67,8 @@ type Daemon struct {
guestWaitForSSH func(context.Context, string, string, time.Duration) error guestWaitForSSH func(context.Context, string, string, time.Duration) error
guestDial func(context.Context, string, string) (guestSSHClient, error) guestDial func(context.Context, string, string) (guestSSHClient, error)
waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error) waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error)
workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error)
workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error
} }
func Open(ctx context.Context) (d *Daemon, err error) { func Open(ctx context.Context) (d *Daemon, err error) {

View file

@ -15,13 +15,25 @@ import (
"banger/internal/model" "banger/internal/model"
) )
// Test seams. Tests swap these to observe or stall the guest-I/O // workspaceInspectRepoHook + workspaceImportHook dispatch through the
// phase without needing a real git repo or SSH server. Production // per-instance Daemon seams when set, falling back to the real
// callers see the real implementations from the workspace package. // workspace package implementations. Keeping the fallbacks here (as
var ( // opposed to always requiring callers to populate d.workspaceInspectRepo
workspaceInspectRepoFunc = ws.InspectRepo // in a constructor) lets tests selectively override one hook without
workspaceImportFunc = ws.ImportRepoToGuest // having to wire both.
) func (d *Daemon) workspaceInspectRepoHook(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error) {
if d != nil && d.workspaceInspectRepo != nil {
return d.workspaceInspectRepo(ctx, sourcePath, branchName, fromRef)
}
return ws.InspectRepo(ctx, sourcePath, branchName, fromRef)
}
func (d *Daemon) workspaceImportHook(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error {
if d != nil && d.workspaceImport != nil {
return d.workspaceImport(ctx, client, spec, guestPath, mode)
}
return ws.ImportRepoToGuest(ctx, client, spec, guestPath, mode)
}
func (d *Daemon) ExportVMWorkspace(ctx context.Context, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { func (d *Daemon) ExportVMWorkspace(ctx context.Context, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
guestPath := strings.TrimSpace(params.GuestPath) guestPath := strings.TrimSpace(params.GuestPath)
@ -156,7 +168,7 @@ func (d *Daemon) PrepareVMWorkspace(ctx context.Context, params api.VMWorkspaceP
// inspect the local repo, dial SSH, stream the tar, optionally chmod // inspect the local repo, dial SSH, stream the tar, optionally chmod
// readonly. It is called without holding the VM mutex. // readonly. It is called without holding the VM mutex.
func (d *Daemon) prepareVMWorkspaceGuestIO(ctx context.Context, vm model.VMRecord, sourcePath, guestPath, branchName, fromRef string, mode model.WorkspacePrepareMode, readOnly bool) (model.WorkspacePrepareResult, error) { func (d *Daemon) prepareVMWorkspaceGuestIO(ctx context.Context, vm model.VMRecord, sourcePath, guestPath, branchName, fromRef string, mode model.WorkspacePrepareMode, readOnly bool) (model.WorkspacePrepareResult, error) {
spec, err := workspaceInspectRepoFunc(ctx, sourcePath, branchName, fromRef) spec, err := d.workspaceInspectRepoHook(ctx, sourcePath, branchName, fromRef)
if err != nil { if err != nil {
return model.WorkspacePrepareResult{}, err return model.WorkspacePrepareResult{}, err
} }
@ -172,7 +184,7 @@ func (d *Daemon) prepareVMWorkspaceGuestIO(ctx context.Context, vm model.VMRecor
return model.WorkspacePrepareResult{}, fmt.Errorf("dial guest ssh: %w", err) return model.WorkspacePrepareResult{}, fmt.Errorf("dial guest ssh: %w", err)
} }
defer client.Close() defer client.Close()
if err := workspaceImportFunc(ctx, client, spec, guestPath, mode); err != nil { if err := d.workspaceImportHook(ctx, client, spec, guestPath, mode); err != nil {
return model.WorkspacePrepareResult{}, err return model.WorkspacePrepareResult{}, err
} }
if readOnly { if readOnly {

View file

@ -370,9 +370,7 @@ func TestExportVMWorkspace_MultipleChangedFiles(t *testing.T) {
// inside the import step and then asserts the VM mutex is acquirable // inside the import step and then asserts the VM mutex is acquirable
// while the prepare is mid-flight. // while the prepare is mid-flight.
func TestPrepareVMWorkspace_ReleasesVMLockDuringGuestIO(t *testing.T) { func TestPrepareVMWorkspace_ReleasesVMLockDuringGuestIO(t *testing.T) {
// Not parallel: mutates package-level workspaceInspectRepoFunc / t.Parallel()
// workspaceImportFunc seams, which the other prepare-concurrency
// test also swaps.
ctx := context.Background() ctx := context.Background()
apiSock := filepath.Join(t.TempDir(), "fc.sock") apiSock := filepath.Join(t.TempDir(), "fc.sock")
@ -395,21 +393,15 @@ func TestPrepareVMWorkspace_ReleasesVMLockDuringGuestIO(t *testing.T) {
upsertDaemonVM(t, ctx, d.store, vm) upsertDaemonVM(t, ctx, d.store, vm)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
// Replace the seams. InspectRepo returns a trivial spec so the // Install the workspace seams on this daemon instance. InspectRepo
// real filesystem isn't touched; Import blocks until we say go. // returns a trivial spec so the real filesystem isn't touched;
origInspect := workspaceInspectRepoFunc // Import blocks until we say go.
origImport := workspaceImportFunc
t.Cleanup(func() {
workspaceInspectRepoFunc = origInspect
workspaceImportFunc = origImport
})
importStarted := make(chan struct{}) importStarted := make(chan struct{})
releaseImport := make(chan struct{}) releaseImport := make(chan struct{})
workspaceInspectRepoFunc = func(context.Context, string, string, string) (workspace.RepoSpec, error) { d.workspaceInspectRepo = func(context.Context, string, string, string) (workspace.RepoSpec, error) {
return workspace.RepoSpec{RepoName: "fake", RepoRoot: "/tmp/fake"}, nil return workspace.RepoSpec{RepoName: "fake", RepoRoot: "/tmp/fake"}, nil
} }
workspaceImportFunc = func(context.Context, workspace.GuestClient, workspace.RepoSpec, string, model.WorkspacePrepareMode) error { d.workspaceImport = func(context.Context, workspace.GuestClient, workspace.RepoSpec, string, model.WorkspacePrepareMode) error {
close(importStarted) close(importStarted)
<-releaseImport <-releaseImport
return nil return nil
@ -465,7 +457,7 @@ func TestPrepareVMWorkspace_ReleasesVMLockDuringGuestIO(t *testing.T) {
// the workspaceLocks scope: two concurrent prepares on the same VM do // the workspaceLocks scope: two concurrent prepares on the same VM do
// NOT interleave, even though they no longer take the core VM mutex. // NOT interleave, even though they no longer take the core VM mutex.
func TestPrepareVMWorkspace_SerialisesConcurrentPreparesOnSameVM(t *testing.T) { func TestPrepareVMWorkspace_SerialisesConcurrentPreparesOnSameVM(t *testing.T) {
// Not parallel: see note on ReleasesVMLockDuringGuestIO. t.Parallel()
ctx := context.Background() ctx := context.Background()
apiSock := filepath.Join(t.TempDir(), "fc.sock") apiSock := filepath.Join(t.TempDir(), "fc.sock")
@ -488,14 +480,7 @@ func TestPrepareVMWorkspace_SerialisesConcurrentPreparesOnSameVM(t *testing.T) {
upsertDaemonVM(t, ctx, d.store, vm) upsertDaemonVM(t, ctx, d.store, vm)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
origInspect := workspaceInspectRepoFunc d.workspaceInspectRepo = func(context.Context, string, string, string) (workspace.RepoSpec, error) {
origImport := workspaceImportFunc
t.Cleanup(func() {
workspaceInspectRepoFunc = origInspect
workspaceImportFunc = origImport
})
workspaceInspectRepoFunc = func(context.Context, string, string, string) (workspace.RepoSpec, error) {
return workspace.RepoSpec{RepoName: "fake", RepoRoot: "/tmp/fake"}, nil return workspace.RepoSpec{RepoName: "fake", RepoRoot: "/tmp/fake"}, nil
} }
@ -503,7 +488,7 @@ func TestPrepareVMWorkspace_SerialisesConcurrentPreparesOnSameVM(t *testing.T) {
var active int32 var active int32
var maxObserved int32 var maxObserved int32
release := make(chan struct{}) release := make(chan struct{})
workspaceImportFunc = func(context.Context, workspace.GuestClient, workspace.RepoSpec, string, model.WorkspacePrepareMode) error { d.workspaceImport = func(context.Context, workspace.GuestClient, workspace.RepoSpec, string, model.WorkspacePrepareMode) error {
n := atomic.AddInt32(&active, 1) n := atomic.AddInt32(&active, 1)
for { for {
prev := atomic.LoadInt32(&maxObserved) prev := atomic.LoadInt32(&maxObserved)