diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index de747a2..92eeb19 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -27,33 +27,36 @@ import ( ) type Daemon struct { - layout paths.Layout - config model.DaemonConfig - store *store.Store - runner system.CommandRunner - logger *slog.Logger - mu sync.Mutex - createOpsMu sync.Mutex - createOps map[string]*vmCreateOperationState - imageBuildOpsMu sync.Mutex - imageBuildOps map[string]*imageBuildOperationState - vmLocksMu sync.Mutex - vmLocks map[string]*sync.Mutex - sessionControllers map[string]*guestSessionController - tapPoolMu sync.Mutex - tapPool []string - tapPoolNext int - closing chan struct{} - once sync.Once - pid int - listener net.Listener - webListener net.Listener - webServer *http.Server - webURL string - vmDNS *vmdns.Server - vmCaps []vmCapability - imageBuild func(context.Context, imageBuildSpec) error - requestHandler func(context.Context, rpc.Request) rpc.Response + layout paths.Layout + config model.DaemonConfig + store *store.Store + runner system.CommandRunner + logger *slog.Logger + mu sync.Mutex + createOpsMu sync.Mutex + createOps map[string]*vmCreateOperationState + imageBuildOpsMu sync.Mutex + imageBuildOps map[string]*imageBuildOperationState + vmLocksMu sync.Mutex + vmLocks map[string]*sync.Mutex + sessionControllers map[string]*guestSessionController + tapPoolMu sync.Mutex + tapPool []string + tapPoolNext int + closing chan struct{} + once sync.Once + pid int + listener net.Listener + webListener net.Listener + webServer *http.Server + webURL string + vmDNS *vmdns.Server + vmCaps []vmCapability + imageBuild func(context.Context, imageBuildSpec) error + requestHandler func(context.Context, rpc.Request) rpc.Response + guestWaitForSSH func(context.Context, string, string, time.Duration) error + guestDial func(context.Context, string, string) (guestSSHClient, error) + waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error) } func Open(ctx context.Context) (d *Daemon, err error) { diff --git a/internal/daemon/guest_sessions.go b/internal/daemon/guest_sessions.go index b0f9dcd..e8fa2a0 100644 --- a/internal/daemon/guest_sessions.go +++ b/internal/daemon/guest_sessions.go @@ -50,6 +50,35 @@ var guestSessionHostCommandOutputFunc = func(ctx context.Context, name string, a return output, fmt.Errorf("%s: %w: %s", command, err, detail) } +type guestSSHClient interface { + Close() error + RunScript(context.Context, string, io.Writer) error + UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error + StreamTar(context.Context, string, string, io.Writer) error + StreamTarEntries(context.Context, string, []string, string, io.Writer) error +} + +func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error { + if d != nil && d.guestWaitForSSH != nil { + return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval) + } + return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, interval) +} + +func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) { + if d != nil && d.guestDial != nil { + return d.guestDial(ctx, address, d.config.SSHKeyPath) + } + return guest.Dial(ctx, address, d.config.SSHKeyPath) +} + +func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, session model.GuestSession) (model.GuestSession, error) { + if d != nil && d.waitForGuestSessionReady != nil { + return d.waitForGuestSessionReady(ctx, vm, session) + } + return d.waitForGuestSessionReadyDefault(ctx, vm, session) +} + type guestSessionController struct { stream *guest.StreamSession streams []*guest.StreamSession @@ -215,10 +244,10 @@ func (d *Daemon) startGuestSessionLocked(ctx context.Context, vm model.VMRecord, return session, nil } address := net.JoinHostPort(vm.Runtime.GuestIP, "22") - if err := guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, 250*time.Millisecond); err != nil { + if err := d.waitForGuestSSH(ctx, address, 250*time.Millisecond); err != nil { return fail("ssh_unavailable", fmt.Sprintf("guest ssh unavailable: %v", err), "") } - client, err := guest.Dial(ctx, address, d.config.SSHKeyPath) + client, err := d.dialGuest(ctx, address) if err != nil { return fail("dial_guest", fmt.Sprintf("dial guest ssh: %v", err), "") } @@ -243,7 +272,7 @@ func (d *Daemon) startGuestSessionLocked(ctx context.Context, vm model.VMRecord, } readyCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - updated, err := d.waitForGuestSessionReady(readyCtx, vm, session) + updated, err := d.waitForGuestSessionReadyHook(readyCtx, vm, session) if err != nil { return fail("ready_wait", "guest session did not report ready state", err.Error()) } @@ -628,7 +657,7 @@ func (d *Daemon) watchGuestSessionAttach(id string, controller *guestSessionCont } } -func (d *Daemon) waitForGuestSessionReady(ctx context.Context, vm model.VMRecord, session model.GuestSession) (model.GuestSession, error) { +func (d *Daemon) waitForGuestSessionReadyDefault(ctx context.Context, vm model.VMRecord, session model.GuestSession) (model.GuestSession, error) { for { updated, err := d.refreshGuestSession(ctx, vm, session) if err == nil { @@ -1037,35 +1066,35 @@ func normalizeGuestSessionRequiredCommands(command string, extras []string) []st func guestSessionCWDPreflightScript(cwd string) string { var script strings.Builder - script.WriteString("set -euo pipefail\\n") - fmt.Fprintf(&script, "DIR=%s\\n", guestShellQuote(defaultGuestSessionCWD(cwd))) - script.WriteString("if [ ! -d \"$DIR\" ]; then echo \"missing cwd: $DIR\"; exit 1; fi\\n") + script.WriteString("set -euo pipefail\n") + fmt.Fprintf(&script, "DIR=%s\n", guestShellQuote(defaultGuestSessionCWD(cwd))) + script.WriteString("if [ ! -d \"$DIR\" ]; then echo \"missing cwd: $DIR\"; exit 1; fi\n") return script.String() } func guestSessionCommandPreflightScript(commands []string) string { var script strings.Builder - script.WriteString("set -euo pipefail\\n") - script.WriteString("check_command() {\\n") - script.WriteString(" cmd=\\\"$1\\\"\\n") - script.WriteString(" case \\\"$cmd\\\" in\\n") - script.WriteString(" */*) [ -x \\\"$cmd\\\" ] || { echo \\\"missing command: $cmd\\\"; exit 1; } ;;\\n") - script.WriteString(" *) command -v \\\"$cmd\\\" >/dev/null 2>&1 || { echo \\\"missing command: $cmd\\\"; exit 1; } ;;\\n") - script.WriteString(" esac\\n") - script.WriteString("}\\n") + script.WriteString("set -euo pipefail\n") + script.WriteString("check_command() {\n") + script.WriteString(" cmd=\"$1\"\n") + script.WriteString(" case \"$cmd\" in\n") + script.WriteString(" */*) [ -x \"$cmd\" ] || { echo \"missing command: $cmd\"; exit 1; } ;;\n") + script.WriteString(" *) command -v \"$cmd\" >/dev/null 2>&1 || { echo \"missing command: $cmd\"; exit 1; } ;;\n") + script.WriteString(" esac\n") + script.WriteString("}\n") for _, command := range commands { - fmt.Fprintf(&script, "check_command %s\\n", guestShellQuote(command)) + fmt.Fprintf(&script, "check_command %s\n", guestShellQuote(command)) } return script.String() } func guestSessionAttachInputCommand(sessionID string) string { path := guestSessionStdinPipePath(sessionID) - return "bash -lc " + guestShellQuote(fmt.Sprintf("set -euo pipefail\\n[ -p %s ] || mkfifo -m 600 %s\\nexec cat > %s\\n", guestShellQuote(path), guestShellQuote(path), guestShellQuote(path))) + return "bash -lc " + guestShellQuote(fmt.Sprintf("set -euo pipefail\n[ -p %s ] || mkfifo -m 600 %s\nexec cat > %s\n", guestShellQuote(path), guestShellQuote(path), guestShellQuote(path))) } func guestSessionAttachTailCommand(path string) string { - return "bash -lc " + guestShellQuote(fmt.Sprintf("set -euo pipefail\\ntouch %s\\nexec tail -n 0 -F %s 2>/dev/null\\n", guestShellQuote(path), guestShellQuote(path))) + return "bash -lc " + guestShellQuote(fmt.Sprintf("set -euo pipefail\ntouch %s\nexec tail -n 0 -F %s 2>/dev/null\n", guestShellQuote(path), guestShellQuote(path))) } func guestSessionEnvLines(values map[string]string) []string { diff --git a/internal/daemon/guest_sessions_test.go b/internal/daemon/guest_sessions_test.go new file mode 100644 index 0000000..49e5127 --- /dev/null +++ b/internal/daemon/guest_sessions_test.go @@ -0,0 +1,216 @@ +package daemon + +import ( + "context" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "banger/internal/api" + "banger/internal/model" +) + +type fakeGuestSSHClient struct { + t *testing.T + existingDirs map[string]bool + closed bool +} + +func (f *fakeGuestSSHClient) Close() error { + f.closed = true + return nil +} + +func (f *fakeGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error { + f.t.Helper() + switch { + case strings.Contains(script, `\n`): + return fmt.Errorf("script still contains escaped newline literals: %q", script) + case strings.Contains(script, `echo "missing cwd: $DIR"`): + if strings.Contains(script, "DIR='/root/repo'\n") && f.existingDirs["/root/repo"] { + return nil + } + return fmt.Errorf("missing cwd") + case strings.Contains(script, "check_command() {"): + return nil + case strings.Contains(script, `git config --global --add safe.directory "$DIR"`): + if strings.Contains(script, "DIR='/root/repo'\n") { + f.existingDirs["/root/repo"] = true + return nil + } + return fmt.Errorf("workspace finalize used unexpected guest path") + case strings.Contains(script, "chmod -R a-w"): + if f.existingDirs["/root/repo"] { + return nil + } + return fmt.Errorf("workspace path missing during readonly chmod") + case strings.Contains(script, "nohup bash "): + return nil + default: + return nil + } +} + +func (f *fakeGuestSSHClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error { + return nil +} + +func (f *fakeGuestSSHClient) StreamTar(_ context.Context, _ string, command string, _ io.Writer) error { + if strings.Contains(command, "/root/repo") { + f.existingDirs["/root/repo"] = true + return nil + } + return fmt.Errorf("unexpected StreamTar command: %s", command) +} + +func (f *fakeGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, command string, _ io.Writer) error { + if strings.Contains(command, "/root/repo") { + f.existingDirs["/root/repo"] = true + return nil + } + return fmt.Errorf("unexpected StreamTarEntries command: %s", command) +} + +func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) { + t.Parallel() + + cwdScript := guestSessionCWDPreflightScript("/root/repo") + if strings.Contains(cwdScript, `\n`) { + t.Fatalf("cwd preflight script still contains escaped newline literals: %q", cwdScript) + } + if !strings.Contains(cwdScript, "\n") { + t.Fatalf("cwd preflight script should contain real newlines: %q", cwdScript) + } + + commandScript := guestSessionCommandPreflightScript([]string{"git", "pi"}) + if strings.Contains(commandScript, `\n`) { + t.Fatalf("command preflight script still contains escaped newline literals: %q", commandScript) + } + if !strings.Contains(commandScript, "\n") { + t.Fatalf("command preflight script should contain real newlines: %q", commandScript) + } + + attachInput := guestSessionAttachInputCommand("session-id") + if strings.Contains(attachInput, `\n`) { + t.Fatalf("attach input command still contains escaped newline literals: %q", attachInput) + } + + attachTail := guestSessionAttachTailCommand("/tmp/stdout.log") + if strings.Contains(attachTail, `\n`) { + t.Fatalf("attach tail command still contains escaped newline literals: %q", attachTail) + } +} + +func TestPrepareWorkspaceThenStartGuestSessionPassesCWDPreflight(t *testing.T) { + ctx := context.Background() + db := openDaemonStore(t) + + repoRoot := filepath.Join(t.TempDir(), "repo") + if err := os.MkdirAll(repoRoot, 0o755); err != nil { + t.Fatalf("MkdirAll(repoRoot): %v", err) + } + if err := os.WriteFile(filepath.Join(repoRoot, "README.md"), []byte("hello\n"), 0o644); err != nil { + t.Fatalf("WriteFile(README.md): %v", err) + } + runGit := func(args ...string) { + t.Helper() + cmd := exec.Command("git", append([]string{"-C", repoRoot}, args...)...) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("git %s: %v\n%s", strings.Join(args, " "), err, output) + } + } + runGit("init", "-b", "main") + runGit("config", "user.name", "Test User") + runGit("config", "user.email", "test@example.com") + runGit("add", ".") + runGit("commit", "-m", "initial") + + apiSock := filepath.Join(t.TempDir(), "fc.sock") + firecracker := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock)) + if err := firecracker.Start(); err != nil { + t.Fatalf("start fake firecracker: %v", err) + } + t.Cleanup(func() { + if firecracker.Process != nil { + _ = firecracker.Process.Kill() + _, _ = firecracker.Process.Wait() + } + }) + + vm := testVM("pi-devbox", "image-pi", "172.16.0.77") + vm.State = model.VMStateRunning + vm.Runtime.State = model.VMStateRunning + vm.Runtime.PID = firecracker.Process.Pid + vm.Runtime.APISockPath = apiSock + upsertDaemonVM(t, ctx, db, vm) + + fakeClient := &fakeGuestSSHClient{t: t, existingDirs: map[string]bool{}} + d := &Daemon{ + store: db, + config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")}, + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil } + d.guestDial = func(context.Context, string, string) (guestSSHClient, error) { return fakeClient, nil } + d.waitForGuestSessionReady = func(_ context.Context, _ model.VMRecord, session model.GuestSession) (model.GuestSession, error) { + now := model.Now() + session.Status = model.GuestSessionStatusRunning + session.GuestPID = 4242 + session.StartedAt = now + session.UpdatedAt = now + session.Attachable = session.StdinMode == model.GuestSessionStdinPipe + session.Reattachable = session.StdinMode == model.GuestSessionStdinPipe + return session, nil + } + + workspace, err := d.PrepareVMWorkspace(ctx, api.VMWorkspacePrepareParams{ + IDOrName: vm.Name, + SourcePath: repoRoot, + GuestPath: "/root/repo", + ReadOnly: true, + }) + if err != nil { + t.Fatalf("PrepareVMWorkspace: %v", err) + } + if workspace.GuestPath != "/root/repo" { + t.Fatalf("PrepareVMWorkspace guest path = %q, want /root/repo", workspace.GuestPath) + } + if !fakeClient.existingDirs["/root/repo"] { + t.Fatalf("workspace prepare did not mark /root/repo as present in fake guest") + } + + session, err := d.StartGuestSession(ctx, api.GuestSessionStartParams{ + VMIDOrName: vm.Name, + Name: "testpi", + Command: "pi", + Args: []string{"--mode", "rpc", "--no-session"}, + CWD: "/root/repo", + StdinMode: string(model.GuestSessionStdinPipe), + RequiredCommands: []string{"git"}, + }) + if err != nil { + t.Fatalf("StartGuestSession: %v", err) + } + if session.Status != model.GuestSessionStatusRunning { + t.Fatalf("session status = %q, want %q", session.Status, model.GuestSessionStatusRunning) + } + if session.LaunchStage != "" { + t.Fatalf("session launch stage = %q, want empty", session.LaunchStage) + } + if session.LaunchMessage != "" { + t.Fatalf("session launch message = %q, want empty", session.LaunchMessage) + } + if session.GuestPID == 0 { + t.Fatalf("session guest pid = 0, want non-zero") + } + if !session.Attachable { + t.Fatalf("session should be attachable for pipe stdin mode") + } +} diff --git a/internal/daemon/workspace.go b/internal/daemon/workspace.go index 1bc396a..33fb1e9 100644 --- a/internal/daemon/workspace.go +++ b/internal/daemon/workspace.go @@ -14,7 +14,6 @@ import ( "time" "banger/internal/api" - "banger/internal/guest" "banger/internal/model" "banger/internal/system" ) @@ -77,10 +76,10 @@ func (d *Daemon) prepareVMWorkspaceLocked(ctx context.Context, vm model.VMRecord return model.WorkspacePrepareResult{}, fmt.Errorf("workspace mode %q does not support git submodules in %s (%s); use --mode full_copy", mode, spec.RepoRoot, strings.Join(spec.Submodules, ", ")) } address := net.JoinHostPort(vm.Runtime.GuestIP, "22") - if err := guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, 250*time.Millisecond); err != nil { + if err := d.waitForGuestSSH(ctx, address, 250*time.Millisecond); err != nil { return model.WorkspacePrepareResult{}, fmt.Errorf("guest ssh unavailable: %w", err) } - client, err := guest.Dial(ctx, address, d.config.SSHKeyPath) + client, err := d.dialGuest(ctx, address) if err != nil { return model.WorkspacePrepareResult{}, fmt.Errorf("dial guest ssh: %w", err) } @@ -179,7 +178,7 @@ func inspectWorkspaceRepo(ctx context.Context, rawPath, branchName, fromRef stri }, nil } -func importWorkspaceRepoToGuest(ctx context.Context, client *guest.Client, spec workspaceRepoSpec, guestPath string, mode model.WorkspacePrepareMode) error { +func importWorkspaceRepoToGuest(ctx context.Context, client guestSSHClient, spec workspaceRepoSpec, guestPath string, mode model.WorkspacePrepareMode) error { switch mode { case model.WorkspacePrepareModeFullCopy: var copyLog bytes.Buffer