Add repo-backed vm run command

Create a CLI-only banger vm run [path] flow that resolves the enclosing git repository, creates a VM, imports a guest checkout, and launches opencode attach automatically from the host.

Build the guest checkout by bundling git history plus the resolved base and head commits, cloning that bundle in the guest, and overlaying tracked plus untracked non-ignored files over SSH so local working-tree changes carry over. Support guest-only branch creation with --branch and --from, reject bare repos and submodules, and add selective tar helpers plus CLI seams to keep the workflow testable.

Validate with go test ./..., make build, banger vm run --help, and the expected --from requires --branch error path.
This commit is contained in:
Thales Maciel 2026-03-21 23:34:20 -03:00
parent 8bcc767824
commit 2ebc6f99c6
No known key found for this signature in database
GPG key ID: 33112E6833C34679
5 changed files with 929 additions and 0 deletions

View file

@ -113,6 +113,15 @@ Create and use a VM:
`vm create` stays synchronous by default, but on a TTY it now shows live progress until the VM is fully ready.
Start a repo-backed VM session and attach `opencode` automatically:
```bash
./build/bin/banger vm run
./build/bin/banger vm run ../some-repo --branch feature/alpine --from HEAD
```
`vm run` resolves the enclosing git repository, creates a VM, copies a git checkout plus current tracked and untracked non-ignored files into `/root/<repo-name>`, and then runs `opencode attach` from the host against the guest.
## Web UI
`bangerd` serves a local web UI by default at:

View file

