Add guest.session.send and vm.workspace.export RPCs
guest.session.send — write to a pipe-mode session's stdin without holding the exclusive attach. The daemon dials a fresh SSH connection, uploads the payload to a temp file, and cats it into the session's named FIFO. Linux atomicity for writes ≤ PIPE_BUF covers all pi RPC JSONL lines. Attach exclusivity is unchanged. vm.workspace.export — pull changes from guest back to host. Runs `git add -A && git diff --cached HEAD --binary` inside the guest via a new RunScriptOutput helper on guest.Client (stdout-only capture, distinct from RunScript which merges stderr). Returns a binary-safe patch and a list of changed files. CLI writes the patch to stdout for `| git apply` or to a file via --output. RunScriptOutput is implemented as a direct SSH session (same pattern as runSession) rather than going through StartCommand/StreamSession to avoid closing the underlying Client, which is required since ExportVMWorkspace calls it twice on the same connection. New files: internal/daemon/workspace_test.go
This commit is contained in:
parent
797a9de1ce
commit
94c353f317
9 changed files with 1074 additions and 1 deletions
|
|
@ -203,6 +203,29 @@ type GuestSessionAttachBeginResult struct {
|
|||
StreamFormat string `json:"stream_format"`
|
||||
}
|
||||
|
||||
type GuestSessionSendParams struct {
|
||||
VMIDOrName string `json:"vm_id_or_name"`
|
||||
SessionIDOrName string `json:"session_id_or_name"`
|
||||
Payload []byte `json:"payload"`
|
||||
}
|
||||
|
||||
type GuestSessionSendResult struct {
|
||||
Session model.GuestSession `json:"session"`
|
||||
BytesWritten int `json:"bytes_written"`
|
||||
}
|
||||
|
||||
type WorkspaceExportParams struct {
|
||||
IDOrName string `json:"id_or_name"`
|
||||
GuestPath string `json:"guest_path,omitempty"`
|
||||
}
|
||||
|
||||
type WorkspaceExportResult struct {
|
||||
GuestPath string `json:"guest_path"`
|
||||
Patch []byte `json:"patch"`
|
||||
ChangedFiles []string `json:"changed_files"`
|
||||
HasChanges bool `json:"has_changes"`
|
||||
}
|
||||
|
||||
type VMWorkspacePrepareParams struct {
|
||||
IDOrName string `json:"id_or_name"`
|
||||
SourcePath string `json:"source_path"`
|
||||
|
|
|
|||
|
|
@ -89,6 +89,9 @@ var (
|
|||
vmWorkspacePrepareFunc = func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) {
|
||||
return rpc.Call[api.VMWorkspacePrepareResult](ctx, socketPath, "vm.workspace.prepare", params)
|
||||
}
|
||||
vmWorkspaceExportFunc = func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
|
||||
return rpc.Call[api.WorkspaceExportResult](ctx, socketPath, "vm.workspace.export", params)
|
||||
}
|
||||
guestSessionStartFunc = func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) {
|
||||
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.start", params)
|
||||
}
|
||||
|
|
@ -110,6 +113,9 @@ var (
|
|||
guestSessionAttachBeginFunc = func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
|
||||
return rpc.Call[api.GuestSessionAttachBeginResult](ctx, socketPath, "guest.session.attach.begin", params)
|
||||
}
|
||||
guestSessionSendFunc = func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
|
||||
return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params)
|
||||
}
|
||||
guestWaitForSSHFunc = func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
|
||||
return guest.WaitForSSH(ctx, address, privateKeyPath, interval)
|
||||
}
|
||||
|
|
@ -869,7 +875,10 @@ func newVMWorkspaceCommand() *cobra.Command {
|
|||
Short: "Manage repository workspaces inside a running VM",
|
||||
RunE: helpNoArgs,
|
||||
}
|
||||
cmd.AddCommand(newVMWorkspacePrepareCommand())
|
||||
cmd.AddCommand(
|
||||
newVMWorkspacePrepareCommand(),
|
||||
newVMWorkspaceExportCommand(),
|
||||
)
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
|
@ -929,6 +938,52 @@ func newVMWorkspacePrepareCommand() *cobra.Command {
|
|||
return cmd
|
||||
}
|
||||
|
||||
func newVMWorkspaceExportCommand() *cobra.Command {
|
||||
var guestPath string
|
||||
var outputPath string
|
||||
cmd := &cobra.Command{
|
||||
Use: "export <id-or-name>",
|
||||
Short: "Pull changes from a guest workspace back to the host as a patch",
|
||||
Long: "Stage all changes inside the guest workspace (git add -A) and emit a binary-safe unified diff against HEAD. With no --output flag the patch is written to stdout so it can be piped directly to git apply.",
|
||||
Args: exactArgsUsage(1, "usage: banger vm workspace export <id-or-name>"),
|
||||
Example: strings.TrimSpace(`
|
||||
banger vm workspace export devbox | git apply
|
||||
banger vm workspace export devbox --output worker.diff
|
||||
banger vm workspace export devbox --guest-path /root/project --output changes.diff
|
||||
`),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := vmWorkspaceExportFunc(cmd.Context(), layout.SocketPath, api.WorkspaceExportParams{
|
||||
IDOrName: args[0],
|
||||
GuestPath: guestPath,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !result.HasChanges {
|
||||
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "no changes")
|
||||
return nil
|
||||
}
|
||||
if outputPath != "" {
|
||||
if err := os.WriteFile(outputPath, result.Patch, 0o644); err != nil {
|
||||
return fmt.Errorf("write patch: %w", err)
|
||||
}
|
||||
_, err = fmt.Fprintf(cmd.ErrOrStderr(), "patch written to %s (%d bytes, %d files)\n",
|
||||
outputPath, len(result.Patch), len(result.ChangedFiles))
|
||||
return err
|
||||
}
|
||||
_, err = cmd.OutOrStdout().Write(result.Patch)
|
||||
return err
|
||||
},
|
||||
}
|
||||
cmd.Flags().StringVar(&guestPath, "guest-path", "/root/repo", "guest workspace path")
|
||||
cmd.Flags().StringVar(&outputPath, "output", "", "write patch to this file instead of stdout")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newVMSessionCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "session",
|
||||
|
|
@ -944,6 +999,7 @@ func newVMSessionCommand() *cobra.Command {
|
|||
newVMSessionStopCommand(),
|
||||
newVMSessionKillCommand(),
|
||||
newVMSessionAttachCommand(),
|
||||
newVMSessionSendCommand(),
|
||||
)
|
||||
return cmd
|
||||
}
|
||||
|
|
@ -1134,6 +1190,51 @@ func newVMSessionAttachCommand() *cobra.Command {
|
|||
}
|
||||
}
|
||||
|
||||
func newVMSessionSendCommand() *cobra.Command {
|
||||
var message string
|
||||
cmd := &cobra.Command{
|
||||
Use: "send <id-or-name> <session>",
|
||||
Short: "Write bytes to a running guest session's stdin pipe",
|
||||
Long: "Write a payload to the stdin pipe of a running pipe-mode guest session without holding the exclusive attach. Use --message for an inline JSONL string, or pipe bytes via stdin when --message is omitted. A trailing newline is appended to --message values that lack one.",
|
||||
Args: exactArgsUsage(2, "usage: banger vm session send <id-or-name> <session> [--message '<json>']"),
|
||||
Example: strings.TrimSpace(`
|
||||
banger vm session send devbox planner --message '{"type":"abort"}'
|
||||
banger vm session send devbox planner --message '{"type":"steer","message":"Focus on src/"}'
|
||||
echo '{"type":"prompt","prompt":"Summarize."}' | banger vm session send devbox planner
|
||||
`),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var payload []byte
|
||||
if message != "" {
|
||||
payload = []byte(message)
|
||||
if len(payload) > 0 && payload[len(payload)-1] != '\n' {
|
||||
payload = append(payload, '\n')
|
||||
}
|
||||
} else {
|
||||
payload, err = io.ReadAll(cmd.InOrStdin())
|
||||
if err != nil {
|
||||
return fmt.Errorf("read stdin: %w", err)
|
||||
}
|
||||
}
|
||||
result, err := guestSessionSendFunc(cmd.Context(), layout.SocketPath, api.GuestSessionSendParams{
|
||||
VMIDOrName: args[0],
|
||||
SessionIDOrName: args[1],
|
||||
Payload: payload,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(cmd.OutOrStdout(), "sent %d bytes to session %s\n", result.BytesWritten, result.Session.Name)
|
||||
return err
|
||||
},
|
||||
}
|
||||
cmd.Flags().StringVar(&message, "message", "", "JSONL message to send; a trailing newline is appended if absent")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func parseKeyValuePairs(values []string) (map[string]string, error) {
|
||||
if len(values) == 0 {
|
||||
return nil, nil
|
||||
|
|
|
|||
|
|
@ -1936,3 +1936,285 @@ func (c *testVMRunGuestClient) StreamTarEntries(ctx context.Context, sourceDir s
|
|||
c.streamCommand = remoteCommand
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestVMSessionSendCommandExists(t *testing.T) {
|
||||
root := NewBangerCommand()
|
||||
vm, _, err := root.Find([]string{"vm"})
|
||||
if err != nil {
|
||||
t.Fatalf("find vm: %v", err)
|
||||
}
|
||||
session, _, err := vm.Find([]string{"session"})
|
||||
if err != nil {
|
||||
t.Fatalf("find session: %v", err)
|
||||
}
|
||||
if _, _, err := session.Find([]string{"send"}); err != nil {
|
||||
t.Fatalf("find session send: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMSessionSendRejectsWrongArgCount(t *testing.T) {
|
||||
cmd := NewBangerCommand()
|
||||
cmd.SetArgs([]string{"vm", "session", "send", "only-one-arg"})
|
||||
err := cmd.Execute()
|
||||
if err == nil || !strings.Contains(err.Error(), "usage: banger vm session send") {
|
||||
t.Fatalf("Execute() error = %v, want send usage error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func stubEnsureDaemonForSend(t *testing.T) {
|
||||
t.Helper()
|
||||
t.Setenv("XDG_CONFIG_HOME", filepath.Join(t.TempDir(), "config"))
|
||||
t.Setenv("XDG_STATE_HOME", filepath.Join(t.TempDir(), "state"))
|
||||
t.Setenv("XDG_RUNTIME_DIR", filepath.Join(t.TempDir(), "run"))
|
||||
origPing := daemonPingFunc
|
||||
t.Cleanup(func() { daemonPingFunc = origPing })
|
||||
daemonPingFunc = func(context.Context, string) (api.PingResult, error) {
|
||||
return api.PingResult{Status: "ok", PID: os.Getpid()}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMSessionSendWithMessageFlag(t *testing.T) {
|
||||
stubEnsureDaemonForSend(t)
|
||||
|
||||
original := guestSessionSendFunc
|
||||
t.Cleanup(func() { guestSessionSendFunc = original })
|
||||
|
||||
var capturedParams api.GuestSessionSendParams
|
||||
guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
|
||||
capturedParams = params
|
||||
return api.GuestSessionSendResult{
|
||||
Session: model.GuestSession{ID: "sess-id", Name: "planner"},
|
||||
BytesWritten: len(params.Payload),
|
||||
}, nil
|
||||
}
|
||||
|
||||
cmd := NewBangerCommand()
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner", "--message", `{"type":"abort"}`})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
wantPayload := []byte(`{"type":"abort"}` + "\n")
|
||||
if string(capturedParams.Payload) != string(wantPayload) {
|
||||
t.Fatalf("payload = %q, want %q", capturedParams.Payload, wantPayload)
|
||||
}
|
||||
if capturedParams.VMIDOrName != "devbox" {
|
||||
t.Fatalf("VMIDOrName = %q, want %q", capturedParams.VMIDOrName, "devbox")
|
||||
}
|
||||
if capturedParams.SessionIDOrName != "planner" {
|
||||
t.Fatalf("SessionIDOrName = %q, want %q", capturedParams.SessionIDOrName, "planner")
|
||||
}
|
||||
if !strings.Contains(out.String(), "17") {
|
||||
t.Fatalf("output = %q, want bytes_written in output", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) {
|
||||
stubEnsureDaemonForSend(t)
|
||||
|
||||
original := guestSessionSendFunc
|
||||
t.Cleanup(func() { guestSessionSendFunc = original })
|
||||
|
||||
var capturedPayload []byte
|
||||
guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
|
||||
capturedPayload = params.Payload
|
||||
return api.GuestSessionSendResult{
|
||||
Session: model.GuestSession{Name: "s"},
|
||||
BytesWritten: len(params.Payload),
|
||||
}, nil
|
||||
}
|
||||
|
||||
cmd := NewBangerCommand()
|
||||
cmd.SetOut(io.Discard)
|
||||
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "s", "--message", "{\"type\":\"abort\"}\n"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
// Must not double-append newline.
|
||||
if capturedPayload[len(capturedPayload)-1] != '\n' {
|
||||
t.Fatalf("payload missing trailing newline: %q", capturedPayload)
|
||||
}
|
||||
if len(capturedPayload) > 0 && capturedPayload[len(capturedPayload)-2] == '\n' {
|
||||
t.Fatalf("payload has double trailing newline: %q", capturedPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMSessionSendFromStdin(t *testing.T) {
|
||||
stubEnsureDaemonForSend(t)
|
||||
|
||||
original := guestSessionSendFunc
|
||||
t.Cleanup(func() { guestSessionSendFunc = original })
|
||||
|
||||
var capturedPayload []byte
|
||||
guestSessionSendFunc = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
|
||||
capturedPayload = params.Payload
|
||||
return api.GuestSessionSendResult{
|
||||
Session: model.GuestSession{Name: "planner"},
|
||||
BytesWritten: len(params.Payload),
|
||||
}, nil
|
||||
}
|
||||
|
||||
stdinPayload := `{"type":"steer","message":"Focus on src/"}` + "\n"
|
||||
cmd := NewBangerCommand()
|
||||
cmd.SetOut(io.Discard)
|
||||
cmd.SetIn(strings.NewReader(stdinPayload))
|
||||
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
if string(capturedPayload) != stdinPayload {
|
||||
t.Fatalf("payload = %q, want %q", capturedPayload, stdinPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMWorkspaceExportCommandExists(t *testing.T) {
|
||||
root := NewBangerCommand()
|
||||
vm, _, err := root.Find([]string{"vm"})
|
||||
if err != nil {
|
||||
t.Fatalf("find vm: %v", err)
|
||||
}
|
||||
workspace, _, err := vm.Find([]string{"workspace"})
|
||||
if err != nil {
|
||||
t.Fatalf("find workspace: %v", err)
|
||||
}
|
||||
if _, _, err := workspace.Find([]string{"export"}); err != nil {
|
||||
t.Fatalf("find workspace export: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMWorkspaceExportRejectsMissingArg(t *testing.T) {
|
||||
cmd := NewBangerCommand()
|
||||
cmd.SetArgs([]string{"vm", "workspace", "export"})
|
||||
err := cmd.Execute()
|
||||
if err == nil || !strings.Contains(err.Error(), "usage: banger vm workspace export") {
|
||||
t.Fatalf("Execute() error = %v, want usage error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMWorkspaceExportWritesToStdout(t *testing.T) {
|
||||
stubEnsureDaemonForSend(t)
|
||||
|
||||
origExport := vmWorkspaceExportFunc
|
||||
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
|
||||
|
||||
patch := []byte("diff --git a/main.go b/main.go\nindex 0000000..1111111 100644\n")
|
||||
vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
|
||||
return api.WorkspaceExportResult{
|
||||
GuestPath: params.GuestPath,
|
||||
Patch: patch,
|
||||
ChangedFiles: []string{"main.go"},
|
||||
HasChanges: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
cmd := NewBangerCommand()
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
cmd.SetErr(io.Discard)
|
||||
cmd.SetArgs([]string{"vm", "workspace", "export", "devbox"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !bytes.Equal(out.Bytes(), patch) {
|
||||
t.Fatalf("stdout = %q, want %q", out.Bytes(), patch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMWorkspaceExportWritesToFile(t *testing.T) {
|
||||
stubEnsureDaemonForSend(t)
|
||||
|
||||
origExport := vmWorkspaceExportFunc
|
||||
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
|
||||
|
||||
patch := []byte("diff --git a/main.go b/main.go\n")
|
||||
vmWorkspaceExportFunc = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
|
||||
return api.WorkspaceExportResult{
|
||||
GuestPath: "/root/repo",
|
||||
Patch: patch,
|
||||
ChangedFiles: []string{"main.go"},
|
||||
HasChanges: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
outFile := filepath.Join(t.TempDir(), "worker.diff")
|
||||
cmd := NewBangerCommand()
|
||||
cmd.SetOut(io.Discard)
|
||||
var stderr bytes.Buffer
|
||||
cmd.SetErr(&stderr)
|
||||
cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--output", outFile})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
got, err := os.ReadFile(outFile)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, patch) {
|
||||
t.Fatalf("file content = %q, want %q", got, patch)
|
||||
}
|
||||
if !strings.Contains(stderr.String(), "worker.diff") {
|
||||
t.Fatalf("stderr = %q, want output path mentioned", stderr.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMWorkspaceExportNoChanges(t *testing.T) {
|
||||
stubEnsureDaemonForSend(t)
|
||||
|
||||
origExport := vmWorkspaceExportFunc
|
||||
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
|
||||
|
||||
vmWorkspaceExportFunc = func(_ context.Context, _ string, _ api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
|
||||
return api.WorkspaceExportResult{
|
||||
GuestPath: "/root/repo",
|
||||
HasChanges: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
cmd := NewBangerCommand()
|
||||
var out bytes.Buffer
|
||||
var stderr bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
cmd.SetErr(&stderr)
|
||||
cmd.SetArgs([]string{"vm", "workspace", "export", "devbox"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if out.Len() != 0 {
|
||||
t.Fatalf("stdout = %q, want empty when no changes", out.String())
|
||||
}
|
||||
if !strings.Contains(stderr.String(), "no changes") {
|
||||
t.Fatalf("stderr = %q, want 'no changes'", stderr.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMWorkspaceExportGuestPathFlag(t *testing.T) {
|
||||
stubEnsureDaemonForSend(t)
|
||||
|
||||
origExport := vmWorkspaceExportFunc
|
||||
t.Cleanup(func() { vmWorkspaceExportFunc = origExport })
|
||||
|
||||
var capturedParams api.WorkspaceExportParams
|
||||
vmWorkspaceExportFunc = func(_ context.Context, _ string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
|
||||
capturedParams = params
|
||||
return api.WorkspaceExportResult{HasChanges: false}, nil
|
||||
}
|
||||
|
||||
cmd := NewBangerCommand()
|
||||
cmd.SetOut(io.Discard)
|
||||
cmd.SetErr(io.Discard)
|
||||
cmd.SetArgs([]string{"vm", "workspace", "export", "devbox", "--guest-path", "/root/project"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if capturedParams.GuestPath != "/root/project" {
|
||||
t.Fatalf("GuestPath = %q, want /root/project", capturedParams.GuestPath)
|
||||
}
|
||||
if capturedParams.IDOrName != "devbox" {
|
||||
t.Fatalf("IDOrName = %q, want devbox", capturedParams.IDOrName)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -407,6 +407,13 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
|
|||
}
|
||||
workspace, err := d.PrepareVMWorkspace(ctx, params)
|
||||
return marshalResultOrError(api.VMWorkspacePrepareResult{Workspace: workspace}, err)
|
||||
case "vm.workspace.export":
|
||||
params, err := rpc.DecodeParams[api.WorkspaceExportParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
result, err := d.ExportVMWorkspace(ctx, params)
|
||||
return marshalResultOrError(result, err)
|
||||
case "guest.session.start":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionStartParams](req)
|
||||
if err != nil {
|
||||
|
|
@ -456,6 +463,13 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
|
|||
}
|
||||
result, err := d.BeginGuestSessionAttach(ctx, params)
|
||||
return marshalResultOrError(result, err)
|
||||
case "guest.session.send":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionSendParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
result, err := d.SendToGuestSession(ctx, params)
|
||||
return marshalResultOrError(result, err)
|
||||
case "image.list":
|
||||
images, err := d.store.ListImages(ctx)
|
||||
return marshalResultOrError(api.ImageListResult{Images: images}, err)
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ var guestSessionHostCommandOutputFunc = func(ctx context.Context, name string, a
|
|||
type guestSSHClient interface {
|
||||
Close() error
|
||||
RunScript(context.Context, string, io.Writer) error
|
||||
RunScriptOutput(context.Context, string) ([]byte, error)
|
||||
UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error
|
||||
StreamTar(context.Context, string, string, io.Writer) error
|
||||
StreamTarEntries(context.Context, string, []string, string, io.Writer) error
|
||||
|
|
@ -400,6 +401,50 @@ func (d *Daemon) GuestSessionLogs(ctx context.Context, params api.GuestSessionLo
|
|||
return api.GuestSessionLogsResult{Session: session, Stream: streamName, Path: path, Content: content}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) SendToGuestSession(ctx context.Context, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionSendResult{}, err
|
||||
}
|
||||
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionSendResult{}, err
|
||||
}
|
||||
if session.StdinMode != model.GuestSessionStdinPipe {
|
||||
return api.GuestSessionSendResult{}, errors.New("session does not have a stdin pipe")
|
||||
}
|
||||
if session.Status != model.GuestSessionStatusRunning {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("session is not running (status=%s)", session.Status)
|
||||
}
|
||||
if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("vm %q is not running", vm.Name)
|
||||
}
|
||||
if len(params.Payload) == 0 {
|
||||
return api.GuestSessionSendResult{Session: session}, nil
|
||||
}
|
||||
client, err := d.dialGuest(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"))
|
||||
if err != nil {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("dial guest: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
tmpPath := fmt.Sprintf("/tmp/banger-send-%s.bin", session.ID[:8])
|
||||
var uploadLog bytes.Buffer
|
||||
if err := client.UploadFile(ctx, tmpPath, 0o600, params.Payload, &uploadLog); err != nil {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("upload payload: %w", err)
|
||||
}
|
||||
sendScript := fmt.Sprintf(
|
||||
"set -euo pipefail\ncat %s >> %s\nrm -f %s\n",
|
||||
guestShellQuote(tmpPath),
|
||||
guestShellQuote(guestSessionStdinPipePath(session.ID)),
|
||||
guestShellQuote(tmpPath),
|
||||
)
|
||||
var sendLog bytes.Buffer
|
||||
if err := client.RunScript(ctx, sendScript, &sendLog); err != nil {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("send to session: %w: %s", err, strings.TrimSpace(sendLog.String()))
|
||||
}
|
||||
return api.GuestSessionSendResult{Session: session, BytesWritten: len(params.Payload)}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) BeginGuestSessionAttach(ctx context.Context, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"banger/internal/api"
|
||||
"banger/internal/model"
|
||||
"banger/internal/store"
|
||||
)
|
||||
|
||||
type fakeGuestSSHClient struct {
|
||||
|
|
@ -57,6 +58,10 @@ func (f *fakeGuestSSHClient) RunScript(_ context.Context, script string, _ io.Wr
|
|||
}
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -77,6 +82,276 @@ func (f *fakeGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []s
|
|||
return fmt.Errorf("unexpected StreamTarEntries command: %s", command)
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_HappyPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := startFakeFirecracker(t, apiSock)
|
||||
|
||||
vm := testVM("sendbox", "image-send", "172.16.0.88")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.PID = firecracker.Process.Pid
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
fake := &recordingGuestSSHClient{}
|
||||
d := newSendTestDaemon(t, db, fake)
|
||||
|
||||
payload := []byte(`{"type":"abort"}` + "\n")
|
||||
result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: payload,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendToGuestSession: %v", err)
|
||||
}
|
||||
if result.BytesWritten != len(payload) {
|
||||
t.Fatalf("BytesWritten = %d, want %d", result.BytesWritten, len(payload))
|
||||
}
|
||||
if result.Session.ID != session.ID {
|
||||
t.Fatalf("Session.ID = %q, want %q", result.Session.ID, session.ID)
|
||||
}
|
||||
if len(fake.uploadedFiles) != 1 {
|
||||
t.Fatalf("UploadFile call count = %d, want 1", len(fake.uploadedFiles))
|
||||
}
|
||||
for path, data := range fake.uploadedFiles {
|
||||
if !strings.HasPrefix(path, "/tmp/banger-send-") {
|
||||
t.Fatalf("upload path = %q, want /tmp/banger-send-... prefix", path)
|
||||
}
|
||||
if string(data) != string(payload) {
|
||||
t.Fatalf("upload data = %q, want %q", data, payload)
|
||||
}
|
||||
}
|
||||
if len(fake.ranScripts) != 1 {
|
||||
t.Fatalf("RunScript call count = %d, want 1", len(fake.ranScripts))
|
||||
}
|
||||
script := fake.ranScripts[0]
|
||||
pipePath := guestSessionStdinPipePath(session.ID)
|
||||
if !strings.Contains(script, "cat ") {
|
||||
t.Fatalf("send script missing cat command: %q", script)
|
||||
}
|
||||
if !strings.Contains(script, pipePath) {
|
||||
t.Fatalf("send script missing pipe path %q: %q", pipePath, script)
|
||||
}
|
||||
if !strings.Contains(script, "rm -f ") {
|
||||
t.Fatalf("send script missing rm cleanup: %q", script)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_EmptyPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := startFakeFirecracker(t, apiSock)
|
||||
|
||||
vm := testVM("sendbox-empty", "image-send", "172.16.0.89")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.PID = firecracker.Process.Pid
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
fake := &recordingGuestSSHClient{}
|
||||
d := newSendTestDaemon(t, db, fake)
|
||||
|
||||
result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendToGuestSession(empty): %v", err)
|
||||
}
|
||||
if result.BytesWritten != 0 {
|
||||
t.Fatalf("BytesWritten = %d, want 0", result.BytesWritten)
|
||||
}
|
||||
if fake.dialCount != 0 {
|
||||
t.Fatalf("SSH dial count = %d, want 0 for empty payload", fake.dialCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_NotPipeMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
vm := testVM("sendbox-closed", "image-send", "172.16.0.90")
|
||||
vm.State = model.VMStateRunning
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinClosed, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
d := &Daemon{store: db}
|
||||
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: []byte("hello\n"),
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "stdin pipe") {
|
||||
t.Fatalf("error = %v, want 'stdin pipe' error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_SessionNotRunning(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
vm := testVM("sendbox-failed", "image-send", "172.16.0.91")
|
||||
vm.State = model.VMStateRunning
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusFailed)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
d := &Daemon{store: db}
|
||||
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: []byte("hello\n"),
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "not running") {
|
||||
t.Fatalf("error = %v, want 'not running' error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_VMNotRunning(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
vm := testVM("sendbox-stopped", "image-send", "172.16.0.92")
|
||||
vm.State = model.VMStateStopped
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
d := &Daemon{store: db}
|
||||
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: []byte("hello\n"),
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "not running") {
|
||||
t.Fatalf("error = %v, want 'not running' error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// recordingGuestSSHClient captures UploadFile and RunScript calls for send tests.
|
||||
type recordingGuestSSHClient struct {
|
||||
dialCount int
|
||||
uploadedFiles map[string][]byte
|
||||
ranScripts []string
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) Close() error { return nil }
|
||||
|
||||
func (r *recordingGuestSSHClient) UploadFile(_ context.Context, path string, _ os.FileMode, data []byte, _ io.Writer) error {
|
||||
if r.uploadedFiles == nil {
|
||||
r.uploadedFiles = make(map[string][]byte)
|
||||
}
|
||||
copy := make([]byte, len(data))
|
||||
_ = copy[:len(data):len(data)]
|
||||
for i, b := range data {
|
||||
copy[i] = b
|
||||
}
|
||||
r.uploadedFiles[path] = copy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error {
|
||||
r.ranScripts = append(r.ranScripts, script)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) StreamTar(_ context.Context, _ string, _ string, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, _ string, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSendTestDaemon(t *testing.T, db *store.Store, fake *recordingGuestSSHClient) *Daemon {
|
||||
t.Helper()
|
||||
d := &Daemon{
|
||||
store: db,
|
||||
config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")},
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
d.guestDial = func(_ context.Context, _ string, _ string) (guestSSHClient, error) {
|
||||
fake.dialCount++
|
||||
return fake, nil
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func testGuestSession(vmID string, stdinMode model.GuestSessionStdinMode, status model.GuestSessionStatus) model.GuestSession {
|
||||
now := model.Now()
|
||||
id := vmID + "-sess-id"
|
||||
return model.GuestSession{
|
||||
ID: id,
|
||||
VMID: vmID,
|
||||
Name: vmID + "-sess",
|
||||
Backend: guestSessionBackendSSH,
|
||||
Command: "pi",
|
||||
Args: []string{"--mode", "rpc"},
|
||||
CWD: "/root/repo",
|
||||
StdinMode: stdinMode,
|
||||
Status: status,
|
||||
GuestStateDir: guestSessionStateDir(id),
|
||||
StdoutLogPath: guestSessionStdoutLogPath(id),
|
||||
StderrLogPath: guestSessionStderrLogPath(id),
|
||||
Attachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning,
|
||||
Reattachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd {
|
||||
t.Helper()
|
||||
cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("start fake firecracker: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
}
|
||||
})
|
||||
return cmd
|
||||
}
|
||||
|
||||
func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
|||
|
|
@ -35,6 +35,56 @@ type workspaceRepoSpec struct {
|
|||
Submodules []string
|
||||
}
|
||||
|
||||
func (d *Daemon) ExportVMWorkspace(ctx context.Context, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
|
||||
guestPath := strings.TrimSpace(params.GuestPath)
|
||||
if guestPath == "" {
|
||||
guestPath = "/root/repo"
|
||||
}
|
||||
vm, err := d.FindVM(ctx, params.IDOrName)
|
||||
if err != nil {
|
||||
return api.WorkspaceExportResult{}, err
|
||||
}
|
||||
if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
|
||||
return api.WorkspaceExportResult{}, fmt.Errorf("vm %q is not running", vm.Name)
|
||||
}
|
||||
client, err := d.dialGuest(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"))
|
||||
if err != nil {
|
||||
return api.WorkspaceExportResult{}, fmt.Errorf("dial guest: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Stage all changes then emit a binary-safe unified diff against HEAD.
|
||||
// --binary ensures binary files are handled correctly by git apply.
|
||||
patchScript := fmt.Sprintf(
|
||||
"set -euo pipefail\ncd %s\ngit add -A\ngit diff --cached HEAD --binary\n",
|
||||
guestShellQuote(guestPath),
|
||||
)
|
||||
patch, err := client.RunScriptOutput(ctx, patchScript)
|
||||
if err != nil {
|
||||
return api.WorkspaceExportResult{}, fmt.Errorf("export workspace diff: %w", err)
|
||||
}
|
||||
|
||||
// Enumerate changed paths (index already staged; this is a cheap read).
|
||||
namesScript := fmt.Sprintf(
|
||||
"set -euo pipefail\ncd %s\ngit diff --cached HEAD --name-only\n",
|
||||
guestShellQuote(guestPath),
|
||||
)
|
||||
namesOut, _ := client.RunScriptOutput(ctx, namesScript)
|
||||
var changed []string
|
||||
for _, line := range strings.Split(strings.TrimSpace(string(namesOut)), "\n") {
|
||||
if line = strings.TrimSpace(line); line != "" {
|
||||
changed = append(changed, line)
|
||||
}
|
||||
}
|
||||
|
||||
return api.WorkspaceExportResult{
|
||||
GuestPath: guestPath,
|
||||
Patch: patch,
|
||||
ChangedFiles: changed,
|
||||
HasChanges: len(patch) > 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) PrepareVMWorkspace(ctx context.Context, params api.VMWorkspacePrepareParams) (model.WorkspacePrepareResult, error) {
|
||||
mode, err := parseWorkspacePrepareMode(params.Mode)
|
||||
if err != nil {
|
||||
|
|
|
|||
254
internal/daemon/workspace_test.go
Normal file
254
internal/daemon/workspace_test.go
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"banger/internal/api"
|
||||
"banger/internal/model"
|
||||
)
|
||||
|
||||
// exportGuestClient is a scriptable fake for RunScriptOutput used in export tests.
|
||||
// Each call to RunScriptOutput returns the next response from the queue.
|
||||
type exportGuestClient struct {
|
||||
responses []exportGuestResponse
|
||||
callIndex int
|
||||
}
|
||||
|
||||
type exportGuestResponse struct {
|
||||
output []byte
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *exportGuestClient) Close() error { return nil }
|
||||
|
||||
func (e *exportGuestClient) RunScript(_ context.Context, _ string, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *exportGuestClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) {
|
||||
if e.callIndex >= len(e.responses) {
|
||||
return nil, nil
|
||||
}
|
||||
r := e.responses[e.callIndex]
|
||||
e.callIndex++
|
||||
return r.output, r.err
|
||||
}
|
||||
|
||||
func (e *exportGuestClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *exportGuestClient) StreamTar(_ context.Context, _ string, _ string, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *exportGuestClient) StreamTarEntries(_ context.Context, _ string, _ []string, _ string, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newExportTestDaemonStore(t *testing.T, fake *exportGuestClient) *Daemon {
|
||||
t.Helper()
|
||||
db := openDaemonStore(t)
|
||||
d := &Daemon{
|
||||
store: db,
|
||||
config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")},
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
d.guestDial = func(_ context.Context, _ string, _ string) (guestSSHClient, error) {
|
||||
return fake, nil
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func TestExportVMWorkspace_HappyPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := startFakeFirecracker(t, apiSock)
|
||||
|
||||
vm := testVM("exportbox", "image-export", "172.16.0.100")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.PID = firecracker.Process.Pid
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
|
||||
patch := []byte("diff --git a/file.go b/file.go\nindex 0000000..1111111 100644\n")
|
||||
names := []byte("file.go\n")
|
||||
|
||||
fake := &exportGuestClient{
|
||||
responses: []exportGuestResponse{
|
||||
{output: patch},
|
||||
{output: names},
|
||||
},
|
||||
}
|
||||
d := newExportTestDaemonStore(t, fake)
|
||||
upsertDaemonVM(t, ctx, d.store, vm)
|
||||
|
||||
result, err := d.ExportVMWorkspace(ctx, api.WorkspaceExportParams{
|
||||
IDOrName: vm.Name,
|
||||
GuestPath: "/root/repo",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExportVMWorkspace: %v", err)
|
||||
}
|
||||
if !result.HasChanges {
|
||||
t.Fatal("HasChanges = false, want true")
|
||||
}
|
||||
if string(result.Patch) != string(patch) {
|
||||
t.Fatalf("Patch = %q, want %q", result.Patch, patch)
|
||||
}
|
||||
if result.GuestPath != "/root/repo" {
|
||||
t.Fatalf("GuestPath = %q, want /root/repo", result.GuestPath)
|
||||
}
|
||||
if len(result.ChangedFiles) != 1 || result.ChangedFiles[0] != "file.go" {
|
||||
t.Fatalf("ChangedFiles = %v, want [file.go]", result.ChangedFiles)
|
||||
}
|
||||
if fake.callIndex != 2 {
|
||||
t.Fatalf("RunScriptOutput call count = %d, want 2", fake.callIndex)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportVMWorkspace_NoChanges(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := startFakeFirecracker(t, apiSock)
|
||||
|
||||
vm := testVM("exportbox-empty", "image-export", "172.16.0.101")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.PID = firecracker.Process.Pid
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
|
||||
// Both scripts return empty output (no changes).
|
||||
fake := &exportGuestClient{
|
||||
responses: []exportGuestResponse{
|
||||
{output: nil},
|
||||
{output: nil},
|
||||
},
|
||||
}
|
||||
d := newExportTestDaemonStore(t, fake)
|
||||
upsertDaemonVM(t, ctx, d.store, vm)
|
||||
|
||||
result, err := d.ExportVMWorkspace(ctx, api.WorkspaceExportParams{
|
||||
IDOrName: vm.Name,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExportVMWorkspace: %v", err)
|
||||
}
|
||||
if result.HasChanges {
|
||||
t.Fatal("HasChanges = true, want false")
|
||||
}
|
||||
if len(result.Patch) != 0 {
|
||||
t.Fatalf("Patch = %q, want empty", result.Patch)
|
||||
}
|
||||
if len(result.ChangedFiles) != 0 {
|
||||
t.Fatalf("ChangedFiles = %v, want empty", result.ChangedFiles)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportVMWorkspace_DefaultGuestPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := startFakeFirecracker(t, apiSock)
|
||||
|
||||
vm := testVM("exportbox-default", "image-export", "172.16.0.102")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.PID = firecracker.Process.Pid
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
|
||||
fake := &exportGuestClient{
|
||||
responses: []exportGuestResponse{
|
||||
{output: nil},
|
||||
{output: nil},
|
||||
},
|
||||
}
|
||||
d := newExportTestDaemonStore(t, fake)
|
||||
upsertDaemonVM(t, ctx, d.store, vm)
|
||||
|
||||
// GuestPath omitted — should default to /root/repo.
|
||||
result, err := d.ExportVMWorkspace(ctx, api.WorkspaceExportParams{
|
||||
IDOrName: vm.Name,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExportVMWorkspace: %v", err)
|
||||
}
|
||||
if result.GuestPath != "/root/repo" {
|
||||
t.Fatalf("GuestPath = %q, want /root/repo", result.GuestPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportVMWorkspace_VMNotRunning(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
vm := testVM("exportbox-stopped", "image-export", "172.16.0.103")
|
||||
vm.State = model.VMStateStopped
|
||||
|
||||
fake := &exportGuestClient{}
|
||||
d := newExportTestDaemonStore(t, fake)
|
||||
upsertDaemonVM(t, ctx, d.store, vm)
|
||||
|
||||
_, err := d.ExportVMWorkspace(ctx, api.WorkspaceExportParams{
|
||||
IDOrName: vm.Name,
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "not running") {
|
||||
t.Fatalf("error = %v, want 'not running' error", err)
|
||||
}
|
||||
if fake.callIndex != 0 {
|
||||
t.Fatal("RunScriptOutput should not be called when VM is not running")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportVMWorkspace_MultipleChangedFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := startFakeFirecracker(t, apiSock)
|
||||
|
||||
vm := testVM("exportbox-multi", "image-export", "172.16.0.104")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.PID = firecracker.Process.Pid
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
|
||||
patch := []byte("diff --git a/a.go b/a.go\n--- a/a.go\n+++ b/a.go\n")
|
||||
names := []byte("a.go\nb.go\nnew/file.go\n")
|
||||
|
||||
fake := &exportGuestClient{
|
||||
responses: []exportGuestResponse{
|
||||
{output: patch},
|
||||
{output: names},
|
||||
},
|
||||
}
|
||||
d := newExportTestDaemonStore(t, fake)
|
||||
upsertDaemonVM(t, ctx, d.store, vm)
|
||||
|
||||
result, err := d.ExportVMWorkspace(ctx, api.WorkspaceExportParams{
|
||||
IDOrName: vm.Name,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExportVMWorkspace: %v", err)
|
||||
}
|
||||
if len(result.ChangedFiles) != 3 {
|
||||
t.Fatalf("ChangedFiles = %v, want 3 entries", result.ChangedFiles)
|
||||
}
|
||||
want := []string{"a.go", "b.go", "new/file.go"}
|
||||
for i, f := range want {
|
||||
if result.ChangedFiles[i] != f {
|
||||
t.Fatalf("ChangedFiles[%d] = %q, want %q", i, result.ChangedFiles[i], f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -89,6 +89,35 @@ func (c *Client) RunScript(ctx context.Context, script string, logWriter io.Writ
|
|||
return c.runSession(ctx, "bash -se", strings.NewReader(script), logWriter)
|
||||
}
|
||||
|
||||
// RunScriptOutput runs script on the guest and returns its stdout.
|
||||
// Stderr is discarded. Use for capturing structured output (patches, JSON,
|
||||
// file content) where mixing stderr into stdout would corrupt the result.
|
||||
func (c *Client) RunScriptOutput(ctx context.Context, script string) ([]byte, error) {
|
||||
if c == nil || c.client == nil {
|
||||
return nil, fmt.Errorf("ssh client is not connected")
|
||||
}
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer session.Close()
|
||||
session.Stdin = strings.NewReader(script)
|
||||
var stdout bytes.Buffer
|
||||
session.Stdout = &stdout
|
||||
// session.Stderr left nil: stderr is intentionally discarded.
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = c.client.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
err = session.Run("bash -se")
|
||||
done <- nil
|
||||
return stdout.Bytes(), err
|
||||
}
|
||||
|
||||
func (c *Client) UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error {
|
||||
command := fmt.Sprintf("install -D -m %04o /dev/stdin %s", mode.Perm(), shellQuote(remotePath))
|
||||
return c.runSession(ctx, command, bytes.NewReader(data), logWriter)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue