diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 231835f..ce52a01 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -1116,27 +1116,28 @@ func TestVMRunPreflightRejectsSubmodules(t *testing.T) { d := defaultDeps() repoRoot := t.TempDir() - origHostCommandOutput := workspace.HostCommandOutputFunc - t.Cleanup(func() { - workspace.HostCommandOutputFunc = origHostCommandOutput - }) - - workspace.HostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) { - t.Helper() - if name != "git" { - t.Fatalf("command = %q, want git", name) - } - switch { - case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--show-toplevel"}): - return []byte(repoRoot + "\n"), nil - case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--is-bare-repository"}): - return []byte("false\n"), nil - case reflect.DeepEqual(args, []string{"-C", repoRoot, "ls-files", "--stage", "-z"}): - return []byte("160000 deadbeef 0\tvendor/submodule\x00"), nil - default: - t.Fatalf("unexpected git args: %v", args) - return nil, nil - } + // Stub the CLI's repo-inspector with a scripted runner. Per-deps + // injection means this test no longer mutates any package global, + // so t.Parallel() is safe to add here in the future without + // worrying about racing another test's fake runner. + d.repoInspector = &workspace.Inspector{ + Runner: func(ctx context.Context, name string, args ...string) ([]byte, error) { + t.Helper() + if name != "git" { + t.Fatalf("command = %q, want git", name) + } + switch { + case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--show-toplevel"}): + return []byte(repoRoot + "\n"), nil + case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--is-bare-repository"}): + return []byte("false\n"), nil + case reflect.DeepEqual(args, []string{"-C", repoRoot, "ls-files", "--stage", "-z"}): + return []byte("160000 deadbeef 0\tvendor/submodule\x00"), nil + default: + t.Fatalf("unexpected git args: %v", args) + return nil, nil + } + }, } _, err := d.vmRunPreflightRepo(context.Background(), repoRoot) diff --git a/internal/cli/commands_vm.go b/internal/cli/commands_vm.go index 524b43b..d4f010a 100644 --- a/internal/cli/commands_vm.go +++ b/internal/cli/commands_vm.go @@ -119,7 +119,7 @@ Three modes: if strings.TrimSpace(repoPtr.branchName) != "" { dryFromRef = repoPtr.fromRef } - return runWorkspaceDryRun(cmd.Context(), cmd.OutOrStdout(), repoPtr.sourcePath, repoPtr.branchName, dryFromRef, repoPtr.includeUntracked) + return d.runWorkspaceDryRun(cmd.Context(), cmd.OutOrStdout(), repoPtr.sourcePath, repoPtr.branchName, dryFromRef, repoPtr.includeUntracked) } layout, err := paths.Resolve() @@ -618,14 +618,14 @@ func (d *deps) newVMWorkspacePrepareCommand() *cobra.Command { prepareFrom = fromRef } if dryRun { - return runWorkspaceDryRun(cmd.Context(), cmd.OutOrStdout(), resolvedPath, branchName, prepareFrom, includeUntracked) + return d.runWorkspaceDryRun(cmd.Context(), cmd.OutOrStdout(), resolvedPath, branchName, prepareFrom, includeUntracked) } layout, _, err := d.ensureDaemon(cmd.Context()) if err != nil { return err } if !includeUntracked { - if err := noteUntrackedSkipped(cmd.Context(), cmd.ErrOrStderr(), resolvedPath); err != nil { + if err := d.noteUntrackedSkipped(cmd.Context(), cmd.ErrOrStderr(), resolvedPath); err != nil { return err } } diff --git a/internal/cli/deps.go b/internal/cli/deps.go index 5940129..e2665ff 100644 --- a/internal/cli/deps.go +++ b/internal/cli/deps.go @@ -12,6 +12,7 @@ import ( "banger/internal/api" "banger/internal/daemon" + "banger/internal/daemon/workspace" "banger/internal/guest" "banger/internal/paths" "banger/internal/rpc" @@ -52,6 +53,12 @@ type deps struct { buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan cwd func() (string, error) completionLister func(ctx context.Context, socketPath, method string) ([]string, error) + // repoInspector is the CLI's single workspace-package Inspector. + // Every code path that needs to shell out to git on the host + // (preflight, dry-run, untracked-count note) goes through it, so + // tests inject a stub Runner via this field instead of mutating a + // package global. + repoInspector *workspace.Inspector } func defaultDeps() *deps { @@ -127,5 +134,6 @@ func defaultDeps() *deps { buildVMRunToolingPlan: toolingplan.Build, cwd: os.Getwd, completionLister: defaultCompletionLister, + repoInspector: workspace.NewInspector(), } } diff --git a/internal/cli/vm_run.go b/internal/cli/vm_run.go index e804450..e758b86 100644 --- a/internal/cli/vm_run.go +++ b/internal/cli/vm_run.go @@ -93,18 +93,18 @@ func (d *deps) vmRunPreflightRepo(ctx context.Context, rawPath string) (string, if err != nil { return "", err } - repoRoot, err := workspace.GitTrimmedOutput(ctx, sourcePath, "rev-parse", "--show-toplevel") + repoRoot, err := d.repoInspector.GitTrimmedOutput(ctx, sourcePath, "rev-parse", "--show-toplevel") if err != nil { return "", fmt.Errorf("%s is not inside a git repository", sourcePath) } - isBare, err := workspace.GitTrimmedOutput(ctx, repoRoot, "rev-parse", "--is-bare-repository") + isBare, err := d.repoInspector.GitTrimmedOutput(ctx, repoRoot, "rev-parse", "--is-bare-repository") if err != nil { return "", fmt.Errorf("inspect git repository %s: %w", repoRoot, err) } if isBare == "true" { return "", fmt.Errorf("vm run requires a non-bare git repository: %s", repoRoot) } - submodules, err := workspace.ListSubmodules(ctx, repoRoot) + submodules, err := d.repoInspector.ListSubmodules(ctx, repoRoot) if err != nil { return "", err } @@ -195,7 +195,7 @@ func (d *deps) runVMRun(ctx context.Context, socketPath string, cfg model.Daemon fromRef = repo.fromRef } if !repo.includeUntracked { - if err := noteUntrackedSkipped(ctx, stderr, repo.sourcePath); err != nil { + if err := d.noteUntrackedSkipped(ctx, stderr, repo.sourcePath); err != nil { printVMRunWarning(stderr, fmt.Sprintf("count untracked files failed: %v", err)) } } diff --git a/internal/cli/workspace_preview.go b/internal/cli/workspace_preview.go index b80c1fc..15528c9 100644 --- a/internal/cli/workspace_preview.go +++ b/internal/cli/workspace_preview.go @@ -4,17 +4,16 @@ import ( "context" "fmt" "io" - - "banger/internal/daemon/workspace" ) // runWorkspaceDryRun inspects the local repo at resolvedPath and // prints the file list that `vm run` / `workspace prepare` would ship // into the guest. Runs on the CLI side (no daemon RPC needed) since // the daemon is always local and the workspace inspection is a pure -// git read. -func runWorkspaceDryRun(ctx context.Context, out io.Writer, resolvedPath, branchName, fromRef string, includeUntracked bool) error { - spec, err := workspace.InspectRepo(ctx, resolvedPath, branchName, fromRef, includeUntracked) +// git read. Git calls go through d.repoInspector so tests inject a +// stub Runner via the deps struct instead of touching package globals. +func (d *deps) runWorkspaceDryRun(ctx context.Context, out io.Writer, resolvedPath, branchName, fromRef string, includeUntracked bool) error { + spec, err := d.repoInspector.InspectRepo(ctx, resolvedPath, branchName, fromRef, includeUntracked) if err != nil { return err } @@ -30,7 +29,7 @@ func runWorkspaceDryRun(ctx context.Context, out io.Writer, resolvedPath, branch fmt.Fprintln(out, path) } if !includeUntracked { - if err := noteUntrackedSkipped(ctx, out, spec.RepoRoot); err != nil { + if err := d.noteUntrackedSkipped(ctx, out, spec.RepoRoot); err != nil { return err } } @@ -41,8 +40,8 @@ func runWorkspaceDryRun(ctx context.Context, out io.Writer, resolvedPath, branch // the repo has untracked non-ignored files that will NOT be copied // because --include-untracked was not passed. Silent when there are // no such files, or when the count can't be determined. -func noteUntrackedSkipped(ctx context.Context, out io.Writer, repoRoot string) error { - count, err := workspace.CountUntrackedPaths(ctx, repoRoot) +func (d *deps) noteUntrackedSkipped(ctx context.Context, out io.Writer, repoRoot string) error { + count, err := d.repoInspector.CountUntrackedPaths(ctx, repoRoot) if err != nil { return err } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 7156473..ed9f715 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -15,6 +15,7 @@ import ( "banger/internal/api" "banger/internal/buildinfo" "banger/internal/config" + ws "banger/internal/daemon/workspace" "banger/internal/model" "banger/internal/paths" "banger/internal/rpc" @@ -616,11 +617,12 @@ func wireServices(d *Daemon) { } if d.ws == nil { d.ws = newWorkspaceService(workspaceServiceDeps{ - runner: d.runner, - logger: d.logger, - config: d.config, - layout: d.layout, - store: d.store, + runner: d.runner, + logger: d.logger, + config: d.config, + layout: d.layout, + store: d.store, + repoInspector: ws.NewInspector(), vmResolver: func(ctx context.Context, idOrName string) (model.VMRecord, error) { return d.vm.FindVM(ctx, idOrName) }, @@ -655,6 +657,7 @@ func wireServices(d *Daemon) { guestDial: d.guestDial, capHooks: d.buildCapabilityHooks(), beginOperation: d.beginOperation, + vsockHostDevice: defaultVsockHostDevice, }) } if len(d.vmCaps) == 0 { diff --git a/internal/daemon/dns_routing.go b/internal/daemon/dns_routing.go index 0160488..92d2c0a 100644 --- a/internal/daemon/dns_routing.go +++ b/internal/daemon/dns_routing.go @@ -3,18 +3,10 @@ package daemon import ( "context" "strings" - - "banger/internal/system" - "banger/internal/vmdns" ) const vmResolverRouteDomain = "~vm" -var ( - lookupExecutableFunc = system.LookupExecutable - vmDNSAddrFunc = func(server *vmdns.Server) string { return server.Addr() } -) - func (n *HostNetwork) syncVMDNSResolverRouting(ctx context.Context) error { if n == nil || n.vmDNS == nil { return nil @@ -22,13 +14,13 @@ func (n *HostNetwork) syncVMDNSResolverRouting(ctx context.Context) error { if strings.TrimSpace(n.config.BridgeName) == "" { return nil } - if _, err := lookupExecutableFunc("resolvectl"); err != nil { + if _, err := n.lookupExecutable("resolvectl"); err != nil { return nil } if _, err := n.runner.Run(ctx, "ip", "link", "show", n.config.BridgeName); err != nil { return nil } - serverAddr := strings.TrimSpace(vmDNSAddrFunc(n.vmDNS)) + serverAddr := strings.TrimSpace(n.vmDNSAddr(n.vmDNS)) if serverAddr == "" { return nil } @@ -46,7 +38,7 @@ func (n *HostNetwork) clearVMDNSResolverRouting(ctx context.Context) error { if n == nil || strings.TrimSpace(n.config.BridgeName) == "" { return nil } - if _, err := lookupExecutableFunc("resolvectl"); err != nil { + if _, err := n.lookupExecutable("resolvectl"); err != nil { return nil } if _, err := n.runner.Run(ctx, "ip", "link", "show", n.config.BridgeName); err != nil { diff --git a/internal/daemon/dns_routing_test.go b/internal/daemon/dns_routing_test.go index bc53945..fb5c056 100644 --- a/internal/daemon/dns_routing_test.go +++ b/internal/daemon/dns_routing_test.go @@ -9,20 +9,6 @@ import ( ) func TestSyncVMDNSResolverRoutingConfiguresResolved(t *testing.T) { - origLookup := lookupExecutableFunc - origAddr := vmDNSAddrFunc - t.Cleanup(func() { - lookupExecutableFunc = origLookup - vmDNSAddrFunc = origAddr - }) - lookupExecutableFunc = func(name string) (string, error) { - if name == "resolvectl" { - return "/usr/bin/resolvectl", nil - } - return "", nil - } - vmDNSAddrFunc = func(*vmdns.Server) string { return "127.0.0.1:42069" } - runner := &scriptedRunner{ t: t, steps: []runnerStep{ @@ -33,7 +19,16 @@ func TestSyncVMDNSResolverRoutingConfiguresResolved(t *testing.T) { }, } cfg := model.DaemonConfig{BridgeName: model.DefaultBridgeName} - n := &HostNetwork{runner: runner, config: cfg, vmDNS: new(vmdns.Server)} + n := &HostNetwork{ + runner: runner, config: cfg, vmDNS: new(vmdns.Server), + lookupExecutable: func(name string) (string, error) { + if name == "resolvectl" { + return "/usr/bin/resolvectl", nil + } + return "", nil + }, + vmDNSAddr: func(*vmdns.Server) string { return "127.0.0.1:42069" }, + } if err := n.syncVMDNSResolverRouting(context.Background()); err != nil { t.Fatalf("syncVMDNSResolverRouting: %v", err) @@ -42,17 +37,6 @@ func TestSyncVMDNSResolverRoutingConfiguresResolved(t *testing.T) { } func TestClearVMDNSResolverRoutingRevertsBridgeConfig(t *testing.T) { - origLookup := lookupExecutableFunc - t.Cleanup(func() { - lookupExecutableFunc = origLookup - }) - lookupExecutableFunc = func(name string) (string, error) { - if name == "resolvectl" { - return "/usr/bin/resolvectl", nil - } - return "", nil - } - runner := &scriptedRunner{ t: t, steps: []runnerStep{ @@ -61,7 +45,15 @@ func TestClearVMDNSResolverRoutingRevertsBridgeConfig(t *testing.T) { }, } cfg := model.DaemonConfig{BridgeName: model.DefaultBridgeName} - n := &HostNetwork{runner: runner, config: cfg} + n := &HostNetwork{ + runner: runner, config: cfg, + lookupExecutable: func(name string) (string, error) { + if name == "resolvectl" { + return "/usr/bin/resolvectl", nil + } + return "", nil + }, + } if err := n.clearVMDNSResolverRouting(context.Background()); err != nil { t.Fatalf("clearVMDNSResolverRouting: %v", err) diff --git a/internal/daemon/doctor.go b/internal/daemon/doctor.go index ff4960f..04bb49f 100644 --- a/internal/daemon/doctor.go +++ b/internal/daemon/doctor.go @@ -184,7 +184,7 @@ func (d *Daemon) vsockChecks() *system.Preflight { } else { checks.Addf("%v", err) } - checks.RequireFile(vsockHostDevicePath, "vsock host device", "load the vhost_vsock kernel module on the host") + checks.RequireFile(d.vm.vsockHostDevice, "vsock host device", "load the vhost_vsock kernel module on the host") return checks } diff --git a/internal/daemon/host_network.go b/internal/daemon/host_network.go index 392fab4..8f04a5b 100644 --- a/internal/daemon/host_network.go +++ b/internal/daemon/host_network.go @@ -41,6 +41,12 @@ type HostNetwork struct { tapPool tapPool vmDNS *vmdns.Server + + // Test seams. Default to real implementations at construction; + // tests build HostNetwork with stubs instead of mutating package + // globals, so parallel tests can't race each other's fake state. + lookupExecutable func(name string) (string, error) + vmDNSAddr func(server *vmdns.Server) string } // hostNetworkDeps is the explicit wiring bag newHostNetwork expects. @@ -56,11 +62,13 @@ type hostNetworkDeps struct { func newHostNetwork(deps hostNetworkDeps) *HostNetwork { return &HostNetwork{ - runner: deps.runner, - logger: deps.logger, - config: deps.config, - layout: deps.layout, - closing: deps.closing, + runner: deps.runner, + logger: deps.logger, + config: deps.config, + layout: deps.layout, + closing: deps.closing, + lookupExecutable: system.LookupExecutable, + vmDNSAddr: func(server *vmdns.Server) string { return server.Addr() }, } } diff --git a/internal/daemon/logger_test.go b/internal/daemon/logger_test.go index 3fe5dde..b9758df 100644 --- a/internal/daemon/logger_test.go +++ b/internal/daemon/logger_test.go @@ -42,11 +42,7 @@ func TestNewDaemonLoggerEmitsJSONAtConfiguredLevel(t *testing.T) { func TestStartVMLockedLogsBridgeFailure(t *testing.T) { ctx := context.Background() - origVsockHostDevicePath := vsockHostDevicePath - vsockHostDevicePath = filepath.Join(t.TempDir(), "vhost-vsock") - t.Cleanup(func() { - vsockHostDevicePath = origVsockHostDevicePath - }) + vsockDevicePath := filepath.Join(t.TempDir(), "vhost-vsock") binDir := t.TempDir() for _, name := range []string{ "sudo", "ip", "dmsetup", "losetup", "blockdev", "truncate", "pgrep", "ps", @@ -62,7 +58,7 @@ func TestStartVMLockedLogsBridgeFailure(t *testing.T) { if err := os.WriteFile(firecrackerBin, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { t.Fatalf("write firecracker: %v", err) } - if err := os.WriteFile(vsockHostDevicePath, []byte{}, 0o644); err != nil { + if err := os.WriteFile(vsockDevicePath, []byte{}, 0o644); err != nil { t.Fatalf("write vsock host device: %v", err) } if err := os.WriteFile(vsockHelper, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { @@ -115,6 +111,7 @@ func TestStartVMLockedLogsBridgeFailure(t *testing.T) { logger: logger, } wireServices(d) + d.vm.vsockHostDevice = vsockDevicePath _, err = d.vm.startVMLocked(ctx, vm, image) if err == nil || !strings.Contains(err.Error(), "bridge up failed") { diff --git a/internal/daemon/preflight.go b/internal/daemon/preflight.go index d2bcec7..ff5d04e 100644 --- a/internal/daemon/preflight.go +++ b/internal/daemon/preflight.go @@ -8,7 +8,11 @@ import ( "banger/internal/system" ) -var vsockHostDevicePath = "/dev/vhost-vsock" +// defaultVsockHostDevice is the vhost-vsock device file every +// Firecracker guest relies on to talk to the host via vsock. Tests +// point at a tempfile by setting VMService.vsockHostDevice; production +// wiring defaults the field to this path in wireServices. +const defaultVsockHostDevice = "/dev/vhost-vsock" func (s *VMService) validateStartPrereqs(ctx context.Context, vm model.VMRecord, image model.Image) error { checks := system.NewPreflight() @@ -33,7 +37,7 @@ func (s *VMService) addBaseStartPrereqs(checks *system.Preflight, image model.Im } else { checks.Addf("%v", err) } - checks.RequireFile(vsockHostDevicePath, "vsock host device", "load the vhost_vsock kernel module on the host") + checks.RequireFile(s.vsockHostDevice, "vsock host device", "load the vhost_vsock kernel module on the host") checks.RequireFile(image.RootfsPath, "rootfs image", "select a valid registered image") checks.RequireFile(image.KernelPath, "kernel image", `re-register or rebuild the image with a valid kernel`) if strings.TrimSpace(image.InitrdPath) != "" { diff --git a/internal/daemon/vm_service.go b/internal/daemon/vm_service.go index 0557d89..30d9d09 100644 --- a/internal/daemon/vm_service.go +++ b/internal/daemon/vm_service.go @@ -66,6 +66,11 @@ type VMService struct { // Test seams. guestWaitForSSH func(context.Context, string, string, time.Duration) error guestDial func(context.Context, string, string) (guestSSHClient, error) + // vsockHostDevice is the path preflight + doctor expect to find for + // the vhost-vsock device. Defaults to defaultVsockHostDevice; tests + // point at a tempfile so RequireFile passes without needing the + // real kernel module loaded. + vsockHostDevice string // Capability hook dispatch. VMService invokes capabilities via // these seams, populated by Daemon.buildCapabilityHooks() at @@ -104,9 +109,14 @@ type vmServiceDeps struct { guestDial func(context.Context, string, string) (guestSSHClient, error) capHooks capabilityHooks beginOperation func(name string, attrs ...any) *operationLog + vsockHostDevice string } func newVMService(deps vmServiceDeps) *VMService { + vsockPath := deps.vsockHostDevice + if vsockPath == "" { + vsockPath = defaultVsockHostDevice + } return &VMService{ runner: deps.runner, logger: deps.logger, @@ -120,6 +130,7 @@ func newVMService(deps vmServiceDeps) *VMService { guestDial: deps.guestDial, capHooks: deps.capHooks, beginOperation: deps.beginOperation, + vsockHostDevice: vsockPath, handles: newHandleCache(), } } diff --git a/internal/daemon/workspace.go b/internal/daemon/workspace.go index 9872f02..c17e622 100644 --- a/internal/daemon/workspace.go +++ b/internal/daemon/workspace.go @@ -24,14 +24,26 @@ func (s *WorkspaceService) workspaceInspectRepoHook(ctx context.Context, sourceP if s != nil && s.workspaceInspectRepo != nil { return s.workspaceInspectRepo(ctx, sourcePath, branchName, fromRef, includeUntracked) } - return ws.InspectRepo(ctx, sourcePath, branchName, fromRef, includeUntracked) + return s.inspector().InspectRepo(ctx, sourcePath, branchName, fromRef, includeUntracked) } func (s *WorkspaceService) workspaceImportHook(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error { if s != nil && s.workspaceImport != nil { return s.workspaceImport(ctx, client, spec, guestPath, mode) } - return ws.ImportRepoToGuest(ctx, client, spec, guestPath, mode) + return s.inspector().ImportRepoToGuest(ctx, client, spec, guestPath, mode) +} + +// inspector returns the service's workspace Inspector, falling back to +// a fresh real-runner Inspector when callers constructed the service +// without wiring one. Keeping the fallback here lets test literals +// that don't care about the Inspector still function without a manual +// NewInspector() call. +func (s *WorkspaceService) inspector() *ws.Inspector { + if s != nil && s.repoInspector != nil { + return s.repoInspector + } + return ws.NewInspector() } func (s *WorkspaceService) ExportVMWorkspace(ctx context.Context, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { diff --git a/internal/daemon/workspace/workspace.go b/internal/daemon/workspace/workspace.go index f9190a7..1f33cf4 100644 --- a/internal/daemon/workspace/workspace.go +++ b/internal/daemon/workspace/workspace.go @@ -2,6 +2,12 @@ // git repo inspection, shallow copy preparation, guest-side tar import, // finalization script generation, and small utilities. // +// Every helper that needs to run a host command (git or otherwise) +// lives as a method on *Inspector rather than a free function that +// routes through a package global. That way two tests running in +// parallel can each build their own Inspector with a stub Runner +// without fighting over shared state. +// // The orchestrator methods (ExportVMWorkspace, PrepareVMWorkspace) stay on // *daemon.Daemon. package workspace @@ -51,9 +57,28 @@ type GuestClient interface { StreamTarEntries(ctx context.Context, dir string, entries []string, command string, log io.Writer) error } -// HostCommandOutputFunc runs a host command and returns its combined output. -// Declared as a package var so tests can substitute a stub runner. -var HostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) { +// RunnerFunc is the single-method surface every Inspector needs: run a +// host command with args, return combined output + error. Tests supply +// a stub that records calls and replays canned responses; production +// uses realHostRunner which wraps system.NewRunner. +type RunnerFunc func(ctx context.Context, name string, args ...string) ([]byte, error) + +// Inspector bundles the host-command seam for all git-using workspace +// helpers. Construct one at the boundary where you're reading the +// filesystem (CLI deps, WorkspaceService) and call its methods directly; +// don't reach into the struct from helper code. +type Inspector struct { + Runner RunnerFunc +} + +// NewInspector returns an Inspector backed by the real host runner. +// Production callers (CLI deps initialisation, daemon WorkspaceService +// wiring) use this; tests construct Inspector{Runner: stub} directly. +func NewInspector() *Inspector { + return &Inspector{Runner: realHostRunner} +} + +func realHostRunner(ctx context.Context, name string, args ...string) ([]byte, error) { runner := system.NewRunner() output, err := runner.Run(ctx, name, args...) if err == nil { @@ -72,55 +97,55 @@ var HostCommandOutputFunc = func(ctx context.Context, name string, args ...strin // submodules, and overlay paths needed for a prepare. Overlay paths // cover tracked files by default; untracked non-ignored files are // included only when includeUntracked is true. -func InspectRepo(ctx context.Context, rawPath, branchName, fromRef string, includeUntracked bool) (RepoSpec, error) { +func (i *Inspector) InspectRepo(ctx context.Context, rawPath, branchName, fromRef string, includeUntracked bool) (RepoSpec, error) { sourcePath, err := ResolveSourcePath(rawPath) if err != nil { return RepoSpec{}, err } - repoRoot, err := GitTrimmedOutput(ctx, sourcePath, "rev-parse", "--show-toplevel") + repoRoot, err := i.GitTrimmedOutput(ctx, sourcePath, "rev-parse", "--show-toplevel") if err != nil { return RepoSpec{}, fmt.Errorf("%s is not inside a git repository", sourcePath) } - isBare, err := GitTrimmedOutput(ctx, repoRoot, "rev-parse", "--is-bare-repository") + isBare, err := i.GitTrimmedOutput(ctx, repoRoot, "rev-parse", "--is-bare-repository") if err != nil { return RepoSpec{}, fmt.Errorf("inspect git repository %s: %w", repoRoot, err) } if isBare == "true" { return RepoSpec{}, fmt.Errorf("workspace prepare requires a non-bare git repository: %s", repoRoot) } - submodules, err := ListSubmodules(ctx, repoRoot) + submodules, err := i.ListSubmodules(ctx, repoRoot) if err != nil { return RepoSpec{}, err } - headCommit, err := GitTrimmedOutput(ctx, repoRoot, "rev-parse", "HEAD^{commit}") + headCommit, err := i.GitTrimmedOutput(ctx, repoRoot, "rev-parse", "HEAD^{commit}") if err != nil { return RepoSpec{}, fmt.Errorf("git repository %s must have at least one commit", repoRoot) } - currentBranch, err := GitTrimmedOutput(ctx, repoRoot, "branch", "--show-current") + currentBranch, err := i.GitTrimmedOutput(ctx, repoRoot, "branch", "--show-current") if err != nil { return RepoSpec{}, fmt.Errorf("resolve current branch for %s: %w", repoRoot, err) } baseCommit := headCommit branchName = strings.TrimSpace(branchName) if branchName != "" { - baseCommit, err = GitTrimmedOutput(ctx, repoRoot, "rev-parse", fromRef+"^{commit}") + baseCommit, err = i.GitTrimmedOutput(ctx, repoRoot, "rev-parse", fromRef+"^{commit}") if err != nil { return RepoSpec{}, fmt.Errorf("resolve workspace from %q: %w", fromRef, err) } } - gitUserName, err := GitResolvedConfigValue(ctx, repoRoot, "user.name") + gitUserName, err := i.GitResolvedConfigValue(ctx, repoRoot, "user.name") if err != nil { return RepoSpec{}, fmt.Errorf("resolve git user.name for %s: %w", repoRoot, err) } - gitUserEmail, err := GitResolvedConfigValue(ctx, repoRoot, "user.email") + gitUserEmail, err := i.GitResolvedConfigValue(ctx, repoRoot, "user.email") if err != nil { return RepoSpec{}, fmt.Errorf("resolve git user.email for %s: %w", repoRoot, err) } - originURL, err := GitResolvedConfigValue(ctx, repoRoot, "remote.origin.url") + originURL, err := i.GitResolvedConfigValue(ctx, repoRoot, "remote.origin.url") if err != nil { return RepoSpec{}, fmt.Errorf("resolve origin url for %s: %w", repoRoot, err) } - overlayPaths, err := ListOverlayPaths(ctx, repoRoot, includeUntracked) + overlayPaths, err := i.ListOverlayPaths(ctx, repoRoot, includeUntracked) if err != nil { return RepoSpec{}, err } @@ -142,7 +167,7 @@ func InspectRepo(ctx context.Context, rawPath, branchName, fromRef string, inclu // ImportRepoToGuest materialises spec inside the guest at guestPath. Mode // selects between full copy, metadata-only, or shallow metadata + overlay. -func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error { +func (i *Inspector) ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error { switch mode { case model.WorkspacePrepareModeFullCopy: var copyLog bytes.Buffer @@ -156,7 +181,7 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g } return nil case model.WorkspacePrepareModeMetadataOnly, model.WorkspacePrepareModeShallowOverlay: - repoCopyDir, cleanup, err := PrepareRepoCopy(ctx, spec) + repoCopyDir, cleanup, err := i.PrepareRepoCopy(ctx, spec) if err != nil { return err } @@ -212,7 +237,7 @@ func FinalizeScript(spec RepoSpec, guestPath string, mode model.WorkspacePrepare // PrepareRepoCopy materialises a shallow clone of spec into a temp dir. The // returned cleanup removes the temp root. -func PrepareRepoCopy(ctx context.Context, spec RepoSpec) (string, func(), error) { +func (i *Inspector) PrepareRepoCopy(ctx context.Context, spec RepoSpec) (string, func(), error) { tempRoot, err := os.MkdirTemp("", "banger-workspace-*") if err != nil { return "", nil, err @@ -224,7 +249,7 @@ func PrepareRepoCopy(ctx context.Context, spec RepoSpec) (string, func(), error) cloneArgs = append(cloneArgs, "--single-branch", "--branch", spec.CurrentBranch) } cloneArgs = append(cloneArgs, GitFileURL(spec.RepoRoot), repoCopyDir) - if err := RunHostCommand(ctx, "git", cloneArgs...); err != nil { + if err := i.RunHostCommand(ctx, "git", cloneArgs...); err != nil { cleanup() return "", nil, fmt.Errorf("clone shallow workspace repo copy: %w", err) } @@ -232,19 +257,19 @@ func PrepareRepoCopy(ctx context.Context, spec RepoSpec) (string, func(), error) if strings.TrimSpace(spec.BranchName) != "" { checkoutCommit = spec.BaseCommit } - if err := RunHostCommand(ctx, "git", "-C", repoCopyDir, "cat-file", "-e", checkoutCommit+"^{commit}"); err != nil { - if err := RunHostCommand(ctx, "git", "-C", repoCopyDir, "fetch", "--depth", fmt.Sprintf("%d", ShallowFetchDepth), GitFileURL(spec.RepoRoot), checkoutCommit); err != nil { + if err := i.RunHostCommand(ctx, "git", "-C", repoCopyDir, "cat-file", "-e", checkoutCommit+"^{commit}"); err != nil { + if err := i.RunHostCommand(ctx, "git", "-C", repoCopyDir, "fetch", "--depth", fmt.Sprintf("%d", ShallowFetchDepth), GitFileURL(spec.RepoRoot), checkoutCommit); err != nil { cleanup() return "", nil, fmt.Errorf("fetch shallow workspace repo commit %s: %w", checkoutCommit, err) } } if strings.TrimSpace(spec.OriginURL) != "" { - if err := RunHostCommand(ctx, "git", "-C", repoCopyDir, "remote", "set-url", "origin", spec.OriginURL); err != nil { + if err := i.RunHostCommand(ctx, "git", "-C", repoCopyDir, "remote", "set-url", "origin", spec.OriginURL); err != nil { cleanup() return "", nil, fmt.Errorf("set workspace origin remote: %w", err) } } else { - if err := RunHostCommand(ctx, "git", "-C", repoCopyDir, "remote", "remove", "origin"); err != nil { + if err := i.RunHostCommand(ctx, "git", "-C", repoCopyDir, "remote", "remove", "origin"); err != nil { cleanup() return "", nil, fmt.Errorf("remove workspace placeholder origin remote: %w", err) } @@ -273,8 +298,8 @@ func ResolveSourcePath(rawPath string) (string, error) { } // ListSubmodules returns the gitlink paths in repoRoot (mode 160000 entries). -func ListSubmodules(ctx context.Context, repoRoot string) ([]string, error) { - output, err := GitOutput(ctx, repoRoot, "ls-files", "--stage", "-z") +func (i *Inspector) ListSubmodules(ctx context.Context, repoRoot string) ([]string, error) { + output, err := i.GitOutput(ctx, repoRoot, "ls-files", "--stage", "-z") if err != nil { return nil, fmt.Errorf("inspect workspace git index for %s: %w", repoRoot, err) } @@ -304,8 +329,8 @@ func ListSubmodules(ctx context.Context, repoRoot string) ([]string, error) { // leave the developer's machine. Callers that genuinely want the // fuller set (scratch repos, vendored binaries the user is iterating // on) opt in explicitly. -func ListOverlayPaths(ctx context.Context, repoRoot string, includeUntracked bool) ([]string, error) { - trackedOutput, err := GitOutput(ctx, repoRoot, "ls-files", "-z") +func (i *Inspector) ListOverlayPaths(ctx context.Context, repoRoot string, includeUntracked bool) ([]string, error) { + trackedOutput, err := i.GitOutput(ctx, repoRoot, "ls-files", "-z") if err != nil { return nil, fmt.Errorf("list tracked files for %s: %w", repoRoot, err) } @@ -325,7 +350,7 @@ func ListOverlayPaths(ctx context.Context, repoRoot string, includeUntracked boo paths = append(paths, relPath) } if includeUntracked { - untrackedOutput, err := GitOutput(ctx, repoRoot, "ls-files", "--others", "--exclude-standard", "-z") + untrackedOutput, err := i.GitOutput(ctx, repoRoot, "ls-files", "--others", "--exclude-standard", "-z") if err != nil { return nil, fmt.Errorf("list untracked files for %s: %w", repoRoot, err) } @@ -348,8 +373,8 @@ func ListOverlayPaths(ctx context.Context, repoRoot string, includeUntracked boo // files in repoRoot. Used by the CLI to warn the user when they are // about to ship a workspace that has local-but-unignored scratch // files which, under the default, will be skipped. -func CountUntrackedPaths(ctx context.Context, repoRoot string) (int, error) { - untrackedOutput, err := GitOutput(ctx, repoRoot, "ls-files", "--others", "--exclude-standard", "-z") +func (i *Inspector) CountUntrackedPaths(ctx context.Context, repoRoot string) (int, error) { + untrackedOutput, err := i.GitOutput(ctx, repoRoot, "ls-files", "--others", "--exclude-standard", "-z") if err != nil { return 0, fmt.Errorf("list untracked files for %s: %w", repoRoot, err) } @@ -377,18 +402,18 @@ func ParsePrepareMode(raw string) (model.WorkspacePrepareMode, error) { } // GitOutput runs `git [-C dir] args...` and returns its raw stdout. -func GitOutput(ctx context.Context, dir string, args ...string) ([]byte, error) { +func (i *Inspector) GitOutput(ctx context.Context, dir string, args ...string) ([]byte, error) { fullArgs := make([]string, 0, len(args)+2) if strings.TrimSpace(dir) != "" { fullArgs = append(fullArgs, "-C", dir) } fullArgs = append(fullArgs, args...) - return HostCommandOutputFunc(ctx, "git", fullArgs...) + return i.Runner(ctx, "git", fullArgs...) } // GitTrimmedOutput returns GitOutput with surrounding whitespace trimmed. -func GitTrimmedOutput(ctx context.Context, dir string, args ...string) (string, error) { - output, err := GitOutput(ctx, dir, args...) +func (i *Inspector) GitTrimmedOutput(ctx context.Context, dir string, args ...string) (string, error) { + output, err := i.GitOutput(ctx, dir, args...) if err != nil { return "", err } @@ -396,8 +421,8 @@ func GitTrimmedOutput(ctx context.Context, dir string, args ...string) (string, } // GitResolvedConfigValue reads git config key with --default "" --get. -func GitResolvedConfigValue(ctx context.Context, dir, key string) (string, error) { - return GitTrimmedOutput(ctx, dir, "config", "--default", "", "--get", key) +func (i *Inspector) GitResolvedConfigValue(ctx context.Context, dir, key string) (string, error) { + return i.GitTrimmedOutput(ctx, dir, "config", "--default", "", "--get", key) } // ParseNullSeparatedOutput splits on NULs and trims, returning non-empty @@ -415,10 +440,10 @@ func ParseNullSeparatedOutput(output []byte) []string { return values } -// RunHostCommand runs a host command via HostCommandOutputFunc, discarding -// its stdout. -func RunHostCommand(ctx context.Context, name string, args ...string) error { - _, err := HostCommandOutputFunc(ctx, name, args...) +// RunHostCommand runs a host command via the Inspector's Runner, +// discarding its stdout. +func (i *Inspector) RunHostCommand(ctx context.Context, name string, args ...string) error { + _, err := i.Runner(ctx, name, args...) return err } diff --git a/internal/daemon/workspace/workspace_test.go b/internal/daemon/workspace/workspace_test.go index 38650f7..6b33205 100644 --- a/internal/daemon/workspace/workspace_test.go +++ b/internal/daemon/workspace/workspace_test.go @@ -59,7 +59,8 @@ func seedRepo(t *testing.T) string { func TestListOverlayPaths_TrackedOnlyByDefault(t *testing.T) { repo := seedRepo(t) - got, err := ListOverlayPaths(context.Background(), repo, false) + i := NewInspector() + got, err := i.ListOverlayPaths(context.Background(), repo, false) if err != nil { t.Fatalf("ListOverlayPaths: %v", err) } @@ -71,7 +72,8 @@ func TestListOverlayPaths_TrackedOnlyByDefault(t *testing.T) { func TestListOverlayPaths_IncludeUntracked(t *testing.T) { repo := seedRepo(t) - got, err := ListOverlayPaths(context.Background(), repo, true) + i := NewInspector() + got, err := i.ListOverlayPaths(context.Background(), repo, true) if err != nil { t.Fatalf("ListOverlayPaths: %v", err) } @@ -89,7 +91,8 @@ func TestListOverlayPaths_IncludeUntracked(t *testing.T) { func TestCountUntrackedPaths(t *testing.T) { repo := seedRepo(t) - count, err := CountUntrackedPaths(context.Background(), repo) + i := NewInspector() + count, err := i.CountUntrackedPaths(context.Background(), repo) if err != nil { t.Fatalf("CountUntrackedPaths: %v", err) } diff --git a/internal/daemon/workspace_service.go b/internal/daemon/workspace_service.go index 5af2e14..386b38b 100644 --- a/internal/daemon/workspace_service.go +++ b/internal/daemon/workspace_service.go @@ -45,6 +45,13 @@ type WorkspaceService struct { beginOperation func(name string, attrs ...any) *operationLog + // repoInspector is the Inspector used by the real InspectRepo / + // ImportRepoToGuest fallbacks when the test seams below aren't + // set. wireServices installs the production one; tests that want + // to intercept only the host-command surface (not the whole + // inspect/import hook) can assign a stub-runner Inspector here. + repoInspector *ws.Inspector + // Test seams. workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string, includeUntracked bool) (ws.RepoSpec, error) workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error @@ -56,6 +63,7 @@ type workspaceServiceDeps struct { config model.DaemonConfig layout paths.Layout store *store.Store + repoInspector *ws.Inspector vmResolver func(ctx context.Context, idOrName string) (model.VMRecord, error) aliveChecker func(vm model.VMRecord) bool waitGuestSSH func(ctx context.Context, address string, interval time.Duration) error @@ -73,6 +81,7 @@ func newWorkspaceService(deps workspaceServiceDeps) *WorkspaceService { config: deps.config, layout: deps.layout, store: deps.store, + repoInspector: deps.repoInspector, vmResolver: deps.vmResolver, aliveChecker: deps.aliveChecker, waitGuestSSH: deps.waitGuestSSH,