diff --git a/internal/daemon/vm.go b/internal/daemon/vm.go index aae5e80..d3ba689 100644 --- a/internal/daemon/vm.go +++ b/internal/daemon/vm.go @@ -703,10 +703,20 @@ func (d *Daemon) flattenNestedWorkHome(ctx context.Context, workMount string) er if !exists(nestedHome) { return nil } - if err := system.CopyDirContents(ctx, d.runner, nestedHome, workMount, true); err != nil { + if _, err := d.runner.RunSudo(ctx, "chmod", "755", nestedHome); err != nil { return err } - _, err := d.runner.RunSudo(ctx, "rm", "-rf", nestedHome) + entries, err := os.ReadDir(nestedHome) + if err != nil { + return err + } + for _, entry := range entries { + sourcePath := filepath.Join(nestedHome, entry.Name()) + if _, err := d.runner.RunSudo(ctx, "cp", "-a", sourcePath, workMount+"/"); err != nil { + return err + } + } + _, err = d.runner.RunSudo(ctx, "rm", "-rf", nestedHome) return err } diff --git a/internal/daemon/vm_test.go b/internal/daemon/vm_test.go index ffed6d1..e5f9442 100644 --- a/internal/daemon/vm_test.go +++ b/internal/daemon/vm_test.go @@ -279,6 +279,35 @@ func TestSetVMDiskResizeFailsPreflightWhenToolsMissing(t *testing.T) { } } +func TestFlattenNestedWorkHomeCopiesEntriesIndividually(t *testing.T) { + t.Parallel() + + workMount := t.TempDir() + nestedHome := filepath.Join(workMount, "root") + if err := os.MkdirAll(filepath.Join(nestedHome, ".ssh"), 0o755); err != nil { + t.Fatalf("MkdirAll(.ssh): %v", err) + } + if err := os.WriteFile(filepath.Join(nestedHome, "notes.txt"), []byte("seed"), 0o644); err != nil { + t.Fatalf("WriteFile(notes.txt): %v", err) + } + + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("", nil, "chmod", "755", nestedHome), + sudoStep("", nil, "cp", "-a", filepath.Join(nestedHome, ".ssh"), workMount+"/"), + sudoStep("", nil, "cp", "-a", filepath.Join(nestedHome, "notes.txt"), workMount+"/"), + sudoStep("", nil, "rm", "-rf", nestedHome), + }, + } + d := &Daemon{runner: runner} + + if err := d.flattenNestedWorkHome(context.Background(), workMount); err != nil { + t.Fatalf("flattenNestedWorkHome: %v", err) + } + runner.assertExhausted() +} + func TestCreateVMRejectsNonPositiveCPUAndMemory(t *testing.T) { d := &Daemon{} if _, err := d.CreateVM(context.Background(), api.VMCreateParams{VCPUCount: ptr(0)}); err == nil || !strings.Contains(err.Error(), "vcpu must be a positive integer") { diff --git a/internal/firecracker/client.go b/internal/firecracker/client.go index ad56e8f..2b72d4a 100644 --- a/internal/firecracker/client.go +++ b/internal/firecracker/client.go @@ -132,7 +132,10 @@ func buildConfig(cfg MachineConfig) sdk.Config { } func buildProcessRunner(cfg MachineConfig, logFile *os.File) *exec.Cmd { - cmd := exec.Command("sudo", "-n", cfg.BinaryPath, "--api-sock", cfg.SocketPath, "--id", cfg.VMID) + script := "umask 000 && exec " + shellQuote(cfg.BinaryPath) + + " --api-sock " + shellQuote(cfg.SocketPath) + + " --id " + shellQuote(cfg.VMID) + cmd := exec.Command("sudo", "-n", "sh", "-c", script) cmd.Stdin = nil if logFile != nil { cmd.Stdout = logFile @@ -141,6 +144,10 @@ func buildProcessRunner(cfg MachineConfig, logFile *os.File) *exec.Cmd { return cmd } +func shellQuote(value string) string { + return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'" +} + func newLogger(base *slog.Logger) *logrus.Entry { logger := logrus.New() logger.SetOutput(io.Discard) diff --git a/internal/firecracker/client_test.go b/internal/firecracker/client_test.go index 89fd184..1db4549 100644 --- a/internal/firecracker/client_test.go +++ b/internal/firecracker/client_test.go @@ -58,7 +58,7 @@ func TestBuildConfig(t *testing.T) { } } -func TestBuildProcessRunnerUsesDirectSudoCommand(t *testing.T) { +func TestBuildProcessRunnerUsesSudoShellWrapper(t *testing.T) { cmd := buildProcessRunner(MachineConfig{ BinaryPath: "/repo/firecracker", SocketPath: "/tmp/fc.sock", @@ -68,14 +68,15 @@ func TestBuildProcessRunnerUsesDirectSudoCommand(t *testing.T) { if cmd.Path != "/usr/bin/sudo" && cmd.Path != "sudo" { t.Fatalf("command path = %q", cmd.Path) } - if len(cmd.Args) != 7 { + if len(cmd.Args) != 5 { t.Fatalf("args = %v", cmd.Args) } - want := []string{"sudo", "-n", "/repo/firecracker", "--api-sock", "/tmp/fc.sock", "--id", "vm-1"} - for i, arg := range want { - if cmd.Args[i] != arg { - t.Fatalf("args[%d] = %q, want %q (all args: %v)", i, cmd.Args[i], arg, cmd.Args) - } + if cmd.Args[1] != "-n" || cmd.Args[2] != "sh" || cmd.Args[3] != "-c" { + t.Fatalf("args = %v", cmd.Args) + } + want := "umask 000 && exec '/repo/firecracker' --api-sock '/tmp/fc.sock' --id 'vm-1'" + if cmd.Args[4] != want { + t.Fatalf("script = %q, want %q", cmd.Args[4], want) } if cmd.Cancel != nil { t.Fatal("process runner should not be tied to a request context")