From 2b6437d1b46ededb56d6e56542d048fb495b1703 Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Mon, 20 Apr 2026 12:47:58 -0300 Subject: [PATCH] remove vm session feature MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cuts the daemon-managed guest-session machinery (start/list/show/ logs/stop/kill/attach/send). The feature shipped aimed at agent- orchestration workflows (programmatic stdin piping into a long-lived guest process) that aren't driving any concrete user today, and the ~2.3K LOC of daemon surface area — attach bridge, FIFO keepalive, controller registry, sessionstream framing, SQLite persistence — was locking in an API we'd have to keep through v0.1.0. Anything session-flavoured that people actually need today can be done with `vm ssh + tmux` or `vm run -- cmd`. Deleted: - internal/cli/commands_vm_session.go - internal/daemon/{guest_sessions,session_lifecycle,session_attach,session_stream,session_controller}.go - internal/daemon/session/ (guest-session helpers package) - internal/sessionstream/ (framing package) - internal/daemon/guest_sessions_test.go - internal/store/guest_session_test.go - GuestSession* types from internal/{api,model} - Store UpsertGuestSession/GetGuestSession/ListGuestSessionsByVM/DeleteGuestSession + scanner helpers - guest.session.* RPC dispatch entries - 5 CLI session tests, 2 completion tests, 2 printer tests Extracted: - ShellQuote + FormatStepError lifted to internal/daemon/workspace/util.go (only non-session consumer); workspace package now self-contained - internal/daemon/guest_ssh.go keeps guestSSHClient + dialGuest + waitForGuestSSH — still used by workspace prepare/export - internal/daemon/fake_firecracker_test.go preserves the test helper that used to live in guest_sessions_test.go Store schema: CREATE TABLE guest_sessions and its column migrations removed. Existing dev DBs keep an orphan table (harmless, pre-v0.1.0). Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 4 +- docs/advanced.md | 28 +- internal/api/types.go | 64 --- internal/cli/aliases_test.go | 1 - internal/cli/cli_test.go | 119 +---- internal/cli/commands_vm.go | 1 - internal/cli/commands_vm_session.go | 370 ------------- internal/cli/completion.go | 42 +- internal/cli/completion_test.go | 77 +-- internal/cli/deps.go | 82 +-- internal/cli/formatters_test.go | 68 --- internal/cli/printers.go | 25 - internal/daemon/ARCHITECTURE.md | 13 +- internal/daemon/daemon.go | 109 +--- internal/daemon/doc.go | 24 +- internal/daemon/fake_firecracker_test.go | 26 + internal/daemon/guest_sessions.go | 142 ----- internal/daemon/guest_sessions_test.go | 492 ----------------- internal/daemon/guest_ssh.go | 35 ++ internal/daemon/open_close_test.go | 25 +- internal/daemon/session/session.go | 521 ------------------- internal/daemon/session/session_test.go | 440 ---------------- internal/daemon/session_attach.go | 224 -------- internal/daemon/session_controller.go | 184 ------- internal/daemon/session_lifecycle.go | 213 -------- internal/daemon/session_stream.go | 120 ----- internal/daemon/workspace.go | 11 +- internal/daemon/workspace/util.go | 20 + internal/daemon/workspace/workspace.go | 29 +- internal/model/types.go | 48 -- internal/sessionstream/sessionstream.go | 76 --- internal/sessionstream/sessionstream_test.go | 117 ----- internal/store/guest_session_test.go | 214 -------- internal/store/store.go | 261 ---------- 34 files changed, 194 insertions(+), 4031 deletions(-) delete mode 100644 internal/cli/commands_vm_session.go create mode 100644 internal/daemon/fake_firecracker_test.go delete mode 100644 internal/daemon/guest_sessions.go delete mode 100644 internal/daemon/guest_sessions_test.go create mode 100644 internal/daemon/guest_ssh.go delete mode 100644 internal/daemon/session/session.go delete mode 100644 internal/daemon/session/session_test.go delete mode 100644 internal/daemon/session_attach.go delete mode 100644 internal/daemon/session_controller.go delete mode 100644 internal/daemon/session_lifecycle.go delete mode 100644 internal/daemon/session_stream.go create mode 100644 internal/daemon/workspace/util.go delete mode 100644 internal/sessionstream/sessionstream.go delete mode 100644 internal/sessionstream/sessionstream_test.go delete mode 100644 internal/store/guest_session_test.go diff --git a/README.md b/README.md index 89a4c4e..7cc850b 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ User data stays in place — the target prints the paths so you can `banger` ships completion scripts for bash, zsh, fish, and powershell. Tab-completion covers subcommands, flags, and live -resource names (VM, image, kernel, session) looked up from the +resource names (VM, image, kernel) looked up from the daemon. With the daemon down, resource completion silently returns nothing — no file-completion fallback. @@ -179,7 +179,7 @@ ones you want. ## Advanced The common path is `vm run`. Power-user flows (`vm create`, OCI pull -for arbitrary images, `image register`, long-lived sessions) are +for arbitrary images, `image register`, manual workspace prepare) are documented in [`docs/advanced.md`](docs/advanced.md). ## Security diff --git a/docs/advanced.md b/docs/advanced.md index 8863739..191086a 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -60,33 +60,19 @@ disk, or pass `--kernel /abs/path/vmlinux` for a one-off kernel. For reproducible custom images, write a Dockerfile and publish it to an image catalog. See [`docs/image-catalog.md`](image-catalog.md). -## Workspace + session primitives +## Workspace primitive -Long-lived guest commands managed by the daemon, attachable over a -local Unix socket bridge. Useful for agent/background processes that -need to survive SSH disconnects. +`vm run ./repo` (see README) handles the common case. For a manual +flow against an already-running VM, `vm workspace prepare` +materialises a local git checkout into the guest: ```bash banger vm workspace prepare ./other-repo --guest-path /root/repo -banger vm session start --name planner --cwd /root/repo \ - --stdin-mode pipe -- pi --mode rpc -banger vm session attach planner -banger vm session logs planner --stream stderr -banger vm session stop planner ``` -Details: - -- `vm workspace prepare` materialises a local git checkout into a - running VM. Default guest path `/root/repo`; default mode is a - shallow metadata copy plus tracked and untracked non-ignored - overlay. -- `vm session start` launches a daemon-managed long-lived guest - command. The daemon preflights that the guest `cwd` exists and the - command is on guest `PATH` before launch. Use `--stdin-mode pipe` - when you need live `attach`. -- `vm session attach` is exclusive and same-host only. Pipe-mode - sessions survive daemon restarts. +Default guest path is `/root/repo`; default mode is a shallow metadata +copy plus tracked and untracked non-ignored overlay. For repositories +with submodules, pass `--mode full_copy`. ## Inspecting boot failures diff --git a/internal/api/types.go b/internal/api/types.go index 5ae0e32..8a3ff99 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -122,70 +122,6 @@ type VMPortsResult struct { Ports []VMPort `json:"ports"` } -type GuestSessionStartParams struct { - VMIDOrName string `json:"vm_id_or_name"` - Name string `json:"name,omitempty"` - Command string `json:"command"` - Args []string `json:"args,omitempty"` - CWD string `json:"cwd,omitempty"` - Env map[string]string `json:"env,omitempty"` - StdinMode string `json:"stdin_mode,omitempty"` - Tags map[string]string `json:"tags,omitempty"` - RequiredCommands []string `json:"required_commands,omitempty"` -} - -type GuestSessionRefParams struct { - VMIDOrName string `json:"vm_id_or_name"` - SessionIDOrName string `json:"session_id_or_name"` -} - -type GuestSessionLogsParams struct { - VMIDOrName string `json:"vm_id_or_name"` - SessionIDOrName string `json:"session_id_or_name"` - Stream string `json:"stream,omitempty"` - TailLines int `json:"tail_lines,omitempty"` -} - -type GuestSessionAttachBeginParams struct { - VMIDOrName string `json:"vm_id_or_name"` - SessionIDOrName string `json:"session_id_or_name"` -} - -type GuestSessionListResult struct { - Sessions []model.GuestSession `json:"sessions"` -} - -type GuestSessionShowResult struct { - Session model.GuestSession `json:"session"` -} - -type GuestSessionLogsResult struct { - Session model.GuestSession `json:"session"` - Stream string `json:"stream"` - Path string `json:"path,omitempty"` - Content string `json:"content,omitempty"` -} - -type GuestSessionAttachBeginResult struct { - Session model.GuestSession `json:"session"` - AttachID string `json:"attach_id"` - TransportKind string `json:"transport_kind"` - TransportTarget string `json:"transport_target"` - SocketPath string `json:"socket_path,omitempty"` - StreamFormat string `json:"stream_format"` -} - -type GuestSessionSendParams struct { - VMIDOrName string `json:"vm_id_or_name"` - SessionIDOrName string `json:"session_id_or_name"` - Payload []byte `json:"payload"` -} - -type GuestSessionSendResult struct { - Session model.GuestSession `json:"session"` - BytesWritten int `json:"bytes_written"` -} - type WorkspaceExportParams struct { IDOrName string `json:"id_or_name"` GuestPath string `json:"guest_path,omitempty"` diff --git a/internal/cli/aliases_test.go b/internal/cli/aliases_test.go index 12853e4..ed1cbe3 100644 --- a/internal/cli/aliases_test.go +++ b/internal/cli/aliases_test.go @@ -46,7 +46,6 @@ func TestListCommandsHaveLsAlias(t *testing.T) { {"vm", "list"}, {"image", "list"}, {"kernel", "list"}, - {"vm", "session", "list"}, } for _, path := range cases { t.Run(path[len(path)-1], func(t *testing.T) { diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index d9030ca..8c9c26a 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -1918,34 +1918,9 @@ func (c *testVMRunGuestClient) StreamTarEntries(ctx context.Context, sourceDir s return nil } -func TestVMSessionSendCommandExists(t *testing.T) { - root := NewBangerCommand() - vm, _, err := root.Find([]string{"vm"}) - if err != nil { - t.Fatalf("find vm: %v", err) - } - session, _, err := vm.Find([]string{"session"}) - if err != nil { - t.Fatalf("find session: %v", err) - } - if _, _, err := session.Find([]string{"send"}); err != nil { - t.Fatalf("find session send: %v", err) - } -} - -func TestVMSessionSendRejectsWrongArgCount(t *testing.T) { - cmd := NewBangerCommand() - cmd.SetArgs([]string{"vm", "session", "send", "only-one-arg"}) - err := cmd.Execute() - if err == nil || !strings.Contains(err.Error(), "usage: banger vm session send") { - t.Fatalf("Execute() error = %v, want send usage error", err) - } -} - // stubEnsureDaemonForSend isolates XDG dirs and installs a daemon-ping // fake onto the caller's *deps so `ensureDaemon` short-circuits without -// trying to spawn bangerd. `vm session send` uses this to avoid needing -// a built binary on disk. +// trying to spawn bangerd. func stubEnsureDaemonForSend(t *testing.T, d *deps) { t.Helper() t.Setenv("XDG_CONFIG_HOME", filepath.Join(t.TempDir(), "config")) @@ -1956,98 +1931,6 @@ func stubEnsureDaemonForSend(t *testing.T, d *deps) { } } -func TestVMSessionSendWithMessageFlag(t *testing.T) { - d := defaultDeps() - stubEnsureDaemonForSend(t, d) - - var capturedParams api.GuestSessionSendParams - d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { - capturedParams = params - return api.GuestSessionSendResult{ - Session: model.GuestSession{ID: "sess-id", Name: "planner"}, - BytesWritten: len(params.Payload), - }, nil - } - - cmd := d.newRootCommand() - var out bytes.Buffer - cmd.SetOut(&out) - cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner", "--message", `{"type":"abort"}`}) - if err := cmd.Execute(); err != nil { - t.Fatalf("Execute: %v", err) - } - - wantPayload := []byte(`{"type":"abort"}` + "\n") - if string(capturedParams.Payload) != string(wantPayload) { - t.Fatalf("payload = %q, want %q", capturedParams.Payload, wantPayload) - } - if capturedParams.VMIDOrName != "devbox" { - t.Fatalf("VMIDOrName = %q, want %q", capturedParams.VMIDOrName, "devbox") - } - if capturedParams.SessionIDOrName != "planner" { - t.Fatalf("SessionIDOrName = %q, want %q", capturedParams.SessionIDOrName, "planner") - } - if !strings.Contains(out.String(), "17") { - t.Fatalf("output = %q, want bytes_written in output", out.String()) - } -} - -func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) { - d := defaultDeps() - stubEnsureDaemonForSend(t, d) - - var capturedPayload []byte - d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { - capturedPayload = params.Payload - return api.GuestSessionSendResult{ - Session: model.GuestSession{Name: "s"}, - BytesWritten: len(params.Payload), - }, nil - } - - cmd := d.newRootCommand() - cmd.SetOut(io.Discard) - cmd.SetArgs([]string{"vm", "session", "send", "devbox", "s", "--message", "{\"type\":\"abort\"}\n"}) - if err := cmd.Execute(); err != nil { - t.Fatalf("Execute: %v", err) - } - - // Must not double-append newline. - if capturedPayload[len(capturedPayload)-1] != '\n' { - t.Fatalf("payload missing trailing newline: %q", capturedPayload) - } - if len(capturedPayload) > 0 && capturedPayload[len(capturedPayload)-2] == '\n' { - t.Fatalf("payload has double trailing newline: %q", capturedPayload) - } -} - -func TestVMSessionSendFromStdin(t *testing.T) { - d := defaultDeps() - stubEnsureDaemonForSend(t, d) - - var capturedPayload []byte - d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { - capturedPayload = params.Payload - return api.GuestSessionSendResult{ - Session: model.GuestSession{Name: "planner"}, - BytesWritten: len(params.Payload), - }, nil - } - - stdinPayload := `{"type":"steer","message":"Focus on src/"}` + "\n" - cmd := d.newRootCommand() - cmd.SetOut(io.Discard) - cmd.SetIn(strings.NewReader(stdinPayload)) - cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner"}) - if err := cmd.Execute(); err != nil { - t.Fatalf("Execute: %v", err) - } - - if string(capturedPayload) != stdinPayload { - t.Fatalf("payload = %q, want %q", capturedPayload, stdinPayload) - } -} - func TestVMWorkspaceExportCommandExists(t *testing.T) { root := NewBangerCommand() vm, _, err := root.Find([]string{"vm"}) diff --git a/internal/cli/commands_vm.go b/internal/cli/commands_vm.go index f85198b..c674d1e 100644 --- a/internal/cli/commands_vm.go +++ b/internal/cli/commands_vm.go @@ -42,7 +42,6 @@ func (d *deps) newVMCommand() *cobra.Command { d.newVMSetCommand(), d.newVMSSHCommand(), d.newVMWorkspaceCommand(), - d.newVMSessionCommand(), d.newVMLogsCommand(), d.newVMStatsCommand(), d.newVMPortsCommand(), diff --git a/internal/cli/commands_vm_session.go b/internal/cli/commands_vm_session.go deleted file mode 100644 index d539445..0000000 --- a/internal/cli/commands_vm_session.go +++ /dev/null @@ -1,370 +0,0 @@ -package cli - -import ( - "context" - "errors" - "fmt" - "io" - "net" - "strings" - - "banger/internal/api" - "banger/internal/model" - "banger/internal/sessionstream" - - "github.com/spf13/cobra" -) - -func (d *deps) newVMSessionCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "session", - Short: "Manage long-lived guest commands inside a VM", - Long: "Start, inspect, stop, and attach to daemon-managed guest commands. Pipe-mode sessions expose live stdio for interactive protocols. Attach is exclusive and currently uses a same-host local bridge.", - RunE: helpNoArgs, - } - cmd.AddCommand( - d.newVMSessionStartCommand(), - d.newVMSessionListCommand(), - d.newVMSessionShowCommand(), - d.newVMSessionLogsCommand(), - d.newVMSessionStopCommand(), - d.newVMSessionKillCommand(), - d.newVMSessionAttachCommand(), - d.newVMSessionSendCommand(), - ) - return cmd -} - -func (d *deps) newVMSessionStartCommand() *cobra.Command { - var name string - var cwd string - var stdinMode string - var envPairs []string - var tagPairs []string - var requiredCommands []string - cmd := &cobra.Command{ - Use: "start [args...]", - Short: "Start a managed guest command", - Long: "Start a daemon-managed guest command. The daemon verifies that the guest working directory exists and that the requested command is present in guest PATH before launch. Use --stdin-mode pipe when you need live attach.", - Args: minArgsUsage(2, "usage: banger vm session start [flags] -- [args...]"), - ValidArgsFunction: d.completeVMNameOnlyAtPos0, - Example: strings.TrimSpace(` - banger vm session start devbox --name planner --cwd /root/repo --stdin-mode pipe --require-command git -- pi --mode rpc --no-session - banger vm session start devbox --name shell --stdin-mode pipe -- bash -lc 'exec bash' -`), - RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := d.ensureDaemon(cmd.Context()) - if err != nil { - return err - } - env, err := parseKeyValuePairs(envPairs) - if err != nil { - return err - } - tags, err := parseKeyValuePairs(tagPairs) - if err != nil { - return err - } - result, err := d.guestSessionStart(cmd.Context(), layout.SocketPath, api.GuestSessionStartParams{ - VMIDOrName: args[0], - Name: name, - Command: args[1], - Args: append([]string(nil), args[2:]...), - CWD: cwd, - Env: env, - StdinMode: stdinMode, - Tags: tags, - RequiredCommands: append([]string(nil), requiredCommands...), - }) - if err != nil { - return err - } - if err := printGuestSessionSummary(cmd.OutOrStdout(), result.Session); err != nil { - return err - } - if result.Session.Status == model.GuestSessionStatusFailed && strings.TrimSpace(result.Session.LaunchMessage) != "" { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "warning: session failed at %s: %s\n", result.Session.LaunchStage, result.Session.LaunchMessage) - } - return nil - }, - } - cmd.Flags().StringVar(&name, "name", "", "session name") - cmd.Flags().StringVar(&cwd, "cwd", "", "guest working directory; must already exist") - cmd.Flags().StringVar(&stdinMode, "stdin-mode", string(model.GuestSessionStdinClosed), "stdin mode: closed or pipe (pipe enables attach)") - cmd.Flags().StringArrayVar(&envPairs, "env", nil, "environment entry in KEY=VALUE form") - cmd.Flags().StringArrayVar(&tagPairs, "tag", nil, "session tag in KEY=VALUE form") - cmd.Flags().StringArrayVar(&requiredCommands, "require-command", nil, "extra guest command that must exist in PATH before launch; repeatable") - return cmd -} - -func (d *deps) newVMSessionListCommand() *cobra.Command { - return &cobra.Command{ - Use: "list ", - Aliases: []string{"ls"}, - Short: "List managed guest commands for a VM", - Args: exactArgsUsage(1, "usage: banger vm session list "), - ValidArgsFunction: d.completeVMNameOnlyAtPos0, - RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := d.ensureDaemon(cmd.Context()) - if err != nil { - return err - } - result, err := d.guestSessionList(cmd.Context(), layout.SocketPath, args[0]) - if err != nil { - return err - } - return printGuestSessionTable(cmd.OutOrStdout(), result.Sessions) - }, - } -} - -func (d *deps) newVMSessionShowCommand() *cobra.Command { - return &cobra.Command{ - Use: "show ", - Short: "Show managed guest command details", - Args: exactArgsUsage(2, "usage: banger vm session show "), - ValidArgsFunction: d.completeSessionNames, - RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := d.ensureDaemon(cmd.Context()) - if err != nil { - return err - } - result, err := d.guestSessionGet(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) - if err != nil { - return err - } - return printJSON(cmd.OutOrStdout(), result.Session) - }, - } -} - -func (d *deps) newVMSessionLogsCommand() *cobra.Command { - var stream string - var tailLines int - cmd := &cobra.Command{ - Use: "logs ", - Short: "Show stdout or stderr for a guest session", - Args: exactArgsUsage(2, "usage: banger vm session logs [--stream stdout|stderr] [-n LINES] "), - ValidArgsFunction: d.completeSessionNames, - RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := d.ensureDaemon(cmd.Context()) - if err != nil { - return err - } - result, err := d.guestSessionLogs(cmd.Context(), layout.SocketPath, api.GuestSessionLogsParams{VMIDOrName: args[0], SessionIDOrName: args[1], Stream: stream, TailLines: tailLines}) - if err != nil { - return err - } - _, err = fmt.Fprint(cmd.OutOrStdout(), result.Content) - return err - }, - } - cmd.Flags().StringVar(&stream, "stream", "stdout", "log stream to read") - cmd.Flags().IntVarP(&tailLines, "lines", "n", 200, "number of lines to tail") - return cmd -} - -func (d *deps) newVMSessionStopCommand() *cobra.Command { - return &cobra.Command{ - Use: "stop ", - Short: "Send SIGTERM to a guest session", - Args: exactArgsUsage(2, "usage: banger vm session stop "), - ValidArgsFunction: d.completeSessionNames, - RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := d.ensureDaemon(cmd.Context()) - if err != nil { - return err - } - result, err := d.guestSessionStop(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) - if err != nil { - return err - } - return printGuestSessionSummary(cmd.OutOrStdout(), result.Session) - }, - } -} - -func (d *deps) newVMSessionKillCommand() *cobra.Command { - return &cobra.Command{ - Use: "kill ", - Short: "Send SIGKILL to a guest session", - Args: exactArgsUsage(2, "usage: banger vm session kill "), - ValidArgsFunction: d.completeSessionNames, - RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := d.ensureDaemon(cmd.Context()) - if err != nil { - return err - } - result, err := d.guestSessionKill(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) - if err != nil { - return err - } - return printGuestSessionSummary(cmd.OutOrStdout(), result.Session) - }, - } -} - -func (d *deps) newVMSessionAttachCommand() *cobra.Command { - return &cobra.Command{ - Use: "attach ", - Short: "Attach local stdio to an attachable guest session", - Long: "Attach local stdio to a pipe-mode session through a daemon-created local Unix socket bridge. Only one active attach is allowed at a time, and the client must run on the same host as the daemon.", - Args: exactArgsUsage(2, "usage: banger vm session attach "), - ValidArgsFunction: d.completeSessionNames, - RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := d.ensureDaemon(cmd.Context()) - if err != nil { - return err - } - result, err := d.guestSessionAttachBegin(cmd.Context(), layout.SocketPath, api.GuestSessionAttachBeginParams{VMIDOrName: args[0], SessionIDOrName: args[1]}) - if err != nil { - return err - } - socketPath := strings.TrimSpace(result.SocketPath) - if socketPath == "" && result.TransportKind == "unix_socket" { - socketPath = strings.TrimSpace(result.TransportTarget) - } - return runGuestSessionAttach(cmd.Context(), cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), socketPath) - }, - } -} - -func (d *deps) newVMSessionSendCommand() *cobra.Command { - var message string - cmd := &cobra.Command{ - Use: "send ", - Short: "Write bytes to a running guest session's stdin pipe", - Long: "Write a payload to the stdin pipe of a running pipe-mode guest session without holding the exclusive attach. Use --message for an inline JSONL string, or pipe bytes via stdin when --message is omitted. A trailing newline is appended to --message values that lack one.", - Args: exactArgsUsage(2, "usage: banger vm session send [--message '']"), - ValidArgsFunction: d.completeSessionNames, - Example: strings.TrimSpace(` - banger vm session send devbox planner --message '{"type":"abort"}' - banger vm session send devbox planner --message '{"type":"steer","message":"Focus on src/"}' - echo '{"type":"prompt","prompt":"Summarize."}' | banger vm session send devbox planner -`), - RunE: func(cmd *cobra.Command, args []string) error { - layout, _, err := d.ensureDaemon(cmd.Context()) - if err != nil { - return err - } - var payload []byte - if message != "" { - payload = []byte(message) - if len(payload) > 0 && payload[len(payload)-1] != '\n' { - payload = append(payload, '\n') - } - } else { - payload, err = io.ReadAll(cmd.InOrStdin()) - if err != nil { - return fmt.Errorf("read stdin: %w", err) - } - } - result, err := d.guestSessionSend(cmd.Context(), layout.SocketPath, api.GuestSessionSendParams{ - VMIDOrName: args[0], - SessionIDOrName: args[1], - Payload: payload, - }) - if err != nil { - return err - } - _, err = fmt.Fprintf(cmd.OutOrStdout(), "sent %d bytes to session %s\n", result.BytesWritten, result.Session.Name) - return err - }, - } - cmd.Flags().StringVar(&message, "message", "", "JSONL message to send; a trailing newline is appended if absent") - return cmd -} - -func parseKeyValuePairs(values []string) (map[string]string, error) { - if len(values) == 0 { - return nil, nil - } - result := make(map[string]string, len(values)) - for _, value := range values { - key, raw, ok := strings.Cut(value, "=") - if !ok || strings.TrimSpace(key) == "" { - return nil, fmt.Errorf("invalid key=value entry %q", value) - } - result[strings.TrimSpace(key)] = raw - } - return result, nil -} - -func runGuestSessionAttach(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, socketPath string) error { - conn, err := (&net.Dialer{}).DialContext(ctx, "unix", socketPath) - if err != nil { - return err - } - defer conn.Close() - writeErrCh := make(chan error, 1) - go func() { - writeErrCh <- streamGuestSessionAttachInput(conn, stdin) - }() - for { - channel, payload, err := sessionstream.ReadFrame(conn) - if err != nil { - if ctx.Err() != nil { - return ctx.Err() - } - if errors.Is(err, io.EOF) { - return nil - } - return err - } - switch channel { - case sessionstream.ChannelStdout: - if _, err := stdout.Write(payload); err != nil { - return err - } - case sessionstream.ChannelStderr: - if _, err := stderr.Write(payload); err != nil { - return err - } - case sessionstream.ChannelControl: - message, err := sessionstream.ReadControl(payload) - if err != nil { - return err - } - switch message.Type { - case "exit": - if message.ExitCode != nil && *message.ExitCode != 0 { - return fmt.Errorf("guest session exited with code %d", *message.ExitCode) - } - return nil - case "error": - if strings.TrimSpace(message.Error) == "" { - return errors.New("guest session attach failed") - } - return errors.New(message.Error) - } - } - select { - case err := <-writeErrCh: - if err != nil { - return err - } - default: - } - } -} - -func streamGuestSessionAttachInput(conn net.Conn, stdin io.Reader) error { - if stdin == nil { - return sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "eof"}) - } - buffer := make([]byte, 32*1024) - for { - n, err := stdin.Read(buffer) - if n > 0 { - if writeErr := sessionstream.WriteFrame(conn, sessionstream.ChannelStdin, buffer[:n]); writeErr != nil { - return writeErr - } - } - if err != nil { - if errors.Is(err, io.EOF) { - return sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "eof"}) - } - return err - } - } -} diff --git a/internal/cli/completion.go b/internal/cli/completion.go index db42627..8032efd 100644 --- a/internal/cli/completion.go +++ b/internal/cli/completion.go @@ -21,9 +21,9 @@ import ( // - Fail silently. Completion is advisory; any error path returns an // empty suggestion list rather than propagating to the user. -// defaultCompletionLister + defaultCompletionSessionLister back the -// corresponding *deps fields; tests inject their own fakes via the -// struct instead of mutating package-level vars. +// defaultCompletionLister backs the *deps.completionLister field; +// tests inject their own fake via the struct instead of mutating +// package-level vars. func defaultCompletionLister(ctx context.Context, socketPath, method string) ([]string, error) { switch method { case "vm.list": @@ -66,20 +66,6 @@ func defaultCompletionLister(ctx context.Context, socketPath, method string) ([] return nil, nil } -func defaultCompletionSessionLister(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) { - result, err := rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: vmIDOrName}) - if err != nil { - return nil, err - } - names := make([]string, 0, len(result.Sessions)) - for _, session := range result.Sessions { - if session.Name != "" { - names = append(names, session.Name) - } - } - return names, nil -} - // daemonSocketForCompletion returns the socket path IFF the daemon is // already running. Returns "", false when no daemon is up — completion // callers use this as the bail signal. @@ -177,25 +163,3 @@ func (d *deps) completeKernelNames(cmd *cobra.Command, args []string, toComplete } return filterPrefix(names, args, toComplete), cobra.ShellCompDirectiveNoFileComp } - -// completeSessionNames handles `... ` commands: pos 0 -// completes VMs, pos 1 completes sessions owned by args[0], pos 2+ is -// silent. -func (d *deps) completeSessionNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - switch len(args) { - case 0: - return d.completeVMNames(cmd, args, toComplete) - case 1: - socket, ok := d.daemonSocketForCompletion(cmd.Context()) - if !ok { - return nil, cobra.ShellCompDirectiveNoFileComp - } - names, err := d.completionSessionLister(cmd.Context(), socket, args[0]) - if err != nil { - return nil, cobra.ShellCompDirectiveNoFileComp - } - return filterPrefix(names, nil, toComplete), cobra.ShellCompDirectiveNoFileComp - default: - return nil, cobra.ShellCompDirectiveNoFileComp - } -} diff --git a/internal/cli/completion_test.go b/internal/cli/completion_test.go index e552732..4c542c4 100644 --- a/internal/cli/completion_test.go +++ b/internal/cli/completion_test.go @@ -19,10 +19,7 @@ func stubCompletionSeams( d *deps, pingErr error, names map[string][]string, - listErr error, - sessions map[string][]string, - sessionErr error, -) { + listErr error) { t.Helper() d.daemonPing = func(ctx context.Context, socketPath string) (api.PingResult, error) { @@ -37,12 +34,6 @@ func stubCompletionSeams( } return names[method], nil } - d.completionSessionLister = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) { - if sessionErr != nil { - return nil, sessionErr - } - return sessions[vmIDOrName], nil - } } func TestFilterPrefix(t *testing.T) { @@ -82,7 +73,7 @@ func testCmdWithCtx() *cobra.Command { func TestCompleteVMNamesHappyPath(t *testing.T) { d := defaultDeps() - stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil) + stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil) got, directive := d.completeVMNames(testCmdWithCtx(), nil, "") if directive != cobra.ShellCompDirectiveNoFileComp { @@ -95,7 +86,7 @@ func TestCompleteVMNamesHappyPath(t *testing.T) { func TestCompleteVMNamesDaemonDown(t *testing.T) { d := defaultDeps() - stubCompletionSeams(t, d, errors.New("connection refused"), nil, nil, nil, nil) + stubCompletionSeams(t, d, errors.New("connection refused"), nil, nil) got, directive := d.completeVMNames(testCmdWithCtx(), nil, "") if len(got) != 0 { @@ -108,7 +99,7 @@ func TestCompleteVMNamesDaemonDown(t *testing.T) { func TestCompleteVMNamesRPCError(t *testing.T) { d := defaultDeps() - stubCompletionSeams(t, d, nil, nil, errors.New("rpc failed"), nil, nil) + stubCompletionSeams(t, d, nil, nil, errors.New("rpc failed")) got, _ := d.completeVMNames(testCmdWithCtx(), nil, "") if len(got) != 0 { @@ -118,7 +109,7 @@ func TestCompleteVMNamesRPCError(t *testing.T) { func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) { d := defaultDeps() - stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil) + stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil) got, _ := d.completeVMNames(testCmdWithCtx(), []string{"alpha"}, "") want := []string{"beta", "gamma"} @@ -129,7 +120,7 @@ func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) { func TestCompleteVMNamesPrefixFilter(t *testing.T) { d := defaultDeps() - stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil, nil, nil) + stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil) got, _ := d.completeVMNames(testCmdWithCtx(), nil, "alp") want := []string{"alpha", "alphabet"} @@ -140,7 +131,7 @@ func TestCompleteVMNamesPrefixFilter(t *testing.T) { func TestCompleteVMNameOnlyAtPos0(t *testing.T) { d := defaultDeps() - stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha"}}, nil, nil, nil) + stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha"}}, nil) atPos0, _ := d.completeVMNameOnlyAtPos0(testCmdWithCtx(), nil, "") if len(atPos0) != 1 || atPos0[0] != "alpha" { @@ -155,7 +146,7 @@ func TestCompleteVMNameOnlyAtPos0(t *testing.T) { func TestCompleteImageNames(t *testing.T) { d := defaultDeps() - stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil, nil, nil) + stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil) got, _ := d.completeImageNames(testCmdWithCtx(), nil, "") if !reflect.DeepEqual(got, []string{"debian-bookworm", "alpine"}) { @@ -165,7 +156,7 @@ func TestCompleteImageNames(t *testing.T) { func TestCompleteKernelNames(t *testing.T) { d := defaultDeps() - stubCompletionSeams(t, d, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil, nil, nil) + stubCompletionSeams(t, d, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil) got, _ := d.completeKernelNames(testCmdWithCtx(), nil, "") if len(got) != 1 || got[0] != "generic-6.12" { @@ -175,58 +166,10 @@ func TestCompleteKernelNames(t *testing.T) { func TestCompleteImageNameOnlyAtPos0SilentAfterFirst(t *testing.T) { d := defaultDeps() - stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"alpine"}}, nil, nil, nil) + stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"alpine"}}, nil) after, _ := d.completeImageNameOnlyAtPos0(testCmdWithCtx(), []string{"alpine"}, "") if len(after) != 0 { t.Errorf("expected silence at pos 1+, got %v", after) } } - -func TestCompleteSessionNames(t *testing.T) { - d := defaultDeps() - stubCompletionSeams(t, d, - nil, - map[string][]string{"vm.list": {"devbox"}}, - nil, - map[string][]string{"devbox": {"planner", "worker"}}, - nil, - ) - - // Position 0 → VMs. - vms, _ := d.completeSessionNames(testCmdWithCtx(), nil, "") - if len(vms) != 1 || vms[0] != "devbox" { - t.Errorf("pos 0: got %v", vms) - } - - // Position 1 → sessions scoped to args[0]. - sessions, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "") - if !reflect.DeepEqual(sessions, []string{"planner", "worker"}) { - t.Errorf("pos 1: got %v", sessions) - } - - // Position 1 with prefix filter. - filtered, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "wor") - if len(filtered) != 1 || filtered[0] != "worker" { - t.Errorf("pos 1 prefix: got %v", filtered) - } - - // Position 2+ silent. - past, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox", "planner"}, "") - if len(past) != 0 { - t.Errorf("pos 2+: got %v", past) - } -} - -func TestCompleteSessionNamesDaemonDown(t *testing.T) { - d := defaultDeps() - stubCompletionSeams(t, d, errors.New("down"), nil, nil, nil, nil) - - got, directive := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "") - if len(got) != 0 { - t.Errorf("expected no suggestions when daemon down, got %v", got) - } - if directive != cobra.ShellCompDirectiveNoFileComp { - t.Errorf("directive = %d, want NoFileComp", directive) - } -} diff --git a/internal/cli/deps.go b/internal/cli/deps.go index e18bff3..5940129 100644 --- a/internal/cli/deps.go +++ b/internal/cli/deps.go @@ -31,36 +31,27 @@ import ( // validators) stay package-level because they hold no references to // external systems. type deps struct { - bangerdPath func() (string, error) - daemonExePath func(pid int) string - doctor func(ctx context.Context) (system.Report, error) - sshExec func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error - hostCommandOutput func(ctx context.Context, name string, args ...string) ([]byte, error) - vmHealth func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) - vmSSH func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error) - vmDelete func(ctx context.Context, socketPath, idOrName string) error - vmList func(ctx context.Context, socketPath string) (api.VMListResult, error) - daemonPing func(ctx context.Context, socketPath string) (api.PingResult, error) - vmCreateBegin func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error) - vmCreateStatus func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error) - vmCreateCancel func(ctx context.Context, socketPath, operationID string) error - vmPorts func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) - vmWorkspacePrepare func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) - vmWorkspaceExport func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) - guestSessionStart func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) - guestSessionGet func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) - guestSessionList func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error) - guestSessionStop func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) - guestSessionKill func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) - guestSessionLogs func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) - guestSessionAttachBegin func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) - guestSessionSend func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) - guestWaitForSSH func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error - guestDial func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) - buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan - cwd func() (string, error) - completionLister func(ctx context.Context, socketPath, method string) ([]string, error) - completionSessionLister func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) + bangerdPath func() (string, error) + daemonExePath func(pid int) string + doctor func(ctx context.Context) (system.Report, error) + sshExec func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error + hostCommandOutput func(ctx context.Context, name string, args ...string) ([]byte, error) + vmHealth func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) + vmSSH func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error) + vmDelete func(ctx context.Context, socketPath, idOrName string) error + vmList func(ctx context.Context, socketPath string) (api.VMListResult, error) + daemonPing func(ctx context.Context, socketPath string) (api.PingResult, error) + vmCreateBegin func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error) + vmCreateStatus func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error) + vmCreateCancel func(ctx context.Context, socketPath, operationID string) error + vmPorts func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) + vmWorkspacePrepare func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) + vmWorkspaceExport func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) + guestWaitForSSH func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error + guestDial func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error) + buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan + cwd func() (string, error) + completionLister func(ctx context.Context, socketPath, method string) ([]string, error) } func defaultDeps() *deps { @@ -125,30 +116,6 @@ func defaultDeps() *deps { vmWorkspaceExport: func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) { return rpc.Call[api.WorkspaceExportResult](ctx, socketPath, "vm.workspace.export", params) }, - guestSessionStart: func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) { - return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.start", params) - }, - guestSessionGet: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) { - return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.get", params) - }, - guestSessionList: func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error) { - return rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: idOrName}) - }, - guestSessionStop: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) { - return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.stop", params) - }, - guestSessionKill: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) { - return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.kill", params) - }, - guestSessionLogs: func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) { - return rpc.Call[api.GuestSessionLogsResult](ctx, socketPath, "guest.session.logs", params) - }, - guestSessionAttachBegin: func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) { - return rpc.Call[api.GuestSessionAttachBeginResult](ctx, socketPath, "guest.session.attach.begin", params) - }, - guestSessionSend: func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) { - return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params) - }, guestWaitForSSH: func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error { knownHosts, _ := bangerKnownHostsPath() return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval) @@ -157,9 +124,8 @@ func defaultDeps() *deps { knownHosts, _ := bangerKnownHostsPath() return guest.Dial(ctx, address, privateKeyPath, knownHosts) }, - buildVMRunToolingPlan: toolingplan.Build, - cwd: os.Getwd, - completionLister: defaultCompletionLister, - completionSessionLister: defaultCompletionSessionLister, + buildVMRunToolingPlan: toolingplan.Build, + cwd: os.Getwd, + completionLister: defaultCompletionLister, } } diff --git a/internal/cli/formatters_test.go b/internal/cli/formatters_test.go index c5833d2..65e2ba0 100644 --- a/internal/cli/formatters_test.go +++ b/internal/cli/formatters_test.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "fmt" - "reflect" "strings" "testing" @@ -52,37 +51,6 @@ func TestDashIfEmpty(t *testing.T) { } } -func TestParseKeyValuePairs(t *testing.T) { - t.Run("nil when empty", func(t *testing.T) { - got, err := parseKeyValuePairs(nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != nil { - t.Fatalf("got %v, want nil", got) - } - }) - - t.Run("parses entries", func(t *testing.T) { - got, err := parseKeyValuePairs([]string{"a=1", " b = two", "c=x=y"}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - want := map[string]string{"a": "1", "b": " two", "c": "x=y"} - if !reflect.DeepEqual(got, want) { - t.Fatalf("got %v, want %v", got, want) - } - }) - - t.Run("rejects malformed entries", func(t *testing.T) { - for _, bad := range []string{"noequals", "=noKey", " =v"} { - if _, err := parseKeyValuePairs([]string{bad}); err == nil { - t.Errorf("expected error for %q", bad) - } - } - }) -} - func TestExitCodeErrorError(t *testing.T) { e := ExitCodeError{Code: 42} got := e.Error() @@ -234,38 +202,6 @@ func TestPrintKernelCatalogTable(t *testing.T) { } } -func TestPrintGuestSessionTable(t *testing.T) { - var buf bytes.Buffer - sessions := []model.GuestSession{ - {ID: "abcdef0123456789", Name: "planner", Status: "running", Command: "pi", CWD: "/root/repo", Attachable: true}, - {ID: "short", Name: "once", Status: "exited", Command: "true", CWD: "/tmp", Attachable: false}, - } - if err := printGuestSessionTable(&buf, sessions); err != nil { - t.Fatalf("printGuestSessionTable: %v", err) - } - got := buf.String() - for _, want := range []string{"ID", "NAME", "planner", "once", "yes", "no", "pi"} { - if !strings.Contains(got, want) { - t.Errorf("output missing %q:\n%s", want, got) - } - } -} - -func TestPrintGuestSessionSummary(t *testing.T) { - var buf bytes.Buffer - session := model.GuestSession{ - ID: "id1", Name: "s", Status: "exited", Command: "true", CWD: "/root", - } - if err := printGuestSessionSummary(&buf, session); err != nil { - t.Fatalf("printGuestSessionSummary: %v", err) - } - got := buf.String() - fields := strings.Split(strings.TrimRight(got, "\n"), "\t") - if len(fields) != 5 { - t.Fatalf("expected 5 tab-separated fields, got %d: %q", len(fields), got) - } -} - func TestPrintJSON(t *testing.T) { var buf bytes.Buffer if err := printJSON(&buf, map[string]int{"a": 1, "b": 2}); err != nil { @@ -340,10 +276,6 @@ type failWriter struct{} func (failWriter) Write([]byte) (int, error) { return 0, fmt.Errorf("boom") } func TestPrintersPropagateWriteErrors(t *testing.T) { - sessions := []model.GuestSession{{ID: "id", Name: "n"}} - if err := printGuestSessionTable(failWriter{}, sessions); err == nil { - t.Error("expected write error from printGuestSessionTable") - } kernels := []api.KernelEntry{{Name: "k"}} if err := printKernelListTable(failWriter{}, kernels); err == nil { t.Error("expected write error from printKernelListTable") diff --git a/internal/cli/printers.go b/internal/cli/printers.go index 54c593b..e370c9b 100644 --- a/internal/cli/printers.go +++ b/internal/cli/printers.go @@ -3,7 +3,6 @@ package cli import ( "encoding/json" "fmt" - "io" "os" "sort" "strings" @@ -276,30 +275,6 @@ func printKernelCatalogTable(out anyWriter, entries []api.KernelCatalogEntry) er return w.Flush() } -// -- guest session printers ----------------------------------------- - -func printGuestSessionSummary(out anyWriter, session model.GuestSession) error { - _, err := fmt.Fprintf(out, "%s\t%s\t%s\t%s\t%s\n", session.ID, session.Name, session.Status, session.Command, session.CWD) - return err -} - -func printGuestSessionTable(out io.Writer, sessions []model.GuestSession) error { - tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0) - if _, err := fmt.Fprintln(tw, "ID\tNAME\tSTATUS\tATTACH\tCOMMAND\tCWD"); err != nil { - return err - } - for _, session := range sessions { - attach := "no" - if session.Attachable { - attach = "yes" - } - if _, err := fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n", shortID(session.ID), session.Name, session.Status, attach, session.Command, session.CWD); err != nil { - return err - } - } - return tw.Flush() -} - // -- doctor printer ------------------------------------------------- func printDoctorReport(out anyWriter, report system.Report) error { diff --git a/internal/daemon/ARCHITECTURE.md b/internal/daemon/ARCHITECTURE.md index 2a3b63e..d1943aa 100644 --- a/internal/daemon/ARCHITECTURE.md +++ b/internal/daemon/ARCHITECTURE.md @@ -32,13 +32,11 @@ owning types: - `createOps opstate.Registry[*vmCreateOperationState]` — in-flight VM create operations; owns its own lock. - `tapPool tapPool` — TAP interface pool; owns its own lock. -- `sessions sessionRegistry` — active guest session controllers; owns - its own lock. - `listener`, `vmDNS` — networking. - `vmCaps` — registered VM capability hooks. - `pullAndFlatten`, `finalizePulledRootfs`, `bundleFetch`, `requestHandler`, `guestWaitForSSH`, `guestDial`, - `waitForGuestSessionReady` — injectable seams used by tests. + `workspaceInspectRepo`, `workspaceImport` — injectable seams used by tests. ## Subpackages @@ -53,11 +51,9 @@ state beyond small test seams. | `internal/daemon/dmsnap` | Device-mapper COW snapshot create/cleanup/remove. | | `internal/daemon/fcproc` | Firecracker process primitives (bridge, tap, binary, PID, kill, wait). | | `internal/daemon/imagemgr` | Image subsystem pure helpers: validators, staging, build script gen. | -| `internal/daemon/session` | Guest-session helpers: state paths, scripts, parsing, utilities. | | `internal/daemon/workspace` | Workspace helpers: git inspection, copy prep, guest import script. | -`workspace` imports `session` for `ShellQuote` and `FormatStepError`; all -other subpackages are leaves (no other intra-daemon subpackage imports). +All subpackages are leaves — no intra-daemon subpackage imports another. ## Lock ordering @@ -73,9 +69,8 @@ time. `workspace.prepare` acquires `vmLocks[id]` just long enough to validate VM state, releases it, then acquires `workspaceLocks[id]` for the guest I/O phase. -Subsystem-local locks (`tapPool.mu`, `sessionRegistry.mu`, -`opstate.Registry` mu, `guestSessionController.attachMu` / -`writeMu`) are leaves. They do not contend with each other. +Subsystem-local locks (`tapPool.mu`, `opstate.Registry` mu) are leaves. +They do not contend with each other. Notes: diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index b06ee80..c582826 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -51,24 +51,22 @@ type Daemon struct { // See internal/daemon/vm_handles.go — persistent durable state // lives in the store, this is rebuildable from a per-VM // handles.json scratch file and OS inspection. - handles *handleCache - sessions sessionRegistry - tapPool tapPool - closing chan struct{} - once sync.Once - pid int - listener net.Listener - vmDNS *vmdns.Server - vmCaps []vmCapability - pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error) - finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error - bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error) - requestHandler func(context.Context, rpc.Request) rpc.Response - guestWaitForSSH func(context.Context, string, string, time.Duration) error - guestDial func(context.Context, string, string) (guestSSHClient, error) - waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error) - workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error) - workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error + handles *handleCache + tapPool tapPool + closing chan struct{} + once sync.Once + pid int + listener net.Listener + vmDNS *vmdns.Server + vmCaps []vmCapability + pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error) + finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error + bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error) + requestHandler func(context.Context, rpc.Request) rpc.Response + guestWaitForSSH func(context.Context, string, string, time.Duration) error + guestDial func(context.Context, string, string) (guestSSHClient, error) + workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error) + workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error } func Open(ctx context.Context) (d *Daemon, err error) { @@ -93,15 +91,14 @@ func Open(ctx context.Context) (d *Daemon, err error) { return nil, err } d = &Daemon{ - layout: layout, - config: cfg, - store: db, - runner: system.NewRunner(), - logger: logger, - closing: make(chan struct{}), - pid: os.Getpid(), - handles: newHandleCache(), - sessions: newSessionRegistry(), + layout: layout, + config: cfg, + store: db, + runner: system.NewRunner(), + logger: logger, + closing: make(chan struct{}), + pid: os.Getpid(), + handles: newHandleCache(), } // From here on, every failure path must run Close() so the host // state we touched (DNS listener goroutine, resolvectl routing, @@ -144,7 +141,7 @@ func (d *Daemon) Close() error { if d.listener != nil { _ = d.listener.Close() } - err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.closeGuestSessionControllers(), d.store.Close()) + err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.store.Close()) }) return err } @@ -424,62 +421,6 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response { } result, err := d.ExportVMWorkspace(ctx, params) return marshalResultOrError(result, err) - case "guest.session.start": - params, err := rpc.DecodeParams[api.GuestSessionStartParams](req) - if err != nil { - return rpc.NewError("bad_request", err.Error()) - } - session, err := d.StartGuestSession(ctx, params) - return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err) - case "guest.session.get": - params, err := rpc.DecodeParams[api.GuestSessionRefParams](req) - if err != nil { - return rpc.NewError("bad_request", err.Error()) - } - session, err := d.GetGuestSession(ctx, params) - return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err) - case "guest.session.list": - params, err := rpc.DecodeParams[api.VMRefParams](req) - if err != nil { - return rpc.NewError("bad_request", err.Error()) - } - sessions, err := d.ListGuestSessions(ctx, params) - return marshalResultOrError(api.GuestSessionListResult{Sessions: sessions}, err) - case "guest.session.stop": - params, err := rpc.DecodeParams[api.GuestSessionRefParams](req) - if err != nil { - return rpc.NewError("bad_request", err.Error()) - } - session, err := d.StopGuestSession(ctx, params) - return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err) - case "guest.session.kill": - params, err := rpc.DecodeParams[api.GuestSessionRefParams](req) - if err != nil { - return rpc.NewError("bad_request", err.Error()) - } - session, err := d.KillGuestSession(ctx, params) - return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err) - case "guest.session.logs": - params, err := rpc.DecodeParams[api.GuestSessionLogsParams](req) - if err != nil { - return rpc.NewError("bad_request", err.Error()) - } - result, err := d.GuestSessionLogs(ctx, params) - return marshalResultOrError(result, err) - case "guest.session.attach.begin": - params, err := rpc.DecodeParams[api.GuestSessionAttachBeginParams](req) - if err != nil { - return rpc.NewError("bad_request", err.Error()) - } - result, err := d.BeginGuestSessionAttach(ctx, params) - return marshalResultOrError(result, err) - case "guest.session.send": - params, err := rpc.DecodeParams[api.GuestSessionSendParams](req) - if err != nil { - return rpc.NewError("bad_request", err.Error()) - } - result, err := d.SendToGuestSession(ctx, params) - return marshalResultOrError(result, err) case "image.list": images, err := d.store.ListImages(ctx) return marshalResultOrError(api.ImageListResult{Images: images}, err) diff --git a/internal/daemon/doc.go b/internal/daemon/doc.go index 0e696c2..f83aeab 100644 --- a/internal/daemon/doc.go +++ b/internal/daemon/doc.go @@ -1,8 +1,8 @@ // Package daemon hosts the Banger daemon process. // // The daemon exposes a JSON-RPC endpoint over a Unix socket. It owns VM -// lifecycle, image management, guest sessions, host networking bootstrap, -// and state persistence via internal/store. +// lifecycle, image management, host networking bootstrap, and state +// persistence via internal/store. // // The package is organised into cohesive groups. Pure stateless helpers for // each group have been lifted into subpackages; orchestrator methods @@ -18,13 +18,9 @@ // internal/daemon/imagemgr Image subsystem helpers: path validation, // artifact staging, guest provisioning script // generator, metadata. -// internal/daemon/session Guest-session helpers: state paths, runner -// / inspect / signal scripts, state snapshot -// parsing, launch helpers, ShellQuote, -// FormatStepError. // internal/daemon/workspace Workspace helpers: git repo inspection, // shallow copy prep, guest-side import, -// finalize script generation. +// finalize script generation, shell quoting. // // VM lifecycle (in this package): // @@ -50,11 +46,7 @@ // // Guest interaction (in this package): // -// guest_sessions.go dialGuest, waitForGuestSSH, refresh/inspect -// session_lifecycle.go Start/Stop/Kill/Get/List/signal orchestrators -// session_attach.go BeginGuestSessionAttach + bridge/forward/watch -// session_stream.go GuestSessionLogs, SendToGuestSession -// session_controller.go guestSessionController, sessionRegistry +// guest_ssh.go guestSSHClient, dialGuest, waitForGuestSSH // ssh_client_config.go daemon-managed SSH client key material // workspace.go ExportVMWorkspace, PrepareVMWorkspace // @@ -73,10 +65,8 @@ // // Lock ordering: // -// vmLocks[id] → {createVMMu, imageOpsMu} → subsystem-local locks +// vmLocks[id] → workspaceLocks[id] → {createVMMu, imageOpsMu} → subsystem-local locks // -// Subsystem-local locks live on their owning type (tapPool.mu, -// sessionRegistry.mu, opstate.Registry mu, guestSessionController.attachMu / -// writeMu) and do not contend with each other. See ARCHITECTURE.md for -// details. +// Subsystem-local locks (tapPool.mu, opstate.Registry mu) are leaves and +// do not contend with each other. See ARCHITECTURE.md for details. package daemon diff --git a/internal/daemon/fake_firecracker_test.go b/internal/daemon/fake_firecracker_test.go new file mode 100644 index 0000000..2ad1555 --- /dev/null +++ b/internal/daemon/fake_firecracker_test.go @@ -0,0 +1,26 @@ +package daemon + +import ( + "fmt" + "os/exec" + "testing" +) + +// startFakeFirecracker launches a bash sleep-loop rewritten to match +// the firecracker command line a real process would expose, so +// reconcile / handle-cache paths that grep /proc//cmdline accept +// it as a firecracker process. Killed on test cleanup. +func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd { + t.Helper() + cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock)) + if err := cmd.Start(); err != nil { + t.Fatalf("start fake firecracker: %v", err) + } + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + } + }) + return cmd +} diff --git a/internal/daemon/guest_sessions.go b/internal/daemon/guest_sessions.go deleted file mode 100644 index bc59742..0000000 --- a/internal/daemon/guest_sessions.go +++ /dev/null @@ -1,142 +0,0 @@ -package daemon - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net" - "os" - "path/filepath" - "strings" - "time" - - "banger/internal/daemon/session" - "banger/internal/guest" - "banger/internal/model" - "banger/internal/system" -) - -type guestSSHClient interface { - Close() error - RunScript(context.Context, string, io.Writer) error - RunScriptOutput(context.Context, string) ([]byte, error) - UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error - StreamTar(context.Context, string, string, io.Writer) error - StreamTarEntries(context.Context, string, []string, string, io.Writer) error -} - -func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error { - if d != nil && d.guestWaitForSSH != nil { - return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval) - } - return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval) -} - -func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) { - if d != nil && d.guestDial != nil { - return d.guestDial(ctx, address, d.config.SSHKeyPath) - } - return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath) -} - -func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) { - if d != nil && d.waitForGuestSessionReady != nil { - return d.waitForGuestSessionReady(ctx, vm, s) - } - return d.waitForGuestSessionReadyDefault(ctx, vm, s) -} - -func (d *Daemon) waitForGuestSessionReadyDefault(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) { - for { - updated, err := d.refreshGuestSession(ctx, vm, s) - if err == nil { - s = updated - if s.GuestPID != 0 || s.ExitCode != nil || s.Status == model.GuestSessionStatusRunning || s.Status == model.GuestSessionStatusFailed || s.Status == model.GuestSessionStatusExited { - return s, nil - } - } - select { - case <-ctx.Done(): - return s, ctx.Err() - case <-time.After(100 * time.Millisecond): - } - } -} - -func (d *Daemon) refreshGuestSession(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) { - if s.Status != model.GuestSessionStatusStarting && s.Status != model.GuestSessionStatusRunning && s.Status != model.GuestSessionStatusStopping { - return s, nil - } - snapshot, err := d.inspectGuestSessionState(ctx, vm, s) - if err != nil { - return s, err - } - original := s - session.ApplyStateSnapshot(&s, snapshot, d.vmAlive(vm)) - if session.StateChanged(original, s) { - s.UpdatedAt = model.Now() - if err := d.store.UpsertGuestSession(ctx, s); err != nil { - return s, err - } - } - return s, nil -} - -func (d *Daemon) inspectGuestSessionState(ctx context.Context, vm model.VMRecord, s model.GuestSession) (session.StateSnapshot, error) { - if d.vmAlive(vm) { - client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath) - if err != nil { - return session.StateSnapshot{}, err - } - defer client.Close() - var output bytes.Buffer - if err := client.RunScript(ctx, session.InspectScript(s.ID), &output); err != nil { - return session.StateSnapshot{}, session.FormatStepError("inspect guest session state", err, output.String()) - } - return session.ParseState(output.String()) - } - return d.inspectGuestSessionStateFromWorkDisk(ctx, vm, s.ID) -} - -func (d *Daemon) inspectGuestSessionStateFromWorkDisk(ctx context.Context, vm model.VMRecord, sessionID string) (session.StateSnapshot, error) { - runner := d.runner - if runner == nil { - runner = system.NewRunner() - } - workMount, cleanup, err := system.MountTempDir(ctx, runner, vm.Runtime.WorkDiskPath, false) - if err != nil { - return session.StateSnapshot{}, err - } - defer cleanup() - stateDir := filepath.Join(workMount, session.RelativeStateDir(sessionID)) - return session.InspectStateFromDir(stateDir) -} - -func (d *Daemon) findGuestSession(ctx context.Context, vmID, idOrName string) (model.GuestSession, error) { - if strings.TrimSpace(idOrName) == "" { - return model.GuestSession{}, errors.New("session id or name is required") - } - if s, err := d.store.GetGuestSession(ctx, vmID, idOrName); err == nil { - return s, nil - } - sessions, err := d.store.ListGuestSessionsByVM(ctx, vmID) - if err != nil { - return model.GuestSession{}, err - } - matches := make([]model.GuestSession, 0, 1) - for _, s := range sessions { - if strings.HasPrefix(s.ID, idOrName) || strings.HasPrefix(s.Name, idOrName) { - matches = append(matches, s) - } - } - switch len(matches) { - case 0: - return model.GuestSession{}, fmt.Errorf("session %q not found", idOrName) - case 1: - return matches[0], nil - default: - return model.GuestSession{}, fmt.Errorf("multiple sessions match %q", idOrName) - } -} diff --git a/internal/daemon/guest_sessions_test.go b/internal/daemon/guest_sessions_test.go deleted file mode 100644 index bbe5f13..0000000 --- a/internal/daemon/guest_sessions_test.go +++ /dev/null @@ -1,492 +0,0 @@ -package daemon - -import ( - "context" - "fmt" - "io" - "log/slog" - "os" - "os/exec" - "path/filepath" - "strings" - "testing" - "time" - - "banger/internal/api" - sess "banger/internal/daemon/session" - "banger/internal/model" - "banger/internal/store" -) - -type fakeGuestSSHClient struct { - t *testing.T - existingDirs map[string]bool - closed bool -} - -func (f *fakeGuestSSHClient) Close() error { - f.closed = true - return nil -} - -func (f *fakeGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error { - f.t.Helper() - switch { - case strings.Contains(script, `\n`): - return fmt.Errorf("script still contains escaped newline literals: %q", script) - case strings.Contains(script, `echo "missing cwd: $DIR"`): - if strings.Contains(script, "DIR='/root/repo'\n") && f.existingDirs["/root/repo"] { - return nil - } - return fmt.Errorf("missing cwd") - case strings.Contains(script, "check_command() {"): - return nil - case strings.Contains(script, `git config --global --add safe.directory "$DIR"`): - if strings.Contains(script, "DIR='/root/repo'\n") { - f.existingDirs["/root/repo"] = true - return nil - } - return fmt.Errorf("workspace finalize used unexpected guest path") - case strings.Contains(script, "chmod -R a-w"): - if f.existingDirs["/root/repo"] { - return nil - } - return fmt.Errorf("workspace path missing during readonly chmod") - case strings.Contains(script, "nohup bash "): - return nil - default: - return nil - } -} - -func (f *fakeGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) { - return nil, nil -} - -func (f *fakeGuestSSHClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error { - return nil -} - -func (f *fakeGuestSSHClient) StreamTar(_ context.Context, _ string, command string, _ io.Writer) error { - if strings.Contains(command, "/root/repo") { - f.existingDirs["/root/repo"] = true - return nil - } - return fmt.Errorf("unexpected StreamTar command: %s", command) -} - -func (f *fakeGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, command string, _ io.Writer) error { - if strings.Contains(command, "/root/repo") { - f.existingDirs["/root/repo"] = true - return nil - } - return fmt.Errorf("unexpected StreamTarEntries command: %s", command) -} - -func TestSendToGuestSession_HappyPath(t *testing.T) { - t.Parallel() - ctx := context.Background() - db := openDaemonStore(t) - - apiSock := filepath.Join(t.TempDir(), "fc.sock") - firecracker := startFakeFirecracker(t, apiSock) - - vm := testVM("sendbox", "image-send", "172.16.0.88") - vm.State = model.VMStateRunning - vm.Runtime.State = model.VMStateRunning - vm.Runtime.APISockPath = apiSock - upsertDaemonVM(t, ctx, db, vm) - - session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning) - if err := db.UpsertGuestSession(ctx, session); err != nil { - t.Fatalf("UpsertGuestSession: %v", err) - } - - fake := &recordingGuestSSHClient{} - d := newSendTestDaemon(t, db, fake) - d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) - - payload := []byte(`{"type":"abort"}` + "\n") - result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ - VMIDOrName: vm.Name, - SessionIDOrName: session.Name, - Payload: payload, - }) - if err != nil { - t.Fatalf("SendToGuestSession: %v", err) - } - if result.BytesWritten != len(payload) { - t.Fatalf("BytesWritten = %d, want %d", result.BytesWritten, len(payload)) - } - if result.Session.ID != session.ID { - t.Fatalf("Session.ID = %q, want %q", result.Session.ID, session.ID) - } - if len(fake.uploadedFiles) != 1 { - t.Fatalf("UploadFile call count = %d, want 1", len(fake.uploadedFiles)) - } - for path, data := range fake.uploadedFiles { - if !strings.HasPrefix(path, "/tmp/banger-send-") { - t.Fatalf("upload path = %q, want /tmp/banger-send-... prefix", path) - } - if string(data) != string(payload) { - t.Fatalf("upload data = %q, want %q", data, payload) - } - } - if len(fake.ranScripts) != 1 { - t.Fatalf("RunScript call count = %d, want 1", len(fake.ranScripts)) - } - script := fake.ranScripts[0] - pipePath := sess.StdinPipePath(session.ID) - if !strings.Contains(script, "cat ") { - t.Fatalf("send script missing cat command: %q", script) - } - if !strings.Contains(script, pipePath) { - t.Fatalf("send script missing pipe path %q: %q", pipePath, script) - } - if !strings.Contains(script, "rm -f ") { - t.Fatalf("send script missing rm cleanup: %q", script) - } -} - -func TestSendToGuestSession_EmptyPayload(t *testing.T) { - t.Parallel() - ctx := context.Background() - db := openDaemonStore(t) - - apiSock := filepath.Join(t.TempDir(), "fc.sock") - firecracker := startFakeFirecracker(t, apiSock) - - vm := testVM("sendbox-empty", "image-send", "172.16.0.89") - vm.State = model.VMStateRunning - vm.Runtime.State = model.VMStateRunning - vm.Runtime.APISockPath = apiSock - upsertDaemonVM(t, ctx, db, vm) - - session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning) - if err := db.UpsertGuestSession(ctx, session); err != nil { - t.Fatalf("UpsertGuestSession: %v", err) - } - - fake := &recordingGuestSSHClient{} - d := newSendTestDaemon(t, db, fake) - d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) - - result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ - VMIDOrName: vm.Name, - SessionIDOrName: session.Name, - Payload: nil, - }) - if err != nil { - t.Fatalf("SendToGuestSession(empty): %v", err) - } - if result.BytesWritten != 0 { - t.Fatalf("BytesWritten = %d, want 0", result.BytesWritten) - } - if fake.dialCount != 0 { - t.Fatalf("SSH dial count = %d, want 0 for empty payload", fake.dialCount) - } -} - -func TestSendToGuestSession_NotPipeMode(t *testing.T) { - t.Parallel() - ctx := context.Background() - db := openDaemonStore(t) - - vm := testVM("sendbox-closed", "image-send", "172.16.0.90") - vm.State = model.VMStateRunning - upsertDaemonVM(t, ctx, db, vm) - - session := testGuestSession(vm.ID, model.GuestSessionStdinClosed, model.GuestSessionStatusRunning) - if err := db.UpsertGuestSession(ctx, session); err != nil { - t.Fatalf("UpsertGuestSession: %v", err) - } - - d := &Daemon{store: db} - _, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ - VMIDOrName: vm.Name, - SessionIDOrName: session.Name, - Payload: []byte("hello\n"), - }) - if err == nil || !strings.Contains(err.Error(), "stdin pipe") { - t.Fatalf("error = %v, want 'stdin pipe' error", err) - } -} - -func TestSendToGuestSession_SessionNotRunning(t *testing.T) { - t.Parallel() - ctx := context.Background() - db := openDaemonStore(t) - - vm := testVM("sendbox-failed", "image-send", "172.16.0.91") - vm.State = model.VMStateRunning - upsertDaemonVM(t, ctx, db, vm) - - session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusFailed) - if err := db.UpsertGuestSession(ctx, session); err != nil { - t.Fatalf("UpsertGuestSession: %v", err) - } - - d := &Daemon{store: db} - _, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ - VMIDOrName: vm.Name, - SessionIDOrName: session.Name, - Payload: []byte("hello\n"), - }) - if err == nil || !strings.Contains(err.Error(), "not running") { - t.Fatalf("error = %v, want 'not running' error", err) - } -} - -func TestSendToGuestSession_VMNotRunning(t *testing.T) { - t.Parallel() - ctx := context.Background() - db := openDaemonStore(t) - - vm := testVM("sendbox-stopped", "image-send", "172.16.0.92") - vm.State = model.VMStateStopped - upsertDaemonVM(t, ctx, db, vm) - - session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning) - if err := db.UpsertGuestSession(ctx, session); err != nil { - t.Fatalf("UpsertGuestSession: %v", err) - } - - d := &Daemon{store: db} - _, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{ - VMIDOrName: vm.Name, - SessionIDOrName: session.Name, - Payload: []byte("hello\n"), - }) - if err == nil || !strings.Contains(err.Error(), "not running") { - t.Fatalf("error = %v, want 'not running' error", err) - } -} - -// recordingGuestSSHClient captures UploadFile and RunScript calls for send tests. -type recordingGuestSSHClient struct { - dialCount int - uploadedFiles map[string][]byte - ranScripts []string -} - -func (r *recordingGuestSSHClient) Close() error { return nil } - -func (r *recordingGuestSSHClient) UploadFile(_ context.Context, path string, _ os.FileMode, data []byte, _ io.Writer) error { - if r.uploadedFiles == nil { - r.uploadedFiles = make(map[string][]byte) - } - copy := make([]byte, len(data)) - _ = copy[:len(data):len(data)] - for i, b := range data { - copy[i] = b - } - r.uploadedFiles[path] = copy - return nil -} - -func (r *recordingGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error { - r.ranScripts = append(r.ranScripts, script) - return nil -} - -func (r *recordingGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) { - return nil, nil -} - -func (r *recordingGuestSSHClient) StreamTar(_ context.Context, _ string, _ string, _ io.Writer) error { - return nil -} - -func (r *recordingGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, _ string, _ io.Writer) error { - return nil -} - -func newSendTestDaemon(t *testing.T, db *store.Store, fake *recordingGuestSSHClient) *Daemon { - t.Helper() - d := &Daemon{ - store: db, - config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")}, - logger: slog.New(slog.NewTextHandler(io.Discard, nil)), - } - d.guestDial = func(_ context.Context, _ string, _ string) (guestSSHClient, error) { - fake.dialCount++ - return fake, nil - } - return d -} - -func testGuestSession(vmID string, stdinMode model.GuestSessionStdinMode, status model.GuestSessionStatus) model.GuestSession { - now := model.Now() - id := vmID + "-sess-id" - return model.GuestSession{ - ID: id, - VMID: vmID, - Name: vmID + "-sess", - Backend: sess.BackendSSH, - Command: "pi", - Args: []string{"--mode", "rpc"}, - CWD: "/root/repo", - StdinMode: stdinMode, - Status: status, - GuestStateDir: sess.StateDir(id), - StdoutLogPath: sess.StdoutLogPath(id), - StderrLogPath: sess.StderrLogPath(id), - Attachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning, - Reattachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning, - CreatedAt: now, - UpdatedAt: now, - } -} - -func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd { - t.Helper() - cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock)) - if err := cmd.Start(); err != nil { - t.Fatalf("start fake firecracker: %v", err) - } - t.Cleanup(func() { - if cmd.Process != nil { - _ = cmd.Process.Kill() - _, _ = cmd.Process.Wait() - } - }) - return cmd -} - -func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) { - t.Parallel() - - cwdScript := sess.CWDPreflightScript("/root/repo") - if strings.Contains(cwdScript, `\n`) { - t.Fatalf("cwd preflight script still contains escaped newline literals: %q", cwdScript) - } - if !strings.Contains(cwdScript, "\n") { - t.Fatalf("cwd preflight script should contain real newlines: %q", cwdScript) - } - - commandScript := sess.CommandPreflightScript([]string{"git", "pi"}) - if strings.Contains(commandScript, `\n`) { - t.Fatalf("command preflight script still contains escaped newline literals: %q", commandScript) - } - if !strings.Contains(commandScript, "\n") { - t.Fatalf("command preflight script should contain real newlines: %q", commandScript) - } - - attachInput := sess.AttachInputCommand("session-id") - if strings.Contains(attachInput, `\n`) { - t.Fatalf("attach input command still contains escaped newline literals: %q", attachInput) - } - - attachTail := sess.AttachTailCommand("/tmp/stdout.log") - if strings.Contains(attachTail, `\n`) { - t.Fatalf("attach tail command still contains escaped newline literals: %q", attachTail) - } -} - -func TestPrepareWorkspaceThenStartGuestSessionPassesCWDPreflight(t *testing.T) { - ctx := context.Background() - db := openDaemonStore(t) - - repoRoot := filepath.Join(t.TempDir(), "repo") - if err := os.MkdirAll(repoRoot, 0o755); err != nil { - t.Fatalf("MkdirAll(repoRoot): %v", err) - } - if err := os.WriteFile(filepath.Join(repoRoot, "README.md"), []byte("hello\n"), 0o644); err != nil { - t.Fatalf("WriteFile(README.md): %v", err) - } - runGit := func(args ...string) { - t.Helper() - cmd := exec.Command("git", append([]string{"-C", repoRoot}, args...)...) - output, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("git %s: %v\n%s", strings.Join(args, " "), err, output) - } - } - runGit("init", "-b", "main") - runGit("config", "user.name", "Test User") - runGit("config", "user.email", "test@example.com") - runGit("add", ".") - runGit("commit", "-m", "initial") - - apiSock := filepath.Join(t.TempDir(), "fc.sock") - firecracker := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock)) - if err := firecracker.Start(); err != nil { - t.Fatalf("start fake firecracker: %v", err) - } - t.Cleanup(func() { - if firecracker.Process != nil { - _ = firecracker.Process.Kill() - _, _ = firecracker.Process.Wait() - } - }) - - vm := testVM("pi-devbox", "image-pi", "172.16.0.77") - vm.State = model.VMStateRunning - vm.Runtime.State = model.VMStateRunning - vm.Runtime.APISockPath = apiSock - upsertDaemonVM(t, ctx, db, vm) - - fakeClient := &fakeGuestSSHClient{t: t, existingDirs: map[string]bool{}} - d := &Daemon{ - store: db, - config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")}, - logger: slog.New(slog.NewTextHandler(io.Discard, nil)), - } - d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid}) - d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil } - d.guestDial = func(context.Context, string, string) (guestSSHClient, error) { return fakeClient, nil } - d.waitForGuestSessionReady = func(_ context.Context, _ model.VMRecord, session model.GuestSession) (model.GuestSession, error) { - now := model.Now() - session.Status = model.GuestSessionStatusRunning - session.GuestPID = 4242 - session.StartedAt = now - session.UpdatedAt = now - session.Attachable = session.StdinMode == model.GuestSessionStdinPipe - session.Reattachable = session.StdinMode == model.GuestSessionStdinPipe - return session, nil - } - - workspace, err := d.PrepareVMWorkspace(ctx, api.VMWorkspacePrepareParams{ - IDOrName: vm.Name, - SourcePath: repoRoot, - GuestPath: "/root/repo", - ReadOnly: true, - }) - if err != nil { - t.Fatalf("PrepareVMWorkspace: %v", err) - } - if workspace.GuestPath != "/root/repo" { - t.Fatalf("PrepareVMWorkspace guest path = %q, want /root/repo", workspace.GuestPath) - } - if !fakeClient.existingDirs["/root/repo"] { - t.Fatalf("workspace prepare did not mark /root/repo as present in fake guest") - } - - session, err := d.StartGuestSession(ctx, api.GuestSessionStartParams{ - VMIDOrName: vm.Name, - Name: "testpi", - Command: "pi", - Args: []string{"--mode", "rpc", "--no-session"}, - CWD: "/root/repo", - StdinMode: string(model.GuestSessionStdinPipe), - RequiredCommands: []string{"git"}, - }) - if err != nil { - t.Fatalf("StartGuestSession: %v", err) - } - if session.Status != model.GuestSessionStatusRunning { - t.Fatalf("session status = %q, want %q", session.Status, model.GuestSessionStatusRunning) - } - if session.LaunchStage != "" { - t.Fatalf("session launch stage = %q, want empty", session.LaunchStage) - } - if session.LaunchMessage != "" { - t.Fatalf("session launch message = %q, want empty", session.LaunchMessage) - } - if session.GuestPID == 0 { - t.Fatalf("session guest pid = 0, want non-zero") - } - if !session.Attachable { - t.Fatalf("session should be attachable for pipe stdin mode") - } -} diff --git a/internal/daemon/guest_ssh.go b/internal/daemon/guest_ssh.go new file mode 100644 index 0000000..de05991 --- /dev/null +++ b/internal/daemon/guest_ssh.go @@ -0,0 +1,35 @@ +package daemon + +import ( + "context" + "io" + "os" + "time" + + "banger/internal/guest" +) + +// guestSSHClient is the narrow guest-SSH surface the daemon uses for +// workspace prepare / export and ad-hoc guest interactions. +type guestSSHClient interface { + Close() error + RunScript(context.Context, string, io.Writer) error + RunScriptOutput(context.Context, string) ([]byte, error) + UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error + StreamTar(context.Context, string, string, io.Writer) error + StreamTarEntries(context.Context, string, []string, string, io.Writer) error +} + +func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error { + if d != nil && d.guestWaitForSSH != nil { + return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval) + } + return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval) +} + +func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) { + if d != nil && d.guestDial != nil { + return d.guestDial(ctx, address, d.config.SSHKeyPath) + } + return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath) +} diff --git a/internal/daemon/open_close_test.go b/internal/daemon/open_close_test.go index 57d70e4..1fb4d3a 100644 --- a/internal/daemon/open_close_test.go +++ b/internal/daemon/open_close_test.go @@ -26,10 +26,9 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) { name: "only store + closing channel (early failure)", build: func(t *testing.T) *Daemon { return &Daemon{ - store: openDaemonStore(t), - closing: make(chan struct{}), - sessions: newSessionRegistry(), - logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + store: openDaemonStore(t), + closing: make(chan struct{}), + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } }, verify: func(t *testing.T, d *Daemon) { @@ -49,11 +48,10 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) { t.Fatalf("vmdns.New: %v", err) } return &Daemon{ - store: openDaemonStore(t), - closing: make(chan struct{}), - sessions: newSessionRegistry(), - vmDNS: server, - logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + store: openDaemonStore(t), + closing: make(chan struct{}), + vmDNS: server, + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } }, verify: func(t *testing.T, d *Daemon) { @@ -86,11 +84,10 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) { // returns and also calls Close afterwards, both paths must survive. func TestCloseIdempotentUnderConcurrency(t *testing.T) { d := &Daemon{ - store: openDaemonStore(t), - closing: make(chan struct{}), - sessions: newSessionRegistry(), - logger: slog.New(slog.NewTextHandler(io.Discard, nil)), - config: model.DaemonConfig{BridgeName: ""}, + store: openDaemonStore(t), + closing: make(chan struct{}), + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + config: model.DaemonConfig{BridgeName: ""}, } var count atomic.Int32 diff --git a/internal/daemon/session/session.go b/internal/daemon/session/session.go deleted file mode 100644 index bb42743..0000000 --- a/internal/daemon/session/session.go +++ /dev/null @@ -1,521 +0,0 @@ -// Package session contains the pure helpers of the guest-session subsystem: -// bash script generators, on-guest state path helpers, state snapshot -// parsing, and small utilities like ShellQuote and FormatStepError. -// -// The orchestrator methods (StartGuestSession, BeginGuestSessionAttach, -// etc.) stay on *daemon.Daemon and compose these helpers. -package session - -import ( - "bufio" - "errors" - "fmt" - "os" - "path/filepath" - "sort" - "strconv" - "strings" - "syscall" - - "banger/internal/model" - "banger/internal/system" - - "golang.org/x/crypto/ssh" -) - -// Constants shared between orchestration and helpers. -const ( - BackendSSH = "ssh" - AttachBackendNone = "none" - AttachBackendSSHBridge = "ssh_rehydratable" - AttachModeExclusive = "exclusive" - TransportUnixSocket = "unix_socket" - StateRoot = "/root/.local/state/banger/sessions" - LogTailLineDefault = 200 -) - -// StateSnapshot is the decoded per-session state as read from the guest. -type StateSnapshot struct { - Status string - GuestPID int - ExitCode *int - Alive bool - LastError string -} - -// -- Guest filesystem paths ------------------------------------------------- - -func StateDir(id string) string { - return filepath.ToSlash(filepath.Join(StateRoot, id)) -} - -func RelativeStateDir(id string) string { - return strings.TrimPrefix(StateDir(id), "/root/") -} - -func ScriptPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "run.sh")) } -func PIDPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "pid")) } -func MonitorPIDPath(id string) string { - return filepath.ToSlash(filepath.Join(StateDir(id), "monitor_pid")) -} -func ExitCodePath(id string) string { - return filepath.ToSlash(filepath.Join(StateDir(id), "exit_code")) -} -func StdinPipePath(id string) string { - return filepath.ToSlash(filepath.Join(StateDir(id), "stdin.pipe")) -} -func StdinKeepalivePIDPath(id string) string { - return filepath.ToSlash(filepath.Join(StateDir(id), "stdin_keepalive.pid")) -} -func StatusPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "status")) } -func ErrorPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "error")) } -func StdoutLogPath(id string) string { - return filepath.ToSlash(filepath.Join(StateDir(id), "stdout.log")) -} -func StderrLogPath(id string) string { - return filepath.ToSlash(filepath.Join(StateDir(id), "stderr.log")) -} - -// -- Script generators ------------------------------------------------------ - -// Script returns the bash runner installed into the guest for session. It -// sets up state/log paths, optional stdin fifo, and wait-loop around the -// user command. -func Script(sess model.GuestSession) string { - var script strings.Builder - script.WriteString("set -euo pipefail\n") - fmt.Fprintf(&script, "STATE_DIR=%s\n", ShellQuote(sess.GuestStateDir)) - fmt.Fprintf(&script, "STDOUT_LOG=%s\n", ShellQuote(sess.StdoutLogPath)) - fmt.Fprintf(&script, "STDERR_LOG=%s\n", ShellQuote(sess.StderrLogPath)) - fmt.Fprintf(&script, "PID_FILE=%s\n", ShellQuote(PIDPath(sess.ID))) - fmt.Fprintf(&script, "MONITOR_PID_FILE=%s\n", ShellQuote(MonitorPIDPath(sess.ID))) - fmt.Fprintf(&script, "EXIT_FILE=%s\n", ShellQuote(ExitCodePath(sess.ID))) - fmt.Fprintf(&script, "STATUS_FILE=%s\n", ShellQuote(StatusPath(sess.ID))) - fmt.Fprintf(&script, "ERROR_FILE=%s\n", ShellQuote(ErrorPath(sess.ID))) - fmt.Fprintf(&script, "STDIN_PIPE=%s\n", ShellQuote(StdinPipePath(sess.ID))) - fmt.Fprintf(&script, "STDIN_KEEPALIVE_PID_FILE=%s\n", ShellQuote(StdinKeepalivePIDPath(sess.ID))) - fmt.Fprintf(&script, "SESSION_CWD=%s\n", ShellQuote(DefaultCWD(sess.CWD))) - script.WriteString("mkdir -p \"$STATE_DIR\"\n") - script.WriteString(": >\"$STDOUT_LOG\"\n") - script.WriteString(": >\"$STDERR_LOG\"\n") - script.WriteString("rm -f \"$EXIT_FILE\" \"$ERROR_FILE\" \"$STDIN_KEEPALIVE_PID_FILE\"\n") - if sess.StdinMode == model.GuestSessionStdinPipe { - script.WriteString("rm -f \"$STDIN_PIPE\"\n") - script.WriteString("mkfifo -m 600 \"$STDIN_PIPE\"\n") - } - script.WriteString("printf '%s\\n' \"${BASHPID:-$$}\" >\"$MONITOR_PID_FILE\"\n") - script.WriteString("printf 'starting\\n' >\"$STATUS_FILE\"\n") - script.WriteString("cd \"$SESSION_CWD\"\n") - script.WriteString("exec > >(tee -a \"$STDOUT_LOG\") 2> >(tee -a \"$STDERR_LOG\" >&2)\n") - for _, line := range EnvLines(sess.Env) { - script.WriteString(line) - script.WriteByte('\n') - } - script.WriteString("COMMAND=(") - for _, value := range append([]string{sess.Command}, sess.Args...) { - script.WriteByte(' ') - script.WriteString(ShellQuote(value)) - } - script.WriteString(" )\n") - if sess.StdinMode == model.GuestSessionStdinPipe { - script.WriteString("( while :; do sleep 3600; done ) >\"$STDIN_PIPE\" &\n") - script.WriteString("keepalive=$!\n") - script.WriteString("printf '%s\\n' \"$keepalive\" >\"$STDIN_KEEPALIVE_PID_FILE\"\n") - script.WriteString("\"${COMMAND[@]}\" <\"$STDIN_PIPE\" &\n") - } else { - script.WriteString("\"${COMMAND[@]}\" &\n") - } - script.WriteString("child=$!\n") - script.WriteString("printf '%s\\n' \"$child\" >\"$PID_FILE\"\n") - script.WriteString("printf 'running\\n' >\"$STATUS_FILE\"\n") - script.WriteString("wait \"$child\"\n") - script.WriteString("rc=$?\n") - if sess.StdinMode == model.GuestSessionStdinPipe { - script.WriteString("if [ -f \"$STDIN_KEEPALIVE_PID_FILE\" ]; then kill \"$(cat \"$STDIN_KEEPALIVE_PID_FILE\")\" 2>/dev/null || true; fi\n") - } - script.WriteString("printf '%s\\n' \"$rc\" >\"$EXIT_FILE\"\n") - script.WriteString("if [ \"$rc\" -eq 0 ]; then printf 'exited\\n' >\"$STATUS_FILE\"; else printf 'failed\\n' >\"$STATUS_FILE\"; fi\n") - script.WriteString("exit \"$rc\"\n") - return script.String() -} - -// InspectScript reads the on-guest state files for sessionID and prints a -// key=value block parseable by ParseState. -func InspectScript(sessionID string) string { - var script strings.Builder - script.WriteString("set -euo pipefail\n") - fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID))) - script.WriteString("status=''\n") - script.WriteString("pid=''\n") - script.WriteString("exit_code=''\n") - script.WriteString("last_error=''\n") - script.WriteString("alive=false\n") - script.WriteString("[ -f \"$DIR/status\" ] && status=\"$(cat \"$DIR/status\")\"\n") - script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n") - script.WriteString("[ -f \"$DIR/exit_code\" ] && exit_code=\"$(cat \"$DIR/exit_code\")\"\n") - script.WriteString("[ -f \"$DIR/error\" ] && last_error=\"$(cat \"$DIR/error\")\"\n") - script.WriteString("if [ -n \"$pid\" ] && kill -0 \"$pid\" 2>/dev/null; then alive=true; fi\n") - script.WriteString("printf 'status=%s\\n' \"$status\"\n") - script.WriteString("printf 'pid=%s\\n' \"$pid\"\n") - script.WriteString("printf 'exit=%s\\n' \"$exit_code\"\n") - script.WriteString("printf 'alive=%s\\n' \"$alive\"\n") - script.WriteString("printf 'error=%s\\n' \"$last_error\"\n") - return script.String() -} - -// SignalScript sends signal to sessionID's runner and monitor processes. -func SignalScript(sessionID, signal string) string { - var script strings.Builder - script.WriteString("set -euo pipefail\n") - fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID))) - fmt.Fprintf(&script, "SIGNAL=%s\n", ShellQuote(signal)) - script.WriteString("pid=''\n") - script.WriteString("monitor=''\n") - script.WriteString("keepalive=''\n") - script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n") - script.WriteString("[ -f \"$DIR/monitor_pid\" ] && monitor=\"$(cat \"$DIR/monitor_pid\")\"\n") - script.WriteString("[ -f \"$DIR/stdin_keepalive.pid\" ] && keepalive=\"$(cat \"$DIR/stdin_keepalive.pid\")\"\n") - script.WriteString("printf 'stopping\\n' >\"$DIR/status\"\n") - script.WriteString("if [ -n \"$pid\" ]; then kill -${SIGNAL} \"$pid\" 2>/dev/null || true; fi\n") - script.WriteString("if [ -n \"$monitor\" ]; then kill -${SIGNAL} \"$monitor\" 2>/dev/null || true; fi\n") - script.WriteString("if [ -n \"$keepalive\" ]; then kill -${SIGNAL} \"$keepalive\" 2>/dev/null || true; fi\n") - return script.String() -} - -// CWDPreflightScript verifies cwd exists on the guest. -func CWDPreflightScript(cwd string) string { - var script strings.Builder - script.WriteString("set -euo pipefail\n") - fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(DefaultCWD(cwd))) - script.WriteString("if [ ! -d \"$DIR\" ]; then echo \"missing cwd: $DIR\"; exit 1; fi\n") - return script.String() -} - -// CommandPreflightScript verifies each command is resolvable on the guest. -func CommandPreflightScript(commands []string) string { - var script strings.Builder - script.WriteString("set -euo pipefail\n") - script.WriteString("check_command() {\n") - script.WriteString(" cmd=\"$1\"\n") - script.WriteString(" case \"$cmd\" in\n") - script.WriteString(" */*) [ -x \"$cmd\" ] || { echo \"missing command: $cmd\"; exit 1; } ;;\n") - script.WriteString(" *) command -v \"$cmd\" >/dev/null 2>&1 || { echo \"missing command: $cmd\"; exit 1; } ;;\n") - script.WriteString(" esac\n") - script.WriteString("}\n") - for _, command := range commands { - fmt.Fprintf(&script, "check_command %s\n", ShellQuote(command)) - } - return script.String() -} - -// AttachInputCommand returns the guest command that creates/opens the stdin -// fifo for sessionID and cats attach-side bytes into it. -func AttachInputCommand(sessionID string) string { - path := StdinPipePath(sessionID) - return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\n[ -p %s ] || mkfifo -m 600 %s\nexec cat > %s\n", ShellQuote(path), ShellQuote(path), ShellQuote(path))) -} - -// AttachTailCommand returns the guest command that tails a log file and -// streams new content back to the attach bridge. -func AttachTailCommand(path string) string { - return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\ntouch %s\nexec tail -n 0 -F %s 2>/dev/null\n", ShellQuote(path), ShellQuote(path))) -} - -// EnvLines returns deterministic `export KEY=value` lines for the session -// launcher, ordered by key. -func EnvLines(values map[string]string) []string { - if len(values) == 0 { - return nil - } - keys := make([]string, 0, len(values)) - for key := range values { - keys = append(keys, key) - } - sort.Strings(keys) - lines := make([]string, 0, len(keys)) - for _, key := range keys { - lines = append(lines, "export "+key+"="+ShellQuote(values[key])) - } - return lines -} - -// -- State snapshot helpers ------------------------------------------------- - -// ParseState decodes the key=value output produced by InspectScript. -func ParseState(raw string) (StateSnapshot, error) { - var snapshot StateSnapshot - scanner := bufio.NewScanner(strings.NewReader(raw)) - for scanner.Scan() { - line := scanner.Text() - key, value, ok := strings.Cut(line, "=") - if !ok { - continue - } - switch strings.TrimSpace(key) { - case "status": - snapshot.Status = strings.TrimSpace(value) - case "pid": - if pid, err := strconv.Atoi(strings.TrimSpace(value)); err == nil { - snapshot.GuestPID = pid - } - case "exit": - if exitCode, err := strconv.Atoi(strings.TrimSpace(value)); err == nil { - snapshot.ExitCode = &exitCode - } - case "alive": - snapshot.Alive = strings.TrimSpace(value) == "true" - case "error": - snapshot.LastError = strings.TrimSpace(value) - } - } - return snapshot, scanner.Err() -} - -// InspectStateFromDir reads the state files directly from stateDir (used -// when the guest is offline and we can mount the work disk from the host). -func InspectStateFromDir(stateDir string) (StateSnapshot, error) { - var snapshot StateSnapshot - statusData, _ := os.ReadFile(filepath.Join(stateDir, "status")) - snapshot.Status = strings.TrimSpace(string(statusData)) - pidData, _ := os.ReadFile(filepath.Join(stateDir, "pid")) - if pidValue, err := strconv.Atoi(strings.TrimSpace(string(pidData))); err == nil { - snapshot.GuestPID = pidValue - } - exitData, _ := os.ReadFile(filepath.Join(stateDir, "exit_code")) - if exitValue, err := strconv.Atoi(strings.TrimSpace(string(exitData))); err == nil { - snapshot.ExitCode = &exitValue - } - errorData, _ := os.ReadFile(filepath.Join(stateDir, "error")) - snapshot.LastError = strings.TrimSpace(string(errorData)) - if snapshot.GuestPID != 0 { - snapshot.Alive = ProcessAlive(snapshot.GuestPID) - } - return snapshot, nil -} - -// ApplyStateSnapshot mutates sess in place to reflect snapshot. vmRunning -// captures whether the VM is currently up so stale in-flight sessions can be -// failed when the VM is gone. -func ApplyStateSnapshot(sess *model.GuestSession, snapshot StateSnapshot, vmRunning bool) { - if sess == nil { - return - } - if snapshot.GuestPID != 0 { - sess.GuestPID = snapshot.GuestPID - } - if snapshot.LastError != "" { - sess.LastError = snapshot.LastError - } - if snapshot.ExitCode != nil { - sess.ExitCode = snapshot.ExitCode - sess.Attachable = false - sess.Reattachable = false - if sess.StartedAt.IsZero() { - sess.StartedAt = model.Now() - } - if sess.EndedAt.IsZero() { - sess.EndedAt = model.Now() - } - if *snapshot.ExitCode == 0 { - sess.Status = model.GuestSessionStatusExited - } else { - sess.Status = model.GuestSessionStatusFailed - } - return - } - if snapshot.Alive { - if sess.StartedAt.IsZero() { - sess.StartedAt = model.Now() - } - sess.Status = model.GuestSessionStatusRunning - return - } - if !vmRunning && (sess.Status == model.GuestSessionStatusStarting || sess.Status == model.GuestSessionStatusRunning || sess.Status == model.GuestSessionStatusStopping) { - sess.Status = model.GuestSessionStatusFailed - sess.Attachable = false - sess.Reattachable = false - if sess.LastError == "" { - sess.LastError = "vm is not running" - } - if sess.EndedAt.IsZero() { - sess.EndedAt = model.Now() - } - return - } - if snapshot.Status == string(model.GuestSessionStatusRunning) { - if sess.StartedAt.IsZero() { - sess.StartedAt = model.Now() - } - sess.Status = model.GuestSessionStatusRunning - } - if sess.Status == model.GuestSessionStatusRunning && sess.StdinMode == model.GuestSessionStdinPipe { - sess.Attachable = true - sess.Reattachable = true - if sess.AttachBackend == "" { - sess.AttachBackend = AttachBackendSSHBridge - } - if sess.AttachMode == "" { - sess.AttachMode = AttachModeExclusive - } - } -} - -// StateChanged reports whether any materially observable field differs -// between before and after, guiding whether to persist an update. -func StateChanged(before, after model.GuestSession) bool { - if before.Status != after.Status || before.GuestPID != after.GuestPID || before.LastError != after.LastError || before.Attachable != after.Attachable || before.Reattachable != after.Reattachable || before.AttachBackend != after.AttachBackend || before.AttachMode != after.AttachMode || before.LaunchStage != after.LaunchStage || before.LaunchMessage != after.LaunchMessage || before.LaunchRawLog != after.LaunchRawLog { - return true - } - if before.StartedAt != after.StartedAt || before.EndedAt != after.EndedAt { - return true - } - switch { - case before.ExitCode == nil && after.ExitCode == nil: - return false - case before.ExitCode == nil || after.ExitCode == nil: - return true - default: - return *before.ExitCode != *after.ExitCode - } -} - -// -- Launch helpers --------------------------------------------------------- - -// DefaultName returns a friendly session name: caller-provided if non-empty, -// otherwise `-`. -func DefaultName(id, command, explicit string) string { - if trimmed := strings.TrimSpace(explicit); trimmed != "" { - return trimmed - } - base := filepath.Base(strings.TrimSpace(command)) - if base == "." || base == string(filepath.Separator) || base == "" { - base = "session" - } - return base + "-" + system.ShortID(id) -} - -// DefaultCWD returns value if non-empty, else /root. -func DefaultCWD(value string) string { - if trimmed := strings.TrimSpace(value); trimmed != "" { - return trimmed - } - return "/root" -} - -// FailLaunch annotates sess as launch-failed with stage/message/raw log and -// returns it for persistence. -func FailLaunch(sess model.GuestSession, stage, message, rawLog string) model.GuestSession { - now := model.Now() - sess.Status = model.GuestSessionStatusFailed - sess.LastError = strings.TrimSpace(message) - sess.Attachable = false - sess.Reattachable = false - sess.LaunchStage = strings.TrimSpace(stage) - sess.LaunchMessage = strings.TrimSpace(message) - sess.LaunchRawLog = strings.TrimSpace(rawLog) - sess.UpdatedAt = now - sess.EndedAt = now - return sess -} - -// NormalizeRequiredCommands returns a de-duplicated, order-preserving list -// of required commands, with the session command first. -func NormalizeRequiredCommands(command string, extras []string) []string { - ordered := make([]string, 0, len(extras)+1) - seen := map[string]struct{}{} - appendValue := func(value string) { - trimmed := strings.TrimSpace(value) - if trimmed == "" { - return - } - if _, ok := seen[trimmed]; ok { - return - } - seen[trimmed] = struct{}{} - ordered = append(ordered, trimmed) - } - appendValue(command) - for _, extra := range extras { - appendValue(extra) - } - return ordered -} - -// -- Small utilities -------------------------------------------------------- - -// ShellQuote returns value single-quoted for bash, escaping embedded quotes. -func ShellQuote(value string) string { - return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'" -} - -// ExitCode extracts the exit status from an ssh.ExitError, returning -// (0, true) for nil errors. -func ExitCode(err error) (int, bool) { - if err == nil { - return 0, true - } - var exitErr *ssh.ExitError - if errors.As(err, &exitErr) { - return exitErr.ExitStatus(), true - } - return 0, false -} - -// CloneStringMap returns a shallow copy of values, or nil if empty. -func CloneStringMap(values map[string]string) map[string]string { - if len(values) == 0 { - return nil - } - cloned := make(map[string]string, len(values)) - for key, value := range values { - cloned[key] = value - } - return cloned -} - -// TailFileContent returns the last N lines of a file, or "" if the file is -// missing. -func TailFileContent(path string, lines int) (string, error) { - data, err := os.ReadFile(path) - if err != nil { - if os.IsNotExist(err) { - return "", nil - } - return "", err - } - if lines <= 0 { - return string(data), nil - } - parts := strings.Split(string(data), "\n") - if len(parts) <= lines { - return string(data), nil - } - return strings.Join(parts[len(parts)-lines-1:], "\n"), nil -} - -// ProcessAlive returns true if the process with pid exists. The syscallKill -// override is exposed for tests that need to simulate alive/dead processes. -func ProcessAlive(pid int) bool { - if pid <= 0 { - return false - } - return syscallKill(pid, syscall.Signal(0)) == nil -} - -// syscallKill is a test seam: tests replace it to stub process-alive checks. -var syscallKill = func(pid int, signal os.Signal) error { - proc, err := os.FindProcess(pid) - if err != nil { - return err - } - return proc.Signal(signal) -} - -// FormatStepError wraps err with an action label and trimmed on-guest log. -func FormatStepError(action string, err error, log string) error { - log = strings.TrimSpace(log) - if log == "" { - return fmt.Errorf("%s: %w", action, err) - } - return fmt.Errorf("%s: %w: %s", action, err, log) -} diff --git a/internal/daemon/session/session_test.go b/internal/daemon/session/session_test.go deleted file mode 100644 index ec093f2..0000000 --- a/internal/daemon/session/session_test.go +++ /dev/null @@ -1,440 +0,0 @@ -package session - -import ( - "errors" - "fmt" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "banger/internal/model" - - "golang.org/x/crypto/ssh" -) - -func TestRelativeStateDir(t *testing.T) { - got := RelativeStateDir("abc") - if strings.HasPrefix(got, "/root/") { - t.Fatalf("RelativeStateDir(%q) = %q, should strip /root/ prefix", "abc", got) - } - if !strings.Contains(got, "abc") { - t.Fatalf("missing session id in %q", got) - } - absolute := StateDir("abc") - if got != strings.TrimPrefix(absolute, "/root/") { - t.Fatalf("relative = %q, want %q", got, strings.TrimPrefix(absolute, "/root/")) - } -} - -func TestDefaultCWD(t *testing.T) { - if DefaultCWD("") != "/root" { - t.Error("empty should return /root") - } - if DefaultCWD(" ") != "/root" { - t.Error("whitespace should return /root") - } - if DefaultCWD("/work") != "/work" { - t.Error("explicit should pass through") - } -} - -func TestShellQuote(t *testing.T) { - if got := ShellQuote(""); got != "''" { - t.Errorf("empty: got %q, want ''", got) - } - if got := ShellQuote("x"); got != "'x'" { - t.Errorf("plain: got %q", got) - } - if got := ShellQuote("it's"); got != `'it'"'"'s'` { - t.Errorf("apostrophe: got %q", got) - } -} - -func TestExitCode(t *testing.T) { - if code, ok := ExitCode(nil); !ok || code != 0 { - t.Errorf("nil err: got (%d, %v), want (0, true)", code, ok) - } - // Build an ssh.ExitError using its real type — can't hand-construct, - // so wrap via errors.As check with a stub. - raw := &ssh.ExitError{} - if _, ok := ExitCode(raw); !ok { - t.Error("ssh.ExitError: ok should be true") - } - if _, ok := ExitCode(errors.New("bare error")); ok { - t.Error("bare error: ok should be false") - } -} - -func TestCloneStringMap(t *testing.T) { - if CloneStringMap(nil) != nil { - t.Error("nil in → nil out") - } - if CloneStringMap(map[string]string{}) != nil { - t.Error("empty in → nil out") - } - src := map[string]string{"a": "1", "b": "2"} - cloned := CloneStringMap(src) - if len(cloned) != 2 { - t.Fatalf("len = %d, want 2", len(cloned)) - } - cloned["a"] = "changed" - if src["a"] != "1" { - t.Error("mutating clone leaked back to source") - } -} - -func TestTailFileContent(t *testing.T) { - // Missing file → empty, no error. - got, err := TailFileContent(filepath.Join(t.TempDir(), "missing"), 10) - if err != nil || got != "" { - t.Errorf("missing: got (%q, %v), want ('', nil)", got, err) - } - - path := filepath.Join(t.TempDir(), "log") - lines := "one\ntwo\nthree\nfour\nfive" - if err := os.WriteFile(path, []byte(lines), 0o600); err != nil { - t.Fatalf("WriteFile: %v", err) - } - - full, err := TailFileContent(path, 0) - if err != nil || full != lines { - t.Errorf("0 lines: got (%q, %v), want (%q, nil)", full, err, lines) - } - - // Request more lines than exist → full content. - all, err := TailFileContent(path, 999) - if err != nil || all != lines { - t.Errorf("999 lines: got %q", all) - } - - last2, err := TailFileContent(path, 2) - if err != nil { - t.Fatalf("2 lines: %v", err) - } - if !strings.Contains(last2, "five") { - t.Errorf("2 lines missing last line: %q", last2) - } -} - -func TestProcessAlive(t *testing.T) { - if ProcessAlive(0) { - t.Error("pid 0 should not be alive") - } - if ProcessAlive(-1) { - t.Error("negative pid should not be alive") - } - // Swap the syscall seam. - original := syscallKill - t.Cleanup(func() { syscallKill = original }) - - syscallKill = func(pid int, signal os.Signal) error { return nil } - if !ProcessAlive(42) { - t.Error("syscallKill=nil should report alive") - } - - syscallKill = func(pid int, signal os.Signal) error { return fmt.Errorf("no such process") } - if ProcessAlive(42) { - t.Error("syscallKill error should report dead") - } -} - -func TestFormatStepError(t *testing.T) { - base := errors.New("boom") - err := FormatStepError("prepare", base, "") - if !errors.Is(err, base) { - t.Error("FormatStepError should wrap the base error") - } - if !strings.Contains(err.Error(), "prepare") { - t.Errorf("missing action: %v", err) - } - - errWithLog := FormatStepError("prepare", base, " log line\n") - if !strings.Contains(errWithLog.Error(), "log line") { - t.Errorf("missing log: %v", errWithLog) - } -} - -func TestParseStateHappyPath(t *testing.T) { - raw := `status=running -pid=123 -exit= -alive=true -error= -` - snap, err := ParseState(raw) - if err != nil { - t.Fatalf("ParseState: %v", err) - } - if snap.Status != "running" { - t.Errorf("Status = %q", snap.Status) - } - if snap.GuestPID != 123 { - t.Errorf("GuestPID = %d", snap.GuestPID) - } - if snap.ExitCode != nil { - t.Errorf("ExitCode should be nil when empty, got %v", snap.ExitCode) - } - if !snap.Alive { - t.Error("Alive should be true") - } -} - -func TestParseStateWithExit(t *testing.T) { - raw := `status=exited -pid=123 -exit=7 -alive=false -error=something bad -` - snap, err := ParseState(raw) - if err != nil { - t.Fatalf("ParseState: %v", err) - } - if snap.ExitCode == nil || *snap.ExitCode != 7 { - t.Errorf("ExitCode = %v, want 7", snap.ExitCode) - } - if snap.LastError != "something bad" { - t.Errorf("LastError = %q", snap.LastError) - } - if snap.Alive { - t.Error("Alive should be false") - } -} - -func TestParseStateIgnoresMalformedLines(t *testing.T) { - raw := "no-equals-here\nstatus=ok\n" - snap, err := ParseState(raw) - if err != nil { - t.Fatalf("ParseState: %v", err) - } - if snap.Status != "ok" { - t.Errorf("Status = %q, want ok", snap.Status) - } -} - -func TestInspectStateFromDir(t *testing.T) { - dir := t.TempDir() - writeFile := func(name, content string) { - if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o600); err != nil { - t.Fatalf("WriteFile(%s): %v", name, err) - } - } - writeFile("status", "running\n") - writeFile("pid", "42\n") - writeFile("exit_code", "0\n") - writeFile("error", "\n") - - original := syscallKill - t.Cleanup(func() { syscallKill = original }) - syscallKill = func(pid int, signal os.Signal) error { return nil } - - snap, err := InspectStateFromDir(dir) - if err != nil { - t.Fatalf("InspectStateFromDir: %v", err) - } - if snap.Status != "running" { - t.Errorf("Status = %q", snap.Status) - } - if snap.GuestPID != 42 { - t.Errorf("GuestPID = %d", snap.GuestPID) - } - if snap.ExitCode == nil || *snap.ExitCode != 0 { - t.Errorf("ExitCode = %v, want 0", snap.ExitCode) - } - if !snap.Alive { - t.Error("Alive should reflect syscallKill result (true)") - } -} - -func TestInspectStateFromDirMissingFiles(t *testing.T) { - snap, err := InspectStateFromDir(t.TempDir()) - if err != nil { - t.Fatalf("InspectStateFromDir (empty): %v", err) - } - if snap.Status != "" || snap.GuestPID != 0 || snap.ExitCode != nil { - t.Errorf("empty dir: snap = %+v", snap) - } -} - -func TestApplyStateSnapshotNilReceiver(t *testing.T) { - ApplyStateSnapshot(nil, StateSnapshot{}, true) // should not panic -} - -func TestApplyStateSnapshotExitedSuccess(t *testing.T) { - exit := 0 - sess := &model.GuestSession{Status: model.GuestSessionStatusRunning, Attachable: true, Reattachable: true} - ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true) - if sess.Status != model.GuestSessionStatusExited { - t.Errorf("Status = %q, want exited", sess.Status) - } - if sess.Attachable || sess.Reattachable { - t.Error("attach flags should be cleared on exit") - } - if sess.EndedAt.IsZero() { - t.Error("EndedAt should be set") - } -} - -func TestApplyStateSnapshotExitedFailure(t *testing.T) { - exit := 2 - sess := &model.GuestSession{Status: model.GuestSessionStatusRunning} - ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true) - if sess.Status != model.GuestSessionStatusFailed { - t.Errorf("Status = %q, want failed", sess.Status) - } -} - -func TestApplyStateSnapshotVMGone(t *testing.T) { - sess := &model.GuestSession{Status: model.GuestSessionStatusRunning} - ApplyStateSnapshot(sess, StateSnapshot{Alive: false}, false) - if sess.Status != model.GuestSessionStatusFailed { - t.Errorf("Status = %q, want failed", sess.Status) - } - if sess.LastError == "" { - t.Error("LastError should be populated when VM is gone") - } -} - -func TestApplyStateSnapshotRunningStatusSetsAttachableForPipe(t *testing.T) { - // When the guest-side status file reports "running" (Alive=false from - // kill -0 may still fail transiently), ApplyStateSnapshot transitions - // the session to running and sets attach flags for pipe-mode. - sess := &model.GuestSession{ - Status: model.GuestSessionStatusStarting, - StdinMode: model.GuestSessionStdinPipe, - } - ApplyStateSnapshot(sess, StateSnapshot{Status: string(model.GuestSessionStatusRunning), GuestPID: 11}, true) - if sess.Status != model.GuestSessionStatusRunning { - t.Errorf("Status = %q, want running", sess.Status) - } - if !sess.Attachable || !sess.Reattachable { - t.Error("pipe-mode running session should be attachable + reattachable") - } - if sess.AttachBackend != AttachBackendSSHBridge { - t.Errorf("AttachBackend = %q, want %q", sess.AttachBackend, AttachBackendSSHBridge) - } -} - -func TestApplyStateSnapshotAliveEarlyReturn(t *testing.T) { - // Alive-true returns immediately after setting status; no attach - // flags set on this path (by design — attach metadata only attaches - // to status-driven transitions). - sess := &model.GuestSession{ - Status: model.GuestSessionStatusStarting, - StdinMode: model.GuestSessionStdinPipe, - } - ApplyStateSnapshot(sess, StateSnapshot{Alive: true, GuestPID: 11}, true) - if sess.Status != model.GuestSessionStatusRunning { - t.Errorf("Status = %q, want running", sess.Status) - } - if sess.StartedAt.IsZero() { - t.Error("StartedAt should have been set") - } -} - -func TestStateChanged(t *testing.T) { - base := model.GuestSession{Status: model.GuestSessionStatusRunning, GuestPID: 10} - - // Identical → no change. - if StateChanged(base, base) { - t.Error("identical states should not be considered changed") - } - - // Status change. - changed := base - changed.Status = model.GuestSessionStatusExited - if !StateChanged(base, changed) { - t.Error("status change should be detected") - } - - // ExitCode change from nil → value. - exit := 3 - changed = base - changed.ExitCode = &exit - if !StateChanged(base, changed) { - t.Error("exit-code appearing should be detected") - } - - // Both have the same exit code → no change. - a := base - a.ExitCode = &exit - b := base - b.ExitCode = &exit - if StateChanged(a, b) { - t.Error("matching exit codes should not trigger change") - } - - // Different exit codes. - other := 5 - b.ExitCode = &other - if !StateChanged(a, b) { - t.Error("differing exit codes should be detected") - } - - // Timestamp change. - changed = base - changed.StartedAt = time.Now() - if !StateChanged(base, changed) { - t.Error("StartedAt change should be detected") - } -} - -func TestFailLaunch(t *testing.T) { - in := model.GuestSession{Status: model.GuestSessionStatusStarting, Attachable: true} - out := FailLaunch(in, "provision", " ssh did not come up ", " raw output\n") - if out.Status != model.GuestSessionStatusFailed { - t.Errorf("Status = %q, want failed", out.Status) - } - if out.LastError != "ssh did not come up" { - t.Errorf("LastError = %q (not trimmed?)", out.LastError) - } - if out.LaunchStage != "provision" || out.LaunchMessage != "ssh did not come up" { - t.Errorf("launch fields not set: %+v", out) - } - if out.LaunchRawLog != "raw output" { - t.Errorf("rawLog = %q (not trimmed?)", out.LaunchRawLog) - } - if out.Attachable { - t.Error("Attachable should be cleared") - } -} - -func TestNormalizeRequiredCommands(t *testing.T) { - got := NormalizeRequiredCommands("pi", []string{"pi", "git", "", "git", " ", "make"}) - want := []string{"pi", "git", "make"} - if len(got) != len(want) { - t.Fatalf("len = %d, want %d (%v)", len(got), len(want), got) - } - for i, v := range want { - if got[i] != v { - t.Errorf("position %d: got %q, want %q", i, got[i], v) - } - } -} - -func TestInspectScriptContainsAllStateFiles(t *testing.T) { - script := InspectScript("sess-abc") - for _, key := range []string{"status", "pid", "exit_code", "error", "alive"} { - if !strings.Contains(script, key) { - t.Errorf("script missing %q:\n%s", key, script) - } - } - if !strings.Contains(script, "sess-abc") { - t.Error("script missing session id") - } -} - -func TestSignalScriptIncludesSignalAndDirPaths(t *testing.T) { - script := SignalScript("sess-x", "TERM") - if !strings.Contains(script, "TERM") { - t.Error("missing signal") - } - if !strings.Contains(script, "sess-x") { - t.Error("missing session id") - } - if !strings.Contains(script, "monitor_pid") || !strings.Contains(script, "stdin_keepalive") { - t.Errorf("expected both monitor + stdin_keepalive kills, got:\n%s", script) - } -} diff --git a/internal/daemon/session_attach.go b/internal/daemon/session_attach.go deleted file mode 100644 index 6c83da4..0000000 --- a/internal/daemon/session_attach.go +++ /dev/null @@ -1,224 +0,0 @@ -package daemon - -import ( - "context" - "errors" - "fmt" - "io" - "net" - "os" - "path/filepath" - "time" - - "banger/internal/api" - sess "banger/internal/daemon/session" - "banger/internal/guest" - "banger/internal/model" - "banger/internal/sessionstream" -) - -func (d *Daemon) BeginGuestSessionAttach(ctx context.Context, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) { - vm, err := d.FindVM(ctx, params.VMIDOrName) - if err != nil { - return api.GuestSessionAttachBeginResult{}, err - } - session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName) - if err != nil { - return api.GuestSessionAttachBeginResult{}, err - } - session, _ = d.refreshGuestSession(ctx, vm, session) - if !session.Attachable { - return api.GuestSessionAttachBeginResult{}, errors.New("session is not attachable") - } - controller := &guestSessionController{} - if !d.claimGuestSessionController(session.ID, controller) { - return api.GuestSessionAttachBeginResult{}, errors.New("session already has an active attach") - } - attachID, err := model.NewID() - if err != nil { - d.clearGuestSessionController(session.ID) - return api.GuestSessionAttachBeginResult{}, err - } - socketPath := filepath.Join(d.layout.RuntimeDir, "guest-session-attach-"+attachID[:12]+".sock") - _ = os.Remove(socketPath) - listener, err := net.Listen("unix", socketPath) - if err != nil { - d.clearGuestSessionController(session.ID) - return api.GuestSessionAttachBeginResult{}, err - } - if err := os.Chmod(socketPath, 0o600); err != nil { - _ = listener.Close() - _ = os.Remove(socketPath) - d.clearGuestSessionController(session.ID) - return api.GuestSessionAttachBeginResult{}, err - } - go d.serveGuestSessionAttach(session, controller, attachID, socketPath, listener) - return api.GuestSessionAttachBeginResult{ - Session: session, - AttachID: attachID, - TransportKind: sess.TransportUnixSocket, - TransportTarget: socketPath, - SocketPath: socketPath, - StreamFormat: sessionstream.FormatV1, - }, nil -} - -func (d *Daemon) forwardGuestSessionOutput(_ string, controller *guestSessionController, channel byte, reader io.Reader) { - buffer := make([]byte, 32*1024) - for { - n, err := reader.Read(buffer) - if n > 0 { - controller.writeFrame(channel, buffer[:n]) - } - if err != nil { - if !errors.Is(err, io.EOF) { - controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()}) - } - return - } - } -} - -func (d *Daemon) waitForGuestSessionExit(id string, controller *guestSessionController, session model.GuestSession) { - err := controller.stream.Wait() - updated := session - updated.Attachable = false - now := model.Now() - updated.UpdatedAt = now - updated.EndedAt = now - if exitCode, ok := sess.ExitCode(err); ok { - updated.ExitCode = &exitCode - if exitCode == 0 { - updated.Status = model.GuestSessionStatusExited - } else { - updated.Status = model.GuestSessionStatusFailed - } - } - if err != nil && updated.LastError == "" { - updated.LastError = err.Error() - } - if vm, getErr := d.store.GetVMByID(context.Background(), updated.VMID); getErr == nil { - if refreshed, refreshErr := d.refreshGuestSession(context.Background(), vm, updated); refreshErr == nil { - updated = refreshed - } - } - _ = d.store.UpsertGuestSession(context.Background(), updated) - controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: updated.ExitCode}) - _ = controller.close() - d.clearGuestSessionController(id) -} - -func (d *Daemon) serveGuestSessionAttach(session model.GuestSession, controller *guestSessionController, _ string, socketPath string, listener net.Listener) { - defer func() { - _ = listener.Close() - _ = os.Remove(socketPath) - _ = controller.close() - d.clearGuestSessionController(session.ID) - }() - conn, err := listener.Accept() - if err != nil { - return - } - defer conn.Close() - if err := controller.setAttach(conn); err != nil { - _ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()}) - return - } - defer controller.clearAttach(conn) - if err := d.attachGuestSessionBridge(session, controller); err != nil { - _ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()}) - return - } - for { - channel, payload, err := sessionstream.ReadFrame(conn) - if err != nil { - return - } - switch channel { - case sessionstream.ChannelStdin: - if controller.stdin == nil { - continue - } - if _, err := controller.stdin.Write(payload); err != nil { - _ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()}) - return - } - case sessionstream.ChannelControl: - message, err := sessionstream.ReadControl(payload) - if err != nil { - _ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()}) - return - } - if message.Type == "eof" && controller.stdin != nil { - _ = controller.stdin.Close() - } - } - } -} - -func (d *Daemon) attachGuestSessionBridge(session model.GuestSession, controller *guestSessionController) error { - vm, err := d.store.GetVMByID(context.Background(), session.VMID) - if err != nil { - return err - } - if !d.vmAlive(vm) { - return fmt.Errorf("vm %q is not running", vm.Name) - } - address := net.JoinHostPort(vm.Runtime.GuestIP, "22") - stdinStream, err := d.openGuestSessionAttachStream(address, sess.AttachInputCommand(session.ID)) - if err != nil { - return fmt.Errorf("open guest session stdin stream: %w", err) - } - stdoutStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StdoutLogPath)) - if err != nil { - _ = stdinStream.Close() - return fmt.Errorf("open guest session stdout stream: %w", err) - } - stderrStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StderrLogPath)) - if err != nil { - _ = stdinStream.Close() - _ = stdoutStream.Close() - return fmt.Errorf("open guest session stderr stream: %w", err) - } - controller.streams = append(controller.streams, stdinStream, stdoutStream, stderrStream) - controller.stdin = stdinStream.Stdin() - go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStdout, stdoutStream.Stdout()) - go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStderr, stderrStream.Stdout()) - go d.watchGuestSessionAttach(session.ID, controller, session) - return nil -} - -func (d *Daemon) openGuestSessionAttachStream(address, command string) (*guest.StreamSession, error) { - client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath, d.layout.KnownHostsPath) - if err != nil { - return nil, err - } - stream, err := client.StartCommand(context.Background(), command) - if err != nil { - _ = client.Close() - return nil, err - } - return stream, nil -} - -func (d *Daemon) watchGuestSessionAttach(id string, controller *guestSessionController, session model.GuestSession) { - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - for range ticker.C { - vm, err := d.store.GetVMByID(context.Background(), session.VMID) - if err != nil { - controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()}) - _ = controller.close() - return - } - refreshed, err := d.refreshGuestSession(context.Background(), vm, session) - if err == nil { - session = refreshed - } - if session.Status == model.GuestSessionStatusExited || session.Status == model.GuestSessionStatusFailed { - controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: session.ExitCode}) - _ = controller.close() - return - } - } -} diff --git a/internal/daemon/session_controller.go b/internal/daemon/session_controller.go deleted file mode 100644 index 1736f7b..0000000 --- a/internal/daemon/session_controller.go +++ /dev/null @@ -1,184 +0,0 @@ -package daemon - -import ( - "errors" - "io" - "net" - "sync" - - "banger/internal/guest" - "banger/internal/sessionstream" -) - -type guestSessionController struct { - stream *guest.StreamSession - streams []*guest.StreamSession - stdin io.WriteCloser - attachMu sync.Mutex - attach net.Conn - writeMu sync.Mutex - closeOnce sync.Once -} - -func (c *guestSessionController) setAttach(conn net.Conn) error { - c.attachMu.Lock() - defer c.attachMu.Unlock() - if c.attach != nil { - return errors.New("session already has an active attach") - } - c.attach = conn - return nil -} - -func (c *guestSessionController) clearAttach(conn net.Conn) { - c.attachMu.Lock() - defer c.attachMu.Unlock() - if c.attach == conn { - c.attach = nil - } -} - -func (c *guestSessionController) writeFrame(channel byte, payload []byte) { - c.attachMu.Lock() - conn := c.attach - c.attachMu.Unlock() - if conn == nil { - return - } - c.writeMu.Lock() - err := sessionstream.WriteFrame(conn, channel, payload) - c.writeMu.Unlock() - if err != nil { - _ = conn.Close() - c.clearAttach(conn) - } -} - -func (c *guestSessionController) writeControl(message sessionstream.ControlMessage) { - c.attachMu.Lock() - conn := c.attach - c.attachMu.Unlock() - if conn == nil { - return - } - c.writeMu.Lock() - err := sessionstream.WriteControl(conn, message) - c.writeMu.Unlock() - if err != nil { - _ = conn.Close() - c.clearAttach(conn) - } -} - -func (c *guestSessionController) close() error { - if c == nil { - return nil - } - var err error - c.closeOnce.Do(func() { - c.attachMu.Lock() - conn := c.attach - c.attach = nil - c.attachMu.Unlock() - if conn != nil { - err = errors.Join(err, conn.Close()) - } - if c.stdin != nil { - err = errors.Join(err, c.stdin.Close()) - } - if c.stream != nil { - err = errors.Join(err, c.stream.Close()) - } - for _, stream := range c.streams { - if stream != nil { - err = errors.Join(err, stream.Close()) - } - } - }) - return err -} - -// sessionRegistry owns the live guest-session controller map. Its lock is -// independent of Daemon.mu so guest-session lookups do not contend with -// unrelated daemon state. -type sessionRegistry struct { - mu sync.Mutex - byID map[string]*guestSessionController - closed bool -} - -func newSessionRegistry() sessionRegistry { - return sessionRegistry{byID: make(map[string]*guestSessionController)} -} - -func (r *sessionRegistry) set(id string, controller *guestSessionController) { - r.mu.Lock() - defer r.mu.Unlock() - if r.closed { - return - } - r.byID[id] = controller -} - -func (r *sessionRegistry) claim(id string, controller *guestSessionController) bool { - r.mu.Lock() - defer r.mu.Unlock() - if r.closed { - return false - } - if r.byID[id] != nil { - return false - } - r.byID[id] = controller - return true -} - -func (r *sessionRegistry) get(id string) *guestSessionController { - r.mu.Lock() - defer r.mu.Unlock() - return r.byID[id] -} - -func (r *sessionRegistry) clear(id string) *guestSessionController { - r.mu.Lock() - defer r.mu.Unlock() - controller := r.byID[id] - delete(r.byID, id) - return controller -} - -func (r *sessionRegistry) closeAll() error { - r.mu.Lock() - controllers := make([]*guestSessionController, 0, len(r.byID)) - for _, controller := range r.byID { - controllers = append(controllers, controller) - } - r.byID = nil - r.closed = true - r.mu.Unlock() - var err error - for _, controller := range controllers { - err = errors.Join(err, controller.close()) - } - return err -} - -func (d *Daemon) setGuestSessionController(id string, controller *guestSessionController) { - d.sessions.set(id, controller) -} - -func (d *Daemon) claimGuestSessionController(id string, controller *guestSessionController) bool { - return d.sessions.claim(id, controller) -} - -func (d *Daemon) getGuestSessionController(id string) *guestSessionController { - return d.sessions.get(id) -} - -func (d *Daemon) clearGuestSessionController(id string) *guestSessionController { - return d.sessions.clear(id) -} - -func (d *Daemon) closeGuestSessionControllers() error { - return d.sessions.closeAll() -} diff --git a/internal/daemon/session_lifecycle.go b/internal/daemon/session_lifecycle.go deleted file mode 100644 index beeaa07..0000000 --- a/internal/daemon/session_lifecycle.go +++ /dev/null @@ -1,213 +0,0 @@ -package daemon - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "strings" - "time" - - "banger/internal/api" - sess "banger/internal/daemon/session" - "banger/internal/guest" - "banger/internal/model" -) - -func (d *Daemon) StartGuestSession(ctx context.Context, params api.GuestSessionStartParams) (model.GuestSession, error) { - stdinMode := model.GuestSessionStdinMode(strings.TrimSpace(params.StdinMode)) - if stdinMode == "" { - stdinMode = model.GuestSessionStdinClosed - } - if stdinMode != model.GuestSessionStdinClosed && stdinMode != model.GuestSessionStdinPipe { - return model.GuestSession{}, fmt.Errorf("unsupported stdin mode %q", params.StdinMode) - } - if strings.TrimSpace(params.Command) == "" { - return model.GuestSession{}, errors.New("session command is required") - } - var created model.GuestSession - _, err := d.withVMLockByRef(ctx, params.VMIDOrName, func(vm model.VMRecord) (model.VMRecord, error) { - if !d.vmAlive(vm) { - return model.VMRecord{}, fmt.Errorf("vm %q is not running", vm.Name) - } - session, err := d.startGuestSessionLocked(ctx, vm, params, stdinMode) - if err != nil { - return model.VMRecord{}, err - } - created = session - return vm, nil - }) - return created, err -} - -func (d *Daemon) startGuestSessionLocked(ctx context.Context, vm model.VMRecord, params api.GuestSessionStartParams, stdinMode model.GuestSessionStdinMode) (model.GuestSession, error) { - id, err := model.NewID() - if err != nil { - return model.GuestSession{}, err - } - now := model.Now() - session := model.GuestSession{ - ID: id, - VMID: vm.ID, - Name: sess.DefaultName(id, params.Command, params.Name), - Backend: sess.BackendSSH, - Command: params.Command, - Args: append([]string(nil), params.Args...), - CWD: strings.TrimSpace(params.CWD), - Env: sess.CloneStringMap(params.Env), - StdinMode: stdinMode, - Status: model.GuestSessionStatusStarting, - GuestStateDir: sess.StateDir(id), - StdoutLogPath: sess.StdoutLogPath(id), - StderrLogPath: sess.StderrLogPath(id), - Tags: sess.CloneStringMap(params.Tags), - Attachable: stdinMode == model.GuestSessionStdinPipe, - Reattachable: stdinMode == model.GuestSessionStdinPipe, - CreatedAt: now, - UpdatedAt: now, - } - if session.Attachable { - session.AttachBackend = sess.AttachBackendSSHBridge - session.AttachMode = sess.AttachModeExclusive - } else { - session.AttachBackend = sess.AttachBackendNone - } - if err := d.store.UpsertGuestSession(ctx, session); err != nil { - return model.GuestSession{}, err - } - fail := func(stage, message, rawLog string) (model.GuestSession, error) { - session = sess.FailLaunch(session, stage, message, rawLog) - if err := d.store.UpsertGuestSession(ctx, session); err != nil { - return model.GuestSession{}, err - } - return session, nil - } - address := net.JoinHostPort(vm.Runtime.GuestIP, "22") - if err := d.waitForGuestSSH(ctx, address, 250*time.Millisecond); err != nil { - return fail("ssh_unavailable", fmt.Sprintf("guest ssh unavailable: %v", err), "") - } - client, err := d.dialGuest(ctx, address) - if err != nil { - return fail("dial_guest", fmt.Sprintf("dial guest ssh: %v", err), "") - } - defer client.Close() - var preflightLog bytes.Buffer - if err := client.RunScript(ctx, sess.CWDPreflightScript(session.CWD), &preflightLog); err != nil { - return fail("preflight_cwd", fmt.Sprintf("guest working directory is unavailable: %s", sess.DefaultCWD(session.CWD)), preflightLog.String()) - } - preflightLog.Reset() - requiredCommands := sess.NormalizeRequiredCommands(params.Command, params.RequiredCommands) - if err := client.RunScript(ctx, sess.CommandPreflightScript(requiredCommands), &preflightLog); err != nil { - return fail("preflight_command", fmt.Sprintf("required guest command is unavailable: %s", strings.TrimSpace(preflightLog.String())), preflightLog.String()) - } - var uploadLog bytes.Buffer - if err := client.UploadFile(ctx, sess.ScriptPath(id), 0o755, []byte(sess.Script(session)), &uploadLog); err != nil { - return fail("upload_script", "upload guest session script failed", uploadLog.String()) - } - var launchLog bytes.Buffer - launchScript := fmt.Sprintf("set -euo pipefail\nnohup bash %s >/dev/null 2>&1 > %s\nrm -f %s\n", - sess.ShellQuote(tmpPath), - sess.ShellQuote(sess.StdinPipePath(session.ID)), - sess.ShellQuote(tmpPath), - ) - var sendLog bytes.Buffer - if err := client.RunScript(ctx, sendScript, &sendLog); err != nil { - return api.GuestSessionSendResult{}, fmt.Errorf("send to session: %w: %s", err, strings.TrimSpace(sendLog.String())) - } - return api.GuestSessionSendResult{Session: session, BytesWritten: len(params.Payload)}, nil -} - -func (d *Daemon) readGuestSessionLog(ctx context.Context, vm model.VMRecord, session model.GuestSession, stream string, tailLines int) (string, error) { - if d.vmAlive(vm) { - client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath) - if err != nil { - return "", err - } - defer client.Close() - path := session.StdoutLogPath - if stream == "stderr" { - path = session.StderrLogPath - } - var output bytes.Buffer - script := fmt.Sprintf("set -euo pipefail\nif [ -f %s ]; then tail -n %d %s; fi\n", sess.ShellQuote(path), tailLines, sess.ShellQuote(path)) - if err := client.RunScript(ctx, script, &output); err != nil { - return "", sess.FormatStepError("read guest session log", err, output.String()) - } - return output.String(), nil - } - runner := d.runner - if runner == nil { - runner = system.NewRunner() - } - workMount, cleanup, err := system.MountTempDir(ctx, runner, vm.Runtime.WorkDiskPath, false) - if err != nil { - return "", err - } - defer cleanup() - logPath := filepath.Join(workMount, sess.RelativeStateDir(session.ID), stream+".log") - return sess.TailFileContent(logPath, tailLines) -} diff --git a/internal/daemon/workspace.go b/internal/daemon/workspace.go index e285c94..553d2b6 100644 --- a/internal/daemon/workspace.go +++ b/internal/daemon/workspace.go @@ -10,7 +10,6 @@ import ( "time" "banger/internal/api" - sess "banger/internal/daemon/session" ws "banger/internal/daemon/workspace" "banger/internal/model" ) @@ -114,9 +113,9 @@ func exportScript(guestPath, diffRef, diffFlag string) string { "git read-tree %s --index-output=\"$tmp_idx\"\n"+ "GIT_INDEX_FILE=\"$tmp_idx\" git add -A\n"+ "GIT_INDEX_FILE=\"$tmp_idx\" git diff --cached %s %s\n", - sess.ShellQuote(guestPath), - sess.ShellQuote(diffRef), - sess.ShellQuote(diffRef), + ws.ShellQuote(guestPath), + ws.ShellQuote(diffRef), + ws.ShellQuote(diffRef), diffFlag, ) } @@ -189,9 +188,9 @@ func (d *Daemon) prepareVMWorkspaceGuestIO(ctx context.Context, vm model.VMRecor } if readOnly { var chmodLog bytes.Buffer - chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", sess.ShellQuote(guestPath)) + chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", ws.ShellQuote(guestPath)) if err := client.RunScript(ctx, chmodScript, &chmodLog); err != nil { - return model.WorkspacePrepareResult{}, sess.FormatStepError("set workspace readonly", err, chmodLog.String()) + return model.WorkspacePrepareResult{}, ws.FormatStepError("set workspace readonly", err, chmodLog.String()) } } return model.WorkspacePrepareResult{ diff --git a/internal/daemon/workspace/util.go b/internal/daemon/workspace/util.go new file mode 100644 index 0000000..9f99b2f --- /dev/null +++ b/internal/daemon/workspace/util.go @@ -0,0 +1,20 @@ +package workspace + +import ( + "fmt" + "strings" +) + +// ShellQuote returns value single-quoted for bash, escaping embedded quotes. +func ShellQuote(value string) string { + return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'" +} + +// FormatStepError wraps err with an action label and trimmed on-guest log. +func FormatStepError(action string, err error, log string) error { + log = strings.TrimSpace(log) + if log == "" { + return fmt.Errorf("%s: %w", action, err) + } + return fmt.Errorf("%s: %w: %s", action, err, log) +} diff --git a/internal/daemon/workspace/workspace.go b/internal/daemon/workspace/workspace.go index 30c1973..1e78af3 100644 --- a/internal/daemon/workspace/workspace.go +++ b/internal/daemon/workspace/workspace.go @@ -18,7 +18,6 @@ import ( "sort" "strings" - sess "banger/internal/daemon/session" "banger/internal/model" "banger/internal/system" ) @@ -146,13 +145,13 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g switch mode { case model.WorkspacePrepareModeFullCopy: var copyLog bytes.Buffer - command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath)) + command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath), ShellQuote(guestPath), ShellQuote(guestPath)) if err := client.StreamTar(ctx, spec.RepoRoot, command, ©Log); err != nil { - return sess.FormatStepError("copy full workspace", err, copyLog.String()) + return FormatStepError("copy full workspace", err, copyLog.String()) } var finalizeLog bytes.Buffer if err := client.RunScript(ctx, FinalizeScript(spec, guestPath, mode), &finalizeLog); err != nil { - return sess.FormatStepError("finalize full workspace", err, finalizeLog.String()) + return FormatStepError("finalize full workspace", err, finalizeLog.String()) } return nil case model.WorkspacePrepareModeMetadataOnly, model.WorkspacePrepareModeShallowOverlay: @@ -162,21 +161,21 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g } defer cleanup() var copyLog bytes.Buffer - command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath)) + command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath), ShellQuote(guestPath), ShellQuote(guestPath)) if err := client.StreamTar(ctx, repoCopyDir, command, ©Log); err != nil { - return sess.FormatStepError("copy guest git metadata", err, copyLog.String()) + return FormatStepError("copy guest git metadata", err, copyLog.String()) } var scriptLog bytes.Buffer if err := client.RunScript(ctx, FinalizeScript(spec, guestPath, mode), &scriptLog); err != nil { - return sess.FormatStepError("prepare guest checkout", err, scriptLog.String()) + return FormatStepError("prepare guest checkout", err, scriptLog.String()) } if mode == model.WorkspacePrepareModeMetadataOnly { return nil } var overlayLog bytes.Buffer - command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath)) + command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath)) if err := client.StreamTarEntries(ctx, spec.RepoRoot, spec.OverlayPaths, command, &overlayLog); err != nil { - return sess.FormatStepError("overlay workspace working tree", err, overlayLog.String()) + return FormatStepError("overlay workspace working tree", err, overlayLog.String()) } return nil default: @@ -190,22 +189,22 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g func FinalizeScript(spec RepoSpec, guestPath string, mode model.WorkspacePrepareMode) string { var script strings.Builder script.WriteString("set -euo pipefail\n") - fmt.Fprintf(&script, "DIR=%s\n", sess.ShellQuote(guestPath)) + fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(guestPath)) script.WriteString("git config --global --add safe.directory \"$DIR\"\n") if mode != model.WorkspacePrepareModeFullCopy { script.WriteString("find \"$DIR\" -mindepth 1 -maxdepth 1 ! -name .git -exec rm -rf {} +\n") } switch { case strings.TrimSpace(spec.BranchName) != "": - fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.BranchName), sess.ShellQuote(spec.BaseCommit)) + fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", ShellQuote(spec.BranchName), ShellQuote(spec.BaseCommit)) case strings.TrimSpace(spec.CurrentBranch) != "": - fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.CurrentBranch), sess.ShellQuote(spec.HeadCommit)) + fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", ShellQuote(spec.CurrentBranch), ShellQuote(spec.HeadCommit)) default: - fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", sess.ShellQuote(spec.HeadCommit)) + fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", ShellQuote(spec.HeadCommit)) } if strings.TrimSpace(spec.GitUserName) != "" && strings.TrimSpace(spec.GitUserEmail) != "" { - fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", sess.ShellQuote(spec.GitUserName)) - fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", sess.ShellQuote(spec.GitUserEmail)) + fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", ShellQuote(spec.GitUserName)) + fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", ShellQuote(spec.GitUserEmail)) } return script.String() } diff --git a/internal/model/types.go b/internal/model/types.go index 2eb0b45..64011dd 100644 --- a/internal/model/types.go +++ b/internal/model/types.go @@ -34,23 +34,6 @@ const ( VMStateError VMState = "error" ) -type GuestSessionStatus string - -const ( - GuestSessionStatusStarting GuestSessionStatus = "starting" - GuestSessionStatusRunning GuestSessionStatus = "running" - GuestSessionStatusExited GuestSessionStatus = "exited" - GuestSessionStatusFailed GuestSessionStatus = "failed" - GuestSessionStatusStopping GuestSessionStatus = "stopping" -) - -type GuestSessionStdinMode string - -const ( - GuestSessionStdinClosed GuestSessionStdinMode = "closed" - GuestSessionStdinPipe GuestSessionStdinMode = "pipe" -) - type DaemonConfig struct { LogLevel string FirecrackerBin string @@ -176,37 +159,6 @@ type VMSetRequest struct { NATEnabled *bool } -type GuestSession struct { - ID string `json:"id"` - VMID string `json:"vm_id"` - Name string `json:"name"` - Backend string `json:"backend"` - AttachBackend string `json:"attach_backend,omitempty"` - AttachMode string `json:"attach_mode,omitempty"` - Command string `json:"command"` - Args []string `json:"args,omitempty"` - CWD string `json:"cwd,omitempty"` - Env map[string]string `json:"env,omitempty"` - StdinMode GuestSessionStdinMode `json:"stdin_mode,omitempty"` - Status GuestSessionStatus `json:"status"` - ExitCode *int `json:"exit_code,omitempty"` - GuestPID int `json:"guest_pid,omitempty"` - GuestStateDir string `json:"guest_state_dir,omitempty"` - StdoutLogPath string `json:"stdout_log_path,omitempty"` - StderrLogPath string `json:"stderr_log_path,omitempty"` - Tags map[string]string `json:"tags,omitempty"` - LastError string `json:"last_error,omitempty"` - Attachable bool `json:"attachable"` - Reattachable bool `json:"reattachable"` - LaunchStage string `json:"launch_stage,omitempty"` - LaunchMessage string `json:"launch_message,omitempty"` - LaunchRawLog string `json:"launch_raw_log,omitempty"` - CreatedAt time.Time `json:"created_at"` - StartedAt time.Time `json:"started_at,omitempty"` - UpdatedAt time.Time `json:"updated_at"` - EndedAt time.Time `json:"ended_at,omitempty"` -} - type WorkspacePrepareMode string const ( diff --git a/internal/sessionstream/sessionstream.go b/internal/sessionstream/sessionstream.go deleted file mode 100644 index 7167f43..0000000 --- a/internal/sessionstream/sessionstream.go +++ /dev/null @@ -1,76 +0,0 @@ -package sessionstream - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "io" -) - -const ( - ChannelStdin byte = 0x01 - ChannelStdout byte = 0x02 - ChannelStderr byte = 0x03 - ChannelControl byte = 0x04 - FormatV1 = "stdio_mux_v1" -) - -type ControlMessage struct { - Type string `json:"type"` - ExitCode *int `json:"exit_code,omitempty"` - Error string `json:"error,omitempty"` -} - -func WriteFrame(w io.Writer, channel byte, payload []byte) error { - var header [5]byte - header[0] = channel - binary.BigEndian.PutUint32(header[1:], uint32(len(payload))) - if _, err := w.Write(header[:]); err != nil { - return err - } - if len(payload) == 0 { - return nil - } - _, err := w.Write(payload) - return err -} - -func ReadFrame(r io.Reader) (byte, []byte, error) { - var header [5]byte - if _, err := io.ReadFull(r, header[:]); err != nil { - return 0, nil, err - } - length := binary.BigEndian.Uint32(header[1:]) - payload := make([]byte, length) - if _, err := io.ReadFull(r, payload); err != nil { - return 0, nil, err - } - return header[0], payload, nil -} - -func WriteControl(w io.Writer, message ControlMessage) error { - payload, err := json.Marshal(message) - if err != nil { - return err - } - return WriteFrame(w, ChannelControl, payload) -} - -func ReadControl(payload []byte) (ControlMessage, error) { - var message ControlMessage - if err := json.Unmarshal(payload, &message); err != nil { - return ControlMessage{}, err - } - return message, nil -} - -func ReadNextControl(r io.Reader) (ControlMessage, error) { - channel, payload, err := ReadFrame(r) - if err != nil { - return ControlMessage{}, err - } - if channel != ChannelControl { - return ControlMessage{}, fmt.Errorf("unexpected channel %d", channel) - } - return ReadControl(payload) -} diff --git a/internal/sessionstream/sessionstream_test.go b/internal/sessionstream/sessionstream_test.go deleted file mode 100644 index aca7446..0000000 --- a/internal/sessionstream/sessionstream_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package sessionstream - -import ( - "bytes" - "errors" - "io" - "testing" -) - -func TestWriteReadFrameRoundtrip(t *testing.T) { - cases := []struct { - name string - channel byte - payload []byte - }{ - {"stdout_bytes", ChannelStdout, []byte("hello world")}, - {"stderr_bytes", ChannelStderr, []byte{0x00, 0xff, 0x7f}}, - {"empty_payload", ChannelStdin, nil}, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - if err := WriteFrame(&buf, tc.channel, tc.payload); err != nil { - t.Fatalf("WriteFrame: %v", err) - } - ch, got, err := ReadFrame(&buf) - if err != nil { - t.Fatalf("ReadFrame: %v", err) - } - if ch != tc.channel { - t.Fatalf("channel = %d, want %d", ch, tc.channel) - } - if !bytes.Equal(got, tc.payload) && !(len(got) == 0 && len(tc.payload) == 0) { - t.Fatalf("payload = %q, want %q", got, tc.payload) - } - }) - } -} - -type shortWriter struct { - failAfter int - written int -} - -func (s *shortWriter) Write(p []byte) (int, error) { - s.written += len(p) - if s.written > s.failAfter { - return 0, io.ErrShortWrite - } - return len(p), nil -} - -func TestWriteFrameWriterError(t *testing.T) { - w := &shortWriter{failAfter: 2} - err := WriteFrame(w, ChannelStdout, []byte("payload")) - if err == nil { - t.Fatal("expected error from short writer") - } -} - -func TestReadFrameTruncated(t *testing.T) { - _, _, err := ReadFrame(bytes.NewReader([]byte{0x02, 0x00})) - if !errors.Is(err, io.ErrUnexpectedEOF) && err == nil { - t.Fatalf("expected EOF-ish error, got %v", err) - } - - // Header OK, but payload truncated. - var buf bytes.Buffer - buf.Write([]byte{ChannelStdout, 0x00, 0x00, 0x00, 0x05}) - buf.Write([]byte("ab")) - if _, _, err := ReadFrame(&buf); err == nil { - t.Fatal("expected truncated payload error") - } -} - -func TestControlRoundtrip(t *testing.T) { - code := 42 - msg := ControlMessage{Type: "exit", ExitCode: &code} - - var buf bytes.Buffer - if err := WriteControl(&buf, msg); err != nil { - t.Fatalf("WriteControl: %v", err) - } - - got, err := ReadNextControl(&buf) - if err != nil { - t.Fatalf("ReadNextControl: %v", err) - } - if got.Type != "exit" { - t.Fatalf("type = %q, want exit", got.Type) - } - if got.ExitCode == nil || *got.ExitCode != 42 { - t.Fatalf("exit_code = %v, want 42", got.ExitCode) - } -} - -func TestReadControlBadJSON(t *testing.T) { - if _, err := ReadControl([]byte("{not json")); err == nil { - t.Fatal("expected JSON error") - } -} - -func TestReadNextControlWrongChannel(t *testing.T) { - var buf bytes.Buffer - if err := WriteFrame(&buf, ChannelStdout, []byte("not a control frame")); err != nil { - t.Fatalf("WriteFrame: %v", err) - } - if _, err := ReadNextControl(&buf); err == nil { - t.Fatal("expected error for non-control channel") - } -} - -func TestFormatConstant(t *testing.T) { - if FormatV1 != "stdio_mux_v1" { - t.Fatalf("FormatV1 = %q, want stdio_mux_v1", FormatV1) - } -} diff --git a/internal/store/guest_session_test.go b/internal/store/guest_session_test.go deleted file mode 100644 index eff1477..0000000 --- a/internal/store/guest_session_test.go +++ /dev/null @@ -1,214 +0,0 @@ -package store - -import ( - "context" - "database/sql" - "errors" - "fmt" - "reflect" - "testing" - "time" - - "banger/internal/model" -) - -func sampleGuestSession(id, vmID, name string) model.GuestSession { - now := fixedTime() - exit := 7 - return model.GuestSession{ - ID: id, - VMID: vmID, - Name: name, - Backend: "ssh", - AttachBackend: "vsock", - AttachMode: "rpc", - Command: "pi", - Args: []string{"--mode", "rpc"}, - CWD: "/root/repo", - Env: map[string]string{"FOO": "bar"}, - StdinMode: model.GuestSessionStdinMode("pipe"), - Status: model.GuestSessionStatus("exited"), - ExitCode: &exit, - GuestPID: 1234, - GuestStateDir: "/tmp/guest-" + id, - StdoutLogPath: "/tmp/" + id + ".stdout", - StderrLogPath: "/tmp/" + id + ".stderr", - Tags: map[string]string{"role": "planner"}, - LastError: "", - Attachable: true, - Reattachable: true, - LaunchStage: "started", - LaunchMessage: "ok", - LaunchRawLog: "boot log...", - CreatedAt: now, - StartedAt: now, - UpdatedAt: now, - EndedAt: now.Add(time.Minute), - } -} - -// openTestStoreWithVMs opens a fresh store seeded with the given VM IDs so -// guest_sessions FK constraints are satisfied. Each VM gets a minimal -// image it references. -func openTestStoreWithVMs(t *testing.T, vmIDs ...string) *Store { - t.Helper() - ctx := context.Background() - store := openTestStore(t) - - image := sampleImage("stub-image") - if err := store.UpsertImage(ctx, image); err != nil { - t.Fatalf("UpsertImage: %v", err) - } - for i, id := range vmIDs { - vm := sampleVM(id, image.ID, fmt.Sprintf("172.16.0.%d", i+2)) - vm.ID = id - if err := store.UpsertVM(ctx, vm); err != nil { - t.Fatalf("UpsertVM(%s): %v", id, err) - } - } - return store -} - -func TestGuestSessionUpsertAndGetByID(t *testing.T) { - t.Parallel() - ctx := context.Background() - store := openTestStoreWithVMs(t, "vm-1") - - session := sampleGuestSession("sess-1", "vm-1", "planner") - if err := store.UpsertGuestSession(ctx, session); err != nil { - t.Fatalf("UpsertGuestSession: %v", err) - } - - got, err := store.GetGuestSessionByID(ctx, "sess-1") - if err != nil { - t.Fatalf("GetGuestSessionByID: %v", err) - } - if !reflect.DeepEqual(got, session) { - t.Fatalf("round-trip mismatch:\n got %+v\n want %+v", got, session) - } -} - -func TestGuestSessionUpsertIsIdempotent(t *testing.T) { - t.Parallel() - ctx := context.Background() - store := openTestStoreWithVMs(t, "vm-1") - - session := sampleGuestSession("sess-1", "vm-1", "planner") - if err := store.UpsertGuestSession(ctx, session); err != nil { - t.Fatalf("UpsertGuestSession (first): %v", err) - } - - // Mutate + re-upsert → existing row updated. - session.Command = "pi --other" - session.Status = model.GuestSessionStatus("running") - session.ExitCode = nil - if err := store.UpsertGuestSession(ctx, session); err != nil { - t.Fatalf("UpsertGuestSession (second): %v", err) - } - - got, err := store.GetGuestSessionByID(ctx, "sess-1") - if err != nil { - t.Fatalf("GetGuestSessionByID: %v", err) - } - if got.Command != "pi --other" { - t.Errorf("command = %q, want 'pi --other'", got.Command) - } - if got.Status != model.GuestSessionStatus("running") { - t.Errorf("status = %q, want running", got.Status) - } - if got.ExitCode != nil { - t.Errorf("ExitCode = %v, want nil after clearing", got.ExitCode) - } -} - -func TestGetGuestSessionByIDOrName(t *testing.T) { - t.Parallel() - ctx := context.Background() - store := openTestStoreWithVMs(t, "vm-1") - - session := sampleGuestSession("sess-1", "vm-1", "planner") - if err := store.UpsertGuestSession(ctx, session); err != nil { - t.Fatalf("UpsertGuestSession: %v", err) - } - - byID, err := store.GetGuestSession(ctx, "vm-1", "sess-1") - if err != nil { - t.Fatalf("GetGuestSession by ID: %v", err) - } - if byID.ID != "sess-1" { - t.Errorf("by-ID: got %q, want sess-1", byID.ID) - } - - byName, err := store.GetGuestSession(ctx, "vm-1", "planner") - if err != nil { - t.Fatalf("GetGuestSession by name: %v", err) - } - if byName.Name != "planner" { - t.Errorf("by-name: got %q, want planner", byName.Name) - } - - // Scoped to the VM. - if _, err := store.GetGuestSession(ctx, "vm-unknown", "sess-1"); !errors.Is(err, sql.ErrNoRows) { - t.Errorf("wrong-vm lookup = %v, want sql.ErrNoRows", err) - } -} - -func TestListGuestSessionsByVMOrdersByCreatedAt(t *testing.T) { - t.Parallel() - ctx := context.Background() - store := openTestStoreWithVMs(t, "vm-1", "vm-2") - - base := fixedTime() - first := sampleGuestSession("sess-early", "vm-1", "first") - first.CreatedAt = base - second := sampleGuestSession("sess-late", "vm-1", "second") - second.CreatedAt = base.Add(time.Hour) - other := sampleGuestSession("sess-other", "vm-2", "other") - - for _, s := range []model.GuestSession{second, first, other} { - if err := store.UpsertGuestSession(ctx, s); err != nil { - t.Fatalf("UpsertGuestSession: %v", err) - } - } - - sessions, err := store.ListGuestSessionsByVM(ctx, "vm-1") - if err != nil { - t.Fatalf("ListGuestSessionsByVM: %v", err) - } - if len(sessions) != 2 { - t.Fatalf("len = %d, want 2 (vm-1 only)", len(sessions)) - } - if sessions[0].ID != "sess-early" || sessions[1].ID != "sess-late" { - t.Fatalf("order: got %q, %q; want sess-early, sess-late", sessions[0].ID, sessions[1].ID) - } - - empty, err := store.ListGuestSessionsByVM(ctx, "vm-unknown") - if err != nil { - t.Fatalf("ListGuestSessionsByVM (unknown vm): %v", err) - } - if len(empty) != 0 { - t.Fatalf("unknown vm sessions = %+v, want empty", empty) - } -} - -func TestDeleteGuestSession(t *testing.T) { - t.Parallel() - ctx := context.Background() - store := openTestStoreWithVMs(t, "vm-1") - - session := sampleGuestSession("sess-1", "vm-1", "planner") - if err := store.UpsertGuestSession(ctx, session); err != nil { - t.Fatalf("UpsertGuestSession: %v", err) - } - if err := store.DeleteGuestSession(ctx, "sess-1"); err != nil { - t.Fatalf("DeleteGuestSession: %v", err) - } - if _, err := store.GetGuestSessionByID(ctx, "sess-1"); !errors.Is(err, sql.ErrNoRows) { - t.Fatalf("after delete err = %v, want sql.ErrNoRows", err) - } - - // Deleting something that doesn't exist is a no-op (matches SQL DELETE semantics). - if err := store.DeleteGuestSession(ctx, "sess-nope"); err != nil { - t.Fatalf("DeleteGuestSession on missing row: %v", err) - } -} diff --git a/internal/store/store.go b/internal/store/store.go index ca73a1d..6ac5a31 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -99,32 +99,6 @@ func (s *Store) migrate() error { stats_json TEXT NOT NULL DEFAULT '{}', FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE RESTRICT );`, - `CREATE TABLE IF NOT EXISTS guest_sessions ( - id TEXT PRIMARY KEY, - vm_id TEXT NOT NULL, - name TEXT NOT NULL, - backend TEXT NOT NULL, - command TEXT NOT NULL, - args_json TEXT NOT NULL DEFAULT '[]', - cwd TEXT, - env_json TEXT NOT NULL DEFAULT '{}', - stdin_mode TEXT NOT NULL, - status TEXT NOT NULL, - exit_code INTEGER, - guest_pid INTEGER NOT NULL DEFAULT 0, - guest_state_dir TEXT, - stdout_log_path TEXT, - stderr_log_path TEXT, - tags_json TEXT NOT NULL DEFAULT '{}', - last_error TEXT, - attachable INTEGER NOT NULL DEFAULT 0, - created_at TEXT NOT NULL, - started_at TEXT, - updated_at TEXT NOT NULL, - ended_at TEXT, - UNIQUE(vm_id, name), - FOREIGN KEY(vm_id) REFERENCES vms(id) ON DELETE CASCADE - );`, } for _, stmt := range stmts { if _, err := s.db.Exec(stmt); err != nil { @@ -137,18 +111,6 @@ func (s *Store) migrate() error { if err := ensureColumnExists(s.db, "images", "seeded_ssh_public_key_fingerprint", "TEXT"); err != nil { return err } - for _, spec := range []struct{ table, column, typ string }{ - {"guest_sessions", "attach_backend", "TEXT"}, - {"guest_sessions", "attach_mode", "TEXT"}, - {"guest_sessions", "reattachable", "INTEGER NOT NULL DEFAULT 0"}, - {"guest_sessions", "launch_stage", "TEXT"}, - {"guest_sessions", "launch_message", "TEXT"}, - {"guest_sessions", "launch_raw_log", "TEXT"}, - } { - if err := ensureColumnExists(s.db, spec.table, spec.column, spec.typ); err != nil { - return err - } - } return nil } @@ -336,122 +298,6 @@ func (s *Store) FindVMsUsingImage(ctx context.Context, imageID string) ([]model. return vms, rows.Err() } -func (s *Store) UpsertGuestSession(ctx context.Context, session model.GuestSession) error { - s.writeMu.Lock() - defer s.writeMu.Unlock() - argsJSON, err := json.Marshal(session.Args) - if err != nil { - return err - } - envJSON, err := json.Marshal(session.Env) - if err != nil { - return err - } - tagsJSON, err := json.Marshal(session.Tags) - if err != nil { - return err - } - const query = ` - INSERT INTO guest_sessions ( - id, vm_id, name, backend, attach_backend, attach_mode, command, args_json, cwd, env_json, stdin_mode, status, - exit_code, guest_pid, guest_state_dir, stdout_log_path, stderr_log_path, tags_json, - last_error, attachable, reattachable, launch_stage, launch_message, launch_raw_log, - created_at, started_at, updated_at, ended_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(id) DO UPDATE SET - vm_id=excluded.vm_id, - name=excluded.name, - backend=excluded.backend, - attach_backend=excluded.attach_backend, - attach_mode=excluded.attach_mode, - command=excluded.command, - args_json=excluded.args_json, - cwd=excluded.cwd, - env_json=excluded.env_json, - stdin_mode=excluded.stdin_mode, - status=excluded.status, - exit_code=excluded.exit_code, - guest_pid=excluded.guest_pid, - guest_state_dir=excluded.guest_state_dir, - stdout_log_path=excluded.stdout_log_path, - stderr_log_path=excluded.stderr_log_path, - tags_json=excluded.tags_json, - last_error=excluded.last_error, - attachable=excluded.attachable, - reattachable=excluded.reattachable, - launch_stage=excluded.launch_stage, - launch_message=excluded.launch_message, - launch_raw_log=excluded.launch_raw_log, - started_at=excluded.started_at, - updated_at=excluded.updated_at, - ended_at=excluded.ended_at` - _, err = s.db.ExecContext(ctx, query, - session.ID, - session.VMID, - session.Name, - session.Backend, - session.AttachBackend, - session.AttachMode, - session.Command, - string(argsJSON), - session.CWD, - string(envJSON), - string(session.StdinMode), - string(session.Status), - nullableInt(session.ExitCode), - session.GuestPID, - session.GuestStateDir, - session.StdoutLogPath, - session.StderrLogPath, - string(tagsJSON), - session.LastError, - boolToInt(session.Attachable), - boolToInt(session.Reattachable), - session.LaunchStage, - session.LaunchMessage, - session.LaunchRawLog, - session.CreatedAt.Format(time.RFC3339), - nullableTimeString(session.StartedAt), - session.UpdatedAt.Format(time.RFC3339), - nullableTimeString(session.EndedAt), - ) - return err -} - -func (s *Store) GetGuestSessionByID(ctx context.Context, id string) (model.GuestSession, error) { - row := s.db.QueryRowContext(ctx, guestSessionSelectSQL+" WHERE id = ?", id) - return scanGuestSessionRow(row) -} - -func (s *Store) GetGuestSession(ctx context.Context, vmID, idOrName string) (model.GuestSession, error) { - row := s.db.QueryRowContext(ctx, guestSessionSelectSQL+" WHERE vm_id = ? AND (id = ? OR name = ?)", vmID, idOrName, idOrName) - return scanGuestSessionRow(row) -} - -func (s *Store) ListGuestSessionsByVM(ctx context.Context, vmID string) ([]model.GuestSession, error) { - rows, err := s.db.QueryContext(ctx, guestSessionSelectSQL+" WHERE vm_id = ? ORDER BY created_at ASC", vmID) - if err != nil { - return nil, err - } - defer rows.Close() - var sessions []model.GuestSession - for rows.Next() { - session, err := scanGuestSession(rows) - if err != nil { - return nil, err - } - sessions = append(sessions, session) - } - return sessions, rows.Err() -} - -func (s *Store) DeleteGuestSession(ctx context.Context, id string) error { - s.writeMu.Lock() - defer s.writeMu.Unlock() - _, err := s.db.ExecContext(ctx, "DELETE FROM guest_sessions WHERE id = ?", id) - return err -} - func (s *Store) NextGuestIP(ctx context.Context, bridgeIPPrefix string) (string, error) { used := map[string]struct{}{} rows, err := s.db.QueryContext(ctx, "SELECT guest_ip FROM vms") @@ -622,113 +468,6 @@ func boolToInt(value bool) int { return 0 } -const guestSessionSelectSQL = ` -SELECT id, vm_id, name, backend, attach_backend, attach_mode, command, args_json, cwd, env_json, stdin_mode, status, - exit_code, guest_pid, guest_state_dir, stdout_log_path, stderr_log_path, tags_json, - last_error, attachable, reattachable, launch_stage, launch_message, launch_raw_log, - created_at, started_at, updated_at, ended_at -FROM guest_sessions` - -func scanGuestSession(rows scanner) (model.GuestSession, error) { - return scanGuestSessionRow(rows) -} - -func scanGuestSessionRow(row scanner) (model.GuestSession, error) { - var session model.GuestSession - var ( - argsJSON string - envJSON string - tagsJSON string - stdinMode string - status string - exitCode sql.NullInt64 - startedAt sql.NullString - endedAt sql.NullString - attachable int - reattachable int - createdRaw string - updatedRaw string - ) - err := row.Scan( - &session.ID, - &session.VMID, - &session.Name, - &session.Backend, - &session.AttachBackend, - &session.AttachMode, - &session.Command, - &argsJSON, - &session.CWD, - &envJSON, - &stdinMode, - &status, - &exitCode, - &session.GuestPID, - &session.GuestStateDir, - &session.StdoutLogPath, - &session.StderrLogPath, - &tagsJSON, - &session.LastError, - &attachable, - &reattachable, - &session.LaunchStage, - &session.LaunchMessage, - &session.LaunchRawLog, - &createdRaw, - &startedAt, - &updatedRaw, - &endedAt, - ) - if err != nil { - return session, err - } - session.StdinMode = model.GuestSessionStdinMode(stdinMode) - session.Status = model.GuestSessionStatus(status) - session.Attachable = attachable == 1 - session.Reattachable = reattachable == 1 - if argsJSON != "" { - if err := json.Unmarshal([]byte(argsJSON), &session.Args); err != nil { - return session, err - } - } - if envJSON != "" { - if err := json.Unmarshal([]byte(envJSON), &session.Env); err != nil { - return session, err - } - } - if tagsJSON != "" { - if err := json.Unmarshal([]byte(tagsJSON), &session.Tags); err != nil { - return session, err - } - } - if exitCode.Valid { - value := int(exitCode.Int64) - session.ExitCode = &value - } - var parseErr error - session.CreatedAt, parseErr = time.Parse(time.RFC3339, createdRaw) - if parseErr != nil { - return session, parseErr - } - session.UpdatedAt, parseErr = time.Parse(time.RFC3339, updatedRaw) - if parseErr != nil { - return session, parseErr - } - if startedAt.Valid && startedAt.String != "" { - session.StartedAt, parseErr = time.Parse(time.RFC3339, startedAt.String) - if parseErr != nil { - return session, parseErr - } - } - if endedAt.Valid && endedAt.String != "" { - session.EndedAt, parseErr = time.Parse(time.RFC3339, endedAt.String) - if parseErr != nil { - return session, parseErr - } - } - return session, nil -} - func nullableTimeString(value time.Time) any { if value.IsZero() { return nil