seams: move the last four package globals onto instance fields

Three test seams were still package-level mutable vars, which tests
had to swap before use. That's the classic path to flaky parallel
tests — two goroutines fighting over the same global fake. Push each
down to the struct that owns the behaviour.

internal/daemon/dns_routing.go
  lookupExecutableFunc + vmDNSAddrFunc → fields on *HostNetwork,
  defaulted at newHostNetwork time. dns_routing_test builds
  HostNetwork{..., lookupExecutable: stub, vmDNSAddr: stub} inline,
  no more t.Cleanup dance around package-level vars.

internal/daemon/preflight.go + doctor.go
  vsockHostDevicePath (mutable string) → vsockHostDevice field on
  *VMService, defaulted via defaultVsockHostDevice constant in
  newVMService. Preflight reads s.vsockHostDevice; doctor reads
  d.vm.vsockHostDevice. Logger test sets d.vm.vsockHostDevice = tmp
  after wireServices.

internal/daemon/workspace/workspace.go
  HostCommandOutputFunc → *Inspector struct with a Runner field.
  Every git-using helper (GitOutput, GitTrimmedOutput,
  GitResolvedConfigValue, RunHostCommand, ListSubmodules,
  ListOverlayPaths, CountUntrackedPaths, InspectRepo,
  ImportRepoToGuest, PrepareRepoCopy) is now a method on *Inspector.
  NewInspector() wraps the real host runner for production;
  WorkspaceService holds one via repoInspector, CLI deps holds one
  too. cli_test.go's submodule-rejection test builds its own
  Inspector with a scripted Runner instead of patching a global.
  Pure helpers (FinalizeScript, ResolveSourcePath, ParsePrepareMode,
  ShellQuote, FormatStepError, GitFileURL, ParseNullSeparatedOutput)
  stay free functions since they don't touch the host.

Sentinel: grep for HostCommandOutputFunc, lookupExecutableFunc,
vmDNSAddrFunc, vsockHostDevicePath is now empty across internal/.
make lint test green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Thales Maciel 2026-04-22 12:07:14 -03:00
parent 2685bc73f8
commit ecb18ce6ca
No known key found for this signature in database
GPG key ID: 33112E6833C34679
17 changed files with 201 additions and 137 deletions

View file

@ -1116,27 +1116,28 @@ func TestVMRunPreflightRejectsSubmodules(t *testing.T) {
d := defaultDeps() d := defaultDeps()
repoRoot := t.TempDir() repoRoot := t.TempDir()
origHostCommandOutput := workspace.HostCommandOutputFunc // Stub the CLI's repo-inspector with a scripted runner. Per-deps
t.Cleanup(func() { // injection means this test no longer mutates any package global,
workspace.HostCommandOutputFunc = origHostCommandOutput // so t.Parallel() is safe to add here in the future without
}) // worrying about racing another test's fake runner.
d.repoInspector = &workspace.Inspector{
workspace.HostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) { Runner: func(ctx context.Context, name string, args ...string) ([]byte, error) {
t.Helper() t.Helper()
if name != "git" { if name != "git" {
t.Fatalf("command = %q, want git", name) t.Fatalf("command = %q, want git", name)
} }
switch { switch {
case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--show-toplevel"}): case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--show-toplevel"}):
return []byte(repoRoot + "\n"), nil return []byte(repoRoot + "\n"), nil
case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--is-bare-repository"}): case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--is-bare-repository"}):
return []byte("false\n"), nil return []byte("false\n"), nil
case reflect.DeepEqual(args, []string{"-C", repoRoot, "ls-files", "--stage", "-z"}): case reflect.DeepEqual(args, []string{"-C", repoRoot, "ls-files", "--stage", "-z"}):
return []byte("160000 deadbeef 0\tvendor/submodule\x00"), nil return []byte("160000 deadbeef 0\tvendor/submodule\x00"), nil
default: default:
t.Fatalf("unexpected git args: %v", args) t.Fatalf("unexpected git args: %v", args)
return nil, nil return nil, nil
} }
},
} }
_, err := d.vmRunPreflightRepo(context.Background(), repoRoot) _, err := d.vmRunPreflightRepo(context.Background(), repoRoot)

View file

@ -119,7 +119,7 @@ Three modes:
if strings.TrimSpace(repoPtr.branchName) != "" { if strings.TrimSpace(repoPtr.branchName) != "" {
dryFromRef = repoPtr.fromRef dryFromRef = repoPtr.fromRef
} }
return runWorkspaceDryRun(cmd.Context(), cmd.OutOrStdout(), repoPtr.sourcePath, repoPtr.branchName, dryFromRef, repoPtr.includeUntracked) return d.runWorkspaceDryRun(cmd.Context(), cmd.OutOrStdout(), repoPtr.sourcePath, repoPtr.branchName, dryFromRef, repoPtr.includeUntracked)
} }
layout, err := paths.Resolve() layout, err := paths.Resolve()
@ -618,14 +618,14 @@ func (d *deps) newVMWorkspacePrepareCommand() *cobra.Command {
prepareFrom = fromRef prepareFrom = fromRef
} }
if dryRun { if dryRun {
return runWorkspaceDryRun(cmd.Context(), cmd.OutOrStdout(), resolvedPath, branchName, prepareFrom, includeUntracked) return d.runWorkspaceDryRun(cmd.Context(), cmd.OutOrStdout(), resolvedPath, branchName, prepareFrom, includeUntracked)
} }
layout, _, err := d.ensureDaemon(cmd.Context()) layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
if !includeUntracked { if !includeUntracked {
if err := noteUntrackedSkipped(cmd.Context(), cmd.ErrOrStderr(), resolvedPath); err != nil { if err := d.noteUntrackedSkipped(cmd.Context(), cmd.ErrOrStderr(), resolvedPath); err != nil {
return err return err
} }
} }

View file

@ -12,6 +12,7 @@ import (
"banger/internal/api" "banger/internal/api"
"banger/internal/daemon" "banger/internal/daemon"
"banger/internal/daemon/workspace"
"banger/internal/guest" "banger/internal/guest"
"banger/internal/paths" "banger/internal/paths"
"banger/internal/rpc" "banger/internal/rpc"
@ -52,6 +53,12 @@ type deps struct {
buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan
cwd func() (string, error) cwd func() (string, error)
completionLister func(ctx context.Context, socketPath, method string) ([]string, error) completionLister func(ctx context.Context, socketPath, method string) ([]string, error)
// repoInspector is the CLI's single workspace-package Inspector.
// Every code path that needs to shell out to git on the host
// (preflight, dry-run, untracked-count note) goes through it, so
// tests inject a stub Runner via this field instead of mutating a
// package global.
repoInspector *workspace.Inspector
} }
func defaultDeps() *deps { func defaultDeps() *deps {
@ -127,5 +134,6 @@ func defaultDeps() *deps {
buildVMRunToolingPlan: toolingplan.Build, buildVMRunToolingPlan: toolingplan.Build,
cwd: os.Getwd, cwd: os.Getwd,
completionLister: defaultCompletionLister, completionLister: defaultCompletionLister,
repoInspector: workspace.NewInspector(),
} }
} }