@ -1,11 +1,13 @@
package cli
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"path/filepath"
@ -19,6 +21,7 @@ import (
"banger/internal/api"
"banger/internal/config"
"banger/internal/daemon"
"banger/internal/guest"
"banger/internal/hostnat"
"banger/internal/imagepreset"
"banger/internal/model"
@ -44,6 +47,26 @@ var (
sshCmd.Stdin = stdin
return sshCmd.Run()
}
opencodeExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
opencodeCmd := exec.CommandContext(ctx, "opencode", args...)
opencodeCmd.Stdout = stdout
opencodeCmd.Stderr = stderr
opencodeCmd.Stdin = stdin
return opencodeCmd.Run()
}
hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) {
cmd := exec.CommandContext(ctx, name, args...)
output, err := cmd.CombinedOutput()
if err == nil {
return output, nil
}
command := strings.TrimSpace(strings.Join(append([]string{name}, args...), " "))
detail := strings.TrimSpace(string(output))
if detail == "" {
return output, fmt.Errorf("%s: %w", command, err)
}
return output, fmt.Errorf("%s: %w: %s", command, err, detail)
}
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName})
}
@ -60,8 +83,35 @@ var (
vmPortsFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) {
return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName})
}
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
return guest.WaitForSSH(ctx, address, privateKeyPath, interval)
}
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
return guest.Dial(ctx, address, privateKeyPath)
}
cwdFunc = os.Getwd
)
type vmRunGuestClient interface {
Close() error
UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error
RunScript(ctx context.Context, script string, logWriter io.Writer) error
StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error
}
type vmRunRepoSpec struct {
SourcePath string
RepoRoot string
RepoName string
HeadCommit string
CurrentBranch string
BranchName string
BaseCommit string
OverlayPaths []string
}
const vmRunGuestBundlePath = "/tmp/banger-vm-run.bundle"
func NewBangerCommand() *cobra.Command {
root := &cobra.Command{
Use: "banger",
@ -358,6 +408,7 @@ func newVMCommand() *cobra.Command {
}
cmd.AddCommand(
newVMCreateCommand(),
newVMRunCommand(),
newVMListCommand(),
newVMShowCommand(),
newVMActionCommand("start", "Start a VM", "vm.start"),
@ -374,6 +425,76 @@ func newVMCommand() *cobra.Command {
return cmd
}
func newVMRunCommand() *cobra.Command {
var (
name string
imageName string
vcpu = model.DefaultVCPUCount
memory = model.DefaultMemoryMiB
systemOverlaySize = model.FormatSizeBytes(model.DefaultSystemOverlaySize)
workDiskSize = model.FormatSizeBytes(model.DefaultWorkDiskSize)
natEnabled bool
branchName string
fromRef = "HEAD"
)
cmd := &cobra.Command{
Use: "run [path]",
Short: "Create a repo-backed VM session and attach opencode",
Args: maxArgsUsage(1, "usage: banger vm run [path]"),
RunE: func(cmd *cobra.Command, args []string) error {
if cmd.Flags().Changed("branch") && strings.TrimSpace(branchName) == "" {
return errors.New("--branch requires a branch name")
}
if cmd.Flags().Changed("from") && strings.TrimSpace(branchName) == "" {
return errors.New("--from requires --branch")
}
sourcePath := ""
if len(args) == 1 {
sourcePath = args[0]
}
spec, err := inspectVMRunRepo(cmd.Context(), sourcePath, branchName, fromRef)
if err != nil {
return err
}
layout, err := paths.Resolve()
if err != nil {
return err
}
cfg, err := config.Load(layout)
if err != nil {
return err
}
if err := validateVMRunPrereqs(cfg); err != nil {
return err
}
params, err := vmCreateParamsFromFlags(cmd, name, imageName, vcpu, memory, systemOverlaySize, workDiskSize, natEnabled, false)
if err != nil {
return err
}
if err := system.EnsureSudo(cmd.Context()); err != nil {
return err
}
layout, cfg, err = ensureDaemon(cmd.Context())
if err != nil {
return err
}
return runVMRun(cmd.Context(), layout.SocketPath, cfg, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), params, spec)
},
}
cmd.Flags().StringVar(&name, "name", "", "vm name")
cmd.Flags().StringVar(&imageName, "image", "", "image name or id")
cmd.Flags().IntVar(&vcpu, "vcpu", model.DefaultVCPUCount, "vcpu count")
cmd.Flags().IntVar(&memory, "memory", model.DefaultMemoryMiB, "memory in MiB")
cmd.Flags().StringVar(&systemOverlaySize, "system-overlay-size", model.FormatSizeBytes(model.DefaultSystemOverlaySize), "system overlay size")
cmd.Flags().StringVar(&workDiskSize, "disk-size", model.FormatSizeBytes(model.DefaultWorkDiskSize), "work disk size")
cmd.Flags().BoolVar(&natEnabled, "nat", false, "enable NAT")
cmd.Flags().StringVar(&branchName, "branch", "", "create and switch to a new guest branch")
cmd.Flags().StringVar(&fromRef, "from", "HEAD", "base ref for --branch")
return cmd
}
func newVMKillCommand() *cobra.Command {
var signal string
cmd := &cobra.Command{
@ -876,6 +997,15 @@ func minArgsUsage(n int, usage string) cobra.PositionalArgs {
}
}
func maxArgsUsage(n int, usage string) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
if len(args) > n {
return errors.New(usage)
}
return nil
}
}
type resolvedVMTarget struct {
Index int
Ref string
@ -1244,6 +1374,320 @@ func validateSSHPrereqs(cfg model.DaemonConfig) error {
return checks.Err("ssh preflight failed")
}
func validateVMRunPrereqs(cfg model.DaemonConfig) error {
checks := system.NewPreflight()
checks.RequireCommand("git", "install git")
checks.RequireCommand("opencode", "install opencode")
if strings.TrimSpace(cfg.SSHKeyPath) != "" {
checks.RequireFile(cfg.SSHKeyPath, "ssh private key", `set "ssh_key_path" or let banger create its default key`)
}
return checks.Err("vm run preflight failed")
}
func inspectVMRunRepo(ctx context.Context, rawPath, branchName, fromRef string) (vmRunRepoSpec, error) {
sourcePath, err := resolveVMRunSourcePath(rawPath)
if err != nil {
return vmRunRepoSpec{}, err
}
repoRoot, err := gitTrimmedOutput(ctx, sourcePath, "rev-parse", "--show-toplevel")
if err != nil {
return vmRunRepoSpec{}, fmt.Errorf("%s is not inside a git repository", sourcePath)
}
isBare, err := gitTrimmedOutput(ctx, repoRoot, "rev-parse", "--is-bare-repository")
if err != nil {
return vmRunRepoSpec{}, fmt.Errorf("inspect git repository %s: %w", repoRoot, err)
}
if isBare == "true" {
return vmRunRepoSpec{}, fmt.Errorf("vm run requires a non-bare git repository: %s", repoRoot)
}
if err := ensureVMRunRepoHasNoSubmodules(ctx, repoRoot); err != nil {
return vmRunRepoSpec{}, err
}
headCommit, err := gitTrimmedOutput(ctx, repoRoot, "rev-parse", "HEAD^{commit}")
if err != nil {
return vmRunRepoSpec{}, fmt.Errorf("git repository %s must have at least one commit", repoRoot)
}
currentBranch, err := gitTrimmedOutput(ctx, repoRoot, "branch", "--show-current")
if err != nil {
return vmRunRepoSpec{}, fmt.Errorf("resolve current branch for %s: %w", repoRoot, err)
}
baseCommit := headCommit
branchName = strings.TrimSpace(branchName)
if branchName != "" {
fromRef = strings.TrimSpace(fromRef)
if fromRef == "" {
return vmRunRepoSpec{}, errors.New("--from cannot be empty")
}
baseCommit, err = gitTrimmedOutput(ctx, repoRoot, "rev-parse", fromRef+"^{commit}")
if err != nil {
return vmRunRepoSpec{}, fmt.Errorf("resolve --from %q: %w", fromRef, err)
}
}
overlayPaths, err := listVMRunOverlayPaths(ctx, repoRoot)
if err != nil {
return vmRunRepoSpec{}, err
}
return vmRunRepoSpec{
SourcePath: sourcePath,
RepoRoot: repoRoot,
RepoName: filepath.Base(repoRoot),
HeadCommit: headCommit,
CurrentBranch: currentBranch,
BranchName: branchName,
BaseCommit: baseCommit,
OverlayPaths: overlayPaths,
}, nil
}
func resolveVMRunSourcePath(rawPath string) (string, error) {
if strings.TrimSpace(rawPath) == "" {
wd, err := cwdFunc()
if err != nil {
return "", err
}
rawPath = wd
}
absPath, err := filepath.Abs(rawPath)
if err != nil {
return "", err
}
info, err := os.Stat(absPath)
if err != nil {
return "", err
}
if !info.IsDir() {
return "", fmt.Errorf("%s is not a directory", absPath)
}
return absPath, nil
}
func ensureVMRunRepoHasNoSubmodules(ctx context.Context, repoRoot string) error {
output, err := gitOutput(ctx, repoRoot, "ls-files", "--stage", "-z")
if err != nil {
return fmt.Errorf("inspect git index for %s: %w", repoRoot, err)
}
for _, record := range parseNullSeparatedOutput(output) {
if strings.HasPrefix(record, "160000 ") {
return fmt.Errorf("vm run does not yet support git submodules: %s", repoRoot)
}
}
return nil
}
func listVMRunOverlayPaths(ctx context.Context, repoRoot string) ([]string, error) {
trackedOutput, err := gitOutput(ctx, repoRoot, "ls-files", "-z")
if err != nil {
return nil, fmt.Errorf("list tracked files for %s: %w", repoRoot, err)
}
untrackedOutput, err := gitOutput(ctx, repoRoot, "ls-files", "--others", "--exclude-standard", "-z")
if err != nil {
return nil, fmt.Errorf("list untracked files for %s: %w", repoRoot, err)
}
paths := make([]string, 0)
seen := make(map[string]struct{})
for _, relPath := range parseNullSeparatedOutput(trackedOutput) {
if relPath == "" {
continue
}
if _, err := os.Lstat(filepath.Join(repoRoot, relPath)); err != nil {
if os.IsNotExist(err) {
continue
}
return nil, err
}
seen[relPath] = struct{}{}
paths = append(paths, relPath)
}
for _, relPath := range parseNullSeparatedOutput(untrackedOutput) {
if relPath == "" {
continue
}
if _, ok := seen[relPath]; ok {
continue
}
seen[relPath] = struct{}{}
paths = append(paths, relPath)
}
sort.Strings(paths)
return paths, nil
}
func gitOutput(ctx context.Context, dir string, args ...string) ([]byte, error) {
fullArgs := make([]string, 0, len(args)+2)
if strings.TrimSpace(dir) != "" {
fullArgs = append(fullArgs, "-C", dir)
}
fullArgs = append(fullArgs, args...)
return hostCommandOutputFunc(ctx, "git", fullArgs...)
}
func gitTrimmedOutput(ctx context.Context, dir string, args ...string) (string, error) {
output, err := gitOutput(ctx, dir, args...)
if err != nil {
return "", err
}
return strings.TrimSpace(string(output)), nil
}
func parseNullSeparatedOutput(output []byte) []string {
chunks := bytes.Split(output, []byte{0})
values := make([]string, 0, len(chunks))
for _, chunk := range chunks {
value := strings.TrimSpace(string(chunk))
if value == "" {
continue
}
values = append(values, value)
}
return values
}
func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, stdin io.Reader, stdout, stderr io.Writer, params api.VMCreateParams, spec vmRunRepoSpec) error {
vm, err := runVMCreate(ctx, socketPath, stderr, params)
if err != nil {
return err
}
vmRef := strings.TrimSpace(vm.Name)
if vmRef == "" {
vmRef = shortID(vm.ID)
}
sshAddress := net.JoinHostPort(vm.Runtime.GuestIP, "22")
if err := guestWaitForSSHFunc(ctx, sshAddress, cfg.SSHKeyPath, 250*time.Millisecond); err != nil {
return fmt.Errorf("vm %q is running but guest ssh is unavailable: %w", vmRef, err)
}
client, err := guestDialFunc(ctx, sshAddress, cfg.SSHKeyPath)
if err != nil {
return fmt.Errorf("vm %q is running but guest ssh is unavailable: %w", vmRef, err)
}
defer client.Close()
if err := importVMRunRepoToGuest(ctx, client, spec); err != nil {
return fmt.Errorf("vm %q is running but repo import failed: %w", vmRef, err)
}
if err := runVMRunAttach(ctx, stdin, stdout, stderr, vm.Runtime.GuestIP, vmRunGuestDir(spec.RepoName)); err != nil {
return fmt.Errorf("vm %q is running but opencode attach failed: %w", vmRef, err)
}
return nil
}
func importVMRunRepoToGuest(ctx context.Context, client vmRunGuestClient, spec vmRunRepoSpec) error {
bundleData, err := createVMRunBundle(ctx, spec)
if err != nil {
return err
}
var uploadLog bytes.Buffer
if err := client.UploadFile(ctx, vmRunGuestBundlePath, 0o600, bundleData, &uploadLog); err != nil {
return formatVMRunStepError("upload git bundle", err, uploadLog.String())
}
var scriptLog bytes.Buffer
if err := client.RunScript(ctx, vmRunCloneScript(spec), &scriptLog); err != nil {
return formatVMRunStepError("prepare guest checkout", err, scriptLog.String())
}
var overlayLog bytes.Buffer
remoteCommand := fmt.Sprintf("tar -C %s --strip-components=1 -xf -", shellQuote(vmRunGuestDir(spec.RepoName)))
if err := client.StreamTarEntries(ctx, spec.RepoRoot, spec.OverlayPaths, remoteCommand, &overlayLog); err != nil {
return formatVMRunStepError("overlay host working tree", err, overlayLog.String())
}
return nil
}
func createVMRunBundle(ctx context.Context, spec vmRunRepoSpec) ([]byte, error) {
tempFile, err := os.CreateTemp("", "banger-vm-run-*.bundle")
if err != nil {
return nil, err
}
tempPath := tempFile.Name()
if err := tempFile.Close(); err != nil {
_ = os.Remove(tempPath)
return nil, err
}
defer os.Remove(tempPath)
args := []string{"-C", spec.RepoRoot, "bundle", "create", tempPath, "--all"}
for _, rev := range uniqueNonEmptyStrings(spec.HeadCommit, spec.BaseCommit) {
args = append(args, rev)
}
if _, err := hostCommandOutputFunc(ctx, "git", args...); err != nil {
return nil, fmt.Errorf("create git bundle: %w", err)
}
data, err := os.ReadFile(tempPath)
if err != nil {
return nil, fmt.Errorf("read git bundle: %w", err)
}
return data, nil
}
func vmRunCloneScript(spec vmRunRepoSpec) string {
guestDir := vmRunGuestDir(spec.RepoName)
var script strings.Builder
script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "DIR=%s\n", shellQuote(guestDir))
fmt.Fprintf(&script, "BUNDLE=%s\n", shellQuote(vmRunGuestBundlePath))
script.WriteString("rm -rf \"$DIR\"\n")
script.WriteString("git clone \"$BUNDLE\" \"$DIR\"\n")
script.WriteString("rm -f \"$BUNDLE\"\n")
switch {
case strings.TrimSpace(spec.BranchName) != "":
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", shellQuote(spec.BranchName), shellQuote(spec.BaseCommit))
case strings.TrimSpace(spec.CurrentBranch) != "":
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", shellQuote(spec.CurrentBranch), shellQuote(spec.HeadCommit))
default:
fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", shellQuote(spec.HeadCommit))
}
script.WriteString("find \"$DIR\" -mindepth 1 -maxdepth 1 ! -name .git -exec rm -rf {} +\n")
return script.String()
}
func vmRunGuestDir(repoName string) string {
return filepath.ToSlash(filepath.Join("/root", repoName))
}
func runVMRunAttach(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, guestIP, guestDir string) error {
guestIP = strings.TrimSpace(guestIP)
if guestIP == "" {
return errors.New("vm has no guest IP")
}
return opencodeExecFunc(ctx, stdin, stdout, stderr, []string{
"attach",
"--dir", guestDir,
"http://" + net.JoinHostPort(guestIP, "4096"),
})
}
func formatVMRunStepError(action string, err error, log string) error {
log = strings.TrimSpace(log)
if log == "" {
return fmt.Errorf("%s: %w", action, err)
}
return fmt.Errorf("%s: %w: %s", action, err, log)
}
func uniqueNonEmptyStrings(values ...string) []string {
unique := make([]string, 0, len(values))
seen := make(map[string]struct{}, len(values))
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" {
continue
}
if _, ok := seen[value]; ok {
continue
}
seen[value] = struct{}{}
unique = append(unique, value)
}
return unique
}
func shellQuote(value string) string {
return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
}
func absolutizeImageBuildPaths(params *api.ImageBuildParams) error {
return absolutizePaths(&params.KernelPath, &params.InitrdPath, &params.ModulesDir)
}

