diff --git a/README.md b/README.md index eea367d..9868e27 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,15 @@ Create and use a VM: `vm create` stays synchronous by default, but on a TTY it now shows live progress until the VM is fully ready. +Start a repo-backed VM session and attach `opencode` automatically: + +```bash +./build/bin/banger vm run +./build/bin/banger vm run ../some-repo --branch feature/alpine --from HEAD +``` + +`vm run` resolves the enclosing git repository, creates a VM, copies a git checkout plus current tracked and untracked non-ignored files into `/root/`, and then runs `opencode attach` from the host against the guest. + ## Web UI `bangerd` serves a local web UI by default at: diff --git a/internal/cli/banger.go b/internal/cli/banger.go index 0743c97..cb595b6 100644 --- a/internal/cli/banger.go +++ b/internal/cli/banger.go @@ -1,11 +1,13 @@ package cli import ( + "bytes" "context" "encoding/json" "errors" "fmt" "io" + "net" "os" "os/exec" "path/filepath" @@ -19,6 +21,7 @@ import ( "banger/internal/api" "banger/internal/config" "banger/internal/daemon" + "banger/internal/guest" "banger/internal/hostnat" "banger/internal/imagepreset" "banger/internal/model" @@ -44,6 +47,26 @@ var ( sshCmd.Stdin = stdin return sshCmd.Run() } + opencodeExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { + opencodeCmd := exec.CommandContext(ctx, "opencode", args...) + opencodeCmd.Stdout = stdout + opencodeCmd.Stderr = stderr + opencodeCmd.Stdin = stdin + return opencodeCmd.Run() + } + hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) { + cmd := exec.CommandContext(ctx, name, args...) + output, err := cmd.CombinedOutput() + if err == nil { + return output, nil + } + command := strings.TrimSpace(strings.Join(append([]string{name}, args...), " ")) + detail := strings.TrimSpace(string(output)) + if detail == "" { + return output, fmt.Errorf("%s: %w", command, err) + } + return output, fmt.Errorf("%s: %w: %s", command, err, detail) + } vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName}) } @@ -60,8 +83,35 @@ var ( vmPortsFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) { return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName}) } + guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { + return guest.WaitForSSH(ctx, address, privateKeyPath, interval) + } + guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { + return guest.Dial(ctx, address, privateKeyPath) + } + cwdFunc = os.Getwd ) +type vmRunGuestClient interface { + Close() error + UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error + RunScript(ctx context.Context, script string, logWriter io.Writer) error + StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error +} + +type vmRunRepoSpec struct { + SourcePath string + RepoRoot string + RepoName string + HeadCommit string + CurrentBranch string + BranchName string + BaseCommit string + OverlayPaths []string +} + +const vmRunGuestBundlePath = "/tmp/banger-vm-run.bundle" + func NewBangerCommand() *cobra.Command { root := &cobra.Command{ Use: "banger", @@ -358,6 +408,7 @@ func newVMCommand() *cobra.Command { } cmd.AddCommand( newVMCreateCommand(), + newVMRunCommand(), newVMListCommand(), newVMShowCommand(), newVMActionCommand("start", "Start a VM", "vm.start"), @@ -374,6 +425,76 @@ func newVMCommand() *cobra.Command { return cmd } +func newVMRunCommand() *cobra.Command { + var ( + name string + imageName string + vcpu = model.DefaultVCPUCount + memory = model.DefaultMemoryMiB + systemOverlaySize = model.FormatSizeBytes(model.DefaultSystemOverlaySize) + workDiskSize = model.FormatSizeBytes(model.DefaultWorkDiskSize) + natEnabled bool + branchName string + fromRef = "HEAD" + ) + cmd := &cobra.Command{ + Use: "run [path]", + Short: "Create a repo-backed VM session and attach opencode", + Args: maxArgsUsage(1, "usage: banger vm run [path]"), + RunE: func(cmd *cobra.Command, args []string) error { + if cmd.Flags().Changed("branch") && strings.TrimSpace(branchName) == "" { + return errors.New("--branch requires a branch name") + } + if cmd.Flags().Changed("from") && strings.TrimSpace(branchName) == "" { + return errors.New("--from requires --branch") + } + + sourcePath := "" + if len(args) == 1 { + sourcePath = args[0] + } + spec, err := inspectVMRunRepo(cmd.Context(), sourcePath, branchName, fromRef) + if err != nil { + return err + } + + layout, err := paths.Resolve() + if err != nil { + return err + } + cfg, err := config.Load(layout) + if err != nil { + return err + } + if err := validateVMRunPrereqs(cfg); err != nil { + return err + } + params, err := vmCreateParamsFromFlags(cmd, name, imageName, vcpu, memory, systemOverlaySize, workDiskSize, natEnabled, false) + if err != nil { + return err + } + if err := system.EnsureSudo(cmd.Context()); err != nil { + return err + } + layout, cfg, err = ensureDaemon(cmd.Context()) + if err != nil { + return err + } + return runVMRun(cmd.Context(), layout.SocketPath, cfg, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), params, spec) + }, + } + cmd.Flags().StringVar(&name, "name", "", "vm name") + cmd.Flags().StringVar(&imageName, "image", "", "image name or id") + cmd.Flags().IntVar(&vcpu, "vcpu", model.DefaultVCPUCount, "vcpu count") + cmd.Flags().IntVar(&memory, "memory", model.DefaultMemoryMiB, "memory in MiB") + cmd.Flags().StringVar(&systemOverlaySize, "system-overlay-size", model.FormatSizeBytes(model.DefaultSystemOverlaySize), "system overlay size") + cmd.Flags().StringVar(&workDiskSize, "disk-size", model.FormatSizeBytes(model.DefaultWorkDiskSize), "work disk size") + cmd.Flags().BoolVar(&natEnabled, "nat", false, "enable NAT") + cmd.Flags().StringVar(&branchName, "branch", "", "create and switch to a new guest branch") + cmd.Flags().StringVar(&fromRef, "from", "HEAD", "base ref for --branch") + return cmd +} + func newVMKillCommand() *cobra.Command { var signal string cmd := &cobra.Command{ @@ -876,6 +997,15 @@ func minArgsUsage(n int, usage string) cobra.PositionalArgs { } } +func maxArgsUsage(n int, usage string) cobra.PositionalArgs { + return func(cmd *cobra.Command, args []string) error { + if len(args) > n { + return errors.New(usage) + } + return nil + } +} + type resolvedVMTarget struct { Index int Ref string @@ -1244,6 +1374,320 @@ func validateSSHPrereqs(cfg model.DaemonConfig) error { return checks.Err("ssh preflight failed") } +func validateVMRunPrereqs(cfg model.DaemonConfig) error { + checks := system.NewPreflight() + checks.RequireCommand("git", "install git") + checks.RequireCommand("opencode", "install opencode") + if strings.TrimSpace(cfg.SSHKeyPath) != "" { + checks.RequireFile(cfg.SSHKeyPath, "ssh private key", `set "ssh_key_path" or let banger create its default key`) + } + return checks.Err("vm run preflight failed") +} + +func inspectVMRunRepo(ctx context.Context, rawPath, branchName, fromRef string) (vmRunRepoSpec, error) { + sourcePath, err := resolveVMRunSourcePath(rawPath) + if err != nil { + return vmRunRepoSpec{}, err + } + + repoRoot, err := gitTrimmedOutput(ctx, sourcePath, "rev-parse", "--show-toplevel") + if err != nil { + return vmRunRepoSpec{}, fmt.Errorf("%s is not inside a git repository", sourcePath) + } + isBare, err := gitTrimmedOutput(ctx, repoRoot, "rev-parse", "--is-bare-repository") + if err != nil { + return vmRunRepoSpec{}, fmt.Errorf("inspect git repository %s: %w", repoRoot, err) + } + if isBare == "true" { + return vmRunRepoSpec{}, fmt.Errorf("vm run requires a non-bare git repository: %s", repoRoot) + } + if err := ensureVMRunRepoHasNoSubmodules(ctx, repoRoot); err != nil { + return vmRunRepoSpec{}, err + } + + headCommit, err := gitTrimmedOutput(ctx, repoRoot, "rev-parse", "HEAD^{commit}") + if err != nil { + return vmRunRepoSpec{}, fmt.Errorf("git repository %s must have at least one commit", repoRoot) + } + currentBranch, err := gitTrimmedOutput(ctx, repoRoot, "branch", "--show-current") + if err != nil { + return vmRunRepoSpec{}, fmt.Errorf("resolve current branch for %s: %w", repoRoot, err) + } + + baseCommit := headCommit + branchName = strings.TrimSpace(branchName) + if branchName != "" { + fromRef = strings.TrimSpace(fromRef) + if fromRef == "" { + return vmRunRepoSpec{}, errors.New("--from cannot be empty") + } + baseCommit, err = gitTrimmedOutput(ctx, repoRoot, "rev-parse", fromRef+"^{commit}") + if err != nil { + return vmRunRepoSpec{}, fmt.Errorf("resolve --from %q: %w", fromRef, err) + } + } + + overlayPaths, err := listVMRunOverlayPaths(ctx, repoRoot) + if err != nil { + return vmRunRepoSpec{}, err + } + + return vmRunRepoSpec{ + SourcePath: sourcePath, + RepoRoot: repoRoot, + RepoName: filepath.Base(repoRoot), + HeadCommit: headCommit, + CurrentBranch: currentBranch, + BranchName: branchName, + BaseCommit: baseCommit, + OverlayPaths: overlayPaths, + }, nil +} + +func resolveVMRunSourcePath(rawPath string) (string, error) { + if strings.TrimSpace(rawPath) == "" { + wd, err := cwdFunc() + if err != nil { + return "", err + } + rawPath = wd + } + absPath, err := filepath.Abs(rawPath) + if err != nil { + return "", err + } + info, err := os.Stat(absPath) + if err != nil { + return "", err + } + if !info.IsDir() { + return "", fmt.Errorf("%s is not a directory", absPath) + } + return absPath, nil +} + +func ensureVMRunRepoHasNoSubmodules(ctx context.Context, repoRoot string) error { + output, err := gitOutput(ctx, repoRoot, "ls-files", "--stage", "-z") + if err != nil { + return fmt.Errorf("inspect git index for %s: %w", repoRoot, err) + } + for _, record := range parseNullSeparatedOutput(output) { + if strings.HasPrefix(record, "160000 ") { + return fmt.Errorf("vm run does not yet support git submodules: %s", repoRoot) + } + } + return nil +} + +func listVMRunOverlayPaths(ctx context.Context, repoRoot string) ([]string, error) { + trackedOutput, err := gitOutput(ctx, repoRoot, "ls-files", "-z") + if err != nil { + return nil, fmt.Errorf("list tracked files for %s: %w", repoRoot, err) + } + untrackedOutput, err := gitOutput(ctx, repoRoot, "ls-files", "--others", "--exclude-standard", "-z") + if err != nil { + return nil, fmt.Errorf("list untracked files for %s: %w", repoRoot, err) + } + + paths := make([]string, 0) + seen := make(map[string]struct{}) + for _, relPath := range parseNullSeparatedOutput(trackedOutput) { + if relPath == "" { + continue + } + if _, err := os.Lstat(filepath.Join(repoRoot, relPath)); err != nil { + if os.IsNotExist(err) { + continue + } + return nil, err + } + seen[relPath] = struct{}{} + paths = append(paths, relPath) + } + for _, relPath := range parseNullSeparatedOutput(untrackedOutput) { + if relPath == "" { + continue + } + if _, ok := seen[relPath]; ok { + continue + } + seen[relPath] = struct{}{} + paths = append(paths, relPath) + } + sort.Strings(paths) + return paths, nil +} + +func gitOutput(ctx context.Context, dir string, args ...string) ([]byte, error) { + fullArgs := make([]string, 0, len(args)+2) + if strings.TrimSpace(dir) != "" { + fullArgs = append(fullArgs, "-C", dir) + } + fullArgs = append(fullArgs, args...) + return hostCommandOutputFunc(ctx, "git", fullArgs...) +} + +func gitTrimmedOutput(ctx context.Context, dir string, args ...string) (string, error) { + output, err := gitOutput(ctx, dir, args...) + if err != nil { + return "", err + } + return strings.TrimSpace(string(output)), nil +} + +func parseNullSeparatedOutput(output []byte) []string { + chunks := bytes.Split(output, []byte{0}) + values := make([]string, 0, len(chunks)) + for _, chunk := range chunks { + value := strings.TrimSpace(string(chunk)) + if value == "" { + continue + } + values = append(values, value) + } + return values +} + +func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, stdin io.Reader, stdout, stderr io.Writer, params api.VMCreateParams, spec vmRunRepoSpec) error { + vm, err := runVMCreate(ctx, socketPath, stderr, params) + if err != nil { + return err + } + vmRef := strings.TrimSpace(vm.Name) + if vmRef == "" { + vmRef = shortID(vm.ID) + } + sshAddress := net.JoinHostPort(vm.Runtime.GuestIP, "22") + if err := guestWaitForSSHFunc(ctx, sshAddress, cfg.SSHKeyPath, 250*time.Millisecond); err != nil { + return fmt.Errorf("vm %q is running but guest ssh is unavailable: %w", vmRef, err) + } + client, err := guestDialFunc(ctx, sshAddress, cfg.SSHKeyPath) + if err != nil { + return fmt.Errorf("vm %q is running but guest ssh is unavailable: %w", vmRef, err) + } + defer client.Close() + if err := importVMRunRepoToGuest(ctx, client, spec); err != nil { + return fmt.Errorf("vm %q is running but repo import failed: %w", vmRef, err) + } + if err := runVMRunAttach(ctx, stdin, stdout, stderr, vm.Runtime.GuestIP, vmRunGuestDir(spec.RepoName)); err != nil { + return fmt.Errorf("vm %q is running but opencode attach failed: %w", vmRef, err) + } + return nil +} + +func importVMRunRepoToGuest(ctx context.Context, client vmRunGuestClient, spec vmRunRepoSpec) error { + bundleData, err := createVMRunBundle(ctx, spec) + if err != nil { + return err + } + var uploadLog bytes.Buffer + if err := client.UploadFile(ctx, vmRunGuestBundlePath, 0o600, bundleData, &uploadLog); err != nil { + return formatVMRunStepError("upload git bundle", err, uploadLog.String()) + } + var scriptLog bytes.Buffer + if err := client.RunScript(ctx, vmRunCloneScript(spec), &scriptLog); err != nil { + return formatVMRunStepError("prepare guest checkout", err, scriptLog.String()) + } + var overlayLog bytes.Buffer + remoteCommand := fmt.Sprintf("tar -C %s --strip-components=1 -xf -", shellQuote(vmRunGuestDir(spec.RepoName))) + if err := client.StreamTarEntries(ctx, spec.RepoRoot, spec.OverlayPaths, remoteCommand, &overlayLog); err != nil { + return formatVMRunStepError("overlay host working tree", err, overlayLog.String()) + } + return nil +} + +func createVMRunBundle(ctx context.Context, spec vmRunRepoSpec) ([]byte, error) { + tempFile, err := os.CreateTemp("", "banger-vm-run-*.bundle") + if err != nil { + return nil, err + } + tempPath := tempFile.Name() + if err := tempFile.Close(); err != nil { + _ = os.Remove(tempPath) + return nil, err + } + defer os.Remove(tempPath) + + args := []string{"-C", spec.RepoRoot, "bundle", "create", tempPath, "--all"} + for _, rev := range uniqueNonEmptyStrings(spec.HeadCommit, spec.BaseCommit) { + args = append(args, rev) + } + if _, err := hostCommandOutputFunc(ctx, "git", args...); err != nil { + return nil, fmt.Errorf("create git bundle: %w", err) + } + data, err := os.ReadFile(tempPath) + if err != nil { + return nil, fmt.Errorf("read git bundle: %w", err) + } + return data, nil +} + +func vmRunCloneScript(spec vmRunRepoSpec) string { + guestDir := vmRunGuestDir(spec.RepoName) + var script strings.Builder + script.WriteString("set -euo pipefail\n") + fmt.Fprintf(&script, "DIR=%s\n", shellQuote(guestDir)) + fmt.Fprintf(&script, "BUNDLE=%s\n", shellQuote(vmRunGuestBundlePath)) + script.WriteString("rm -rf \"$DIR\"\n") + script.WriteString("git clone \"$BUNDLE\" \"$DIR\"\n") + script.WriteString("rm -f \"$BUNDLE\"\n") + switch { + case strings.TrimSpace(spec.BranchName) != "": + fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", shellQuote(spec.BranchName), shellQuote(spec.BaseCommit)) + case strings.TrimSpace(spec.CurrentBranch) != "": + fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", shellQuote(spec.CurrentBranch), shellQuote(spec.HeadCommit)) + default: + fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", shellQuote(spec.HeadCommit)) + } + script.WriteString("find \"$DIR\" -mindepth 1 -maxdepth 1 ! -name .git -exec rm -rf {} +\n") + return script.String() +} + +func vmRunGuestDir(repoName string) string { + return filepath.ToSlash(filepath.Join("/root", repoName)) +} + +func runVMRunAttach(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, guestIP, guestDir string) error { + guestIP = strings.TrimSpace(guestIP) + if guestIP == "" { + return errors.New("vm has no guest IP") + } + return opencodeExecFunc(ctx, stdin, stdout, stderr, []string{ + "attach", + "--dir", guestDir, + "http://" + net.JoinHostPort(guestIP, "4096"), + }) +} + +func formatVMRunStepError(action string, err error, log string) error { + log = strings.TrimSpace(log) + if log == "" { + return fmt.Errorf("%s: %w", action, err) + } + return fmt.Errorf("%s: %w: %s", action, err, log) +} + +func uniqueNonEmptyStrings(values ...string) []string { + unique := make([]string, 0, len(values)) + seen := make(map[string]struct{}, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + if value == "" { + continue + } + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + unique = append(unique, value) + } + return unique +} + +func shellQuote(value string) string { + return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'" +} + func absolutizeImageBuildPaths(params *api.ImageBuildParams) error { return absolutizePaths(¶ms.KernelPath, ¶ms.InitrdPath, ¶ms.ModulesDir) } diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 3ec9968..5083811 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -146,6 +146,26 @@ func TestVMCreateFlagsExist(t *testing.T) { } } +func TestVMRunFlagsExist(t *testing.T) { + root := NewBangerCommand() + vm, _, err := root.Find([]string{"vm"}) + if err != nil { + t.Fatalf("find vm: %v", err) + } + run, _, err := vm.Find([]string{"run"}) + if err != nil { + t.Fatalf("find run: %v", err) + } + for _, flagName := range []string{"name", "image", "vcpu", "memory", "system-overlay-size", "disk-size", "nat", "branch", "from"} { + if run.Flags().Lookup(flagName) == nil { + t.Fatalf("missing flag %q", flagName) + } + } + if run.Flags().Lookup("no-start") != nil { + t.Fatal("vm run should not expose --no-start") + } +} + func TestVMCreateFlagsShowStaticDefaults(t *testing.T) { root := NewBangerCommand() vm, _, err := root.Find([]string{"vm"}) @@ -171,6 +191,16 @@ func TestVMCreateFlagsShowStaticDefaults(t *testing.T) { } } +func TestVMRunRejectsFromWithoutBranch(t *testing.T) { + cmd := NewBangerCommand() + cmd.SetArgs([]string{"vm", "run", "--from", "HEAD"}) + + err := cmd.Execute() + if err == nil || !strings.Contains(err.Error(), "--from requires --branch") { + t.Fatalf("Execute() error = %v, want --from requires --branch", err) + } +} + func TestImageRegisterFlagsExist(t *testing.T) { root := NewBangerCommand() image, _, err := root.Find([]string{"image"}) @@ -837,6 +867,278 @@ func TestValidateSSHPrereqsFailsForMissingKey(t *testing.T) { } } +func TestResolveVMRunSourcePathDefaultsToCWD(t *testing.T) { + origCWD := cwdFunc + t.Cleanup(func() { + cwdFunc = origCWD + }) + + want := t.TempDir() + cwdFunc = func() (string, error) { + return want, nil + } + + got, err := resolveVMRunSourcePath("") + if err != nil { + t.Fatalf("resolveVMRunSourcePath: %v", err) + } + if got != want { + t.Fatalf("resolveVMRunSourcePath() = %q, want %q", got, want) + } +} + +func TestInspectVMRunRepoUsesRepoRootAndOverlayPaths(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not installed") + } + + repoRoot := t.TempDir() + testRunGit(t, repoRoot, "init") + testRunGit(t, repoRoot, "config", "user.email", "test@example.com") + testRunGit(t, repoRoot, "config", "user.name", "Banger Test") + + if err := os.MkdirAll(filepath.Join(repoRoot, "dir"), 0o755); err != nil { + t.Fatalf("MkdirAll(dir): %v", err) + } + if err := os.WriteFile(filepath.Join(repoRoot, ".gitignore"), []byte("ignored.txt\n"), 0o644); err != nil { + t.Fatalf("WriteFile(.gitignore): %v", err) + } + if err := os.WriteFile(filepath.Join(repoRoot, "tracked.txt"), []byte("tracked\n"), 0o644); err != nil { + t.Fatalf("WriteFile(tracked.txt): %v", err) + } + if err := os.WriteFile(filepath.Join(repoRoot, "dir", "keep.txt"), []byte("keep\n"), 0o644); err != nil { + t.Fatalf("WriteFile(keep.txt): %v", err) + } + testRunGit(t, repoRoot, "add", ".") + testRunGit(t, repoRoot, "commit", "-m", "init") + testRunGit(t, repoRoot, "checkout", "-b", "trunk") + + if err := os.WriteFile(filepath.Join(repoRoot, "tracked.txt"), []byte("tracked local\n"), 0o644); err != nil { + t.Fatalf("WriteFile(tracked.txt local): %v", err) + } + if err := os.WriteFile(filepath.Join(repoRoot, "untracked.txt"), []byte("untracked\n"), 0o644); err != nil { + t.Fatalf("WriteFile(untracked.txt): %v", err) + } + if err := os.WriteFile(filepath.Join(repoRoot, "ignored.txt"), []byte("ignored\n"), 0o644); err != nil { + t.Fatalf("WriteFile(ignored.txt): %v", err) + } + + spec, err := inspectVMRunRepo(context.Background(), filepath.Join(repoRoot, "dir"), "", "HEAD") + if err != nil { + t.Fatalf("inspectVMRunRepo: %v", err) + } + + if spec.RepoRoot != repoRoot { + t.Fatalf("RepoRoot = %q, want %q", spec.RepoRoot, repoRoot) + } + if spec.RepoName != filepath.Base(repoRoot) { + t.Fatalf("RepoName = %q, want %q", spec.RepoName, filepath.Base(repoRoot)) + } + if spec.CurrentBranch != "trunk" { + t.Fatalf("CurrentBranch = %q, want trunk", spec.CurrentBranch) + } + if spec.HeadCommit == "" { + t.Fatal("HeadCommit should not be empty") + } + if spec.BaseCommit != spec.HeadCommit { + t.Fatalf("BaseCommit = %q, want head %q", spec.BaseCommit, spec.HeadCommit) + } + wantOverlay := []string{".gitignore", "dir/keep.txt", "tracked.txt", "untracked.txt"} + if !reflect.DeepEqual(spec.OverlayPaths, wantOverlay) { + t.Fatalf("OverlayPaths = %v, want %v", spec.OverlayPaths, wantOverlay) + } +} + +func TestInspectVMRunRepoRejectsSubmodules(t *testing.T) { + repoRoot := t.TempDir() + + origHostCommandOutput := hostCommandOutputFunc + t.Cleanup(func() { + hostCommandOutputFunc = origHostCommandOutput + }) + + hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) { + t.Helper() + if name != "git" { + t.Fatalf("command = %q, want git", name) + } + switch { + case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--show-toplevel"}): + return []byte(repoRoot + "\n"), nil + case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--is-bare-repository"}): + return []byte("false\n"), nil + case reflect.DeepEqual(args, []string{"-C", repoRoot, "ls-files", "--stage", "-z"}): + return []byte("160000 deadbeef 0\tvendor/submodule\x00"), nil + default: + t.Fatalf("unexpected git args: %v", args) + return nil, nil + } + } + + _, err := inspectVMRunRepo(context.Background(), repoRoot, "", "HEAD") + if err == nil || !strings.Contains(err.Error(), "submodules") { + t.Fatalf("inspectVMRunRepo() error = %v, want submodule rejection", err) + } +} + +func TestRunVMRunCreatesImportsAndAttaches(t *testing.T) { + repoRoot := t.TempDir() + + origBegin := vmCreateBeginFunc + origStatus := vmCreateStatusFunc + origCancel := vmCreateCancelFunc + origWaitForSSH := guestWaitForSSHFunc + origGuestDial := guestDialFunc + origHostCommandOutput := hostCommandOutputFunc + origOpencodeExec := opencodeExecFunc + t.Cleanup(func() { + vmCreateBeginFunc = origBegin + vmCreateStatusFunc = origStatus + vmCreateCancelFunc = origCancel + guestWaitForSSHFunc = origWaitForSSH + guestDialFunc = origGuestDial + hostCommandOutputFunc = origHostCommandOutput + opencodeExecFunc = origOpencodeExec + }) + + vm := model.VMRecord{ + ID: "vm-id", + Name: "devbox", + Runtime: model.VMRuntime{ + State: model.VMStateRunning, + GuestIP: "172.16.0.2", + }, + } + vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) { + return api.VMCreateBeginResult{ + Operation: api.VMCreateOperation{ + ID: "op-1", + Stage: "ready", + Detail: "vm is ready", + Done: true, + Success: true, + VM: &vm, + }, + }, nil + } + vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) { + t.Fatal("vmCreateStatusFunc should not be called") + return api.VMCreateStatusResult{}, nil + } + vmCreateCancelFunc = func(context.Context, string, string) error { + t.Fatal("vmCreateCancelFunc should not be called") + return nil + } + + fakeClient := &testVMRunGuestClient{} + waitAddress := "" + waitKeyPath := "" + waitInterval := time.Duration(0) + guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { + waitAddress = address + waitKeyPath = privateKeyPath + waitInterval = interval + return nil + } + dialAddress := "" + dialKeyPath := "" + guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) { + dialAddress = address + dialKeyPath = privateKeyPath + return fakeClient, nil + } + hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) { + if name != "git" { + t.Fatalf("command = %q, want git", name) + } + if len(args) < 7 || args[0] != "-C" || args[1] != repoRoot || args[2] != "bundle" || args[3] != "create" || args[5] != "--all" { + t.Fatalf("unexpected bundle args: %v", args) + } + if !reflect.DeepEqual(args[6:], []string{"deadbeef", "cafebabe"}) { + t.Fatalf("bundle revs = %v, want deadbeef/cafebabe", args[6:]) + } + if err := os.WriteFile(args[4], []byte("bundle-data"), 0o600); err != nil { + t.Fatalf("WriteFile(bundle): %v", err) + } + return nil, nil + } + var attachArgs []string + opencodeExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error { + attachArgs = append([]string(nil), args...) + return nil + } + + spec := vmRunRepoSpec{ + RepoRoot: repoRoot, + RepoName: "repo", + HeadCommit: "deadbeef", + CurrentBranch: "main", + BranchName: "feature", + BaseCommit: "cafebabe", + OverlayPaths: []string{"tracked.txt", "nested/keep.txt"}, + } + err := runVMRun( + context.Background(), + "/tmp/bangerd.sock", + model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"}, + strings.NewReader(""), + &bytes.Buffer{}, + &bytes.Buffer{}, + api.VMCreateParams{Name: "devbox"}, + spec, + ) + if err != nil { + t.Fatalf("runVMRun: %v", err) + } + + if waitAddress != "172.16.0.2:22" { + t.Fatalf("waitAddress = %q, want 172.16.0.2:22", waitAddress) + } + if waitKeyPath != "/tmp/id_ed25519" { + t.Fatalf("waitKeyPath = %q, want /tmp/id_ed25519", waitKeyPath) + } + if waitInterval <= 0 { + t.Fatalf("waitInterval = %s, want positive interval", waitInterval) + } + if dialAddress != waitAddress { + t.Fatalf("dialAddress = %q, want %q", dialAddress, waitAddress) + } + if dialKeyPath != waitKeyPath { + t.Fatalf("dialKeyPath = %q, want %q", dialKeyPath, waitKeyPath) + } + if fakeClient.uploadPath != vmRunGuestBundlePath { + t.Fatalf("uploadPath = %q, want %q", fakeClient.uploadPath, vmRunGuestBundlePath) + } + if fakeClient.uploadMode != 0o600 { + t.Fatalf("uploadMode = %v, want 0600", fakeClient.uploadMode) + } + if string(fakeClient.uploadData) != "bundle-data" { + t.Fatalf("uploadData = %q, want bundle-data", string(fakeClient.uploadData)) + } + if !strings.Contains(fakeClient.script, `git clone "$BUNDLE" "$DIR"`) { + t.Fatalf("script = %q, want clone command", fakeClient.script) + } + if !strings.Contains(fakeClient.script, `git -C "$DIR" checkout -B 'feature' 'cafebabe'`) { + t.Fatalf("script = %q, want guest branch checkout", fakeClient.script) + } + if fakeClient.streamSourceDir != repoRoot { + t.Fatalf("streamSourceDir = %q, want %q", fakeClient.streamSourceDir, repoRoot) + } + if !reflect.DeepEqual(fakeClient.streamEntries, spec.OverlayPaths) { + t.Fatalf("streamEntries = %v, want %v", fakeClient.streamEntries, spec.OverlayPaths) + } + if fakeClient.streamCommand != "tar -C '/root/repo' --strip-components=1 -xf -" { + t.Fatalf("streamCommand = %q", fakeClient.streamCommand) + } + wantAttach := []string{"attach", "--dir", "/root/repo", "http://172.16.0.2:4096"} + if !reflect.DeepEqual(attachArgs, wantAttach) { + t.Fatalf("attachArgs = %v, want %v", attachArgs, wantAttach) + } + if !fakeClient.closed { + t.Fatal("guest client should be closed") + } +} + func TestNewBangerdCommandRejectsArgs(t *testing.T) { cmd := NewBangerdCommand() cmd.SetArgs([]string{"extra"}) @@ -965,3 +1267,48 @@ func TestAbsolutizeImageBuildPaths(t *testing.T) { func testCLIResolvedVM(id, name string) model.VMRecord { return model.VMRecord{ID: id, Name: name} } + +func testRunGit(t *testing.T, dir string, args ...string) string { + t.Helper() + cmd := exec.Command("git", append([]string{"-c", "commit.gpgsign=false", "-C", dir}, args...)...) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("git %v: %v\n%s", args, err, string(output)) + } + return string(output) +} + +type testVMRunGuestClient struct { + closed bool + uploadPath string + uploadMode os.FileMode + uploadData []byte + script string + streamSourceDir string + streamEntries []string + streamCommand string +} + +func (c *testVMRunGuestClient) Close() error { + c.closed = true + return nil +} + +func (c *testVMRunGuestClient) UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error { + c.uploadPath = remotePath + c.uploadMode = mode + c.uploadData = append([]byte(nil), data...) + return nil +} + +func (c *testVMRunGuestClient) RunScript(ctx context.Context, script string, logWriter io.Writer) error { + c.script = script + return nil +} + +func (c *testVMRunGuestClient) StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error { + c.streamSourceDir = sourceDir + c.streamEntries = append([]string(nil), entries...) + c.streamCommand = remoteCommand + return nil +} diff --git a/internal/guest/ssh.go b/internal/guest/ssh.go index 3422da2..2f6af93 100644 --- a/internal/guest/ssh.go +++ b/internal/guest/ssh.go @@ -11,7 +11,9 @@ import ( "io" "net" "os" + "path" "path/filepath" + "sort" "strings" "time" @@ -94,6 +96,19 @@ func (c *Client) StreamTar(ctx context.Context, sourceDir, remoteCommand string, return errors.Join(runErr, tarErr) } +func (c *Client) StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error { + reader, writer := io.Pipe() + writeErr := make(chan error, 1) + go func() { + writeErr <- writeTarEntriesArchive(writer, sourceDir, entries) + _ = writer.Close() + }() + + runErr := c.runSession(ctx, remoteCommand, reader, logWriter) + tarErr := <-writeErr + return errors.Join(runErr, tarErr) +} + func (c *Client) runSession(ctx context.Context, command string, stdin io.Reader, logWriter io.Writer) error { if c == nil || c.client == nil { return fmt.Errorf("ssh client is not connected") @@ -197,3 +212,68 @@ func writeTarArchive(dst io.Writer, sourceDir string) error { return err }) } + +func writeTarEntriesArchive(dst io.Writer, sourceDir string, entries []string) error { + tw := tar.NewWriter(dst) + defer tw.Close() + + sourceDir = filepath.Clean(sourceDir) + rootName := filepath.Base(sourceDir) + + uniqueEntries := make([]string, 0, len(entries)) + seen := make(map[string]struct{}, len(entries)) + for _, entry := range entries { + entry = strings.TrimSpace(entry) + if entry == "" { + continue + } + entry = filepath.Clean(entry) + if entry == "." || entry == ".." || strings.HasPrefix(entry, ".."+string(filepath.Separator)) { + return fmt.Errorf("tar entry %q escapes source dir", entry) + } + if _, ok := seen[entry]; ok { + continue + } + seen[entry] = struct{}{} + uniqueEntries = append(uniqueEntries, entry) + } + sort.Strings(uniqueEntries) + + for _, entry := range uniqueEntries { + fullPath := filepath.Join(sourceDir, entry) + info, err := os.Lstat(fullPath) + if err != nil { + return err + } + linkTarget := "" + if info.Mode()&os.ModeSymlink != 0 { + linkTarget, err = os.Readlink(fullPath) + if err != nil { + return err + } + } + header, err := tar.FileInfoHeader(info, linkTarget) + if err != nil { + return err + } + header.Name = path.Join(rootName, filepath.ToSlash(entry)) + if err := tw.WriteHeader(header); err != nil { + return err + } + if !info.Mode().IsRegular() { + continue + } + file, err := os.Open(fullPath) + if err != nil { + return err + } + if _, err := io.Copy(tw, file); err != nil { + _ = file.Close() + return err + } + if err := file.Close(); err != nil { + return err + } + } + return nil +} diff --git a/internal/guest/ssh_test.go b/internal/guest/ssh_test.go index 3c8411d..fadba7c 100644 --- a/internal/guest/ssh_test.go +++ b/internal/guest/ssh_test.go @@ -91,3 +91,52 @@ func TestAuthorizedPublicKey(t *testing.T) { t.Fatalf("key type = %q, want %q", parsed.Type(), ssh.KeyAlgoRSA) } } + +func TestWriteTarEntriesArchiveIncludesOnlySelectedPaths(t *testing.T) { + t.Parallel() + + sourceDir := filepath.Join(t.TempDir(), "repo") + if err := os.MkdirAll(filepath.Join(sourceDir, "nested"), 0o755); err != nil { + t.Fatalf("MkdirAll(nested): %v", err) + } + if err := os.WriteFile(filepath.Join(sourceDir, "tracked.txt"), []byte("tracked"), 0o644); err != nil { + t.Fatalf("WriteFile(tracked.txt): %v", err) + } + if err := os.WriteFile(filepath.Join(sourceDir, "nested", "keep.txt"), []byte("keep"), 0o644); err != nil { + t.Fatalf("WriteFile(keep.txt): %v", err) + } + if err := os.WriteFile(filepath.Join(sourceDir, "nested", "skip.txt"), []byte("skip"), 0o644); err != nil { + t.Fatalf("WriteFile(skip.txt): %v", err) + } + + var buf bytes.Buffer + if err := writeTarEntriesArchive(&buf, sourceDir, []string{"tracked.txt", "nested/keep.txt"}); err != nil { + t.Fatalf("writeTarEntriesArchive: %v", err) + } + + tr := tar.NewReader(bytes.NewReader(buf.Bytes())) + var names []string + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("tar.Next: %v", err) + } + names = append(names, header.Name) + } + + want := map[string]struct{}{ + "repo/tracked.txt": {}, + "repo/nested/keep.txt": {}, + } + if len(names) != len(want) { + t.Fatalf("archive names = %v, want %d entries", names, len(want)) + } + for _, name := range names { + if _, ok := want[name]; !ok { + t.Fatalf("unexpected archive entry %q in %v", name, names) + } + } +}