View file

@ -93,18 +93,18 @@ func (d *deps) vmRunPreflightRepo(ctx context.Context, rawPath string) (string,
if err != nil { if err != nil {
return "", err return "", err
} }
repoRoot, err := workspace.GitTrimmedOutput(ctx, sourcePath, "rev-parse", "--show-toplevel") repoRoot, err := d.repoInspector.GitTrimmedOutput(ctx, sourcePath, "rev-parse", "--show-toplevel")
if err != nil { if err != nil {
return "", fmt.Errorf("%s is not inside a git repository", sourcePath) return "", fmt.Errorf("%s is not inside a git repository", sourcePath)
} }
isBare, err := workspace.GitTrimmedOutput(ctx, repoRoot, "rev-parse", "--is-bare-repository") isBare, err := d.repoInspector.GitTrimmedOutput(ctx, repoRoot, "rev-parse", "--is-bare-repository")
if err != nil { if err != nil {
return "", fmt.Errorf("inspect git repository %s: %w", repoRoot, err) return "", fmt.Errorf("inspect git repository %s: %w", repoRoot, err)
} }
if isBare == "true" { if isBare == "true" {
return "", fmt.Errorf("vm run requires a non-bare git repository: %s", repoRoot) return "", fmt.Errorf("vm run requires a non-bare git repository: %s", repoRoot)
} }
submodules, err := workspace.ListSubmodules(ctx, repoRoot) submodules, err := d.repoInspector.ListSubmodules(ctx, repoRoot)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -195,7 +195,7 @@ func (d *deps) runVMRun(ctx context.Context, socketPath string, cfg model.Daemon
fromRef = repo.fromRef fromRef = repo.fromRef
} }
if !repo.includeUntracked { if !repo.includeUntracked {
if err := noteUntrackedSkipped(ctx, stderr, repo.sourcePath); err != nil { if err := d.noteUntrackedSkipped(ctx, stderr, repo.sourcePath); err != nil {
printVMRunWarning(stderr, fmt.Sprintf("count untracked files failed: %v", err)) printVMRunWarning(stderr, fmt.Sprintf("count untracked files failed: %v", err))
} }
} }

View file

