package daemon import ( "context" "fmt" "io" "log/slog" "os" "os/exec" "path/filepath" "strings" "testing" "time" "banger/internal/api" sess "banger/internal/daemon/session" "banger/internal/model" "banger/internal/store" ) type fakeGuestSSHClient struct { t *testing.T existingDirs map[string]bool closed bool } func (f *fakeGuestSSHClient) Close() error { f.closed = true return nil } func (f *fakeGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error { f.t.Helper() switch { case strings.Contains(script, `\n`): return fmt.Errorf("script still contains escaped newline literals: %q", script) case strings.Contains(script, `echo "missing cwd: $DIR"`): if strings.Contains(script, "DIR='/root/repo'\n") && f.existingDirs["/root/repo"] { return nil } return fmt.Errorf("missing cwd") case strings.Contains(script, "check_command() {"): return nil case strings.Contains(script, `git config --global --add safe.directory "$DIR"`): if strings.Contains(script, "DIR='/root/repo'\n") { f.existingDirs["/root/repo"] = true return nil } return fmt.Errorf("workspace finalize used unexpected guest path") case strings.Contains(script, "chmod -R a-w"): if f.existingDirs["/root/repo"] { return nil } return fmt.Errorf("workspace path missing during readonly chmod") case strings.Contains(script, "nohup bash "): return nil default: return nil } } func (f *fakeGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) { return nil, nil } func (f *fakeGuestSSHClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error { return nil } func (f *fakeGuestSSHClient) StreamTar(_ context.Context, _ string, command string, _ io.Writer) error { if strings.Contains(command, "/root/repo") { f.existingDirs["/root/repo"] = true return nil } return fmt.Errorf("unexpected StreamTar command: %s", command) } func (f *fakeGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, command string, _ io.Writer) error { if strings.Contains(command, "/root/repo") { f.existingDirs["/root/repo"] = true return nil } return fmt.Errorf("unexpected StreamTarEntries command: %s", command) } func TestSendToGuestSession_HappyPath(t *testing.T) { t.Parallel() ctx := context.Background() db := openDaemonStore(t) apiSock := filepath.Join(t.TempDir(), "fc.sock") firecracker := startFakeFirecracker(t, apiSock) vm := testVM("sendbox", "image-send", "172.16.0.88") vm.State = model.VMStateRunning vm.Runtime.State = model.VMStateRunning vm.Runtime.PID = firecracker.Process.Pid vm.Runtime.APISockPath = apiSock upsertDaemonVM(t, ctx, db, vm) session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning) if err := db.UpsertGuestSession(ctx, session); err != nil { t.Fatalf("UpsertGuestSession: %v", err) } fake := &recordingGuestSSHClient{} d := newSendTestDaemon(t, db, fake) payload := []byte(`{"type":"abort"}` + "\n") result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ VMIDOrName: vm.Name, SessionIDOrName: session.Name, Payload: payload, }) if err != nil { t.Fatalf("SendToGuestSession: %v", err) } if result.BytesWritten != len(payload) { t.Fatalf("BytesWritten = %d, want %d", result.BytesWritten, len(payload)) } if result.Session.ID != session.ID { t.Fatalf("Session.ID = %q, want %q", result.Session.ID, session.ID) } if len(fake.uploadedFiles) != 1 { t.Fatalf("UploadFile call count = %d, want 1", len(fake.uploadedFiles)) } for path, data := range fake.uploadedFiles { if !strings.HasPrefix(path, "/tmp/banger-send-") { t.Fatalf("upload path = %q, want /tmp/banger-send-... prefix", path) } if string(data) != string(payload) { t.Fatalf("upload data = %q, want %q", data, payload) } } if len(fake.ranScripts) != 1 { t.Fatalf("RunScript call count = %d, want 1", len(fake.ranScripts)) } script := fake.ranScripts[0] pipePath := sess.StdinPipePath(session.ID) if !strings.Contains(script, "cat ") { t.Fatalf("send script missing cat command: %q", script) } if !strings.Contains(script, pipePath) { t.Fatalf("send script missing pipe path %q: %q", pipePath, script) } if !strings.Contains(script, "rm -f ") { t.Fatalf("send script missing rm cleanup: %q", script) } } func TestSendToGuestSession_EmptyPayload(t *testing.T) { t.Parallel() ctx := context.Background() db := openDaemonStore(t) apiSock := filepath.Join(t.TempDir(), "fc.sock") firecracker := startFakeFirecracker(t, apiSock) vm := testVM("sendbox-empty", "image-send", "172.16.0.89") vm.State = model.VMStateRunning vm.Runtime.State = model.VMStateRunning vm.Runtime.PID = firecracker.Process.Pid vm.Runtime.APISockPath = apiSock upsertDaemonVM(t, ctx, db, vm) session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning) if err := db.UpsertGuestSession(ctx, session); err != nil { t.Fatalf("UpsertGuestSession: %v", err) } fake := &recordingGuestSSHClient{} d := newSendTestDaemon(t, db, fake) result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ VMIDOrName: vm.Name, SessionIDOrName: session.Name, Payload: nil, }) if err != nil { t.Fatalf("SendToGuestSession(empty): %v", err) } if result.BytesWritten != 0 { t.Fatalf("BytesWritten = %d, want 0", result.BytesWritten) } if fake.dialCount != 0 { t.Fatalf("SSH dial count = %d, want 0 for empty payload", fake.dialCount) } } func TestSendToGuestSession_NotPipeMode(t *testing.T) { t.Parallel() ctx := context.Background() db := openDaemonStore(t) vm := testVM("sendbox-closed", "image-send", "172.16.0.90") vm.State = model.VMStateRunning upsertDaemonVM(t, ctx, db, vm) session := testGuestSession(vm.ID, model.GuestSessionStdinClosed, model.GuestSessionStatusRunning) if err := db.UpsertGuestSession(ctx, session); err != nil { t.Fatalf("UpsertGuestSession: %v", err) } d := &Daemon{store: db} _, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ VMIDOrName: vm.Name, SessionIDOrName: session.Name, Payload: []byte("hello\n"), }) if err == nil || !strings.Contains(err.Error(), "stdin pipe") { t.Fatalf("error = %v, want 'stdin pipe' error", err) } } func TestSendToGuestSession_SessionNotRunning(t *testing.T) { t.Parallel() ctx := context.Background() db := openDaemonStore(t) vm := testVM("sendbox-failed", "image-send", "172.16.0.91") vm.State = model.VMStateRunning upsertDaemonVM(t, ctx, db, vm) session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusFailed) if err := db.UpsertGuestSession(ctx, session); err != nil { t.Fatalf("UpsertGuestSession: %v", err) } d := &Daemon{store: db} _, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ VMIDOrName: vm.Name, SessionIDOrName: session.Name, Payload: []byte("hello\n"), }) if err == nil || !strings.Contains(err.Error(), "not running") { t.Fatalf("error = %v, want 'not running' error", err) } } func TestSendToGuestSession_VMNotRunning(t *testing.T) { t.Parallel() ctx := context.Background() db := openDaemonStore(t) vm := testVM("sendbox-stopped", "image-send", "172.16.0.92") vm.State = model.VMStateStopped upsertDaemonVM(t, ctx, db, vm) session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning) if err := db.UpsertGuestSession(ctx, session); err != nil { t.Fatalf("UpsertGuestSession: %v", err) } d := &Daemon{store: db} _, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ VMIDOrName: vm.Name, SessionIDOrName: session.Name, Payload: []byte("hello\n"), }) if err == nil || !strings.Contains(err.Error(), "not running") { t.Fatalf("error = %v, want 'not running' error", err) } } // recordingGuestSSHClient captures UploadFile and RunScript calls for send tests. type recordingGuestSSHClient struct { dialCount int uploadedFiles map[string][]byte ranScripts []string } func (r *recordingGuestSSHClient) Close() error { return nil } func (r *recordingGuestSSHClient) UploadFile(_ context.Context, path string, _ os.FileMode, data []byte, _ io.Writer) error { if r.uploadedFiles == nil { r.uploadedFiles = make(map[string][]byte) } copy := make([]byte, len(data)) _ = copy[:len(data):len(data)] for i, b := range data { copy[i] = b } r.uploadedFiles[path] = copy return nil } func (r *recordingGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error { r.ranScripts = append(r.ranScripts, script) return nil } func (r *recordingGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) { return nil, nil } func (r *recordingGuestSSHClient) StreamTar(_ context.Context, _ string, _ string, _ io.Writer) error { return nil } func (r *recordingGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, _ string, _ io.Writer) error { return nil } func newSendTestDaemon(t *testing.T, db *store.Store, fake *recordingGuestSSHClient) *Daemon { t.Helper() d := &Daemon{ store: db, config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")}, logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } d.guestDial = func(_ context.Context, _ string, _ string) (guestSSHClient, error) { fake.dialCount++ return fake, nil } return d } func testGuestSession(vmID string, stdinMode model.GuestSessionStdinMode, status model.GuestSessionStatus) model.GuestSession { now := model.Now() id := vmID + "-sess-id" return model.GuestSession{ ID: id, VMID: vmID, Name: vmID + "-sess", Backend: sess.BackendSSH, Command: "pi", Args: []string{"--mode", "rpc"}, CWD: "/root/repo", StdinMode: stdinMode, Status: status, GuestStateDir: sess.StateDir(id), StdoutLogPath: sess.StdoutLogPath(id), StderrLogPath: sess.StderrLogPath(id), Attachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning, Reattachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning, CreatedAt: now, UpdatedAt: now, } } func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd { t.Helper() cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock)) if err := cmd.Start(); err != nil { t.Fatalf("start fake firecracker: %v", err) } t.Cleanup(func() { if cmd.Process != nil { _ = cmd.Process.Kill() _, _ = cmd.Process.Wait() } }) return cmd } func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) { t.Parallel() cwdScript := sess.CWDPreflightScript("/root/repo") if strings.Contains(cwdScript, `\n`) { t.Fatalf("cwd preflight script still contains escaped newline literals: %q", cwdScript) } if !strings.Contains(cwdScript, "\n") { t.Fatalf("cwd preflight script should contain real newlines: %q", cwdScript) } commandScript := sess.CommandPreflightScript([]string{"git", "pi"}) if strings.Contains(commandScript, `\n`) { t.Fatalf("command preflight script still contains escaped newline literals: %q", commandScript) } if !strings.Contains(commandScript, "\n") { t.Fatalf("command preflight script should contain real newlines: %q", commandScript) } attachInput := sess.AttachInputCommand("session-id") if strings.Contains(attachInput, `\n`) { t.Fatalf("attach input command still contains escaped newline literals: %q", attachInput) } attachTail := sess.AttachTailCommand("/tmp/stdout.log") if strings.Contains(attachTail, `\n`) { t.Fatalf("attach tail command still contains escaped newline literals: %q", attachTail) } } func TestPrepareWorkspaceThenStartGuestSessionPassesCWDPreflight(t *testing.T) { ctx := context.Background() db := openDaemonStore(t) repoRoot := filepath.Join(t.TempDir(), "repo") if err := os.MkdirAll(repoRoot, 0o755); err != nil { t.Fatalf("MkdirAll(repoRoot): %v", err) } if err := os.WriteFile(filepath.Join(repoRoot, "README.md"), []byte("hello\n"), 0o644); err != nil { t.Fatalf("WriteFile(README.md): %v", err) } runGit := func(args ...string) { t.Helper() cmd := exec.Command("git", append([]string{"-C", repoRoot}, args...)...) output, err := cmd.CombinedOutput() if err != nil { t.Fatalf("git %s: %v\n%s", strings.Join(args, " "), err, output) } } runGit("init", "-b", "main") runGit("config", "user.name", "Test User") runGit("config", "user.email", "test@example.com") runGit("add", ".") runGit("commit", "-m", "initial") apiSock := filepath.Join(t.TempDir(), "fc.sock") firecracker := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock)) if err := firecracker.Start(); err != nil { t.Fatalf("start fake firecracker: %v", err) } t.Cleanup(func() { if firecracker.Process != nil { _ = firecracker.Process.Kill() _, _ = firecracker.Process.Wait() } }) vm := testVM("pi-devbox", "image-pi", "172.16.0.77") vm.State = model.VMStateRunning vm.Runtime.State = model.VMStateRunning vm.Runtime.PID = firecracker.Process.Pid vm.Runtime.APISockPath = apiSock upsertDaemonVM(t, ctx, db, vm) fakeClient := &fakeGuestSSHClient{t: t, existingDirs: map[string]bool{}} d := &Daemon{ store: db, config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")}, logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil } d.guestDial = func(context.Context, string, string) (guestSSHClient, error) { return fakeClient, nil } d.waitForGuestSessionReady = func(_ context.Context, _ model.VMRecord, session model.GuestSession) (model.GuestSession, error) { now := model.Now() session.Status = model.GuestSessionStatusRunning session.GuestPID = 4242 session.StartedAt = now session.UpdatedAt = now session.Attachable = session.StdinMode == model.GuestSessionStdinPipe session.Reattachable = session.StdinMode == model.GuestSessionStdinPipe return session, nil } workspace, err := d.PrepareVMWorkspace(ctx, api.VMWorkspacePrepareParams{ IDOrName: vm.Name, SourcePath: repoRoot, GuestPath: "/root/repo", ReadOnly: true, }) if err != nil { t.Fatalf("PrepareVMWorkspace: %v", err) } if workspace.GuestPath != "/root/repo" { t.Fatalf("PrepareVMWorkspace guest path = %q, want /root/repo", workspace.GuestPath) } if !fakeClient.existingDirs["/root/repo"] { t.Fatalf("workspace prepare did not mark /root/repo as present in fake guest") } session, err := d.StartGuestSession(ctx, api.GuestSessionStartParams{ VMIDOrName: vm.Name, Name: "testpi", Command: "pi", Args: []string{"--mode", "rpc", "--no-session"}, CWD: "/root/repo", StdinMode: string(model.GuestSessionStdinPipe), RequiredCommands: []string{"git"}, }) if err != nil { t.Fatalf("StartGuestSession: %v", err) } if session.Status != model.GuestSessionStatusRunning { t.Fatalf("session status = %q, want %q", session.Status, model.GuestSessionStatusRunning) } if session.LaunchStage != "" { t.Fatalf("session launch stage = %q, want empty", session.LaunchStage) } if session.LaunchMessage != "" { t.Fatalf("session launch message = %q, want empty", session.LaunchMessage) } if session.GuestPID == 0 { t.Fatalf("session guest pid = 0, want non-zero") } if !session.Attachable { t.Fatalf("session should be attachable for pipe stdin mode") } }