Fix guest session cwd preflight scripts

Guest session cwd and command preflight helpers were emitting literal
`\\n` separators, so the guest shell saw malformed one-line scripts and
could fail `preflight_cwd` even when `/root/repo` already existed.

Replace those builders with real newlines, and fix the nearby attach
helper commands that were making the same mistake.

Add a small daemon guest-SSH seam so workspace preparation and session
start can share a fake backend in tests, then cover the regression with
an end-to-end daemon test for `PrepareVMWorkspace` followed by
`StartGuestSession` on `/root/repo`.

Validation: `GOCACHE=/tmp/banger-gocache go test ./internal/daemon` and
`GOCACHE=/tmp/banger-gocache go test ./...`.
This commit is contained in:
Thales Maciel 2026-04-13 18:26:19 -03:00
parent 37c4c091ec
commit 5e26fd7544
No known key found for this signature in database
GPG key ID: 33112E6833C34679
4 changed files with 296 additions and 49 deletions

View file

@ -27,33 +27,36 @@ import (
) )
type Daemon struct { type Daemon struct {
layout paths.Layout layout paths.Layout
config model.DaemonConfig config model.DaemonConfig
store *store.Store store *store.Store
runner system.CommandRunner runner system.CommandRunner
logger *slog.Logger logger *slog.Logger
mu sync.Mutex mu sync.Mutex
createOpsMu sync.Mutex createOpsMu sync.Mutex
createOps map[string]*vmCreateOperationState createOps map[string]*vmCreateOperationState
imageBuildOpsMu sync.Mutex imageBuildOpsMu sync.Mutex
imageBuildOps map[string]*imageBuildOperationState imageBuildOps map[string]*imageBuildOperationState
vmLocksMu sync.Mutex vmLocksMu sync.Mutex
vmLocks map[string]*sync.Mutex vmLocks map[string]*sync.Mutex
sessionControllers map[string]*guestSessionController sessionControllers map[string]*guestSessionController
tapPoolMu sync.Mutex tapPoolMu sync.Mutex
tapPool []string tapPool []string
tapPoolNext int tapPoolNext int
closing chan struct{} closing chan struct{}
once sync.Once once sync.Once
pid int pid int
listener net.Listener listener net.Listener
webListener net.Listener webListener net.Listener
webServer *http.Server webServer *http.Server
webURL string webURL string
vmDNS *vmdns.Server vmDNS *vmdns.Server
vmCaps []vmCapability vmCaps []vmCapability
imageBuild func(context.Context, imageBuildSpec) error imageBuild func(context.Context, imageBuildSpec) error
requestHandler func(context.Context, rpc.Request) rpc.Response requestHandler func(context.Context, rpc.Request) rpc.Response
guestWaitForSSH func(context.Context, string, string, time.Duration) error
guestDial func(context.Context, string, string) (guestSSHClient, error)
waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error)
} }
func Open(ctx context.Context) (d *Daemon, err error) { func Open(ctx context.Context) (d *Daemon, err error) {

View file

@ -50,6 +50,35 @@ var guestSessionHostCommandOutputFunc = func(ctx context.Context, name string, a
return output, fmt.Errorf("%s: %w: %s", command, err, detail) return output, fmt.Errorf("%s: %w: %s", command, err, detail)
} }
type guestSSHClient interface {
Close() error
RunScript(context.Context, string, io.Writer) error
UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error
StreamTar(context.Context, string, string, io.Writer) error
StreamTarEntries(context.Context, string, []string, string, io.Writer) error
}
func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error {
if d != nil && d.guestWaitForSSH != nil {
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
}
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
}
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
if d != nil && d.guestDial != nil {
return d.guestDial(ctx, address, d.config.SSHKeyPath)
}
return guest.Dial(ctx, address, d.config.SSHKeyPath)
}
func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, session model.GuestSession) (model.GuestSession, error) {
if d != nil && d.waitForGuestSessionReady != nil {
return d.waitForGuestSessionReady(ctx, vm, session)
}
return d.waitForGuestSessionReadyDefault(ctx, vm, session)
}
type guestSessionController struct { type guestSessionController struct {
stream *guest.StreamSession stream *guest.StreamSession
streams []*guest.StreamSession streams []*guest.StreamSession
@ -215,10 +244,10 @@ func (d *Daemon) startGuestSessionLocked(ctx context.Context, vm model.VMRecord,
return session, nil return session, nil
} }
address := net.JoinHostPort(vm.Runtime.GuestIP, "22") address := net.JoinHostPort(vm.Runtime.GuestIP, "22")
if err := guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, 250*time.Millisecond); err != nil { if err := d.waitForGuestSSH(ctx, address, 250*time.Millisecond); err != nil {
return fail("ssh_unavailable", fmt.Sprintf("guest ssh unavailable: %v", err), "") return fail("ssh_unavailable", fmt.Sprintf("guest ssh unavailable: %v", err), "")
} }
client, err := guest.Dial(ctx, address, d.config.SSHKeyPath) client, err := d.dialGuest(ctx, address)
if err != nil { if err != nil {
return fail("dial_guest", fmt.Sprintf("dial guest ssh: %v", err), "") return fail("dial_guest", fmt.Sprintf("dial guest ssh: %v", err), "")
} }
@ -243,7 +272,7 @@ func (d *Daemon) startGuestSessionLocked(ctx context.Context, vm model.VMRecord,
} }
readyCtx, cancel := context.WithTimeout(ctx, 5*time.Second) readyCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
updated, err := d.waitForGuestSessionReady(readyCtx, vm, session) updated, err := d.waitForGuestSessionReadyHook(readyCtx, vm, session)
if err != nil { if err != nil {
return fail("ready_wait", "guest session did not report ready state", err.Error()) return fail("ready_wait", "guest session did not report ready state", err.Error())
} }
@ -628,7 +657,7 @@ func (d *Daemon) watchGuestSessionAttach(id string, controller *guestSessionCont
} }
} }
func (d *Daemon) waitForGuestSessionReady(ctx context.Context, vm model.VMRecord, session model.GuestSession) (model.GuestSession, error) { func (d *Daemon) waitForGuestSessionReadyDefault(ctx context.Context, vm model.VMRecord, session model.GuestSession) (model.GuestSession, error) {
for { for {
updated, err := d.refreshGuestSession(ctx, vm, session) updated, err := d.refreshGuestSession(ctx, vm, session)
if err == nil { if err == nil {
@ -1037,35 +1066,35 @@ func normalizeGuestSessionRequiredCommands(command string, extras []string) []st
func guestSessionCWDPreflightScript(cwd string) string { func guestSessionCWDPreflightScript(cwd string) string {
var script strings.Builder var script strings.Builder
script.WriteString("set -euo pipefail\\n") script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "DIR=%s\\n", guestShellQuote(defaultGuestSessionCWD(cwd))) fmt.Fprintf(&script, "DIR=%s\n", guestShellQuote(defaultGuestSessionCWD(cwd)))
script.WriteString("if [ ! -d \"$DIR\" ]; then echo \"missing cwd: $DIR\"; exit 1; fi\\n") script.WriteString("if [ ! -d \"$DIR\" ]; then echo \"missing cwd: $DIR\"; exit 1; fi\n")
return script.String() return script.String()
} }
func guestSessionCommandPreflightScript(commands []string) string { func guestSessionCommandPreflightScript(commands []string) string {
var script strings.Builder var script strings.Builder
script.WriteString("set -euo pipefail\\n") script.WriteString("set -euo pipefail\n")
script.WriteString("check_command() {\\n") script.WriteString("check_command() {\n")
script.WriteString(" cmd=\\\"$1\\\"\\n") script.WriteString(" cmd=\"$1\"\n")
script.WriteString(" case \\\"$cmd\\\" in\\n") script.WriteString(" case \"$cmd\" in\n")
script.WriteString(" */*) [ -x \\\"$cmd\\\" ] || { echo \\\"missing command: $cmd\\\"; exit 1; } ;;\\n") script.WriteString(" */*) [ -x \"$cmd\" ] || { echo \"missing command: $cmd\"; exit 1; } ;;\n")
script.WriteString(" *) command -v \\\"$cmd\\\" >/dev/null 2>&1 || { echo \\\"missing command: $cmd\\\"; exit 1; } ;;\\n") script.WriteString(" *) command -v \"$cmd\" >/dev/null 2>&1 || { echo \"missing command: $cmd\"; exit 1; } ;;\n")
script.WriteString(" esac\\n") script.WriteString(" esac\n")
script.WriteString("}\\n") script.WriteString("}\n")
for _, command := range commands { for _, command := range commands {
fmt.Fprintf(&script, "check_command %s\\n", guestShellQuote(command)) fmt.Fprintf(&script, "check_command %s\n", guestShellQuote(command))
} }
return script.String() return script.String()
} }
func guestSessionAttachInputCommand(sessionID string) string { func guestSessionAttachInputCommand(sessionID string) string {
path := guestSessionStdinPipePath(sessionID) path := guestSessionStdinPipePath(sessionID)
return "bash -lc " + guestShellQuote(fmt.Sprintf("set -euo pipefail\\n[ -p %s ] || mkfifo -m 600 %s\\nexec cat > %s\\n", guestShellQuote(path), guestShellQuote(path), guestShellQuote(path))) return "bash -lc " + guestShellQuote(fmt.Sprintf("set -euo pipefail\n[ -p %s ] || mkfifo -m 600 %s\nexec cat > %s\n", guestShellQuote(path), guestShellQuote(path), guestShellQuote(path)))
} }
func guestSessionAttachTailCommand(path string) string { func guestSessionAttachTailCommand(path string) string {
return "bash -lc " + guestShellQuote(fmt.Sprintf("set -euo pipefail\\ntouch %s\\nexec tail -n 0 -F %s 2>/dev/null\\n", guestShellQuote(path), guestShellQuote(path))) return "bash -lc " + guestShellQuote(fmt.Sprintf("set -euo pipefail\ntouch %s\nexec tail -n 0 -F %s 2>/dev/null\n", guestShellQuote(path), guestShellQuote(path)))
} }
func guestSessionEnvLines(values map[string]string) []string { func guestSessionEnvLines(values map[string]string) []string {

View file

@ -0,0 +1,216 @@
package daemon
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
"banger/internal/api"
"banger/internal/model"
)
type fakeGuestSSHClient struct {
t *testing.T
existingDirs map[string]bool
closed bool
}
func (f *fakeGuestSSHClient) Close() error {
f.closed = true
return nil
}
func (f *fakeGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error {
f.t.Helper()
switch {
case strings.Contains(script, `\n`):
return fmt.Errorf("script still contains escaped newline literals: %q", script)
case strings.Contains(script, `echo "missing cwd: $DIR"`):
if strings.Contains(script, "DIR='/root/repo'\n") && f.existingDirs["/root/repo"] {
return nil
}
return fmt.Errorf("missing cwd")
case strings.Contains(script, "check_command() {"):
return nil
case strings.Contains(script, `git config --global --add safe.directory "$DIR"`):
if strings.Contains(script, "DIR='/root/repo'\n") {
f.existingDirs["/root/repo"] = true
return nil
}
return fmt.Errorf("workspace finalize used unexpected guest path")
case strings.Contains(script, "chmod -R a-w"):
if f.existingDirs["/root/repo"] {
return nil
}
return fmt.Errorf("workspace path missing during readonly chmod")
case strings.Contains(script, "nohup bash "):
return nil
default:
return nil
}
}
func (f *fakeGuestSSHClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error {
return nil
}
func (f *fakeGuestSSHClient) StreamTar(_ context.Context, _ string, command string, _ io.Writer) error {
if strings.Contains(command, "/root/repo") {
f.existingDirs["/root/repo"] = true
return nil
}
return fmt.Errorf("unexpected StreamTar command: %s", command)
}
func (f *fakeGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, command string, _ io.Writer) error {
if strings.Contains(command, "/root/repo") {
f.existingDirs["/root/repo"] = true
return nil
}
return fmt.Errorf("unexpected StreamTarEntries command: %s", command)
}
func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) {
t.Parallel()
cwdScript := guestSessionCWDPreflightScript("/root/repo")
if strings.Contains(cwdScript, `\n`) {
t.Fatalf("cwd preflight script still contains escaped newline literals: %q", cwdScript)
}
if !strings.Contains(cwdScript, "\n") {
t.Fatalf("cwd preflight script should contain real newlines: %q", cwdScript)
}
commandScript := guestSessionCommandPreflightScript([]string{"git", "pi"})
if strings.Contains(commandScript, `\n`) {
t.Fatalf("command preflight script still contains escaped newline literals: %q", commandScript)
}
if !strings.Contains(commandScript, "\n") {
t.Fatalf("command preflight script should contain real newlines: %q", commandScript)
}
attachInput := guestSessionAttachInputCommand("session-id")
if strings.Contains(attachInput, `\n`) {
t.Fatalf("attach input command still contains escaped newline literals: %q", attachInput)
}
attachTail := guestSessionAttachTailCommand("/tmp/stdout.log")
if strings.Contains(attachTail, `\n`) {
t.Fatalf("attach tail command still contains escaped newline literals: %q", attachTail)
}
}
func TestPrepareWorkspaceThenStartGuestSessionPassesCWDPreflight(t *testing.T) {
ctx := context.Background()
db := openDaemonStore(t)
repoRoot := filepath.Join(t.TempDir(), "repo")
if err := os.MkdirAll(repoRoot, 0o755); err != nil {
t.Fatalf("MkdirAll(repoRoot): %v", err)
}
if err := os.WriteFile(filepath.Join(repoRoot, "README.md"), []byte("hello\n"), 0o644); err != nil {
t.Fatalf("WriteFile(README.md): %v", err)
}
runGit := func(args ...string) {
t.Helper()
cmd := exec.Command("git", append([]string{"-C", repoRoot}, args...)...)
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("git %s: %v\n%s", strings.Join(args, " "), err, output)
}
}
runGit("init", "-b", "main")
runGit("config", "user.name", "Test User")
runGit("config", "user.email", "test@example.com")
runGit("add", ".")
runGit("commit", "-m", "initial")
apiSock := filepath.Join(t.TempDir(), "fc.sock")
firecracker := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
if err := firecracker.Start(); err != nil {
t.Fatalf("start fake firecracker: %v", err)
}
t.Cleanup(func() {
if firecracker.Process != nil {
_ = firecracker.Process.Kill()
_, _ = firecracker.Process.Wait()
}
})
vm := testVM("pi-devbox", "image-pi", "172.16.0.77")
vm.State = model.VMStateRunning
vm.Runtime.State = model.VMStateRunning
vm.Runtime.PID = firecracker.Process.Pid
vm.Runtime.APISockPath = apiSock
upsertDaemonVM(t, ctx, db, vm)
fakeClient := &fakeGuestSSHClient{t: t, existingDirs: map[string]bool{}}
d := &Daemon{
store: db,
config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")},
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
}
d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
d.guestDial = func(context.Context, string, string) (guestSSHClient, error) { return fakeClient, nil }
d.waitForGuestSessionReady = func(_ context.Context, _ model.VMRecord, session model.GuestSession) (model.GuestSession, error) {
now := model.Now()
session.Status = model.GuestSessionStatusRunning
session.GuestPID = 4242
session.StartedAt = now
session.UpdatedAt = now
session.Attachable = session.StdinMode == model.GuestSessionStdinPipe
session.Reattachable = session.StdinMode == model.GuestSessionStdinPipe
return session, nil
}
workspace, err := d.PrepareVMWorkspace(ctx, api.VMWorkspacePrepareParams{
IDOrName: vm.Name,
SourcePath: repoRoot,
GuestPath: "/root/repo",
ReadOnly: true,
})
if err != nil {
t.Fatalf("PrepareVMWorkspace: %v", err)
}
if workspace.GuestPath != "/root/repo" {
t.Fatalf("PrepareVMWorkspace guest path = %q, want /root/repo", workspace.GuestPath)
}
if !fakeClient.existingDirs["/root/repo"] {
t.Fatalf("workspace prepare did not mark /root/repo as present in fake guest")
}
session, err := d.StartGuestSession(ctx, api.GuestSessionStartParams{
VMIDOrName: vm.Name,
Name: "testpi",
Command: "pi",
Args: []string{"--mode", "rpc", "--no-session"},
CWD: "/root/repo",
StdinMode: string(model.GuestSessionStdinPipe),
RequiredCommands: []string{"git"},
})
if err != nil {
t.Fatalf("StartGuestSession: %v", err)
}
if session.Status != model.GuestSessionStatusRunning {
t.Fatalf("session status = %q, want %q", session.Status, model.GuestSessionStatusRunning)
}
if session.LaunchStage != "" {
t.Fatalf("session launch stage = %q, want empty", session.LaunchStage)
}
if session.LaunchMessage != "" {
t.Fatalf("session launch message = %q, want empty", session.LaunchMessage)
}
if session.GuestPID == 0 {
t.Fatalf("session guest pid = 0, want non-zero")
}
if !session.Attachable {
t.Fatalf("session should be attachable for pipe stdin mode")
}
}

View file

@ -14,7 +14,6 @@ import (
"time" "time"
"banger/internal/api" "banger/internal/api"
"banger/internal/guest"
"banger/internal/model" "banger/internal/model"
"banger/internal/system" "banger/internal/system"
) )
@ -77,10 +76,10 @@ func (d *Daemon) prepareVMWorkspaceLocked(ctx context.Context, vm model.VMRecord
return model.WorkspacePrepareResult{}, fmt.Errorf("workspace mode %q does not support git submodules in %s (%s); use --mode full_copy", mode, spec.RepoRoot, strings.Join(spec.Submodules, ", ")) return model.WorkspacePrepareResult{}, fmt.Errorf("workspace mode %q does not support git submodules in %s (%s); use --mode full_copy", mode, spec.RepoRoot, strings.Join(spec.Submodules, ", "))
} }
address := net.JoinHostPort(vm.Runtime.GuestIP, "22") address := net.JoinHostPort(vm.Runtime.GuestIP, "22")
if err := guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, 250*time.Millisecond); err != nil { if err := d.waitForGuestSSH(ctx, address, 250*time.Millisecond); err != nil {
return model.WorkspacePrepareResult{}, fmt.Errorf("guest ssh unavailable: %w", err) return model.WorkspacePrepareResult{}, fmt.Errorf("guest ssh unavailable: %w", err)
} }
client, err := guest.Dial(ctx, address, d.config.SSHKeyPath) client, err := d.dialGuest(ctx, address)
if err != nil { if err != nil {
return model.WorkspacePrepareResult{}, fmt.Errorf("dial guest ssh: %w", err) return model.WorkspacePrepareResult{}, fmt.Errorf("dial guest ssh: %w", err)
} }
@ -179,7 +178,7 @@ func inspectWorkspaceRepo(ctx context.Context, rawPath, branchName, fromRef stri
}, nil }, nil
} }
func importWorkspaceRepoToGuest(ctx context.Context, client *guest.Client, spec workspaceRepoSpec, guestPath string, mode model.WorkspacePrepareMode) error { func importWorkspaceRepoToGuest(ctx context.Context, client guestSSHClient, spec workspaceRepoSpec, guestPath string, mode model.WorkspacePrepareMode) error {
switch mode { switch mode {
case model.WorkspacePrepareModeFullCopy: case model.WorkspacePrepareModeFullCopy:
var copyLog bytes.Buffer var copyLog bytes.Buffer