@ -4,17 +4,16 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"banger/internal/daemon/workspace"
) )
// runWorkspaceDryRun inspects the local repo at resolvedPath and // runWorkspaceDryRun inspects the local repo at resolvedPath and
// prints the file list that `vm run` / `workspace prepare` would ship // prints the file list that `vm run` / `workspace prepare` would ship
// into the guest. Runs on the CLI side (no daemon RPC needed) since // into the guest. Runs on the CLI side (no daemon RPC needed) since
// the daemon is always local and the workspace inspection is a pure // the daemon is always local and the workspace inspection is a pure
// git read. // git read. Git calls go through d.repoInspector so tests inject a
func runWorkspaceDryRun(ctx context.Context, out io.Writer, resolvedPath, branchName, fromRef string, includeUntracked bool) error { // stub Runner via the deps struct instead of touching package globals.
spec, err := workspace.InspectRepo(ctx, resolvedPath, branchName, fromRef, includeUntracked) func (d *deps) runWorkspaceDryRun(ctx context.Context, out io.Writer, resolvedPath, branchName, fromRef string, includeUntracked bool) error {
spec, err := d.repoInspector.InspectRepo(ctx, resolvedPath, branchName, fromRef, includeUntracked)
if err != nil { if err != nil {
return err return err
} }
@ -30,7 +29,7 @@ func runWorkspaceDryRun(ctx context.Context, out io.Writer, resolvedPath, branch
fmt.Fprintln(out, path) fmt.Fprintln(out, path)
} }
if !includeUntracked { if !includeUntracked {
if err := noteUntrackedSkipped(ctx, out, spec.RepoRoot); err != nil { if err := d.noteUntrackedSkipped(ctx, out, spec.RepoRoot); err != nil {
return err return err
} }
} }
@ -41,8 +40,8 @@ func runWorkspaceDryRun(ctx context.Context, out io.Writer, resolvedPath, branch
// the repo has untracked non-ignored files that will NOT be copied // the repo has untracked non-ignored files that will NOT be copied
// because --include-untracked was not passed. Silent when there are // because --include-untracked was not passed. Silent when there are
// no such files, or when the count can't be determined. // no such files, or when the count can't be determined.
func noteUntrackedSkipped(ctx context.Context, out io.Writer, repoRoot string) error { func (d *deps) noteUntrackedSkipped(ctx context.Context, out io.Writer, repoRoot string) error {
count, err := workspace.CountUntrackedPaths(ctx, repoRoot) count, err := d.repoInspector.CountUntrackedPaths(ctx, repoRoot)
if err != nil { if err != nil {
return err return err
} }

View file

@ -15,6 +15,7 @@ import (
"banger/internal/api" "banger/internal/api"
"banger/internal/buildinfo" "banger/internal/buildinfo"
"banger/internal/config" "banger/internal/config"
ws "banger/internal/daemon/workspace"
"banger/internal/model" "banger/internal/model"
"banger/internal/paths" "banger/internal/paths"
"banger/internal/rpc" "banger/internal/rpc"
@ -616,11 +617,12 @@ func wireServices(d *Daemon) {
} }
if d.ws == nil { if d.ws == nil {
d.ws = newWorkspaceService(workspaceServiceDeps{ d.ws = newWorkspaceService(workspaceServiceDeps{
runner: d.runner, runner: d.runner,
logger: d.logger, logger: d.logger,
config: d.config, config: d.config,
layout: d.layout, layout: d.layout,
store: d.store, store: d.store,
repoInspector: ws.NewInspector(),
vmResolver: func(ctx context.Context, idOrName string) (model.VMRecord, error) { vmResolver: func(ctx context.Context, idOrName string) (model.VMRecord, error) {
return d.vm.FindVM(ctx, idOrName) return d.vm.FindVM(ctx, idOrName)
}, },
@ -655,6 +657,7 @@ func wireServices(d *Daemon) {
guestDial: d.guestDial, guestDial: d.guestDial,
capHooks: d.buildCapabilityHooks(), capHooks: d.buildCapabilityHooks(),
beginOperation: d.beginOperation, beginOperation: d.beginOperation,
vsockHostDevice: defaultVsockHostDevice,
}) })
} }
if len(d.vmCaps) == 0 { if len(d.vmCaps) == 0 {

View file

@ -3,18 +3,10 @@ package daemon
import ( import (
"context" "context"
"strings" "strings"
"banger/internal/system"
"banger/internal/vmdns"
) )
const vmResolverRouteDomain = "~vm" const vmResolverRouteDomain = "~vm"
var (
lookupExecutableFunc = system.LookupExecutable
vmDNSAddrFunc = func(server *vmdns.Server) string { return server.Addr() }
)
func (n *HostNetwork) syncVMDNSResolverRouting(ctx context.Context) error { func (n *HostNetwork) syncVMDNSResolverRouting(ctx context.Context) error {
if n == nil || n.vmDNS == nil { if n == nil || n.vmDNS == nil {
return nil return nil
@ -22,13 +14,13 @@ func (n *HostNetwork) syncVMDNSResolverRouting(ctx context.Context) error {
if strings.TrimSpace(n.config.BridgeName) == "" { if strings.TrimSpace(n.config.BridgeName) == "" {
return nil return nil
} }
if _, err := lookupExecutableFunc("resolvectl"); err != nil { if _, err := n.lookupExecutable("resolvectl"); err != nil {
return nil return nil
} }
if _, err := n.runner.Run(ctx, "ip", "link", "show", n.config.BridgeName); err != nil { if _, err := n.runner.Run(ctx, "ip", "link", "show", n.config.BridgeName); err != nil {
return nil return nil
} }
serverAddr := strings.TrimSpace(vmDNSAddrFunc(n.vmDNS)) serverAddr := strings.TrimSpace(n.vmDNSAddr(n.vmDNS))
if serverAddr == "" { if serverAddr == "" {
return nil return nil
} }
@ -46,7 +38,7 @@ func (n *HostNetwork) clearVMDNSResolverRouting(ctx context.Context) error {
if n == nil || strings.TrimSpace(n.config.BridgeName) == "" { if n == nil || strings.TrimSpace(n.config.BridgeName) == "" {
return nil return nil
} }
if _, err := lookupExecutableFunc("resolvectl"); err != nil { if _, err := n.lookupExecutable("resolvectl"); err != nil {
return nil return nil
} }
if _, err := n.runner.Run(ctx, "ip", "link", "show", n.config.BridgeName); err != nil { if _, err := n.runner.Run(ctx, "ip", "link", "show", n.config.BridgeName); err != nil {

View file

@ -9,20 +9,6 @@ import (
) )
func TestSyncVMDNSResolverRoutingConfiguresResolved(t *testing.T) { func TestSyncVMDNSResolverRoutingConfiguresResolved(t *testing.T) {
origLookup := lookupExecutableFunc
origAddr := vmDNSAddrFunc
t.Cleanup(func() {
lookupExecutableFunc = origLookup
vmDNSAddrFunc = origAddr
})
lookupExecutableFunc = func(name string) (string, error) {
if name == "resolvectl" {
return "/usr/bin/resolvectl", nil
}
return "", nil
}
vmDNSAddrFunc = func(*vmdns.Server) string { return "127.0.0.1:42069" }
runner := &scriptedRunner{ runner := &scriptedRunner{
t: t, t: t,
steps: []runnerStep{ steps: []runnerStep{
@ -33,7 +19,16 @@ func TestSyncVMDNSResolverRoutingConfiguresResolved(t *testing.T) {
}, },
} }
cfg := model.DaemonConfig{BridgeName: model.DefaultBridgeName} cfg := model.DaemonConfig{BridgeName: model.DefaultBridgeName}
n := &HostNetwork{runner: runner, config: cfg, vmDNS: new(vmdns.Server)} n := &HostNetwork{
runner: runner, config: cfg, vmDNS: new(vmdns.Server),
lookupExecutable: func(name string) (string, error) {
if name == "resolvectl" {
return "/usr/bin/resolvectl", nil
}
return "", nil
},
vmDNSAddr: func(*vmdns.Server) string { return "127.0.0.1:42069" },
}
if err := n.syncVMDNSResolverRouting(context.Background()); err != nil { if err := n.syncVMDNSResolverRouting(context.Background()); err != nil {
t.Fatalf("syncVMDNSResolverRouting: %v", err) t.Fatalf("syncVMDNSResolverRouting: %v", err)
@ -42,17 +37,6 @@ func TestSyncVMDNSResolverRoutingConfiguresResolved(t *testing.T) {
} }
func TestClearVMDNSResolverRoutingRevertsBridgeConfig(t *testing.T) { func TestClearVMDNSResolverRoutingRevertsBridgeConfig(t *testing.T) {
origLookup := lookupExecutableFunc
t.Cleanup(func() {
lookupExecutableFunc = origLookup
})
lookupExecutableFunc = func(name string) (string, error) {
if name == "resolvectl" {
return "/usr/bin/resolvectl", nil
}
return "", nil
}
runner := &scriptedRunner{ runner := &scriptedRunner{
t: t, t: t,
steps: []runnerStep{ steps: []runnerStep{
@ -61,7 +45,15 @@ func TestClearVMDNSResolverRoutingRevertsBridgeConfig(t *testing.T) {
}, },
} }
cfg := model.DaemonConfig{BridgeName: model.DefaultBridgeName} cfg := model.DaemonConfig{BridgeName: model.DefaultBridgeName}
n := &HostNetwork{runner: runner, config: cfg} n := &HostNetwork{
runner: runner, config: cfg,
lookupExecutable: func(name string) (string, error) {
if name == "resolvectl" {
return "/usr/bin/resolvectl", nil
}
return "", nil
},
}
if err := n.clearVMDNSResolverRouting(context.Background()); err != nil { if err := n.clearVMDNSResolverRouting(context.Background()); err != nil {
t.Fatalf("clearVMDNSResolverRouting: %v", err) t.Fatalf("clearVMDNSResolverRouting: %v", err)

View file

@ -184,7 +184,7 @@ func (d *Daemon) vsockChecks() *system.Preflight {
} else { } else {
checks.Addf("%v", err) checks.Addf("%v", err)
} }
checks.RequireFile(vsockHostDevicePath, "vsock host device", "load the vhost_vsock kernel module on the host") checks.RequireFile(d.vm.vsockHostDevice, "vsock host device", "load the vhost_vsock kernel module on the host")
return checks return checks
} }

View file

@ -41,6 +41,12 @@ type HostNetwork struct {
tapPool tapPool tapPool tapPool
vmDNS *vmdns.Server vmDNS *vmdns.Server
// Test seams. Default to real implementations at construction;
// tests build HostNetwork with stubs instead of mutating package
// globals, so parallel tests can't race each other's fake state.
lookupExecutable func(name string) (string, error)
vmDNSAddr func(server *vmdns.Server) string
} }
// hostNetworkDeps is the explicit wiring bag newHostNetwork expects. // hostNetworkDeps is the explicit wiring bag newHostNetwork expects.
@ -56,11 +62,13 @@ type hostNetworkDeps struct {
func newHostNetwork(deps hostNetworkDeps) *HostNetwork { func newHostNetwork(deps hostNetworkDeps) *HostNetwork {
return &HostNetwork{ return &HostNetwork{
runner: deps.runner, runner: deps.runner,
logger: deps.logger, logger: deps.logger,
config: deps.config, config: deps.config,
layout: deps.layout, layout: deps.layout,
closing: deps.closing, closing: deps.closing,
lookupExecutable: system.LookupExecutable,
vmDNSAddr: func(server *vmdns.Server) string { return server.Addr() },
} }
} }

View file

@ -42,11 +42,7 @@ func TestNewDaemonLoggerEmitsJSONAtConfiguredLevel(t *testing.T) {
func TestStartVMLockedLogsBridgeFailure(t *testing.T) { func TestStartVMLockedLogsBridgeFailure(t *testing.T) {
ctx := context.Background() ctx := context.Background()
origVsockHostDevicePath := vsockHostDevicePath vsockDevicePath := filepath.Join(t.TempDir(), "vhost-vsock")
vsockHostDevicePath = filepath.Join(t.TempDir(), "vhost-vsock")
t.Cleanup(func() {
vsockHostDevicePath = origVsockHostDevicePath
})
binDir := t.TempDir() binDir := t.TempDir()
for _, name := range []string{ for _, name := range []string{
"sudo", "ip", "dmsetup", "losetup", "blockdev", "truncate", "pgrep", "ps", "sudo", "ip", "dmsetup", "losetup", "blockdev", "truncate", "pgrep", "ps",
@ -62,7 +58,7 @@ func TestStartVMLockedLogsBridgeFailure(t *testing.T) {
if err := os.WriteFile(firecrackerBin, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { if err := os.WriteFile(firecrackerBin, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("write firecracker: %v", err) t.Fatalf("write firecracker: %v", err)
} }
if err := os.WriteFile(vsockHostDevicePath, []byte{}, 0o644); err != nil { if err := os.WriteFile(vsockDevicePath, []byte{}, 0o644); err != nil {
t.Fatalf("write vsock host device: %v", err) t.Fatalf("write vsock host device: %v", err)
} }
if err := os.WriteFile(vsockHelper, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { if err := os.WriteFile(vsockHelper, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
@ -115,6 +111,7 @@ func TestStartVMLockedLogsBridgeFailure(t *testing.T) {
logger: logger, logger: logger,
} }
wireServices(d) wireServices(d)
d.vm.vsockHostDevice = vsockDevicePath
_, err = d.vm.startVMLocked(ctx, vm, image) _, err = d.vm.startVMLocked(ctx, vm, image)
if err == nil || !strings.Contains(err.Error(), "bridge up failed") { if err == nil || !strings.Contains(err.Error(), "bridge up failed") {

View file

@ -8,7 +8,11 @@ import (
"banger/internal/system" "banger/internal/system"
) )
var vsockHostDevicePath = "/dev/vhost-vsock" // defaultVsockHostDevice is the vhost-vsock device file every
// Firecracker guest relies on to talk to the host via vsock. Tests
// point at a tempfile by setting VMService.vsockHostDevice; production
// wiring defaults the field to this path in wireServices.
const defaultVsockHostDevice = "/dev/vhost-vsock"
func (s *VMService) validateStartPrereqs(ctx context.Context, vm model.VMRecord, image model.Image) error { func (s *VMService) validateStartPrereqs(ctx context.Context, vm model.VMRecord, image model.Image) error {
checks := system.NewPreflight() checks := system.NewPreflight()
@ -33,7 +37,7 @@ func (s *VMService) addBaseStartPrereqs(checks *system.Preflight, image model.Im
} else { } else {
checks.Addf("%v", err) checks.Addf("%v", err)
} }
checks.RequireFile(vsockHostDevicePath, "vsock host device", "load the vhost_vsock kernel module on the host") checks.RequireFile(s.vsockHostDevice, "vsock host device", "load the vhost_vsock kernel module on the host")
checks.RequireFile(image.RootfsPath, "rootfs image", "select a valid registered image") checks.RequireFile(image.RootfsPath, "rootfs image", "select a valid registered image")
checks.RequireFile(image.KernelPath, "kernel image", `re-register or rebuild the image with a valid kernel`) checks.RequireFile(image.KernelPath, "kernel image", `re-register or rebuild the image with a valid kernel`)
if strings.TrimSpace(image.InitrdPath) != "" { if strings.TrimSpace(image.InitrdPath) != "" {

View file

@ -66,6 +66,11 @@ type VMService struct {
// Test seams. // Test seams.
guestWaitForSSH func(context.Context, string, string, time.Duration) error guestWaitForSSH func(context.Context, string, string, time.Duration) error
guestDial func(context.Context, string, string) (guestSSHClient, error) guestDial func(context.Context, string, string) (guestSSHClient, error)
// vsockHostDevice is the path preflight + doctor expect to find for
// the vhost-vsock device. Defaults to defaultVsockHostDevice; tests
// point at a tempfile so RequireFile passes without needing the
// real kernel module loaded.
vsockHostDevice string
// Capability hook dispatch. VMService invokes capabilities via // Capability hook dispatch. VMService invokes capabilities via
// these seams, populated by Daemon.buildCapabilityHooks() at // these seams, populated by Daemon.buildCapabilityHooks() at
@ -104,9 +109,14 @@ type vmServiceDeps struct {
guestDial func(context.Context, string, string) (guestSSHClient, error) guestDial func(context.Context, string, string) (guestSSHClient, error)
capHooks capabilityHooks capHooks capabilityHooks
beginOperation func(name string, attrs ...any) *operationLog beginOperation func(name string, attrs ...any) *operationLog
vsockHostDevice string
} }
func newVMService(deps vmServiceDeps) *VMService { func newVMService(deps vmServiceDeps) *VMService {
vsockPath := deps.vsockHostDevice
if vsockPath == "" {
vsockPath = defaultVsockHostDevice
}
return &VMService{ return &VMService{
runner: deps.runner, runner: deps.runner,
logger: deps.logger, logger: deps.logger,
@ -120,6 +130,7 @@ func newVMService(deps vmServiceDeps) *VMService {
guestDial: deps.guestDial, guestDial: deps.guestDial,
capHooks: deps.capHooks, capHooks: deps.capHooks,
beginOperation: deps.beginOperation, beginOperation: deps.beginOperation,
vsockHostDevice: vsockPath,
handles: newHandleCache(), handles: newHandleCache(),
} }
} }

View file

@ -24,14 +24,26 @@ func (s *WorkspaceService) workspaceInspectRepoHook(ctx context.Context, sourceP
if s != nil && s.workspaceInspectRepo != nil { if s != nil && s.workspaceInspectRepo != nil {
return s.workspaceInspectRepo(ctx, sourcePath, branchName, fromRef, includeUntracked) return s.workspaceInspectRepo(ctx, sourcePath, branchName, fromRef, includeUntracked)
} }
return ws.InspectRepo(ctx, sourcePath, branchName, fromRef, includeUntracked) return s.inspector().InspectRepo(ctx, sourcePath, branchName, fromRef, includeUntracked)
} }
func (s *WorkspaceService) workspaceImportHook(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error { func (s *WorkspaceService) workspaceImportHook(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error {
if s != nil && s.workspaceImport != nil { if s != nil && s.workspaceImport != nil {
return s.workspaceImport(ctx, client, spec, guestPath, mode) return s.workspaceImport(ctx, client, spec, guestPath, mode)
} }
return ws.ImportRepoToGuest(ctx, client, spec, guestPath, mode) return s.inspector().ImportRepoToGuest(ctx, client, spec, guestPath, mode)
}
// inspector returns the service's workspace Inspector, falling back to
// a fresh real-runner Inspector when callers constructed the service
// without wiring one. Keeping the fallback here lets test literals
// that don't care about the Inspector still function without a manual
// NewInspector() call.
func (s *WorkspaceService) inspector() *ws.Inspector {
if s != nil && s.repoInspector != nil {
return s.repoInspector
}
return ws.NewInspector()
} }
func (s *WorkspaceService) ExportVMWorkspace(ctx context.Context, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { func (s *WorkspaceService) ExportVMWorkspace(ctx context.Context, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {

View file

@ -2,6 +2,12 @@
// git repo inspection, shallow copy preparation, guest-side tar import, // git repo inspection, shallow copy preparation, guest-side tar import,
// finalization script generation, and small utilities. // finalization script generation, and small utilities.
// //
// Every helper that needs to run a host command (git or otherwise)
// lives as a method on *Inspector rather than a free function that
// routes through a package global. That way two tests running in
// parallel can each build their own Inspector with a stub Runner
// without fighting over shared state.
//
// The orchestrator methods (ExportVMWorkspace, PrepareVMWorkspace) stay on // The orchestrator methods (ExportVMWorkspace, PrepareVMWorkspace) stay on
// *daemon.Daemon. // *daemon.Daemon.
package workspace package workspace
@ -51,9 +57,28 @@ type GuestClient interface {
StreamTarEntries(ctx context.Context, dir string, entries []string, command string, log io.Writer) error StreamTarEntries(ctx context.Context, dir string, entries []string, command string, log io.Writer) error
} }
// HostCommandOutputFunc runs a host command and returns its combined output. // RunnerFunc is the single-method surface every Inspector needs: run a
// Declared as a package var so tests can substitute a stub runner. // host command with args, return combined output + error. Tests supply
var HostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) { // a stub that records calls and replays canned responses; production
// uses realHostRunner which wraps system.NewRunner.
type RunnerFunc func(ctx context.Context, name string, args ...string) ([]byte, error)
// Inspector bundles the host-command seam for all git-using workspace
// helpers. Construct one at the boundary where you're reading the
// filesystem (CLI deps, WorkspaceService) and call its methods directly;
// don't reach into the struct from helper code.
type Inspector struct {
Runner RunnerFunc
}
// NewInspector returns an Inspector backed by the real host runner.
// Production callers (CLI deps initialisation, daemon WorkspaceService
// wiring) use this; tests construct Inspector{Runner: stub} directly.
func NewInspector() *Inspector {
return &Inspector{Runner: realHostRunner}
}
func realHostRunner(ctx context.Context, name string, args ...string) ([]byte, error) {
runner := system.NewRunner() runner := system.NewRunner()
output, err := runner.Run(ctx, name, args...) output, err := runner.Run(ctx, name, args...)
if err == nil { if err == nil {
@ -72,55 +97,55 @@ var HostCommandOutputFunc = func(ctx context.Context, name string, args ...strin
// submodules, and overlay paths needed for a prepare. Overlay paths // submodules, and overlay paths needed for a prepare. Overlay paths
// cover tracked files by default; untracked non-ignored files are // cover tracked files by default; untracked non-ignored files are
// included only when includeUntracked is true. // included only when includeUntracked is true.
func InspectRepo(ctx context.Context, rawPath, branchName, fromRef string, includeUntracked bool) (RepoSpec, error) { func (i *Inspector) InspectRepo(ctx context.Context, rawPath, branchName, fromRef string, includeUntracked bool) (RepoSpec, error) {
sourcePath, err := ResolveSourcePath(rawPath) sourcePath, err := ResolveSourcePath(rawPath)
if err != nil { if err != nil {
return RepoSpec{}, err return RepoSpec{}, err
} }
repoRoot, err := GitTrimmedOutput(ctx, sourcePath, "rev-parse", "--show-toplevel") repoRoot, err := i.GitTrimmedOutput(ctx, sourcePath, "rev-parse", "--show-toplevel")
if err != nil { if err != nil {
return RepoSpec{}, fmt.Errorf("%s is not inside a git repository", sourcePath) return RepoSpec{}, fmt.Errorf("%s is not inside a git repository", sourcePath)
} }
isBare, err := GitTrimmedOutput(ctx, repoRoot, "rev-parse", "--is-bare-repository") isBare, err := i.GitTrimmedOutput(ctx, repoRoot, "rev-parse", "--is-bare-repository")
if err != nil { if err != nil {
return RepoSpec{}, fmt.Errorf("inspect git repository %s: %w", repoRoot, err) return RepoSpec{}, fmt.Errorf("inspect git repository %s: %w", repoRoot, err)
} }
if isBare == "true" { if isBare == "true" {
return RepoSpec{}, fmt.Errorf("workspace prepare requires a non-bare git repository: %s", repoRoot) return RepoSpec{}, fmt.Errorf("workspace prepare requires a non-bare git repository: %s", repoRoot)
} }
submodules, err := ListSubmodules(ctx, repoRoot) submodules, err := i.ListSubmodules(ctx, repoRoot)
if err != nil { if err != nil {
return RepoSpec{}, err return RepoSpec{}, err
} }
headCommit, err := GitTrimmedOutput(ctx, repoRoot, "rev-parse", "HEAD^{commit}") headCommit, err := i.GitTrimmedOutput(ctx, repoRoot, "rev-parse", "HEAD^{commit}")
if err != nil { if err != nil {
return RepoSpec{}, fmt.Errorf("git repository %s must have at least one commit", repoRoot) return RepoSpec{}, fmt.Errorf("git repository %s must have at least one commit", repoRoot)
} }
currentBranch, err := GitTrimmedOutput(ctx, repoRoot, "branch", "--show-current") currentBranch, err := i.GitTrimmedOutput(ctx, repoRoot, "branch", "--show-current")
if err != nil { if err != nil {
return RepoSpec{}, fmt.Errorf("resolve current branch for %s: %w", repoRoot, err) return RepoSpec{}, fmt.Errorf("resolve current branch for %s: %w", repoRoot, err)
} }
baseCommit := headCommit baseCommit := headCommit
branchName = strings.TrimSpace(branchName) branchName = strings.TrimSpace(branchName)
if branchName != "" { if branchName != "" {
baseCommit, err = GitTrimmedOutput(ctx, repoRoot, "rev-parse", fromRef+"^{commit}") baseCommit, err = i.GitTrimmedOutput(ctx, repoRoot, "rev-parse", fromRef+"^{commit}")
if err != nil { if err != nil {
return RepoSpec{}, fmt.Errorf("resolve workspace from %q: %w", fromRef, err) return RepoSpec{}, fmt.Errorf("resolve workspace from %q: %w", fromRef, err)
} }
} }
gitUserName, err := GitResolvedConfigValue(ctx, repoRoot, "user.name") gitUserName, err := i.GitResolvedConfigValue(ctx, repoRoot, "user.name")
if err != nil { if err != nil {
return RepoSpec{}, fmt.Errorf("resolve git user.name for %s: %w", repoRoot, err) return RepoSpec{}, fmt.Errorf("resolve git user.name for %s: %w", repoRoot, err)
} }
gitUserEmail, err := GitResolvedConfigValue(ctx, repoRoot, "user.email") gitUserEmail, err := i.GitResolvedConfigValue(ctx, repoRoot, "user.email")
if err != nil { if err != nil {
return RepoSpec{}, fmt.Errorf("resolve git user.email for %s: %w", repoRoot, err) return RepoSpec{}, fmt.Errorf("resolve git user.email for %s: %w", repoRoot, err)
} }
originURL, err := GitResolvedConfigValue(ctx, repoRoot, "remote.origin.url") originURL, err := i.GitResolvedConfigValue(ctx, repoRoot, "remote.origin.url")
if err != nil { if err != nil {
return RepoSpec{}, fmt.Errorf("resolve origin url for %s: %w", repoRoot, err) return RepoSpec{}, fmt.Errorf("resolve origin url for %s: %w", repoRoot, err)
} }
overlayPaths, err := ListOverlayPaths(ctx, repoRoot, includeUntracked) overlayPaths, err := i.ListOverlayPaths(ctx, repoRoot, includeUntracked)
if err != nil { if err != nil {
return RepoSpec{}, err return RepoSpec{}, err
} }
@ -142,7 +167,7 @@ func InspectRepo(ctx context.Context, rawPath, branchName, fromRef string, inclu
// ImportRepoToGuest materialises spec inside the guest at guestPath. Mode // ImportRepoToGuest materialises spec inside the guest at guestPath. Mode
// selects between full copy, metadata-only, or shallow metadata + overlay. // selects between full copy, metadata-only, or shallow metadata + overlay.
func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error { func (i *Inspector) ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error {
switch mode { switch mode {
case model.WorkspacePrepareModeFullCopy: case model.WorkspacePrepareModeFullCopy:
var copyLog bytes.Buffer var copyLog bytes.Buffer
@ -156,7 +181,7 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
} }
return nil return nil
case model.WorkspacePrepareModeMetadataOnly, model.WorkspacePrepareModeShallowOverlay: case model.WorkspacePrepareModeMetadataOnly, model.WorkspacePrepareModeShallowOverlay:
repoCopyDir, cleanup, err := PrepareRepoCopy(ctx, spec) repoCopyDir, cleanup, err := i.PrepareRepoCopy(ctx, spec)
if err != nil { if err != nil {
return err return err
} }
@ -212,7 +237,7 @@ func FinalizeScript(spec RepoSpec, guestPath string, mode model.WorkspacePrepare
// PrepareRepoCopy materialises a shallow clone of spec into a temp dir. The // PrepareRepoCopy materialises a shallow clone of spec into a temp dir. The
// returned cleanup removes the temp root. // returned cleanup removes the temp root.
func PrepareRepoCopy(ctx context.Context, spec RepoSpec) (string, func(), error) { func (i *Inspector) PrepareRepoCopy(ctx context.Context, spec RepoSpec) (string, func(), error) {
tempRoot, err := os.MkdirTemp("", "banger-workspace-*") tempRoot, err := os.MkdirTemp("", "banger-workspace-*")
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@ -224,7 +249,7 @@ func PrepareRepoCopy(ctx context.Context, spec RepoSpec) (string, func(), error)
cloneArgs = append(cloneArgs, "--single-branch", "--branch", spec.CurrentBranch) cloneArgs = append(cloneArgs, "--single-branch", "--branch", spec.CurrentBranch)
} }
cloneArgs = append(cloneArgs, GitFileURL(spec.RepoRoot), repoCopyDir) cloneArgs = append(cloneArgs, GitFileURL(spec.RepoRoot), repoCopyDir)
if err := RunHostCommand(ctx, "git", cloneArgs...); err != nil { if err := i.RunHostCommand(ctx, "git", cloneArgs...); err != nil {
cleanup() cleanup()
return "", nil, fmt.Errorf("clone shallow workspace repo copy: %w", err) return "", nil, fmt.Errorf("clone shallow workspace repo copy: %w", err)
} }
@ -232,19 +257,19 @@ func PrepareRepoCopy(ctx context.Context, spec RepoSpec) (string, func(), error)
if strings.TrimSpace(spec.BranchName) != "" { if strings.TrimSpace(spec.BranchName) != "" {
checkoutCommit = spec.BaseCommit checkoutCommit = spec.BaseCommit
} }
if err := RunHostCommand(ctx, "git", "-C", repoCopyDir, "cat-file", "-e", checkoutCommit+"^{commit}"); err != nil { if err := i.RunHostCommand(ctx, "git", "-C", repoCopyDir, "cat-file", "-e", checkoutCommit+"^{commit}"); err != nil {
if err := RunHostCommand(ctx, "git", "-C", repoCopyDir, "fetch", "--depth", fmt.Sprintf("%d", ShallowFetchDepth), GitFileURL(spec.RepoRoot), checkoutCommit); err != nil { if err := i.RunHostCommand(ctx, "git", "-C", repoCopyDir, "fetch", "--depth", fmt.Sprintf("%d", ShallowFetchDepth), GitFileURL(spec.RepoRoot), checkoutCommit); err != nil {
cleanup() cleanup()
return "", nil, fmt.Errorf("fetch shallow workspace repo commit %s: %w", checkoutCommit, err) return "", nil, fmt.Errorf("fetch shallow workspace repo commit %s: %w", checkoutCommit, err)
} }
} }
if strings.TrimSpace(spec.OriginURL) != "" { if strings.TrimSpace(spec.OriginURL) != "" {
if err := RunHostCommand(ctx, "git", "-C", repoCopyDir, "remote", "set-url", "origin", spec.OriginURL); err != nil { if err := i.RunHostCommand(ctx, "git", "-C", repoCopyDir, "remote", "set-url", "origin", spec.OriginURL); err != nil {
cleanup() cleanup()
return "", nil, fmt.Errorf("set workspace origin remote: %w", err) return "", nil, fmt.Errorf("set workspace origin remote: %w", err)
} }
} else { } else {
if err := RunHostCommand(ctx, "git", "-C", repoCopyDir, "remote", "remove", "origin"); err != nil { if err := i.RunHostCommand(ctx, "git", "-C", repoCopyDir, "remote", "remove", "origin"); err != nil {
cleanup() cleanup()
return "", nil, fmt.Errorf("remove workspace placeholder origin remote: %w", err) return "", nil, fmt.Errorf("remove workspace placeholder origin remote: %w", err)
} }
@ -273,8 +298,8 @@ func ResolveSourcePath(rawPath string) (string, error) {
} }
// ListSubmodules returns the gitlink paths in repoRoot (mode 160000 entries). // ListSubmodules returns the gitlink paths in repoRoot (mode 160000 entries).
func ListSubmodules(ctx context.Context, repoRoot string) ([]string, error) { func (i *Inspector) ListSubmodules(ctx context.Context, repoRoot string) ([]string, error) {
output, err := GitOutput(ctx, repoRoot, "ls-files", "--stage", "-z") output, err := i.GitOutput(ctx, repoRoot, "ls-files", "--stage", "-z")
if err != nil { if err != nil {
return nil, fmt.Errorf("inspect workspace git index for %s: %w", repoRoot, err) return nil, fmt.Errorf("inspect workspace git index for %s: %w", repoRoot, err)
} }
@ -304,8 +329,8 @@ func ListSubmodules(ctx context.Context, repoRoot string) ([]string, error) {
// leave the developer's machine. Callers that genuinely want the // leave the developer's machine. Callers that genuinely want the
// fuller set (scratch repos, vendored binaries the user is iterating // fuller set (scratch repos, vendored binaries the user is iterating
// on) opt in explicitly. // on) opt in explicitly.
func ListOverlayPaths(ctx context.Context, repoRoot string, includeUntracked bool) ([]string, error) { func (i *Inspector) ListOverlayPaths(ctx context.Context, repoRoot string, includeUntracked bool) ([]string, error) {
trackedOutput, err := GitOutput(ctx, repoRoot, "ls-files", "-z") trackedOutput, err := i.GitOutput(ctx, repoRoot, "ls-files", "-z")
if err != nil { if err != nil {
return nil, fmt.Errorf("list tracked files for %s: %w", repoRoot, err) return nil, fmt.Errorf("list tracked files for %s: %w", repoRoot, err)
} }
@ -325,7 +350,7 @@ func ListOverlayPaths(ctx context.Context, repoRoot string, includeUntracked boo
paths = append(paths, relPath) paths = append(paths, relPath)
} }
if includeUntracked { if includeUntracked {
untrackedOutput, err := GitOutput(ctx, repoRoot, "ls-files", "--others", "--exclude-standard", "-z") untrackedOutput, err := i.GitOutput(ctx, repoRoot, "ls-files", "--others", "--exclude-standard", "-z")
if err != nil { if err != nil {
return nil, fmt.Errorf("list untracked files for %s: %w", repoRoot, err) return nil, fmt.Errorf("list untracked files for %s: %w", repoRoot, err)
} }
@ -348,8 +373,8 @@ func ListOverlayPaths(ctx context.Context, repoRoot string, includeUntracked boo
// files in repoRoot. Used by the CLI to warn the user when they are // files in repoRoot. Used by the CLI to warn the user when they are
// about to ship a workspace that has local-but-unignored scratch // about to ship a workspace that has local-but-unignored scratch
// files which, under the default, will be skipped. // files which, under the default, will be skipped.
func CountUntrackedPaths(ctx context.Context, repoRoot string) (int, error) { func (i *Inspector) CountUntrackedPaths(ctx context.Context, repoRoot string) (int, error) {
untrackedOutput, err := GitOutput(ctx, repoRoot, "ls-files", "--others", "--exclude-standard", "-z") untrackedOutput, err := i.GitOutput(ctx, repoRoot, "ls-files", "--others", "--exclude-standard", "-z")
if err != nil { if err != nil {
return 0, fmt.Errorf("list untracked files for %s: %w", repoRoot, err) return 0, fmt.Errorf("list untracked files for %s: %w", repoRoot, err)
} }
@ -377,18 +402,18 @@ func ParsePrepareMode(raw string) (model.WorkspacePrepareMode, error) {
} }
// GitOutput runs `git [-C dir] args...` and returns its raw stdout. // GitOutput runs `git [-C dir] args...` and returns its raw stdout.
func GitOutput(ctx context.Context, dir string, args ...string) ([]byte, error) { func (i *Inspector) GitOutput(ctx context.Context, dir string, args ...string) ([]byte, error) {
fullArgs := make([]string, 0, len(args)+2) fullArgs := make([]string, 0, len(args)+2)
if strings.TrimSpace(dir) != "" { if strings.TrimSpace(dir) != "" {
fullArgs = append(fullArgs, "-C", dir) fullArgs = append(fullArgs, "-C", dir)
} }
fullArgs = append(fullArgs, args...) fullArgs = append(fullArgs, args...)
return HostCommandOutputFunc(ctx, "git", fullArgs...) return i.Runner(ctx, "git", fullArgs...)
} }
// GitTrimmedOutput returns GitOutput with surrounding whitespace trimmed. // GitTrimmedOutput returns GitOutput with surrounding whitespace trimmed.
func GitTrimmedOutput(ctx context.Context, dir string, args ...string) (string, error) { func (i *Inspector) GitTrimmedOutput(ctx context.Context, dir string, args ...string) (string, error) {
output, err := GitOutput(ctx, dir, args...) output, err := i.GitOutput(ctx, dir, args...)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -396,8 +421,8 @@ func GitTrimmedOutput(ctx context.Context, dir string, args ...string) (string,
} }
// GitResolvedConfigValue reads git config key with --default "" --get. // GitResolvedConfigValue reads git config key with --default "" --get.
func GitResolvedConfigValue(ctx context.Context, dir, key string) (string, error) { func (i *Inspector) GitResolvedConfigValue(ctx context.Context, dir, key string) (string, error) {
return GitTrimmedOutput(ctx, dir, "config", "--default", "", "--get", key) return i.GitTrimmedOutput(ctx, dir, "config", "--default", "", "--get", key)
} }
// ParseNullSeparatedOutput splits on NULs and trims, returning non-empty // ParseNullSeparatedOutput splits on NULs and trims, returning non-empty
@ -415,10 +440,10 @@ func ParseNullSeparatedOutput(output []byte) []string {
return values return values
} }
// RunHostCommand runs a host command via HostCommandOutputFunc, discarding // RunHostCommand runs a host command via the Inspector's Runner,
// its stdout. // discarding its stdout.
func RunHostCommand(ctx context.Context, name string, args ...string) error { func (i *Inspector) RunHostCommand(ctx context.Context, name string, args ...string) error {
_, err := HostCommandOutputFunc(ctx, name, args...) _, err := i.Runner(ctx, name, args...)
return err return err
} }

View file

@ -59,7 +59,8 @@ func seedRepo(t *testing.T) string {
func TestListOverlayPaths_TrackedOnlyByDefault(t *testing.T) { func TestListOverlayPaths_TrackedOnlyByDefault(t *testing.T) {
repo := seedRepo(t) repo := seedRepo(t)
got, err := ListOverlayPaths(context.Background(), repo, false) i := NewInspector()
got, err := i.ListOverlayPaths(context.Background(), repo, false)
if err != nil { if err != nil {
t.Fatalf("ListOverlayPaths: %v", err) t.Fatalf("ListOverlayPaths: %v", err)
} }
@ -71,7 +72,8 @@ func TestListOverlayPaths_TrackedOnlyByDefault(t *testing.T) {
func TestListOverlayPaths_IncludeUntracked(t *testing.T) { func TestListOverlayPaths_IncludeUntracked(t *testing.T) {
repo := seedRepo(t) repo := seedRepo(t)
got, err := ListOverlayPaths(context.Background(), repo, true) i := NewInspector()
got, err := i.ListOverlayPaths(context.Background(), repo, true)
if err != nil { if err != nil {
t.Fatalf("ListOverlayPaths: %v", err) t.Fatalf("ListOverlayPaths: %v", err)
} }
@ -89,7 +91,8 @@ func TestListOverlayPaths_IncludeUntracked(t *testing.T) {
func TestCountUntrackedPaths(t *testing.T) { func TestCountUntrackedPaths(t *testing.T) {
repo := seedRepo(t) repo := seedRepo(t)
count, err := CountUntrackedPaths(context.Background(), repo) i := NewInspector()
count, err := i.CountUntrackedPaths(context.Background(), repo)
if err != nil { if err != nil {
t.Fatalf("CountUntrackedPaths: %v", err) t.Fatalf("CountUntrackedPaths: %v", err)
} }

View file

@ -45,6 +45,13 @@ type WorkspaceService struct {
beginOperation func(name string, attrs ...any) *operationLog beginOperation func(name string, attrs ...any) *operationLog
// repoInspector is the Inspector used by the real InspectRepo /
// ImportRepoToGuest fallbacks when the test seams below aren't
// set. wireServices installs the production one; tests that want
// to intercept only the host-command surface (not the whole
// inspect/import hook) can assign a stub-runner Inspector here.
repoInspector *ws.Inspector
// Test seams. // Test seams.
workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string, includeUntracked bool) (ws.RepoSpec, error) workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string, includeUntracked bool) (ws.RepoSpec, error)
workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error
@ -56,6 +63,7 @@ type workspaceServiceDeps struct {
config model.DaemonConfig config model.DaemonConfig
layout paths.Layout layout paths.Layout
store *store.Store store *store.Store
repoInspector *ws.Inspector
vmResolver func(ctx context.Context, idOrName string) (model.VMRecord, error) vmResolver func(ctx context.Context, idOrName string) (model.VMRecord, error)
aliveChecker func(vm model.VMRecord) bool aliveChecker func(vm model.VMRecord) bool
waitGuestSSH func(ctx context.Context, address string, interval time.Duration) error waitGuestSSH func(ctx context.Context, address string, interval time.Duration) error
@ -73,6 +81,7 @@ func newWorkspaceService(deps workspaceServiceDeps) *WorkspaceService {
config: deps.config, config: deps.config,
layout: deps.layout, layout: deps.layout,
store: deps.store, store: deps.store,
repoInspector: deps.repoInspector,
vmResolver: deps.vmResolver, vmResolver: deps.vmResolver,
aliveChecker: deps.aliveChecker, aliveChecker: deps.aliveChecker,
waitGuestSSH: deps.waitGuestSSH, waitGuestSSH: deps.waitGuestSSH,