View file

@ -146,6 +146,26 @@ func TestVMCreateFlagsExist(t *testing.T) {
}
}
func TestVMRunFlagsExist(t *testing.T) {
root := NewBangerCommand()
vm, _, err := root.Find([]string{"vm"})
if err != nil {
t.Fatalf("find vm: %v", err)
}
run, _, err := vm.Find([]string{"run"})
if err != nil {
t.Fatalf("find run: %v", err)
}
for _, flagName := range []string{"name", "image", "vcpu", "memory", "system-overlay-size", "disk-size", "nat", "branch", "from"} {
if run.Flags().Lookup(flagName) == nil {
t.Fatalf("missing flag %q", flagName)
}
}
if run.Flags().Lookup("no-start") != nil {
t.Fatal("vm run should not expose --no-start")
}
}
func TestVMCreateFlagsShowStaticDefaults(t *testing.T) {
root := NewBangerCommand()
vm, _, err := root.Find([]string{"vm"})
@ -171,6 +191,16 @@ func TestVMCreateFlagsShowStaticDefaults(t *testing.T) {
}
}
func TestVMRunRejectsFromWithoutBranch(t *testing.T) {
cmd := NewBangerCommand()
cmd.SetArgs([]string{"vm", "run", "--from", "HEAD"})
err := cmd.Execute()
if err == nil || !strings.Contains(err.Error(), "--from requires --branch") {
t.Fatalf("Execute() error = %v, want --from requires --branch", err)
}
}
func TestImageRegisterFlagsExist(t *testing.T) {
root := NewBangerCommand()
image, _, err := root.Find([]string{"image"})
@ -837,6 +867,278 @@ func TestValidateSSHPrereqsFailsForMissingKey(t *testing.T) {
}
}
func TestResolveVMRunSourcePathDefaultsToCWD(t *testing.T) {
origCWD := cwdFunc
t.Cleanup(func() {
cwdFunc = origCWD
})
want := t.TempDir()
cwdFunc = func() (string, error) {
return want, nil
}
got, err := resolveVMRunSourcePath("")
if err != nil {
t.Fatalf("resolveVMRunSourcePath: %v", err)
}
if got != want {
t.Fatalf("resolveVMRunSourcePath() = %q, want %q", got, want)
}
}
func TestInspectVMRunRepoUsesRepoRootAndOverlayPaths(t *testing.T) {
if _, err := exec.LookPath("git"); err != nil {
t.Skip("git not installed")
}
repoRoot := t.TempDir()
testRunGit(t, repoRoot, "init")
testRunGit(t, repoRoot, "config", "user.email", "test@example.com")
testRunGit(t, repoRoot, "config", "user.name", "Banger Test")
if err := os.MkdirAll(filepath.Join(repoRoot, "dir"), 0o755); err != nil {
t.Fatalf("MkdirAll(dir): %v", err)
}
if err := os.WriteFile(filepath.Join(repoRoot, ".gitignore"), []byte("ignored.txt\n"), 0o644); err != nil {
t.Fatalf("WriteFile(.gitignore): %v", err)
}
if err := os.WriteFile(filepath.Join(repoRoot, "tracked.txt"), []byte("tracked\n"), 0o644); err != nil {
t.Fatalf("WriteFile(tracked.txt): %v", err)
}
if err := os.WriteFile(filepath.Join(repoRoot, "dir", "keep.txt"), []byte("keep\n"), 0o644); err != nil {
t.Fatalf("WriteFile(keep.txt): %v", err)
}
testRunGit(t, repoRoot, "add", ".")
testRunGit(t, repoRoot, "commit", "-m", "init")
testRunGit(t, repoRoot, "checkout", "-b", "trunk")
if err := os.WriteFile(filepath.Join(repoRoot, "tracked.txt"), []byte("tracked local\n"), 0o644); err != nil {
t.Fatalf("WriteFile(tracked.txt local): %v", err)
}
if err := os.WriteFile(filepath.Join(repoRoot, "untracked.txt"), []byte("untracked\n"), 0o644); err != nil {
t.Fatalf("WriteFile(untracked.txt): %v", err)
}
if err := os.WriteFile(filepath.Join(repoRoot, "ignored.txt"), []byte("ignored\n"), 0o644); err != nil {
t.Fatalf("WriteFile(ignored.txt): %v", err)
}
spec, err := inspectVMRunRepo(context.Background(), filepath.Join(repoRoot, "dir"), "", "HEAD")
if err != nil {
t.Fatalf("inspectVMRunRepo: %v", err)
}
if spec.RepoRoot != repoRoot {
t.Fatalf("RepoRoot = %q, want %q", spec.RepoRoot, repoRoot)
}
if spec.RepoName != filepath.Base(repoRoot) {
t.Fatalf("RepoName = %q, want %q", spec.RepoName, filepath.Base(repoRoot))
}
if spec.CurrentBranch != "trunk" {
t.Fatalf("CurrentBranch = %q, want trunk", spec.CurrentBranch)
}
if spec.HeadCommit == "" {
t.Fatal("HeadCommit should not be empty")
}
if spec.BaseCommit != spec.HeadCommit {
t.Fatalf("BaseCommit = %q, want head %q", spec.BaseCommit, spec.HeadCommit)
}
wantOverlay := []string{".gitignore", "dir/keep.txt", "tracked.txt", "untracked.txt"}
if !reflect.DeepEqual(spec.OverlayPaths, wantOverlay) {
t.Fatalf("OverlayPaths = %v, want %v", spec.OverlayPaths, wantOverlay)
}
}
func TestInspectVMRunRepoRejectsSubmodules(t *testing.T) {
repoRoot := t.TempDir()
origHostCommandOutput := hostCommandOutputFunc
t.Cleanup(func() {
hostCommandOutputFunc = origHostCommandOutput
})
hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) {
t.Helper()
if name != "git" {
t.Fatalf("command = %q, want git", name)
}
switch {
case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--show-toplevel"}):
return []byte(repoRoot + "\n"), nil
case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--is-bare-repository"}):
return []byte("false\n"), nil
case reflect.DeepEqual(args, []string{"-C", repoRoot, "ls-files", "--stage", "-z"}):
return []byte("160000 deadbeef 0\tvendor/submodule\x00"), nil
default:
t.Fatalf("unexpected git args: %v", args)
return nil, nil
}
}
_, err := inspectVMRunRepo(context.Background(), repoRoot, "", "HEAD")
if err == nil || !strings.Contains(err.Error(), "submodules") {
t.Fatalf("inspectVMRunRepo() error = %v, want submodule rejection", err)
}
}
func TestRunVMRunCreatesImportsAndAttaches(t *testing.T) {
repoRoot := t.TempDir()
origBegin := vmCreateBeginFunc
origStatus := vmCreateStatusFunc
origCancel := vmCreateCancelFunc
origWaitForSSH := guestWaitForSSHFunc
origGuestDial := guestDialFunc
origHostCommandOutput := hostCommandOutputFunc
origOpencodeExec := opencodeExecFunc
t.Cleanup(func() {
vmCreateBeginFunc = origBegin
vmCreateStatusFunc = origStatus
vmCreateCancelFunc = origCancel
guestWaitForSSHFunc = origWaitForSSH
guestDialFunc = origGuestDial
hostCommandOutputFunc = origHostCommandOutput
opencodeExecFunc = origOpencodeExec
})
vm := model.VMRecord{
ID: "vm-id",
Name: "devbox",
Runtime: model.VMRuntime{
State: model.VMStateRunning,
GuestIP: "172.16.0.2",
},
}
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
return api.VMCreateBeginResult{
Operation: api.VMCreateOperation{
ID: "op-1",
Stage: "ready",
Detail: "vm is ready",
Done: true,
Success: true,
VM: &vm,
},
}, nil
}
vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
t.Fatal("vmCreateStatusFunc should not be called")
return api.VMCreateStatusResult{}, nil
}
vmCreateCancelFunc = func(context.Context, string, string) error {
t.Fatal("vmCreateCancelFunc should not be called")
return nil
}
fakeClient := &testVMRunGuestClient{}
waitAddress := ""
waitKeyPath := ""
waitInterval := time.Duration(0)
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
waitAddress = address
waitKeyPath = privateKeyPath
waitInterval = interval
return nil
}
dialAddress := ""
dialKeyPath := ""
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
dialAddress = address
dialKeyPath = privateKeyPath
return fakeClient, nil
}
hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) {
if name != "git" {
t.Fatalf("command = %q, want git", name)
}
if len(args) < 7 || args[0] != "-C" || args[1] != repoRoot || args[2] != "bundle" || args[3] != "create" || args[5] != "--all" {
t.Fatalf("unexpected bundle args: %v", args)
}
if !reflect.DeepEqual(args[6:], []string{"deadbeef", "cafebabe"}) {
t.Fatalf("bundle revs = %v, want deadbeef/cafebabe", args[6:])
}
if err := os.WriteFile(args[4], []byte("bundle-data"), 0o600); err != nil {
t.Fatalf("WriteFile(bundle): %v", err)
}
return nil, nil
}
var attachArgs []string
opencodeExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
attachArgs = append([]string(nil), args...)
return nil
}
spec := vmRunRepoSpec{
RepoRoot: repoRoot,
RepoName: "repo",
HeadCommit: "deadbeef",
CurrentBranch: "main",
BranchName: "feature",
BaseCommit: "cafebabe",
OverlayPaths: []string{"tracked.txt", "nested/keep.txt"},
}
err := runVMRun(
context.Background(),
"/tmp/bangerd.sock",
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
strings.NewReader(""),
&bytes.Buffer{},
&bytes.Buffer{},
api.VMCreateParams{Name: "devbox"},
spec,
)
if err != nil {
t.Fatalf("runVMRun: %v", err)
}
if waitAddress != "172.16.0.2:22" {
t.Fatalf("waitAddress = %q, want 172.16.0.2:22", waitAddress)
}
if waitKeyPath != "/tmp/id_ed25519" {
t.Fatalf("waitKeyPath = %q, want /tmp/id_ed25519", waitKeyPath)
}
if waitInterval <= 0 {
t.Fatalf("waitInterval = %s, want positive interval", waitInterval)
}
if dialAddress != waitAddress {
t.Fatalf("dialAddress = %q, want %q", dialAddress, waitAddress)
}
if dialKeyPath != waitKeyPath {
t.Fatalf("dialKeyPath = %q, want %q", dialKeyPath, waitKeyPath)
}
if fakeClient.uploadPath != vmRunGuestBundlePath {
t.Fatalf("uploadPath = %q, want %q", fakeClient.uploadPath, vmRunGuestBundlePath)
}
if fakeClient.uploadMode != 0o600 {
t.Fatalf("uploadMode = %v, want 0600", fakeClient.uploadMode)
}
if string(fakeClient.uploadData) != "bundle-data" {
t.Fatalf("uploadData = %q, want bundle-data", string(fakeClient.uploadData))
}
if !strings.Contains(fakeClient.script, `git clone "$BUNDLE" "$DIR"`) {
t.Fatalf("script = %q, want clone command", fakeClient.script)
}
if !strings.Contains(fakeClient.script, `git -C "$DIR" checkout -B 'feature' 'cafebabe'`) {
t.Fatalf("script = %q, want guest branch checkout", fakeClient.script)
}
if fakeClient.streamSourceDir != repoRoot {
t.Fatalf("streamSourceDir = %q, want %q", fakeClient.streamSourceDir, repoRoot)
}
if !reflect.DeepEqual(fakeClient.streamEntries, spec.OverlayPaths) {
t.Fatalf("streamEntries = %v, want %v", fakeClient.streamEntries, spec.OverlayPaths)
}
if fakeClient.streamCommand != "tar -C '/root/repo' --strip-components=1 -xf -" {
t.Fatalf("streamCommand = %q", fakeClient.streamCommand)
}
wantAttach := []string{"attach", "--dir", "/root/repo", "http://172.16.0.2:4096"}
if !reflect.DeepEqual(attachArgs, wantAttach) {
t.Fatalf("attachArgs = %v, want %v", attachArgs, wantAttach)
}
if !fakeClient.closed {
t.Fatal("guest client should be closed")
}
}
func TestNewBangerdCommandRejectsArgs(t *testing.T) {
cmd := NewBangerdCommand()
cmd.SetArgs([]string{"extra"})
@ -965,3 +1267,48 @@ func TestAbsolutizeImageBuildPaths(t *testing.T) {
func testCLIResolvedVM(id, name string) model.VMRecord {
return model.VMRecord{ID: id, Name: name}
}
func testRunGit(t *testing.T, dir string, args ...string) string {
t.Helper()
cmd := exec.Command("git", append([]string{"-c", "commit.gpgsign=false", "-C", dir}, args...)...)
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("git %v: %v\n%s", args, err, string(output))
}
return string(output)
}
type testVMRunGuestClient struct {
closed bool
uploadPath string
uploadMode os.FileMode
uploadData []byte
script string
streamSourceDir string
streamEntries []string
streamCommand string
}
func (c *testVMRunGuestClient) Close() error {
c.closed = true
return nil
}
func (c *testVMRunGuestClient) UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error {
c.uploadPath = remotePath
c.uploadMode = mode
c.uploadData = append([]byte(nil), data...)
return nil
}
func (c *testVMRunGuestClient) RunScript(ctx context.Context, script string, logWriter io.Writer) error {
c.script = script
return nil
}
func (c *testVMRunGuestClient) StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error {
c.streamSourceDir = sourceDir
c.streamEntries = append([]string(nil), entries...)
c.streamCommand = remoteCommand
return nil
}

View file

@ -11,7 +11,9 @@ import (
"io"
"net"
"os"
"path"
"path/filepath"
"sort"
"strings"
"time"
@ -94,6 +96,19 @@ func (c *Client) StreamTar(ctx context.Context, sourceDir, remoteCommand string,
return errors.Join(runErr, tarErr)
}
func (c *Client) StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error {
reader, writer := io.Pipe()
writeErr := make(chan error, 1)
go func() {
writeErr <- writeTarEntriesArchive(writer, sourceDir, entries)
_ = writer.Close()
}()
runErr := c.runSession(ctx, remoteCommand, reader, logWriter)
tarErr := <-writeErr
return errors.Join(runErr, tarErr)
}
func (c *Client) runSession(ctx context.Context, command string, stdin io.Reader, logWriter io.Writer) error {
if c == nil || c.client == nil {
return fmt.Errorf("ssh client is not connected")
@ -197,3 +212,68 @@ func writeTarArchive(dst io.Writer, sourceDir string) error {
return err
})
}
func writeTarEntriesArchive(dst io.Writer, sourceDir string, entries []string) error {
tw := tar.NewWriter(dst)
defer tw.Close()
sourceDir = filepath.Clean(sourceDir)
rootName := filepath.Base(sourceDir)
uniqueEntries := make([]string, 0, len(entries))
seen := make(map[string]struct{}, len(entries))
for _, entry := range entries {
entry = strings.TrimSpace(entry)
if entry == "" {
continue
}
entry = filepath.Clean(entry)
if entry == "." || entry == ".." || strings.HasPrefix(entry, ".."+string(filepath.Separator)) {
return fmt.Errorf("tar entry %q escapes source dir", entry)
}
if _, ok := seen[entry]; ok {
continue
}
seen[entry] = struct{}{}
uniqueEntries = append(uniqueEntries, entry)
}
sort.Strings(uniqueEntries)
for _, entry := range uniqueEntries {
fullPath := filepath.Join(sourceDir, entry)
info, err := os.Lstat(fullPath)
if err != nil {
return err
}
linkTarget := ""
if info.Mode()&os.ModeSymlink != 0 {
linkTarget, err = os.Readlink(fullPath)
if err != nil {
return err
}
}
header, err := tar.FileInfoHeader(info, linkTarget)
if err != nil {
return err
}
header.Name = path.Join(rootName, filepath.ToSlash(entry))
if err := tw.WriteHeader(header); err != nil {
return err
}
if !info.Mode().IsRegular() {
continue
}
file, err := os.Open(fullPath)
if err != nil {
return err
}
if _, err := io.Copy(tw, file); err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
}
return nil
}

View file

@ -91,3 +91,52 @@ func TestAuthorizedPublicKey(t *testing.T) {
t.Fatalf("key type = %q, want %q", parsed.Type(), ssh.KeyAlgoRSA)
}
}
func TestWriteTarEntriesArchiveIncludesOnlySelectedPaths(t *testing.T) {
t.Parallel()
sourceDir := filepath.Join(t.TempDir(), "repo")
if err := os.MkdirAll(filepath.Join(sourceDir, "nested"), 0o755); err != nil {
t.Fatalf("MkdirAll(nested): %v", err)
}
if err := os.WriteFile(filepath.Join(sourceDir, "tracked.txt"), []byte("tracked"), 0o644); err != nil {
t.Fatalf("WriteFile(tracked.txt): %v", err)
}
if err := os.WriteFile(filepath.Join(sourceDir, "nested", "keep.txt"), []byte("keep"), 0o644); err != nil {
t.Fatalf("WriteFile(keep.txt): %v", err)
}
if err := os.WriteFile(filepath.Join(sourceDir, "nested", "skip.txt"), []byte("skip"), 0o644); err != nil {
t.Fatalf("WriteFile(skip.txt): %v", err)
}
var buf bytes.Buffer
if err := writeTarEntriesArchive(&buf, sourceDir, []string{"tracked.txt", "nested/keep.txt"}); err != nil {
t.Fatalf("writeTarEntriesArchive: %v", err)
}
tr := tar.NewReader(bytes.NewReader(buf.Bytes()))
var names []string
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("tar.Next: %v", err)
}
names = append(names, header.Name)
}
want := map[string]struct{}{
"repo/tracked.txt": {},
"repo/nested/keep.txt": {},
}
if len(names) != len(want) {
t.Fatalf("archive names = %v, want %d entries", names, len(want))
}
for _, name := range names {
if _, ok := want[name]; !ok {
t.Fatalf("unexpected archive entry %q in %v", name, names)
}
}
}