Add repo-backed vm run command
Create a CLI-only banger vm run [path] flow that resolves the enclosing git repository, creates a VM, imports a guest checkout, and launches opencode attach automatically from the host. Build the guest checkout by bundling git history plus the resolved base and head commits, cloning that bundle in the guest, and overlaying tracked plus untracked non-ignored files over SSH so local working-tree changes carry over. Support guest-only branch creation with --branch and --from, reject bare repos and submodules, and add selective tar helpers plus CLI seams to keep the workflow testable. Validate with go test ./..., make build, banger vm run --help, and the expected --from requires --branch error path.
This commit is contained in:
parent
8bcc767824
commit
2ebc6f99c6
5 changed files with 929 additions and 0 deletions
|
|
@ -113,6 +113,15 @@ Create and use a VM:
|
|||
|
||||
`vm create` stays synchronous by default, but on a TTY it now shows live progress until the VM is fully ready.
|
||||
|
||||
Start a repo-backed VM session and attach `opencode` automatically:
|
||||
|
||||
```bash
|
||||
./build/bin/banger vm run
|
||||
./build/bin/banger vm run ../some-repo --branch feature/alpine --from HEAD
|
||||
```
|
||||
|
||||
`vm run` resolves the enclosing git repository, creates a VM, copies a git checkout plus current tracked and untracked non-ignored files into `/root/<repo-name>`, and then runs `opencode attach` from the host against the guest.
|
||||
|
||||
## Web UI
|
||||
|
||||
`bangerd` serves a local web UI by default at:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
|
@ -19,6 +21,7 @@ import (
|
|||
"banger/internal/api"
|
||||
"banger/internal/config"
|
||||
"banger/internal/daemon"
|
||||
"banger/internal/guest"
|
||||
"banger/internal/hostnat"
|
||||
"banger/internal/imagepreset"
|
||||
"banger/internal/model"
|
||||
|
|
@ -44,6 +47,26 @@ var (
|
|||
sshCmd.Stdin = stdin
|
||||
return sshCmd.Run()
|
||||
}
|
||||
opencodeExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
|
||||
opencodeCmd := exec.CommandContext(ctx, "opencode", args...)
|
||||
opencodeCmd.Stdout = stdout
|
||||
opencodeCmd.Stderr = stderr
|
||||
opencodeCmd.Stdin = stdin
|
||||
return opencodeCmd.Run()
|
||||
}
|
||||
hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||
cmd := exec.CommandContext(ctx, name, args...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err == nil {
|
||||
return output, nil
|
||||
}
|
||||
command := strings.TrimSpace(strings.Join(append([]string{name}, args...), " "))
|
||||
detail := strings.TrimSpace(string(output))
|
||||
if detail == "" {
|
||||
return output, fmt.Errorf("%s: %w", command, err)
|
||||
}
|
||||
return output, fmt.Errorf("%s: %w: %s", command, err, detail)
|
||||
}
|
||||
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
|
||||
return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName})
|
||||
}
|
||||
|
|
@ -60,8 +83,35 @@ var (
|
|||
vmPortsFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) {
|
||||
return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName})
|
||||
}
|
||||
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
|
||||
return guest.WaitForSSH(ctx, address, privateKeyPath, interval)
|
||||
}
|
||||
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
|
||||
return guest.Dial(ctx, address, privateKeyPath)
|
||||
}
|
||||
cwdFunc = os.Getwd
|
||||
)
|
||||
|
||||
type vmRunGuestClient interface {
|
||||
Close() error
|
||||
UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error
|
||||
RunScript(ctx context.Context, script string, logWriter io.Writer) error
|
||||
StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error
|
||||
}
|
||||
|
||||
type vmRunRepoSpec struct {
|
||||
SourcePath string
|
||||
RepoRoot string
|
||||
RepoName string
|
||||
HeadCommit string
|
||||
CurrentBranch string
|
||||
BranchName string
|
||||
BaseCommit string
|
||||
OverlayPaths []string
|
||||
}
|
||||
|
||||
const vmRunGuestBundlePath = "/tmp/banger-vm-run.bundle"
|
||||
|
||||
func NewBangerCommand() *cobra.Command {
|
||||
root := &cobra.Command{
|
||||
Use: "banger",
|
||||
|
|
@ -358,6 +408,7 @@ func newVMCommand() *cobra.Command {
|
|||
}
|
||||
cmd.AddCommand(
|
||||
newVMCreateCommand(),
|
||||
newVMRunCommand(),
|
||||
newVMListCommand(),
|
||||
newVMShowCommand(),
|
||||
newVMActionCommand("start", "Start a VM", "vm.start"),
|
||||
|
|
@ -374,6 +425,76 @@ func newVMCommand() *cobra.Command {
|
|||
return cmd
|
||||
}
|
||||
|
||||
func newVMRunCommand() *cobra.Command {
|
||||
var (
|
||||
name string
|
||||
imageName string
|
||||
vcpu = model.DefaultVCPUCount
|
||||
memory = model.DefaultMemoryMiB
|
||||
systemOverlaySize = model.FormatSizeBytes(model.DefaultSystemOverlaySize)
|
||||
workDiskSize = model.FormatSizeBytes(model.DefaultWorkDiskSize)
|
||||
natEnabled bool
|
||||
branchName string
|
||||
fromRef = "HEAD"
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Use: "run [path]",
|
||||
Short: "Create a repo-backed VM session and attach opencode",
|
||||
Args: maxArgsUsage(1, "usage: banger vm run [path]"),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if cmd.Flags().Changed("branch") && strings.TrimSpace(branchName) == "" {
|
||||
return errors.New("--branch requires a branch name")
|
||||
}
|
||||
if cmd.Flags().Changed("from") && strings.TrimSpace(branchName) == "" {
|
||||
return errors.New("--from requires --branch")
|
||||
}
|
||||
|
||||
sourcePath := ""
|
||||
if len(args) == 1 {
|
||||
sourcePath = args[0]
|
||||
}
|
||||
spec, err := inspectVMRunRepo(cmd.Context(), sourcePath, branchName, fromRef)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layout, err := paths.Resolve()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg, err := config.Load(layout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateVMRunPrereqs(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
params, err := vmCreateParamsFromFlags(cmd, name, imageName, vcpu, memory, systemOverlaySize, workDiskSize, natEnabled, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := system.EnsureSudo(cmd.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
layout, cfg, err = ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return runVMRun(cmd.Context(), layout.SocketPath, cfg, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), params, spec)
|
||||
},
|
||||
}
|
||||
cmd.Flags().StringVar(&name, "name", "", "vm name")
|
||||
cmd.Flags().StringVar(&imageName, "image", "", "image name or id")
|
||||
cmd.Flags().IntVar(&vcpu, "vcpu", model.DefaultVCPUCount, "vcpu count")
|
||||
cmd.Flags().IntVar(&memory, "memory", model.DefaultMemoryMiB, "memory in MiB")
|
||||
cmd.Flags().StringVar(&systemOverlaySize, "system-overlay-size", model.FormatSizeBytes(model.DefaultSystemOverlaySize), "system overlay size")
|
||||
cmd.Flags().StringVar(&workDiskSize, "disk-size", model.FormatSizeBytes(model.DefaultWorkDiskSize), "work disk size")
|
||||
cmd.Flags().BoolVar(&natEnabled, "nat", false, "enable NAT")
|
||||
cmd.Flags().StringVar(&branchName, "branch", "", "create and switch to a new guest branch")
|
||||
cmd.Flags().StringVar(&fromRef, "from", "HEAD", "base ref for --branch")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newVMKillCommand() *cobra.Command {
|
||||
var signal string
|
||||
cmd := &cobra.Command{
|
||||
|
|
@ -876,6 +997,15 @@ func minArgsUsage(n int, usage string) cobra.PositionalArgs {
|
|||
}
|
||||
}
|
||||
|
||||
func maxArgsUsage(n int, usage string) cobra.PositionalArgs {
|
||||
return func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) > n {
|
||||
return errors.New(usage)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type resolvedVMTarget struct {
|
||||
Index int
|
||||
Ref string
|
||||
|
|
@ -1244,6 +1374,320 @@ func validateSSHPrereqs(cfg model.DaemonConfig) error {
|
|||
return checks.Err("ssh preflight failed")
|
||||
}
|
||||
|
||||
func validateVMRunPrereqs(cfg model.DaemonConfig) error {
|
||||
checks := system.NewPreflight()
|
||||
checks.RequireCommand("git", "install git")
|
||||
checks.RequireCommand("opencode", "install opencode")
|
||||
if strings.TrimSpace(cfg.SSHKeyPath) != "" {
|
||||
checks.RequireFile(cfg.SSHKeyPath, "ssh private key", `set "ssh_key_path" or let banger create its default key`)
|
||||
}
|
||||
return checks.Err("vm run preflight failed")
|
||||
}
|
||||
|
||||
func inspectVMRunRepo(ctx context.Context, rawPath, branchName, fromRef string) (vmRunRepoSpec, error) {
|
||||
sourcePath, err := resolveVMRunSourcePath(rawPath)
|
||||
if err != nil {
|
||||
return vmRunRepoSpec{}, err
|
||||
}
|
||||
|
||||
repoRoot, err := gitTrimmedOutput(ctx, sourcePath, "rev-parse", "--show-toplevel")
|
||||
if err != nil {
|
||||
return vmRunRepoSpec{}, fmt.Errorf("%s is not inside a git repository", sourcePath)
|
||||
}
|
||||
isBare, err := gitTrimmedOutput(ctx, repoRoot, "rev-parse", "--is-bare-repository")
|
||||
if err != nil {
|
||||
return vmRunRepoSpec{}, fmt.Errorf("inspect git repository %s: %w", repoRoot, err)
|
||||
}
|
||||
if isBare == "true" {
|
||||
return vmRunRepoSpec{}, fmt.Errorf("vm run requires a non-bare git repository: %s", repoRoot)
|
||||
}
|
||||
if err := ensureVMRunRepoHasNoSubmodules(ctx, repoRoot); err != nil {
|
||||
return vmRunRepoSpec{}, err
|
||||
}
|
||||
|
||||
headCommit, err := gitTrimmedOutput(ctx, repoRoot, "rev-parse", "HEAD^{commit}")
|
||||
if err != nil {
|
||||
return vmRunRepoSpec{}, fmt.Errorf("git repository %s must have at least one commit", repoRoot)
|
||||
}
|
||||
currentBranch, err := gitTrimmedOutput(ctx, repoRoot, "branch", "--show-current")
|
||||
if err != nil {
|
||||
return vmRunRepoSpec{}, fmt.Errorf("resolve current branch for %s: %w", repoRoot, err)
|
||||
}
|
||||
|
||||
baseCommit := headCommit
|
||||
branchName = strings.TrimSpace(branchName)
|
||||
if branchName != "" {
|
||||
fromRef = strings.TrimSpace(fromRef)
|
||||
if fromRef == "" {
|
||||
return vmRunRepoSpec{}, errors.New("--from cannot be empty")
|
||||
}
|
||||
baseCommit, err = gitTrimmedOutput(ctx, repoRoot, "rev-parse", fromRef+"^{commit}")
|
||||
if err != nil {
|
||||
return vmRunRepoSpec{}, fmt.Errorf("resolve --from %q: %w", fromRef, err)
|
||||
}
|
||||
}
|
||||
|
||||
overlayPaths, err := listVMRunOverlayPaths(ctx, repoRoot)
|
||||
if err != nil {
|
||||
return vmRunRepoSpec{}, err
|
||||
}
|
||||
|
||||
return vmRunRepoSpec{
|
||||
SourcePath: sourcePath,
|
||||
RepoRoot: repoRoot,
|
||||
RepoName: filepath.Base(repoRoot),
|
||||
HeadCommit: headCommit,
|
||||
CurrentBranch: currentBranch,
|
||||
BranchName: branchName,
|
||||
BaseCommit: baseCommit,
|
||||
OverlayPaths: overlayPaths,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func resolveVMRunSourcePath(rawPath string) (string, error) {
|
||||
if strings.TrimSpace(rawPath) == "" {
|
||||
wd, err := cwdFunc()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rawPath = wd
|
||||
}
|
||||
absPath, err := filepath.Abs(rawPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
info, err := os.Stat(absPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return "", fmt.Errorf("%s is not a directory", absPath)
|
||||
}
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
func ensureVMRunRepoHasNoSubmodules(ctx context.Context, repoRoot string) error {
|
||||
output, err := gitOutput(ctx, repoRoot, "ls-files", "--stage", "-z")
|
||||
if err != nil {
|
||||
return fmt.Errorf("inspect git index for %s: %w", repoRoot, err)
|
||||
}
|
||||
for _, record := range parseNullSeparatedOutput(output) {
|
||||
if strings.HasPrefix(record, "160000 ") {
|
||||
return fmt.Errorf("vm run does not yet support git submodules: %s", repoRoot)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func listVMRunOverlayPaths(ctx context.Context, repoRoot string) ([]string, error) {
|
||||
trackedOutput, err := gitOutput(ctx, repoRoot, "ls-files", "-z")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list tracked files for %s: %w", repoRoot, err)
|
||||
}
|
||||
untrackedOutput, err := gitOutput(ctx, repoRoot, "ls-files", "--others", "--exclude-standard", "-z")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list untracked files for %s: %w", repoRoot, err)
|
||||
}
|
||||
|
||||
paths := make([]string, 0)
|
||||
seen := make(map[string]struct{})
|
||||
for _, relPath := range parseNullSeparatedOutput(trackedOutput) {
|
||||
if relPath == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Lstat(filepath.Join(repoRoot, relPath)); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
seen[relPath] = struct{}{}
|
||||
paths = append(paths, relPath)
|
||||
}
|
||||
for _, relPath := range parseNullSeparatedOutput(untrackedOutput) {
|
||||
if relPath == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[relPath]; ok {
|
||||
continue
|
||||
}
|
||||
seen[relPath] = struct{}{}
|
||||
paths = append(paths, relPath)
|
||||
}
|
||||
sort.Strings(paths)
|
||||
return paths, nil
|
||||
}
|
||||
|
||||
func gitOutput(ctx context.Context, dir string, args ...string) ([]byte, error) {
|
||||
fullArgs := make([]string, 0, len(args)+2)
|
||||
if strings.TrimSpace(dir) != "" {
|
||||
fullArgs = append(fullArgs, "-C", dir)
|
||||
}
|
||||
fullArgs = append(fullArgs, args...)
|
||||
return hostCommandOutputFunc(ctx, "git", fullArgs...)
|
||||
}
|
||||
|
||||
func gitTrimmedOutput(ctx context.Context, dir string, args ...string) (string, error) {
|
||||
output, err := gitOutput(ctx, dir, args...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(string(output)), nil
|
||||
}
|
||||
|
||||
func parseNullSeparatedOutput(output []byte) []string {
|
||||
chunks := bytes.Split(output, []byte{0})
|
||||
values := make([]string, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
value := strings.TrimSpace(string(chunk))
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func runVMRun(ctx context.Context, socketPath string, cfg model.DaemonConfig, stdin io.Reader, stdout, stderr io.Writer, params api.VMCreateParams, spec vmRunRepoSpec) error {
|
||||
vm, err := runVMCreate(ctx, socketPath, stderr, params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
vmRef := strings.TrimSpace(vm.Name)
|
||||
if vmRef == "" {
|
||||
vmRef = shortID(vm.ID)
|
||||
}
|
||||
sshAddress := net.JoinHostPort(vm.Runtime.GuestIP, "22")
|
||||
if err := guestWaitForSSHFunc(ctx, sshAddress, cfg.SSHKeyPath, 250*time.Millisecond); err != nil {
|
||||
return fmt.Errorf("vm %q is running but guest ssh is unavailable: %w", vmRef, err)
|
||||
}
|
||||
client, err := guestDialFunc(ctx, sshAddress, cfg.SSHKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("vm %q is running but guest ssh is unavailable: %w", vmRef, err)
|
||||
}
|
||||
defer client.Close()
|
||||
if err := importVMRunRepoToGuest(ctx, client, spec); err != nil {
|
||||
return fmt.Errorf("vm %q is running but repo import failed: %w", vmRef, err)
|
||||
}
|
||||
if err := runVMRunAttach(ctx, stdin, stdout, stderr, vm.Runtime.GuestIP, vmRunGuestDir(spec.RepoName)); err != nil {
|
||||
return fmt.Errorf("vm %q is running but opencode attach failed: %w", vmRef, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func importVMRunRepoToGuest(ctx context.Context, client vmRunGuestClient, spec vmRunRepoSpec) error {
|
||||
bundleData, err := createVMRunBundle(ctx, spec)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var uploadLog bytes.Buffer
|
||||
if err := client.UploadFile(ctx, vmRunGuestBundlePath, 0o600, bundleData, &uploadLog); err != nil {
|
||||
return formatVMRunStepError("upload git bundle", err, uploadLog.String())
|
||||
}
|
||||
var scriptLog bytes.Buffer
|
||||
if err := client.RunScript(ctx, vmRunCloneScript(spec), &scriptLog); err != nil {
|
||||
return formatVMRunStepError("prepare guest checkout", err, scriptLog.String())
|
||||
}
|
||||
var overlayLog bytes.Buffer
|
||||
remoteCommand := fmt.Sprintf("tar -C %s --strip-components=1 -xf -", shellQuote(vmRunGuestDir(spec.RepoName)))
|
||||
if err := client.StreamTarEntries(ctx, spec.RepoRoot, spec.OverlayPaths, remoteCommand, &overlayLog); err != nil {
|
||||
return formatVMRunStepError("overlay host working tree", err, overlayLog.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createVMRunBundle(ctx context.Context, spec vmRunRepoSpec) ([]byte, error) {
|
||||
tempFile, err := os.CreateTemp("", "banger-vm-run-*.bundle")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tempPath := tempFile.Name()
|
||||
if err := tempFile.Close(); err != nil {
|
||||
_ = os.Remove(tempPath)
|
||||
return nil, err
|
||||
}
|
||||
defer os.Remove(tempPath)
|
||||
|
||||
args := []string{"-C", spec.RepoRoot, "bundle", "create", tempPath, "--all"}
|
||||
for _, rev := range uniqueNonEmptyStrings(spec.HeadCommit, spec.BaseCommit) {
|
||||
args = append(args, rev)
|
||||
}
|
||||
if _, err := hostCommandOutputFunc(ctx, "git", args...); err != nil {
|
||||
return nil, fmt.Errorf("create git bundle: %w", err)
|
||||
}
|
||||
data, err := os.ReadFile(tempPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read git bundle: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func vmRunCloneScript(spec vmRunRepoSpec) string {
|
||||
guestDir := vmRunGuestDir(spec.RepoName)
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
fmt.Fprintf(&script, "DIR=%s\n", shellQuote(guestDir))
|
||||
fmt.Fprintf(&script, "BUNDLE=%s\n", shellQuote(vmRunGuestBundlePath))
|
||||
script.WriteString("rm -rf \"$DIR\"\n")
|
||||
script.WriteString("git clone \"$BUNDLE\" \"$DIR\"\n")
|
||||
script.WriteString("rm -f \"$BUNDLE\"\n")
|
||||
switch {
|
||||
case strings.TrimSpace(spec.BranchName) != "":
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", shellQuote(spec.BranchName), shellQuote(spec.BaseCommit))
|
||||
case strings.TrimSpace(spec.CurrentBranch) != "":
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", shellQuote(spec.CurrentBranch), shellQuote(spec.HeadCommit))
|
||||
default:
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", shellQuote(spec.HeadCommit))
|
||||
}
|
||||
script.WriteString("find \"$DIR\" -mindepth 1 -maxdepth 1 ! -name .git -exec rm -rf {} +\n")
|
||||
return script.String()
|
||||
}
|
||||
|
||||
func vmRunGuestDir(repoName string) string {
|
||||
return filepath.ToSlash(filepath.Join("/root", repoName))
|
||||
}
|
||||
|
||||
func runVMRunAttach(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, guestIP, guestDir string) error {
|
||||
guestIP = strings.TrimSpace(guestIP)
|
||||
if guestIP == "" {
|
||||
return errors.New("vm has no guest IP")
|
||||
}
|
||||
return opencodeExecFunc(ctx, stdin, stdout, stderr, []string{
|
||||
"attach",
|
||||
"--dir", guestDir,
|
||||
"http://" + net.JoinHostPort(guestIP, "4096"),
|
||||
})
|
||||
}
|
||||
|
||||
func formatVMRunStepError(action string, err error, log string) error {
|
||||
log = strings.TrimSpace(log)
|
||||
if log == "" {
|
||||
return fmt.Errorf("%s: %w", action, err)
|
||||
}
|
||||
return fmt.Errorf("%s: %w: %s", action, err, log)
|
||||
}
|
||||
|
||||
func uniqueNonEmptyStrings(values ...string) []string {
|
||||
unique := make([]string, 0, len(values))
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[value]; ok {
|
||||
continue
|
||||
}
|
||||
seen[value] = struct{}{}
|
||||
unique = append(unique, value)
|
||||
}
|
||||
return unique
|
||||
}
|
||||
|
||||
func shellQuote(value string) string {
|
||||
return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
|
||||
}
|
||||
|
||||
func absolutizeImageBuildPaths(params *api.ImageBuildParams) error {
|
||||
return absolutizePaths(¶ms.KernelPath, ¶ms.InitrdPath, ¶ms.ModulesDir)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -146,6 +146,26 @@ func TestVMCreateFlagsExist(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestVMRunFlagsExist(t *testing.T) {
|
||||
root := NewBangerCommand()
|
||||
vm, _, err := root.Find([]string{"vm"})
|
||||
if err != nil {
|
||||
t.Fatalf("find vm: %v", err)
|
||||
}
|
||||
run, _, err := vm.Find([]string{"run"})
|
||||
if err != nil {
|
||||
t.Fatalf("find run: %v", err)
|
||||
}
|
||||
for _, flagName := range []string{"name", "image", "vcpu", "memory", "system-overlay-size", "disk-size", "nat", "branch", "from"} {
|
||||
if run.Flags().Lookup(flagName) == nil {
|
||||
t.Fatalf("missing flag %q", flagName)
|
||||
}
|
||||
}
|
||||
if run.Flags().Lookup("no-start") != nil {
|
||||
t.Fatal("vm run should not expose --no-start")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMCreateFlagsShowStaticDefaults(t *testing.T) {
|
||||
root := NewBangerCommand()
|
||||
vm, _, err := root.Find([]string{"vm"})
|
||||
|
|
@ -171,6 +191,16 @@ func TestVMCreateFlagsShowStaticDefaults(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestVMRunRejectsFromWithoutBranch(t *testing.T) {
|
||||
cmd := NewBangerCommand()
|
||||
cmd.SetArgs([]string{"vm", "run", "--from", "HEAD"})
|
||||
|
||||
err := cmd.Execute()
|
||||
if err == nil || !strings.Contains(err.Error(), "--from requires --branch") {
|
||||
t.Fatalf("Execute() error = %v, want --from requires --branch", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageRegisterFlagsExist(t *testing.T) {
|
||||
root := NewBangerCommand()
|
||||
image, _, err := root.Find([]string{"image"})
|
||||
|
|
@ -837,6 +867,278 @@ func TestValidateSSHPrereqsFailsForMissingKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestResolveVMRunSourcePathDefaultsToCWD(t *testing.T) {
|
||||
origCWD := cwdFunc
|
||||
t.Cleanup(func() {
|
||||
cwdFunc = origCWD
|
||||
})
|
||||
|
||||
want := t.TempDir()
|
||||
cwdFunc = func() (string, error) {
|
||||
return want, nil
|
||||
}
|
||||
|
||||
got, err := resolveVMRunSourcePath("")
|
||||
if err != nil {
|
||||
t.Fatalf("resolveVMRunSourcePath: %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("resolveVMRunSourcePath() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectVMRunRepoUsesRepoRootAndOverlayPaths(t *testing.T) {
|
||||
if _, err := exec.LookPath("git"); err != nil {
|
||||
t.Skip("git not installed")
|
||||
}
|
||||
|
||||
repoRoot := t.TempDir()
|
||||
testRunGit(t, repoRoot, "init")
|
||||
testRunGit(t, repoRoot, "config", "user.email", "test@example.com")
|
||||
testRunGit(t, repoRoot, "config", "user.name", "Banger Test")
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(repoRoot, "dir"), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(dir): %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(repoRoot, ".gitignore"), []byte("ignored.txt\n"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(.gitignore): %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(repoRoot, "tracked.txt"), []byte("tracked\n"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(tracked.txt): %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(repoRoot, "dir", "keep.txt"), []byte("keep\n"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(keep.txt): %v", err)
|
||||
}
|
||||
testRunGit(t, repoRoot, "add", ".")
|
||||
testRunGit(t, repoRoot, "commit", "-m", "init")
|
||||
testRunGit(t, repoRoot, "checkout", "-b", "trunk")
|
||||
|
||||
if err := os.WriteFile(filepath.Join(repoRoot, "tracked.txt"), []byte("tracked local\n"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(tracked.txt local): %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(repoRoot, "untracked.txt"), []byte("untracked\n"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(untracked.txt): %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(repoRoot, "ignored.txt"), []byte("ignored\n"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(ignored.txt): %v", err)
|
||||
}
|
||||
|
||||
spec, err := inspectVMRunRepo(context.Background(), filepath.Join(repoRoot, "dir"), "", "HEAD")
|
||||
if err != nil {
|
||||
t.Fatalf("inspectVMRunRepo: %v", err)
|
||||
}
|
||||
|
||||
if spec.RepoRoot != repoRoot {
|
||||
t.Fatalf("RepoRoot = %q, want %q", spec.RepoRoot, repoRoot)
|
||||
}
|
||||
if spec.RepoName != filepath.Base(repoRoot) {
|
||||
t.Fatalf("RepoName = %q, want %q", spec.RepoName, filepath.Base(repoRoot))
|
||||
}
|
||||
if spec.CurrentBranch != "trunk" {
|
||||
t.Fatalf("CurrentBranch = %q, want trunk", spec.CurrentBranch)
|
||||
}
|
||||
if spec.HeadCommit == "" {
|
||||
t.Fatal("HeadCommit should not be empty")
|
||||
}
|
||||
if spec.BaseCommit != spec.HeadCommit {
|
||||
t.Fatalf("BaseCommit = %q, want head %q", spec.BaseCommit, spec.HeadCommit)
|
||||
}
|
||||
wantOverlay := []string{".gitignore", "dir/keep.txt", "tracked.txt", "untracked.txt"}
|
||||
if !reflect.DeepEqual(spec.OverlayPaths, wantOverlay) {
|
||||
t.Fatalf("OverlayPaths = %v, want %v", spec.OverlayPaths, wantOverlay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectVMRunRepoRejectsSubmodules(t *testing.T) {
|
||||
repoRoot := t.TempDir()
|
||||
|
||||
origHostCommandOutput := hostCommandOutputFunc
|
||||
t.Cleanup(func() {
|
||||
hostCommandOutputFunc = origHostCommandOutput
|
||||
})
|
||||
|
||||
hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||
t.Helper()
|
||||
if name != "git" {
|
||||
t.Fatalf("command = %q, want git", name)
|
||||
}
|
||||
switch {
|
||||
case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--show-toplevel"}):
|
||||
return []byte(repoRoot + "\n"), nil
|
||||
case reflect.DeepEqual(args, []string{"-C", repoRoot, "rev-parse", "--is-bare-repository"}):
|
||||
return []byte("false\n"), nil
|
||||
case reflect.DeepEqual(args, []string{"-C", repoRoot, "ls-files", "--stage", "-z"}):
|
||||
return []byte("160000 deadbeef 0\tvendor/submodule\x00"), nil
|
||||
default:
|
||||
t.Fatalf("unexpected git args: %v", args)
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
_, err := inspectVMRunRepo(context.Background(), repoRoot, "", "HEAD")
|
||||
if err == nil || !strings.Contains(err.Error(), "submodules") {
|
||||
t.Fatalf("inspectVMRunRepo() error = %v, want submodule rejection", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunVMRunCreatesImportsAndAttaches(t *testing.T) {
|
||||
repoRoot := t.TempDir()
|
||||
|
||||
origBegin := vmCreateBeginFunc
|
||||
origStatus := vmCreateStatusFunc
|
||||
origCancel := vmCreateCancelFunc
|
||||
origWaitForSSH := guestWaitForSSHFunc
|
||||
origGuestDial := guestDialFunc
|
||||
origHostCommandOutput := hostCommandOutputFunc
|
||||
origOpencodeExec := opencodeExecFunc
|
||||
t.Cleanup(func() {
|
||||
vmCreateBeginFunc = origBegin
|
||||
vmCreateStatusFunc = origStatus
|
||||
vmCreateCancelFunc = origCancel
|
||||
guestWaitForSSHFunc = origWaitForSSH
|
||||
guestDialFunc = origGuestDial
|
||||
hostCommandOutputFunc = origHostCommandOutput
|
||||
opencodeExecFunc = origOpencodeExec
|
||||
})
|
||||
|
||||
vm := model.VMRecord{
|
||||
ID: "vm-id",
|
||||
Name: "devbox",
|
||||
Runtime: model.VMRuntime{
|
||||
State: model.VMStateRunning,
|
||||
GuestIP: "172.16.0.2",
|
||||
},
|
||||
}
|
||||
vmCreateBeginFunc = func(context.Context, string, api.VMCreateParams) (api.VMCreateBeginResult, error) {
|
||||
return api.VMCreateBeginResult{
|
||||
Operation: api.VMCreateOperation{
|
||||
ID: "op-1",
|
||||
Stage: "ready",
|
||||
Detail: "vm is ready",
|
||||
Done: true,
|
||||
Success: true,
|
||||
VM: &vm,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
vmCreateStatusFunc = func(context.Context, string, string) (api.VMCreateStatusResult, error) {
|
||||
t.Fatal("vmCreateStatusFunc should not be called")
|
||||
return api.VMCreateStatusResult{}, nil
|
||||
}
|
||||
vmCreateCancelFunc = func(context.Context, string, string) error {
|
||||
t.Fatal("vmCreateCancelFunc should not be called")
|
||||
return nil
|
||||
}
|
||||
|
||||
fakeClient := &testVMRunGuestClient{}
|
||||
waitAddress := ""
|
||||
waitKeyPath := ""
|
||||
waitInterval := time.Duration(0)
|
||||
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
|
||||
waitAddress = address
|
||||
waitKeyPath = privateKeyPath
|
||||
waitInterval = interval
|
||||
return nil
|
||||
}
|
||||
dialAddress := ""
|
||||
dialKeyPath := ""
|
||||
guestDialFunc = func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) {
|
||||
dialAddress = address
|
||||
dialKeyPath = privateKeyPath
|
||||
return fakeClient, nil
|
||||
}
|
||||
hostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||
if name != "git" {
|
||||
t.Fatalf("command = %q, want git", name)
|
||||
}
|
||||
if len(args) < 7 || args[0] != "-C" || args[1] != repoRoot || args[2] != "bundle" || args[3] != "create" || args[5] != "--all" {
|
||||
t.Fatalf("unexpected bundle args: %v", args)
|
||||
}
|
||||
if !reflect.DeepEqual(args[6:], []string{"deadbeef", "cafebabe"}) {
|
||||
t.Fatalf("bundle revs = %v, want deadbeef/cafebabe", args[6:])
|
||||
}
|
||||
if err := os.WriteFile(args[4], []byte("bundle-data"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(bundle): %v", err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
var attachArgs []string
|
||||
opencodeExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
|
||||
attachArgs = append([]string(nil), args...)
|
||||
return nil
|
||||
}
|
||||
|
||||
spec := vmRunRepoSpec{
|
||||
RepoRoot: repoRoot,
|
||||
RepoName: "repo",
|
||||
HeadCommit: "deadbeef",
|
||||
CurrentBranch: "main",
|
||||
BranchName: "feature",
|
||||
BaseCommit: "cafebabe",
|
||||
OverlayPaths: []string{"tracked.txt", "nested/keep.txt"},
|
||||
}
|
||||
err := runVMRun(
|
||||
context.Background(),
|
||||
"/tmp/bangerd.sock",
|
||||
model.DaemonConfig{SSHKeyPath: "/tmp/id_ed25519"},
|
||||
strings.NewReader(""),
|
||||
&bytes.Buffer{},
|
||||
&bytes.Buffer{},
|
||||
api.VMCreateParams{Name: "devbox"},
|
||||
spec,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("runVMRun: %v", err)
|
||||
}
|
||||
|
||||
if waitAddress != "172.16.0.2:22" {
|
||||
t.Fatalf("waitAddress = %q, want 172.16.0.2:22", waitAddress)
|
||||
}
|
||||
if waitKeyPath != "/tmp/id_ed25519" {
|
||||
t.Fatalf("waitKeyPath = %q, want /tmp/id_ed25519", waitKeyPath)
|
||||
}
|
||||
if waitInterval <= 0 {
|
||||
t.Fatalf("waitInterval = %s, want positive interval", waitInterval)
|
||||
}
|
||||
if dialAddress != waitAddress {
|
||||
t.Fatalf("dialAddress = %q, want %q", dialAddress, waitAddress)
|
||||
}
|
||||
if dialKeyPath != waitKeyPath {
|
||||
t.Fatalf("dialKeyPath = %q, want %q", dialKeyPath, waitKeyPath)
|
||||
}
|
||||
if fakeClient.uploadPath != vmRunGuestBundlePath {
|
||||
t.Fatalf("uploadPath = %q, want %q", fakeClient.uploadPath, vmRunGuestBundlePath)
|
||||
}
|
||||
if fakeClient.uploadMode != 0o600 {
|
||||
t.Fatalf("uploadMode = %v, want 0600", fakeClient.uploadMode)
|
||||
}
|
||||
if string(fakeClient.uploadData) != "bundle-data" {
|
||||
t.Fatalf("uploadData = %q, want bundle-data", string(fakeClient.uploadData))
|
||||
}
|
||||
if !strings.Contains(fakeClient.script, `git clone "$BUNDLE" "$DIR"`) {
|
||||
t.Fatalf("script = %q, want clone command", fakeClient.script)
|
||||
}
|
||||
if !strings.Contains(fakeClient.script, `git -C "$DIR" checkout -B 'feature' 'cafebabe'`) {
|
||||
t.Fatalf("script = %q, want guest branch checkout", fakeClient.script)
|
||||
}
|
||||
if fakeClient.streamSourceDir != repoRoot {
|
||||
t.Fatalf("streamSourceDir = %q, want %q", fakeClient.streamSourceDir, repoRoot)
|
||||
}
|
||||
if !reflect.DeepEqual(fakeClient.streamEntries, spec.OverlayPaths) {
|
||||
t.Fatalf("streamEntries = %v, want %v", fakeClient.streamEntries, spec.OverlayPaths)
|
||||
}
|
||||
if fakeClient.streamCommand != "tar -C '/root/repo' --strip-components=1 -xf -" {
|
||||
t.Fatalf("streamCommand = %q", fakeClient.streamCommand)
|
||||
}
|
||||
wantAttach := []string{"attach", "--dir", "/root/repo", "http://172.16.0.2:4096"}
|
||||
if !reflect.DeepEqual(attachArgs, wantAttach) {
|
||||
t.Fatalf("attachArgs = %v, want %v", attachArgs, wantAttach)
|
||||
}
|
||||
if !fakeClient.closed {
|
||||
t.Fatal("guest client should be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBangerdCommandRejectsArgs(t *testing.T) {
|
||||
cmd := NewBangerdCommand()
|
||||
cmd.SetArgs([]string{"extra"})
|
||||
|
|
@ -965,3 +1267,48 @@ func TestAbsolutizeImageBuildPaths(t *testing.T) {
|
|||
func testCLIResolvedVM(id, name string) model.VMRecord {
|
||||
return model.VMRecord{ID: id, Name: name}
|
||||
}
|
||||
|
||||
func testRunGit(t *testing.T, dir string, args ...string) string {
|
||||
t.Helper()
|
||||
cmd := exec.Command("git", append([]string{"-c", "commit.gpgsign=false", "-C", dir}, args...)...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("git %v: %v\n%s", args, err, string(output))
|
||||
}
|
||||
return string(output)
|
||||
}
|
||||
|
||||
type testVMRunGuestClient struct {
|
||||
closed bool
|
||||
uploadPath string
|
||||
uploadMode os.FileMode
|
||||
uploadData []byte
|
||||
script string
|
||||
streamSourceDir string
|
||||
streamEntries []string
|
||||
streamCommand string
|
||||
}
|
||||
|
||||
func (c *testVMRunGuestClient) Close() error {
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testVMRunGuestClient) UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error {
|
||||
c.uploadPath = remotePath
|
||||
c.uploadMode = mode
|
||||
c.uploadData = append([]byte(nil), data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testVMRunGuestClient) RunScript(ctx context.Context, script string, logWriter io.Writer) error {
|
||||
c.script = script
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testVMRunGuestClient) StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error {
|
||||
c.streamSourceDir = sourceDir
|
||||
c.streamEntries = append([]string(nil), entries...)
|
||||
c.streamCommand = remoteCommand
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -94,6 +96,19 @@ func (c *Client) StreamTar(ctx context.Context, sourceDir, remoteCommand string,
|
|||
return errors.Join(runErr, tarErr)
|
||||
}
|
||||
|
||||
func (c *Client) StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error {
|
||||
reader, writer := io.Pipe()
|
||||
writeErr := make(chan error, 1)
|
||||
go func() {
|
||||
writeErr <- writeTarEntriesArchive(writer, sourceDir, entries)
|
||||
_ = writer.Close()
|
||||
}()
|
||||
|
||||
runErr := c.runSession(ctx, remoteCommand, reader, logWriter)
|
||||
tarErr := <-writeErr
|
||||
return errors.Join(runErr, tarErr)
|
||||
}
|
||||
|
||||
func (c *Client) runSession(ctx context.Context, command string, stdin io.Reader, logWriter io.Writer) error {
|
||||
if c == nil || c.client == nil {
|
||||
return fmt.Errorf("ssh client is not connected")
|
||||
|
|
@ -197,3 +212,68 @@ func writeTarArchive(dst io.Writer, sourceDir string) error {
|
|||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func writeTarEntriesArchive(dst io.Writer, sourceDir string, entries []string) error {
|
||||
tw := tar.NewWriter(dst)
|
||||
defer tw.Close()
|
||||
|
||||
sourceDir = filepath.Clean(sourceDir)
|
||||
rootName := filepath.Base(sourceDir)
|
||||
|
||||
uniqueEntries := make([]string, 0, len(entries))
|
||||
seen := make(map[string]struct{}, len(entries))
|
||||
for _, entry := range entries {
|
||||
entry = strings.TrimSpace(entry)
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
entry = filepath.Clean(entry)
|
||||
if entry == "." || entry == ".." || strings.HasPrefix(entry, ".."+string(filepath.Separator)) {
|
||||
return fmt.Errorf("tar entry %q escapes source dir", entry)
|
||||
}
|
||||
if _, ok := seen[entry]; ok {
|
||||
continue
|
||||
}
|
||||
seen[entry] = struct{}{}
|
||||
uniqueEntries = append(uniqueEntries, entry)
|
||||
}
|
||||
sort.Strings(uniqueEntries)
|
||||
|
||||
for _, entry := range uniqueEntries {
|
||||
fullPath := filepath.Join(sourceDir, entry)
|
||||
info, err := os.Lstat(fullPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
linkTarget := ""
|
||||
if info.Mode()&os.ModeSymlink != 0 {
|
||||
linkTarget, err = os.Readlink(fullPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
header, err := tar.FileInfoHeader(info, linkTarget)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Name = path.Join(rootName, filepath.ToSlash(entry))
|
||||
if err := tw.WriteHeader(header); err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.Mode().IsRegular() {
|
||||
continue
|
||||
}
|
||||
file, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.Copy(tw, file); err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -91,3 +91,52 @@ func TestAuthorizedPublicKey(t *testing.T) {
|
|||
t.Fatalf("key type = %q, want %q", parsed.Type(), ssh.KeyAlgoRSA)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTarEntriesArchiveIncludesOnlySelectedPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sourceDir := filepath.Join(t.TempDir(), "repo")
|
||||
if err := os.MkdirAll(filepath.Join(sourceDir, "nested"), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(nested): %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(sourceDir, "tracked.txt"), []byte("tracked"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(tracked.txt): %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(sourceDir, "nested", "keep.txt"), []byte("keep"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(keep.txt): %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(sourceDir, "nested", "skip.txt"), []byte("skip"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(skip.txt): %v", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := writeTarEntriesArchive(&buf, sourceDir, []string{"tracked.txt", "nested/keep.txt"}); err != nil {
|
||||
t.Fatalf("writeTarEntriesArchive: %v", err)
|
||||
}
|
||||
|
||||
tr := tar.NewReader(bytes.NewReader(buf.Bytes()))
|
||||
var names []string
|
||||
for {
|
||||
header, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("tar.Next: %v", err)
|
||||
}
|
||||
names = append(names, header.Name)
|
||||
}
|
||||
|
||||
want := map[string]struct{}{
|
||||
"repo/tracked.txt": {},
|
||||
"repo/nested/keep.txt": {},
|
||||
}
|
||||
if len(names) != len(want) {
|
||||
t.Fatalf("archive names = %v, want %d entries", names, len(want))
|
||||
}
|
||||
for _, name := range names {
|
||||
if _, ok := want[name]; !ok {
|
||||
t.Fatalf("unexpected archive entry %q in %v", name, names)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue