diff --git a/internal/api/types.go b/internal/api/types.go index 6756d7e..eccb9c8 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -203,6 +203,29 @@ type GuestSessionAttachBeginResult struct { StreamFormat string `json:"stream_format"` } +type GuestSessionSendParams struct { + VMIDOrName string `json:"vm_id_or_name"` + SessionIDOrName string `json:"session_id_or_name"` + Payload []byte `json:"payload"` +} + +type GuestSessionSendResult struct { + Session model.GuestSession `json:"session"` + BytesWritten int `json:"bytes_written"` +} + +type WorkspaceExportParams struct { + IDOrName string `json:"id_or_name"` + GuestPath string `json:"guest_path,omitempty"` +} + +type WorkspaceExportResult struct { + GuestPath string `json:"guest_path"` + Patch []byte `json:"patch"` + ChangedFiles []string `json:"changed_files"` + HasChanges bool `json:"has_changes"` +} + type VMWorkspacePrepareParams struct { IDOrName string `json:"id_or_name"` SourcePath string `json:"source_path"` diff --git a/internal/cli/banger.go b/internal/cli/banger.go index 458f706..38393b9 100644 --- a/internal/cli/banger.go +++ b/internal/cli/banger.go @@ -89,6 +89,9 @@ var ( vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { return rpc.Call[api.VMWorkspacePrepareResult](ctx, socketPath, "vm.workspace.prepare", params) } + vmWorkspaceExportFunc = func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + return rpc.Call[api.WorkspaceExportResult](ctx, socketPath, "vm.workspace.export", params) + } guestSessionStartFunc = func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) { return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.start", params) } @@ -110,6 +113,9 @@ var ( guestSessionAttachBeginFunc = func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) { return rpc.Call[api.GuestSessionAttachBeginResult](ctx, socketPath, "guest.session.attach.begin", params) } + guestSessionSendFunc = func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { + return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params) + } guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { return guest.WaitForSSH(ctx, address, privateKeyPath, interval) } @@ -869,7 +875,10 @@ func newVMWorkspaceCommand() *cobra.Command { Short: "Manage repository workspaces inside a running VM", RunE: helpNoArgs, } - cmd.AddCommand(newVMWorkspacePrepareCommand()) + cmd.AddCommand( + newVMWorkspacePrepareCommand(), + newVMWorkspaceExportCommand(), + ) return cmd } @@ -929,6 +938,52 @@ func newVMWorkspacePrepareCommand() *cobra.Command { return cmd } +func newVMWorkspaceExportCommand() *cobra.Command { + var guestPath string + var outputPath string + cmd := &cobra.Command{ + Use: "export ", + Short: "Pull changes from a guest workspace back to the host as a patch", + Long: "Stage all changes inside the guest workspace (git add -A) and emit a binary-safe unified diff against HEAD. With no --output flag the patch is written to stdout so it can be piped directly to git apply.", + Args: exactArgsUsage(1, "usage: banger vm workspace export "), + Example: strings.TrimSpace(` + banger vm workspace export devbox | git apply + banger vm workspace export devbox --output worker.diff + banger vm workspace export devbox --guest-path /root/project --output changes.diff +`), + RunE: func(cmd *cobra.Command, args []string) error { + layout, _, err := ensureDaemon(cmd.Context()) + if err != nil { + return err + } + result, err := vmWorkspaceExportFunc(cmd.Context(), layout.SocketPath, api.WorkspaceExportParams{ + IDOrName: args[0], + GuestPath: guestPath, + }) + if err != nil { + return err + } + if !result.HasChanges { + _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "no changes") + return nil + } + if outputPath != "" { + if err := os.WriteFile(outputPath, result.Patch, 0o644); err != nil { + return fmt.Errorf("write patch: %w", err) + } + _, err = fmt.Fprintf(cmd.ErrOrStderr(), "patch written to %s (%d bytes, %d files)\n", + outputPath, len(result.Patch), len(result.ChangedFiles)) + return err + } + _, err = cmd.OutOrStdout().Write(result.Patch) + return err + }, + } + cmd.Flags().StringVar(&guestPath, "guest-path", "/root/repo", "guest workspace path") + cmd.Flags().StringVar(&outputPath, "output", "", "write patch to this file instead of stdout") + return cmd +} + func newVMSessionCommand() *cobra.Command { cmd := &cobra.Command{ Use: "session", @@ -944,6 +999,7 @@ func newVMSessionCommand() *cobra.Command { newVMSessionStopCommand(), newVMSessionKillCommand(), newVMSessionAttachCommand(), + newVMSessionSendCommand(), ) return cmd } @@ -1134,6 +1190,51 @@ func newVMSessionAttachCommand() *cobra.Command { } } +func newVMSessionSendCommand() *cobra.Command { + var message string + cmd := &cobra.Command{ + Use: "send ", + Short: "Write bytes to a running guest session's stdin pipe", + Long: "Write a payload to the stdin pipe of a running pipe-mode guest session without holding the exclusive attach. Use --message for an inline JSONL string, or pipe bytes via stdin when --message is omitted. A trailing newline is appended to --message values that lack one.", + Args: exactArgsUsage(2, "usage: banger vm session send [--message '']"), + Example: strings.TrimSpace(` + banger vm session send devbox planner --message '{"type":"abort"}' + banger vm session send devbox planner --message '{"type":"steer","message":"Focus on src/"}' + echo '{"type":"prompt","prompt":"Summarize."}' | banger vm session send devbox planner +`), + RunE: func(cmd *cobra.Command, args []string) error { + layout, _, err := ensureDaemon(cmd.Context()) + if err != nil { + return err + } + var payload []byte + if message != "" { + payload = []byte(message) + if len(payload) > 0 && payload[len(payload)-1] != '\n' { + payload = append(payload, '\n') + } + } else { + payload, err = io.ReadAll(cmd.InOrStdin()) + if err != nil { + return fmt.Errorf("read stdin: %w", err) + } + } + result, err := guestSessionSendFunc(cmd.Context(), layout.SocketPath, api.GuestSessionSendParams{ + VMIDOrName: args[0], + SessionIDOrName: args[1], + Payload: payload, + }) + if err != nil { + return err + } + _, err = fmt.Fprintf(cmd.OutOrStdout(), "sent %d bytes to session %s\n", result.BytesWritten, result.Session.Name) + return err + }, + } + cmd.Flags().StringVar(&message, "message", "", "JSONL message to send; a trailing newline is appended if absent") + return cmd +} + func parseKeyValuePairs(values []string) (map[string]string, error) { if len(values) == 0 { return nil, nil diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 11e2b57..ca5f38b 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -1936,3 +1936,285 @@ func (c *testVMRunGuestClient) StreamTarEntries(ctx context.Context, sourceDir s c.streamCommand = remoteCommand return nil } + +func TestVMSessionSendCommandExists(t *testing.T) { + root := NewBangerCommand() + vm, _, err := root.Find([]string{"vm"}) + if err != nil { + t.Fatalf("find vm: %v", err) + } + session, _, err := vm.Find([]string{"session"}) + if err != nil { + t.Fatalf("find session: %v", err) + } + if _, _, err := session.Find([]string{"send"}); err != nil { + t.Fatalf("find session send: %v", err) + } +} + +func TestVMSessionSendRejectsWrongArgCount(t *testing.T) { + cmd := NewBangerCommand() + cmd.SetArgs([]string{"vm", "session", "send", "only-one-arg"}) + err := cmd.Execute() + if err == nil || !strings.Contains(err.Error(), "usage: banger vm session send") { + t.Fatalf("Execute() error = %v, want send usage error", err) + } +} + +func stubEnsureDaemonForSend(t *testing.T) { + t.Helper() + t.Setenv("XDG_CONFIG_HOME", filepath.Join(t.TempDir(), "config")) + t.Setenv("XDG_STATE_HOME", filepath.Join(t.TempDir(), "state")) + t.Setenv("XDG_RUNTIME_DIR", filepath.Join(t.TempDir(), "run")) + origPing := daemonPingFunc + t.Cleanup(func() { daemonPingFunc = origPing }) + daemonPingFunc = func(context.Context, string) (api.PingResult, error) { + return api.PingResult{Status: "ok", PID: os.Getpid()}, nil + } +} + +func TestVMSessionSendWithMessageFlag(t *testing.T) { + stubEnsureDaemonForSend(t) + + original := guestSessionSendFunc + t.Cleanup(func() { guestSessionSendFunc = original }) + + var capturedParams api.GuestSessionSendParams + guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { + capturedParams = params + return api.GuestSessionSendResult{ + Session: model.GuestSession{ID: "sess-id", Name: "planner"}, + BytesWritten: len(params.Payload), + }, nil + } + + cmd := NewBangerCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner", "--message", `{"type":"abort"}`}) + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute: %v", err) + } + + wantPayload := []byte(`{"type":"abort"}` + "\n") + if string(capturedParams.Payload) != string(wantPayload) { + t.Fatalf("payload = %q, want %q", capturedParams.Payload, wantPayload) + } + if capturedParams.VMIDOrName != "devbox" { + t.Fatalf("VMIDOrName = %q, want %q", capturedParams.VMIDOrName, "devbox") + } + if capturedParams.SessionIDOrName != "planner" { + t.Fatalf("SessionIDOrName = %q, want %q", capturedParams.SessionIDOrName, "planner") + } + if !strings.Contains(out.String(), "17") { + t.Fatalf("output = %q, want bytes_written in output", out.String()) + } +} + +func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) { + stubEnsureDaemonForSend(t) + + original := guestSessionSendFunc + t.Cleanup(func() { guestSessionSendFunc = original }) + + var capturedPayload []byte + guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { + capturedPayload = params.Payload + return api.GuestSessionSendResult{ + Session: model.GuestSession{Name: "s"}, + BytesWritten: len(params.Payload), + }, nil + } + + cmd := NewBangerCommand() + cmd.SetOut(io.Discard) + cmd.SetArgs([]string{"vm", "session", "send", "devbox", "s", "--message", "{\"type\":\"abort\"}\n"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute: %v", err) + } + + // Must not double-append newline. + if capturedPayload[len(capturedPayload)-1] != '\n' { + t.Fatalf("payload missing trailing newline: %q", capturedPayload) + } + if len(capturedPayload) > 0 && capturedPayload[len(capturedPayload)-2] == '\n' { + t.Fatalf("payload has double trailing newline: %q", capturedPayload) + } +} + +func TestVMSessionSendFromStdin(t *testing.T) { + stubEnsureDaemonForSend(t) + + original := guestSessionSendFunc + t.Cleanup(func() { guestSessionSendFunc = original }) + + var capturedPayload []byte + guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { + capturedPayload = params.Payload + return api.GuestSessionSendResult{ + Session: model.GuestSession{Name: "planner"}, + BytesWritten: len(params.Payload), + }, nil + } + + stdinPayload := `{"type":"steer","message":"Focus on src/"}` + "\n" + cmd := NewBangerCommand() + cmd.SetOut(io.Discard) + cmd.SetIn(strings.NewReader(stdinPayload)) + cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute: %v", err) + } + + if string(capturedPayload) != stdinPayload { + t.Fatalf("payload = %q, want %q", capturedPayload, stdinPayload) + } +} + +func TestVMWorkspaceExportCommandExists(t *testing.T) { + root := NewBangerCommand() + vm, _, err := root.Find([]string{"vm"}) + if err != nil { + t.Fatalf("find vm: %v", err) + } + workspace, _, err := vm.Find([]string{"workspace"}) + if err != nil { + t.Fatalf("find workspace: %v", err) + } + if _, _, err := workspace.Find([]string{"export"}); err != nil { + t.Fatalf("find workspace export: %v", err) + } +} + +func TestVMWorkspaceExportRejectsMissingArg(t *testing.T) { + cmd := NewBangerCommand() + cmd.SetArgs([]string{"vm", "workspace", "export"}) + err := cmd.Execute() + if err == nil || !strings.Contains(err.Error(), "usage: banger vm workspace export") { + t.Fatalf("Execute() error = %v, want usage error", err) + } +} + +func TestVMWorkspaceExportWritesToStdout(t *testing.T) { + stubEnsureDaemonForSend(t) + + origExport := vmWorkspaceExportFunc + t.Cleanup(func() { vmWorkspaceExportFunc = origExport }) + + patch := []byte("diff --git a/main.go b/main.go\nindex 0000000..1111111 100644\n") + vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + return api.WorkspaceExportResult{ + GuestPath: params.GuestPath, + Patch: patch, + ChangedFiles: []string{"main.go"}, + HasChanges: true, + }, nil + } + + cmd := NewBangerCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(io.Discard) + cmd.SetArgs([]string{"vm", "workspace", "export", "devbox"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute: %v", err) + } + if !bytes.Equal(out.Bytes(), patch) { + t.Fatalf("stdout = %q, want %q", out.Bytes(), patch) + } +} + +func TestVMWorkspaceExportWritesToFile(t *testing.T) { + stubEnsureDaemonForSend(t) + + origExport := vmWorkspaceExportFunc + t.Cleanup(func() { vmWorkspaceExportFunc = origExport }) + + patch := []byte("diff --git a/main.go b/main.go\n") + vmWorkspaceExportFunc = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + return api.WorkspaceExportResult{ + GuestPath: "/root/repo", + Patch: patch, + ChangedFiles: []string{"main.go"}, + HasChanges: true, + }, nil + } + + outFile := filepath.Join(t.TempDir(), "worker.diff") + cmd := NewBangerCommand() + cmd.SetOut(io.Discard) + var stderr bytes.Buffer + cmd.SetErr(&stderr) + cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--output", outFile}) + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute: %v", err) + } + + got, err := os.ReadFile(outFile) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if !bytes.Equal(got, patch) { + t.Fatalf("file content = %q, want %q", got, patch) + } + if !strings.Contains(stderr.String(), "worker.diff") { + t.Fatalf("stderr = %q, want output path mentioned", stderr.String()) + } +} + +func TestVMWorkspaceExportNoChanges(t *testing.T) { + stubEnsureDaemonForSend(t) + + origExport := vmWorkspaceExportFunc + t.Cleanup(func() { vmWorkspaceExportFunc = origExport }) + + vmWorkspaceExportFunc = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + return api.WorkspaceExportResult{ + GuestPath: "/root/repo", + HasChanges: false, + }, nil + } + + cmd := NewBangerCommand() + var out bytes.Buffer + var stderr bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&stderr) + cmd.SetArgs([]string{"vm", "workspace", "export", "devbox"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute: %v", err) + } + if out.Len() != 0 { + t.Fatalf("stdout = %q, want empty when no changes", out.String()) + } + if !strings.Contains(stderr.String(), "no changes") { + t.Fatalf("stderr = %q, want 'no changes'", stderr.String()) + } +} + +func TestVMWorkspaceExportGuestPathFlag(t *testing.T) { + stubEnsureDaemonForSend(t) + + origExport := vmWorkspaceExportFunc + t.Cleanup(func() { vmWorkspaceExportFunc = origExport }) + + var capturedParams api.WorkspaceExportParams + vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + capturedParams = params + return api.WorkspaceExportResult{HasChanges: false}, nil + } + + cmd := NewBangerCommand() + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--guest-path", "/root/project"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute: %v", err) + } + if capturedParams.GuestPath != "/root/project" { + t.Fatalf("GuestPath = %q, want /root/project", capturedParams.GuestPath) + } + if capturedParams.IDOrName != "devbox" { + t.Fatalf("IDOrName = %q, want devbox", capturedParams.IDOrName) + } +} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 92eeb19..9a3b84d 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -407,6 +407,13 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response { } workspace, err := d.PrepareVMWorkspace(ctx, params) return marshalResultOrError(api.VMWorkspacePrepareResult{Workspace: workspace}, err) + case "vm.workspace.export": + params, err := rpc.DecodeParams[api.WorkspaceExportParams](req) + if err != nil { + return rpc.NewError("bad_request", err.Error()) + } + result, err := d.ExportVMWorkspace(ctx, params) + return marshalResultOrError(result, err) case "guest.session.start": params, err := rpc.DecodeParams[api.GuestSessionStartParams](req) if err != nil { @@ -456,6 +463,13 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response { } result, err := d.BeginGuestSessionAttach(ctx, params) return marshalResultOrError(result, err) + case "guest.session.send": + params, err := rpc.DecodeParams[api.GuestSessionSendParams](req) + if err != nil { + return rpc.NewError("bad_request", err.Error()) + } + result, err := d.SendToGuestSession(ctx, params) + return marshalResultOrError(result, err) case "image.list": images, err := d.store.ListImages(ctx) return marshalResultOrError(api.ImageListResult{Images: images}, err) diff --git a/internal/daemon/guest_sessions.go b/internal/daemon/guest_sessions.go index e8fa2a0..cf0f9d9 100644 --- a/internal/daemon/guest_sessions.go +++ b/internal/daemon/guest_sessions.go @@ -53,6 +53,7 @@ var guestSessionHostCommandOutputFunc = func(ctx context.Context, name string, a type guestSSHClient interface { Close() error RunScript(context.Context, string, io.Writer) error + RunScriptOutput(context.Context, string) ([]byte, error) UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error StreamTar(context.Context, string, string, io.Writer) error StreamTarEntries(context.Context, string, []string, string, io.Writer) error @@ -400,6 +401,50 @@ func (d *Daemon) GuestSessionLogs(ctx context.Context, params api.GuestSessionLo return api.GuestSessionLogsResult{Session: session, Stream: streamName, Path: path, Content: content}, nil } +func (d *Daemon) SendToGuestSession(ctx context.Context, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { + vm, err := d.FindVM(ctx, params.VMIDOrName) + if err != nil { + return api.GuestSessionSendResult{}, err + } + session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName) + if err != nil { + return api.GuestSessionSendResult{}, err + } + if session.StdinMode != model.GuestSessionStdinPipe { + return api.GuestSessionSendResult{}, errors.New("session does not have a stdin pipe") + } + if session.Status != model.GuestSessionStatusRunning { + return api.GuestSessionSendResult{}, fmt.Errorf("session is not running (status=%s)", session.Status) + } + if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) { + return api.GuestSessionSendResult{}, fmt.Errorf("vm %q is not running", vm.Name) + } + if len(params.Payload) == 0 { + return api.GuestSessionSendResult{Session: session}, nil + } + client, err := d.dialGuest(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22")) + if err != nil { + return api.GuestSessionSendResult{}, fmt.Errorf("dial guest: %w", err) + } + defer client.Close() + tmpPath := fmt.Sprintf("/tmp/banger-send-%s.bin", session.ID[:8]) + var uploadLog bytes.Buffer + if err := client.UploadFile(ctx, tmpPath, 0o600, params.Payload, &uploadLog); err != nil { + return api.GuestSessionSendResult{}, fmt.Errorf("upload payload: %w", err) + } + sendScript := fmt.Sprintf( + "set -euo pipefail\ncat %s >> %s\nrm -f %s\n", + guestShellQuote(tmpPath), + guestShellQuote(guestSessionStdinPipePath(session.ID)), + guestShellQuote(tmpPath), + ) + var sendLog bytes.Buffer + if err := client.RunScript(ctx, sendScript, &sendLog); err != nil { + return api.GuestSessionSendResult{}, fmt.Errorf("send to session: %w: %s", err, strings.TrimSpace(sendLog.String())) + } + return api.GuestSessionSendResult{Session: session, BytesWritten: len(params.Payload)}, nil +} + func (d *Daemon) BeginGuestSessionAttach(ctx context.Context, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) { vm, err := d.FindVM(ctx, params.VMIDOrName) if err != nil { diff --git a/internal/daemon/guest_sessions_test.go b/internal/daemon/guest_sessions_test.go index 49e5127..fc75367 100644 --- a/internal/daemon/guest_sessions_test.go +++ b/internal/daemon/guest_sessions_test.go @@ -14,6 +14,7 @@ import ( "banger/internal/api" "banger/internal/model" + "banger/internal/store" ) type fakeGuestSSHClient struct { @@ -57,6 +58,10 @@ func (f *fakeGuestSSHClient) RunScript(_ context.Context, script string, _ io.Wr } } +func (f *fakeGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) { + return nil, nil +} + func (f *fakeGuestSSHClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error { return nil } @@ -77,6 +82,276 @@ func (f *fakeGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []s return fmt.Errorf("unexpected StreamTarEntries command: %s", command) } +func TestSendToGuestSession_HappyPath(t *testing.T) { + t.Parallel() + ctx := context.Background() + db := openDaemonStore(t) + + apiSock := filepath.Join(t.TempDir(), "fc.sock") + firecracker := startFakeFirecracker(t, apiSock) + + vm := testVM("sendbox", "image-send", "172.16.0.88") + vm.State = model.VMStateRunning + vm.Runtime.State = model.VMStateRunning + vm.Runtime.PID = firecracker.Process.Pid + vm.Runtime.APISockPath = apiSock + upsertDaemonVM(t, ctx, db, vm) + + session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning) + if err := db.UpsertGuestSession(ctx, session); err != nil { + t.Fatalf("UpsertGuestSession: %v", err) + } + + fake := &recordingGuestSSHClient{} + d := newSendTestDaemon(t, db, fake) + + payload := []byte(`{"type":"abort"}` + "\n") + result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ + VMIDOrName: vm.Name, + SessionIDOrName: session.Name, + Payload: payload, + }) + if err != nil { + t.Fatalf("SendToGuestSession: %v", err) + } + if result.BytesWritten != len(payload) { + t.Fatalf("BytesWritten = %d, want %d", result.BytesWritten, len(payload)) + } + if result.Session.ID != session.ID { + t.Fatalf("Session.ID = %q, want %q", result.Session.ID, session.ID) + } + if len(fake.uploadedFiles) != 1 { + t.Fatalf("UploadFile call count = %d, want 1", len(fake.uploadedFiles)) + } + for path, data := range fake.uploadedFiles { + if !strings.HasPrefix(path, "/tmp/banger-send-") { + t.Fatalf("upload path = %q, want /tmp/banger-send-... prefix", path) + } + if string(data) != string(payload) { + t.Fatalf("upload data = %q, want %q", data, payload) + } + } + if len(fake.ranScripts) != 1 { + t.Fatalf("RunScript call count = %d, want 1", len(fake.ranScripts)) + } + script := fake.ranScripts[0] + pipePath := guestSessionStdinPipePath(session.ID) + if !strings.Contains(script, "cat ") { + t.Fatalf("send script missing cat command: %q", script) + } + if !strings.Contains(script, pipePath) { + t.Fatalf("send script missing pipe path %q: %q", pipePath, script) + } + if !strings.Contains(script, "rm -f ") { + t.Fatalf("send script missing rm cleanup: %q", script) + } +} + +func TestSendToGuestSession_EmptyPayload(t *testing.T) { + t.Parallel() + ctx := context.Background() + db := openDaemonStore(t) + + apiSock := filepath.Join(t.TempDir(), "fc.sock") + firecracker := startFakeFirecracker(t, apiSock) + + vm := testVM("sendbox-empty", "image-send", "172.16.0.89") + vm.State = model.VMStateRunning + vm.Runtime.State = model.VMStateRunning + vm.Runtime.PID = firecracker.Process.Pid + vm.Runtime.APISockPath = apiSock + upsertDaemonVM(t, ctx, db, vm) + + session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning) + if err := db.UpsertGuestSession(ctx, session); err != nil { + t.Fatalf("UpsertGuestSession: %v", err) + } + + fake := &recordingGuestSSHClient{} + d := newSendTestDaemon(t, db, fake) + + result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ + VMIDOrName: vm.Name, + SessionIDOrName: session.Name, + Payload: nil, + }) + if err != nil { + t.Fatalf("SendToGuestSession(empty): %v", err) + } + if result.BytesWritten != 0 { + t.Fatalf("BytesWritten = %d, want 0", result.BytesWritten) + } + if fake.dialCount != 0 { + t.Fatalf("SSH dial count = %d, want 0 for empty payload", fake.dialCount) + } +} + +func TestSendToGuestSession_NotPipeMode(t *testing.T) { + t.Parallel() + ctx := context.Background() + db := openDaemonStore(t) + + vm := testVM("sendbox-closed", "image-send", "172.16.0.90") + vm.State = model.VMStateRunning + upsertDaemonVM(t, ctx, db, vm) + + session := testGuestSession(vm.ID, model.GuestSessionStdinClosed, model.GuestSessionStatusRunning) + if err := db.UpsertGuestSession(ctx, session); err != nil { + t.Fatalf("UpsertGuestSession: %v", err) + } + + d := &Daemon{store: db} + _, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ + VMIDOrName: vm.Name, + SessionIDOrName: session.Name, + Payload: []byte("hello\n"), + }) + if err == nil || !strings.Contains(err.Error(), "stdin pipe") { + t.Fatalf("error = %v, want 'stdin pipe' error", err) + } +} + +func TestSendToGuestSession_SessionNotRunning(t *testing.T) { + t.Parallel() + ctx := context.Background() + db := openDaemonStore(t) + + vm := testVM("sendbox-failed", "image-send", "172.16.0.91") + vm.State = model.VMStateRunning + upsertDaemonVM(t, ctx, db, vm) + + session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusFailed) + if err := db.UpsertGuestSession(ctx, session); err != nil { + t.Fatalf("UpsertGuestSession: %v", err) + } + + d := &Daemon{store: db} + _, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ + VMIDOrName: vm.Name, + SessionIDOrName: session.Name, + Payload: []byte("hello\n"), + }) + if err == nil || !strings.Contains(err.Error(), "not running") { + t.Fatalf("error = %v, want 'not running' error", err) + } +} + +func TestSendToGuestSession_VMNotRunning(t *testing.T) { + t.Parallel() + ctx := context.Background() + db := openDaemonStore(t) + + vm := testVM("sendbox-stopped", "image-send", "172.16.0.92") + vm.State = model.VMStateStopped + upsertDaemonVM(t, ctx, db, vm) + + session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning) + if err := db.UpsertGuestSession(ctx, session); err != nil { + t.Fatalf("UpsertGuestSession: %v", err) + } + + d := &Daemon{store: db} + _, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ + VMIDOrName: vm.Name, + SessionIDOrName: session.Name, + Payload: []byte("hello\n"), + }) + if err == nil || !strings.Contains(err.Error(), "not running") { + t.Fatalf("error = %v, want 'not running' error", err) + } +} + +// recordingGuestSSHClient captures UploadFile and RunScript calls for send tests. +type recordingGuestSSHClient struct { + dialCount int + uploadedFiles map[string][]byte + ranScripts []string +} + +func (r *recordingGuestSSHClient) Close() error { return nil } + +func (r *recordingGuestSSHClient) UploadFile(_ context.Context, path string, _ os.FileMode, data []byte, _ io.Writer) error { + if r.uploadedFiles == nil { + r.uploadedFiles = make(map[string][]byte) + } + copy := make([]byte, len(data)) + _ = copy[:len(data):len(data)] + for i, b := range data { + copy[i] = b + } + r.uploadedFiles[path] = copy + return nil +} + +func (r *recordingGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error { + r.ranScripts = append(r.ranScripts, script) + return nil +} + +func (r *recordingGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) { + return nil, nil +} + +func (r *recordingGuestSSHClient) StreamTar(_ context.Context, _ string, _ string, _ io.Writer) error { + return nil +} + +func (r *recordingGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, _ string, _ io.Writer) error { + return nil +} + +func newSendTestDaemon(t *testing.T, db *store.Store, fake *recordingGuestSSHClient) *Daemon { + t.Helper() + d := &Daemon{ + store: db, + config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")}, + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + d.guestDial = func(_ context.Context, _ string, _ string) (guestSSHClient, error) { + fake.dialCount++ + return fake, nil + } + return d +} + +func testGuestSession(vmID string, stdinMode model.GuestSessionStdinMode, status model.GuestSessionStatus) model.GuestSession { + now := model.Now() + id := vmID + "-sess-id" + return model.GuestSession{ + ID: id, + VMID: vmID, + Name: vmID + "-sess", + Backend: guestSessionBackendSSH, + Command: "pi", + Args: []string{"--mode", "rpc"}, + CWD: "/root/repo", + StdinMode: stdinMode, + Status: status, + GuestStateDir: guestSessionStateDir(id), + StdoutLogPath: guestSessionStdoutLogPath(id), + StderrLogPath: guestSessionStderrLogPath(id), + Attachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning, + Reattachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning, + CreatedAt: now, + UpdatedAt: now, + } +} + +func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd { + t.Helper() + cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock)) + if err := cmd.Start(); err != nil { + t.Fatalf("start fake firecracker: %v", err) + } + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + } + }) + return cmd +} + func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) { t.Parallel() diff --git a/internal/daemon/workspace.go b/internal/daemon/workspace.go index 33fb1e9..85919dd 100644 --- a/internal/daemon/workspace.go +++ b/internal/daemon/workspace.go @@ -35,6 +35,56 @@ type workspaceRepoSpec struct { Submodules []string } +func (d *Daemon) ExportVMWorkspace(ctx context.Context, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { + guestPath := strings.TrimSpace(params.GuestPath) + if guestPath == "" { + guestPath = "/root/repo" + } + vm, err := d.FindVM(ctx, params.IDOrName) + if err != nil { + return api.WorkspaceExportResult{}, err + } + if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) { + return api.WorkspaceExportResult{}, fmt.Errorf("vm %q is not running", vm.Name) + } + client, err := d.dialGuest(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22")) + if err != nil { + return api.WorkspaceExportResult{}, fmt.Errorf("dial guest: %w", err) + } + defer client.Close() + + // Stage all changes then emit a binary-safe unified diff against HEAD. + // --binary ensures binary files are handled correctly by git apply. + patchScript := fmt.Sprintf( + "set -euo pipefail\ncd %s\ngit add -A\ngit diff --cached HEAD --binary\n", + guestShellQuote(guestPath), + ) + patch, err := client.RunScriptOutput(ctx, patchScript) + if err != nil { + return api.WorkspaceExportResult{}, fmt.Errorf("export workspace diff: %w", err) + } + + // Enumerate changed paths (index already staged; this is a cheap read). + namesScript := fmt.Sprintf( + "set -euo pipefail\ncd %s\ngit diff --cached HEAD --name-only\n", + guestShellQuote(guestPath), + ) + namesOut, _ := client.RunScriptOutput(ctx, namesScript) + var changed []string + for _, line := range strings.Split(strings.TrimSpace(string(namesOut)), "\n") { + if line = strings.TrimSpace(line); line != "" { + changed = append(changed, line) + } + } + + return api.WorkspaceExportResult{ + GuestPath: guestPath, + Patch: patch, + ChangedFiles: changed, + HasChanges: len(patch) > 0, + }, nil +} + func (d *Daemon) PrepareVMWorkspace(ctx context.Context, params api.VMWorkspacePrepareParams) (model.WorkspacePrepareResult, error) { mode, err := parseWorkspacePrepareMode(params.Mode) if err != nil { diff --git a/internal/daemon/workspace_test.go b/internal/daemon/workspace_test.go new file mode 100644 index 0000000..df0ea90 --- /dev/null +++ b/internal/daemon/workspace_test.go @@ -0,0 +1,254 @@ +package daemon + +import ( + "context" + "io" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" + + "banger/internal/api" + "banger/internal/model" +) + +// exportGuestClient is a scriptable fake for RunScriptOutput used in export tests. +// Each call to RunScriptOutput returns the next response from the queue. +type exportGuestClient struct { + responses []exportGuestResponse + callIndex int +} + +type exportGuestResponse struct { + output []byte + err error +} + +func (e *exportGuestClient) Close() error { return nil } + +func (e *exportGuestClient) RunScript(_ context.Context, _ string, _ io.Writer) error { + return nil +} + +func (e *exportGuestClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) { + if e.callIndex >= len(e.responses) { + return nil, nil + } + r := e.responses[e.callIndex] + e.callIndex++ + return r.output, r.err +} + +func (e *exportGuestClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error { + return nil +} + +func (e *exportGuestClient) StreamTar(_ context.Context, _ string, _ string, _ io.Writer) error { + return nil +} + +func (e *exportGuestClient) StreamTarEntries(_ context.Context, _ string, _ []string, _ string, _ io.Writer) error { + return nil +} + +func newExportTestDaemonStore(t *testing.T, fake *exportGuestClient) *Daemon { + t.Helper() + db := openDaemonStore(t) + d := &Daemon{ + store: db, + config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")}, + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + d.guestDial = func(_ context.Context, _ string, _ string) (guestSSHClient, error) { + return fake, nil + } + return d +} + +func TestExportVMWorkspace_HappyPath(t *testing.T) { + t.Parallel() + ctx := context.Background() + + apiSock := filepath.Join(t.TempDir(), "fc.sock") + firecracker := startFakeFirecracker(t, apiSock) + + vm := testVM("exportbox", "image-export", "172.16.0.100") + vm.State = model.VMStateRunning + vm.Runtime.State = model.VMStateRunning + vm.Runtime.PID = firecracker.Process.Pid + vm.Runtime.APISockPath = apiSock + + patch := []byte("diff --git a/file.go b/file.go\nindex 0000000..1111111 100644\n") + names := []byte("file.go\n") + + fake := &exportGuestClient{ + responses: []exportGuestResponse{ + {output: patch}, + {output: names}, + }, + } + d := newExportTestDaemonStore(t, fake) + upsertDaemonVM(t, ctx, d.store, vm) + + result, err := d.ExportVMWorkspace(ctx, api.WorkspaceExportParams{ + IDOrName: vm.Name, + GuestPath: "/root/repo", + }) + if err != nil { + t.Fatalf("ExportVMWorkspace: %v", err) + } + if !result.HasChanges { + t.Fatal("HasChanges = false, want true") + } + if string(result.Patch) != string(patch) { + t.Fatalf("Patch = %q, want %q", result.Patch, patch) + } + if result.GuestPath != "/root/repo" { + t.Fatalf("GuestPath = %q, want /root/repo", result.GuestPath) + } + if len(result.ChangedFiles) != 1 || result.ChangedFiles[0] != "file.go" { + t.Fatalf("ChangedFiles = %v, want [file.go]", result.ChangedFiles) + } + if fake.callIndex != 2 { + t.Fatalf("RunScriptOutput call count = %d, want 2", fake.callIndex) + } +} + +func TestExportVMWorkspace_NoChanges(t *testing.T) { + t.Parallel() + ctx := context.Background() + + apiSock := filepath.Join(t.TempDir(), "fc.sock") + firecracker := startFakeFirecracker(t, apiSock) + + vm := testVM("exportbox-empty", "image-export", "172.16.0.101") + vm.State = model.VMStateRunning + vm.Runtime.State = model.VMStateRunning + vm.Runtime.PID = firecracker.Process.Pid + vm.Runtime.APISockPath = apiSock + + // Both scripts return empty output (no changes). + fake := &exportGuestClient{ + responses: []exportGuestResponse{ + {output: nil}, + {output: nil}, + }, + } + d := newExportTestDaemonStore(t, fake) + upsertDaemonVM(t, ctx, d.store, vm) + + result, err := d.ExportVMWorkspace(ctx, api.WorkspaceExportParams{ + IDOrName: vm.Name, + }) + if err != nil { + t.Fatalf("ExportVMWorkspace: %v", err) + } + if result.HasChanges { + t.Fatal("HasChanges = true, want false") + } + if len(result.Patch) != 0 { + t.Fatalf("Patch = %q, want empty", result.Patch) + } + if len(result.ChangedFiles) != 0 { + t.Fatalf("ChangedFiles = %v, want empty", result.ChangedFiles) + } +} + +func TestExportVMWorkspace_DefaultGuestPath(t *testing.T) { + t.Parallel() + ctx := context.Background() + + apiSock := filepath.Join(t.TempDir(), "fc.sock") + firecracker := startFakeFirecracker(t, apiSock) + + vm := testVM("exportbox-default", "image-export", "172.16.0.102") + vm.State = model.VMStateRunning + vm.Runtime.State = model.VMStateRunning + vm.Runtime.PID = firecracker.Process.Pid + vm.Runtime.APISockPath = apiSock + + fake := &exportGuestClient{ + responses: []exportGuestResponse{ + {output: nil}, + {output: nil}, + }, + } + d := newExportTestDaemonStore(t, fake) + upsertDaemonVM(t, ctx, d.store, vm) + + // GuestPath omitted — should default to /root/repo. + result, err := d.ExportVMWorkspace(ctx, api.WorkspaceExportParams{ + IDOrName: vm.Name, + }) + if err != nil { + t.Fatalf("ExportVMWorkspace: %v", err) + } + if result.GuestPath != "/root/repo" { + t.Fatalf("GuestPath = %q, want /root/repo", result.GuestPath) + } +} + +func TestExportVMWorkspace_VMNotRunning(t *testing.T) { + t.Parallel() + ctx := context.Background() + + vm := testVM("exportbox-stopped", "image-export", "172.16.0.103") + vm.State = model.VMStateStopped + + fake := &exportGuestClient{} + d := newExportTestDaemonStore(t, fake) + upsertDaemonVM(t, ctx, d.store, vm) + + _, err := d.ExportVMWorkspace(ctx, api.WorkspaceExportParams{ + IDOrName: vm.Name, + }) + if err == nil || !strings.Contains(err.Error(), "not running") { + t.Fatalf("error = %v, want 'not running' error", err) + } + if fake.callIndex != 0 { + t.Fatal("RunScriptOutput should not be called when VM is not running") + } +} + +func TestExportVMWorkspace_MultipleChangedFiles(t *testing.T) { + t.Parallel() + ctx := context.Background() + + apiSock := filepath.Join(t.TempDir(), "fc.sock") + firecracker := startFakeFirecracker(t, apiSock) + + vm := testVM("exportbox-multi", "image-export", "172.16.0.104") + vm.State = model.VMStateRunning + vm.Runtime.State = model.VMStateRunning + vm.Runtime.PID = firecracker.Process.Pid + vm.Runtime.APISockPath = apiSock + + patch := []byte("diff --git a/a.go b/a.go\n--- a/a.go\n+++ b/a.go\n") + names := []byte("a.go\nb.go\nnew/file.go\n") + + fake := &exportGuestClient{ + responses: []exportGuestResponse{ + {output: patch}, + {output: names}, + }, + } + d := newExportTestDaemonStore(t, fake) + upsertDaemonVM(t, ctx, d.store, vm) + + result, err := d.ExportVMWorkspace(ctx, api.WorkspaceExportParams{ + IDOrName: vm.Name, + }) + if err != nil { + t.Fatalf("ExportVMWorkspace: %v", err) + } + if len(result.ChangedFiles) != 3 { + t.Fatalf("ChangedFiles = %v, want 3 entries", result.ChangedFiles) + } + want := []string{"a.go", "b.go", "new/file.go"} + for i, f := range want { + if result.ChangedFiles[i] != f { + t.Fatalf("ChangedFiles[%d] = %q, want %q", i, result.ChangedFiles[i], f) + } + } +} diff --git a/internal/guest/ssh.go b/internal/guest/ssh.go index 193e058..6723710 100644 --- a/internal/guest/ssh.go +++ b/internal/guest/ssh.go @@ -89,6 +89,35 @@ func (c *Client) RunScript(ctx context.Context, script string, logWriter io.Writ return c.runSession(ctx, "bash -se", strings.NewReader(script), logWriter) } +// RunScriptOutput runs script on the guest and returns its stdout. +// Stderr is discarded. Use for capturing structured output (patches, JSON, +// file content) where mixing stderr into stdout would corrupt the result. +func (c *Client) RunScriptOutput(ctx context.Context, script string) ([]byte, error) { + if c == nil || c.client == nil { + return nil, fmt.Errorf("ssh client is not connected") + } + session, err := c.client.NewSession() + if err != nil { + return nil, err + } + defer session.Close() + session.Stdin = strings.NewReader(script) + var stdout bytes.Buffer + session.Stdout = &stdout + // session.Stderr left nil: stderr is intentionally discarded. + done := make(chan error, 1) + go func() { + select { + case <-ctx.Done(): + _ = c.client.Close() + case <-done: + } + }() + err = session.Run("bash -se") + done <- nil + return stdout.Bytes(), err +} + func (c *Client) UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error { command := fmt.Sprintf("install -D -m %04o /dev/stdin %s", mode.Perm(), shellQuote(remotePath)) return c.runSession(ctx, command, bytes.NewReader(data), logWriter)