remove vm session feature
Cuts the daemon-managed guest-session machinery (start/list/show/
logs/stop/kill/attach/send). The feature shipped aimed at agent-
orchestration workflows (programmatic stdin piping into a long-lived
guest process) that aren't driving any concrete user today, and the
~2.3K LOC of daemon surface area — attach bridge, FIFO keepalive,
controller registry, sessionstream framing, SQLite persistence — was
locking in an API we'd have to keep through v0.1.0.
Anything session-flavoured that people actually need today can be
done with `vm ssh + tmux` or `vm run -- cmd`.
Deleted:
- internal/cli/commands_vm_session.go
- internal/daemon/{guest_sessions,session_lifecycle,session_attach,session_stream,session_controller}.go
- internal/daemon/session/ (guest-session helpers package)
- internal/sessionstream/ (framing package)
- internal/daemon/guest_sessions_test.go
- internal/store/guest_session_test.go
- GuestSession* types from internal/{api,model}
- Store UpsertGuestSession/GetGuestSession/ListGuestSessionsByVM/DeleteGuestSession + scanner helpers
- guest.session.* RPC dispatch entries
- 5 CLI session tests, 2 completion tests, 2 printer tests
Extracted:
- ShellQuote + FormatStepError lifted to internal/daemon/workspace/util.go
(only non-session consumer); workspace package now self-contained
- internal/daemon/guest_ssh.go keeps guestSSHClient + dialGuest +
waitForGuestSSH — still used by workspace prepare/export
- internal/daemon/fake_firecracker_test.go preserves the test helper
that used to live in guest_sessions_test.go
Store schema: CREATE TABLE guest_sessions and its column migrations
removed. Existing dev DBs keep an orphan table (harmless, pre-v0.1.0).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
c42fcbe012
commit
2b6437d1b4
34 changed files with 194 additions and 4031 deletions
|
|
@ -122,70 +122,6 @@ type VMPortsResult struct {
|
|||
Ports []VMPort `json:"ports"`
|
||||
}
|
||||
|
||||
type GuestSessionStartParams struct {
|
||||
VMIDOrName string `json:"vm_id_or_name"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
CWD string `json:"cwd,omitempty"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
StdinMode string `json:"stdin_mode,omitempty"`
|
||||
Tags map[string]string `json:"tags,omitempty"`
|
||||
RequiredCommands []string `json:"required_commands,omitempty"`
|
||||
}
|
||||
|
||||
type GuestSessionRefParams struct {
|
||||
VMIDOrName string `json:"vm_id_or_name"`
|
||||
SessionIDOrName string `json:"session_id_or_name"`
|
||||
}
|
||||
|
||||
type GuestSessionLogsParams struct {
|
||||
VMIDOrName string `json:"vm_id_or_name"`
|
||||
SessionIDOrName string `json:"session_id_or_name"`
|
||||
Stream string `json:"stream,omitempty"`
|
||||
TailLines int `json:"tail_lines,omitempty"`
|
||||
}
|
||||
|
||||
type GuestSessionAttachBeginParams struct {
|
||||
VMIDOrName string `json:"vm_id_or_name"`
|
||||
SessionIDOrName string `json:"session_id_or_name"`
|
||||
}
|
||||
|
||||
type GuestSessionListResult struct {
|
||||
Sessions []model.GuestSession `json:"sessions"`
|
||||
}
|
||||
|
||||
type GuestSessionShowResult struct {
|
||||
Session model.GuestSession `json:"session"`
|
||||
}
|
||||
|
||||
type GuestSessionLogsResult struct {
|
||||
Session model.GuestSession `json:"session"`
|
||||
Stream string `json:"stream"`
|
||||
Path string `json:"path,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type GuestSessionAttachBeginResult struct {
|
||||
Session model.GuestSession `json:"session"`
|
||||
AttachID string `json:"attach_id"`
|
||||
TransportKind string `json:"transport_kind"`
|
||||
TransportTarget string `json:"transport_target"`
|
||||
SocketPath string `json:"socket_path,omitempty"`
|
||||
StreamFormat string `json:"stream_format"`
|
||||
}
|
||||
|
||||
type GuestSessionSendParams struct {
|
||||
VMIDOrName string `json:"vm_id_or_name"`
|
||||
SessionIDOrName string `json:"session_id_or_name"`
|
||||
Payload []byte `json:"payload"`
|
||||
}
|
||||
|
||||
type GuestSessionSendResult struct {
|
||||
Session model.GuestSession `json:"session"`
|
||||
BytesWritten int `json:"bytes_written"`
|
||||
}
|
||||
|
||||
type WorkspaceExportParams struct {
|
||||
IDOrName string `json:"id_or_name"`
|
||||
GuestPath string `json:"guest_path,omitempty"`
|
||||
|
|
|
|||
|
|
@ -46,7 +46,6 @@ func TestListCommandsHaveLsAlias(t *testing.T) {
|
|||
{"vm", "list"},
|
||||
{"image", "list"},
|
||||
{"kernel", "list"},
|
||||
{"vm", "session", "list"},
|
||||
}
|
||||
for _, path := range cases {
|
||||
t.Run(path[len(path)-1], func(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -1918,34 +1918,9 @@ func (c *testVMRunGuestClient) StreamTarEntries(ctx context.Context, sourceDir s
|
|||
return nil
|
||||
}
|
||||
|
||||
func TestVMSessionSendCommandExists(t *testing.T) {
|
||||
root := NewBangerCommand()
|
||||
vm, _, err := root.Find([]string{"vm"})
|
||||
if err != nil {
|
||||
t.Fatalf("find vm: %v", err)
|
||||
}
|
||||
session, _, err := vm.Find([]string{"session"})
|
||||
if err != nil {
|
||||
t.Fatalf("find session: %v", err)
|
||||
}
|
||||
if _, _, err := session.Find([]string{"send"}); err != nil {
|
||||
t.Fatalf("find session send: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMSessionSendRejectsWrongArgCount(t *testing.T) {
|
||||
cmd := NewBangerCommand()
|
||||
cmd.SetArgs([]string{"vm", "session", "send", "only-one-arg"})
|
||||
err := cmd.Execute()
|
||||
if err == nil || !strings.Contains(err.Error(), "usage: banger vm session send") {
|
||||
t.Fatalf("Execute() error = %v, want send usage error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// stubEnsureDaemonForSend isolates XDG dirs and installs a daemon-ping
|
||||
// fake onto the caller's *deps so `ensureDaemon` short-circuits without
|
||||
// trying to spawn bangerd. `vm session send` uses this to avoid needing
|
||||
// a built binary on disk.
|
||||
// trying to spawn bangerd.
|
||||
func stubEnsureDaemonForSend(t *testing.T, d *deps) {
|
||||
t.Helper()
|
||||
t.Setenv("XDG_CONFIG_HOME", filepath.Join(t.TempDir(), "config"))
|
||||
|
|
@ -1956,98 +1931,6 @@ func stubEnsureDaemonForSend(t *testing.T, d *deps) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestVMSessionSendWithMessageFlag(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubEnsureDaemonForSend(t, d)
|
||||
|
||||
var capturedParams api.GuestSessionSendParams
|
||||
d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
|
||||
capturedParams = params
|
||||
return api.GuestSessionSendResult{
|
||||
Session: model.GuestSession{ID: "sess-id", Name: "planner"},
|
||||
BytesWritten: len(params.Payload),
|
||||
}, nil
|
||||
}
|
||||
|
||||
cmd := d.newRootCommand()
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner", "--message", `{"type":"abort"}`})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
wantPayload := []byte(`{"type":"abort"}` + "\n")
|
||||
if string(capturedParams.Payload) != string(wantPayload) {
|
||||
t.Fatalf("payload = %q, want %q", capturedParams.Payload, wantPayload)
|
||||
}
|
||||
if capturedParams.VMIDOrName != "devbox" {
|
||||
t.Fatalf("VMIDOrName = %q, want %q", capturedParams.VMIDOrName, "devbox")
|
||||
}
|
||||
if capturedParams.SessionIDOrName != "planner" {
|
||||
t.Fatalf("SessionIDOrName = %q, want %q", capturedParams.SessionIDOrName, "planner")
|
||||
}
|
||||
if !strings.Contains(out.String(), "17") {
|
||||
t.Fatalf("output = %q, want bytes_written in output", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubEnsureDaemonForSend(t, d)
|
||||
|
||||
var capturedPayload []byte
|
||||
d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
|
||||
capturedPayload = params.Payload
|
||||
return api.GuestSessionSendResult{
|
||||
Session: model.GuestSession{Name: "s"},
|
||||
BytesWritten: len(params.Payload),
|
||||
}, nil
|
||||
}
|
||||
|
||||
cmd := d.newRootCommand()
|
||||
cmd.SetOut(io.Discard)
|
||||
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "s", "--message", "{\"type\":\"abort\"}\n"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
// Must not double-append newline.
|
||||
if capturedPayload[len(capturedPayload)-1] != '\n' {
|
||||
t.Fatalf("payload missing trailing newline: %q", capturedPayload)
|
||||
}
|
||||
if len(capturedPayload) > 0 && capturedPayload[len(capturedPayload)-2] == '\n' {
|
||||
t.Fatalf("payload has double trailing newline: %q", capturedPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMSessionSendFromStdin(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubEnsureDaemonForSend(t, d)
|
||||
|
||||
var capturedPayload []byte
|
||||
d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
|
||||
capturedPayload = params.Payload
|
||||
return api.GuestSessionSendResult{
|
||||
Session: model.GuestSession{Name: "planner"},
|
||||
BytesWritten: len(params.Payload),
|
||||
}, nil
|
||||
}
|
||||
|
||||
stdinPayload := `{"type":"steer","message":"Focus on src/"}` + "\n"
|
||||
cmd := d.newRootCommand()
|
||||
cmd.SetOut(io.Discard)
|
||||
cmd.SetIn(strings.NewReader(stdinPayload))
|
||||
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
if string(capturedPayload) != stdinPayload {
|
||||
t.Fatalf("payload = %q, want %q", capturedPayload, stdinPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMWorkspaceExportCommandExists(t *testing.T) {
|
||||
root := NewBangerCommand()
|
||||
vm, _, err := root.Find([]string{"vm"})
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ func (d *deps) newVMCommand() *cobra.Command {
|
|||
d.newVMSetCommand(),
|
||||
d.newVMSSHCommand(),
|
||||
d.newVMWorkspaceCommand(),
|
||||
d.newVMSessionCommand(),
|
||||
d.newVMLogsCommand(),
|
||||
d.newVMStatsCommand(),
|
||||
d.newVMPortsCommand(),
|
||||
|
|
|
|||
|
|
@ -1,370 +0,0 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"banger/internal/api"
|
||||
"banger/internal/model"
|
||||
"banger/internal/sessionstream"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func (d *deps) newVMSessionCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "session",
|
||||
Short: "Manage long-lived guest commands inside a VM",
|
||||
Long: "Start, inspect, stop, and attach to daemon-managed guest commands. Pipe-mode sessions expose live stdio for interactive protocols. Attach is exclusive and currently uses a same-host local bridge.",
|
||||
RunE: helpNoArgs,
|
||||
}
|
||||
cmd.AddCommand(
|
||||
d.newVMSessionStartCommand(),
|
||||
d.newVMSessionListCommand(),
|
||||
d.newVMSessionShowCommand(),
|
||||
d.newVMSessionLogsCommand(),
|
||||
d.newVMSessionStopCommand(),
|
||||
d.newVMSessionKillCommand(),
|
||||
d.newVMSessionAttachCommand(),
|
||||
d.newVMSessionSendCommand(),
|
||||
)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (d *deps) newVMSessionStartCommand() *cobra.Command {
|
||||
var name string
|
||||
var cwd string
|
||||
var stdinMode string
|
||||
var envPairs []string
|
||||
var tagPairs []string
|
||||
var requiredCommands []string
|
||||
cmd := &cobra.Command{
|
||||
Use: "start <id-or-name> <command> [args...]",
|
||||
Short: "Start a managed guest command",
|
||||
Long: "Start a daemon-managed guest command. The daemon verifies that the guest working directory exists and that the requested command is present in guest PATH before launch. Use --stdin-mode pipe when you need live attach.",
|
||||
Args: minArgsUsage(2, "usage: banger vm session start <id-or-name> [flags] -- <command> [args...]"),
|
||||
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
|
||||
Example: strings.TrimSpace(`
|
||||
banger vm session start devbox --name planner --cwd /root/repo --stdin-mode pipe --require-command git -- pi --mode rpc --no-session
|
||||
banger vm session start devbox --name shell --stdin-mode pipe -- bash -lc 'exec bash'
|
||||
`),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := d.ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
env, err := parseKeyValuePairs(envPairs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tags, err := parseKeyValuePairs(tagPairs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := d.guestSessionStart(cmd.Context(), layout.SocketPath, api.GuestSessionStartParams{
|
||||
VMIDOrName: args[0],
|
||||
Name: name,
|
||||
Command: args[1],
|
||||
Args: append([]string(nil), args[2:]...),
|
||||
CWD: cwd,
|
||||
Env: env,
|
||||
StdinMode: stdinMode,
|
||||
Tags: tags,
|
||||
RequiredCommands: append([]string(nil), requiredCommands...),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := printGuestSessionSummary(cmd.OutOrStdout(), result.Session); err != nil {
|
||||
return err
|
||||
}
|
||||
if result.Session.Status == model.GuestSessionStatusFailed && strings.TrimSpace(result.Session.LaunchMessage) != "" {
|
||||
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "warning: session failed at %s: %s\n", result.Session.LaunchStage, result.Session.LaunchMessage)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cmd.Flags().StringVar(&name, "name", "", "session name")
|
||||
cmd.Flags().StringVar(&cwd, "cwd", "", "guest working directory; must already exist")
|
||||
cmd.Flags().StringVar(&stdinMode, "stdin-mode", string(model.GuestSessionStdinClosed), "stdin mode: closed or pipe (pipe enables attach)")
|
||||
cmd.Flags().StringArrayVar(&envPairs, "env", nil, "environment entry in KEY=VALUE form")
|
||||
cmd.Flags().StringArrayVar(&tagPairs, "tag", nil, "session tag in KEY=VALUE form")
|
||||
cmd.Flags().StringArrayVar(&requiredCommands, "require-command", nil, "extra guest command that must exist in PATH before launch; repeatable")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (d *deps) newVMSessionListCommand() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "list <id-or-name>",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List managed guest commands for a VM",
|
||||
Args: exactArgsUsage(1, "usage: banger vm session list <id-or-name>"),
|
||||
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := d.ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := d.guestSessionList(cmd.Context(), layout.SocketPath, args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return printGuestSessionTable(cmd.OutOrStdout(), result.Sessions)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (d *deps) newVMSessionShowCommand() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "show <id-or-name> <session>",
|
||||
Short: "Show managed guest command details",
|
||||
Args: exactArgsUsage(2, "usage: banger vm session show <id-or-name> <session>"),
|
||||
ValidArgsFunction: d.completeSessionNames,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := d.ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := d.guestSessionGet(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return printJSON(cmd.OutOrStdout(), result.Session)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (d *deps) newVMSessionLogsCommand() *cobra.Command {
|
||||
var stream string
|
||||
var tailLines int
|
||||
cmd := &cobra.Command{
|
||||
Use: "logs <id-or-name> <session>",
|
||||
Short: "Show stdout or stderr for a guest session",
|
||||
Args: exactArgsUsage(2, "usage: banger vm session logs [--stream stdout|stderr] [-n LINES] <id-or-name> <session>"),
|
||||
ValidArgsFunction: d.completeSessionNames,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := d.ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := d.guestSessionLogs(cmd.Context(), layout.SocketPath, api.GuestSessionLogsParams{VMIDOrName: args[0], SessionIDOrName: args[1], Stream: stream, TailLines: tailLines})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprint(cmd.OutOrStdout(), result.Content)
|
||||
return err
|
||||
},
|
||||
}
|
||||
cmd.Flags().StringVar(&stream, "stream", "stdout", "log stream to read")
|
||||
cmd.Flags().IntVarP(&tailLines, "lines", "n", 200, "number of lines to tail")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (d *deps) newVMSessionStopCommand() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "stop <id-or-name> <session>",
|
||||
Short: "Send SIGTERM to a guest session",
|
||||
Args: exactArgsUsage(2, "usage: banger vm session stop <id-or-name> <session>"),
|
||||
ValidArgsFunction: d.completeSessionNames,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := d.ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := d.guestSessionStop(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return printGuestSessionSummary(cmd.OutOrStdout(), result.Session)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (d *deps) newVMSessionKillCommand() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "kill <id-or-name> <session>",
|
||||
Short: "Send SIGKILL to a guest session",
|
||||
Args: exactArgsUsage(2, "usage: banger vm session kill <id-or-name> <session>"),
|
||||
ValidArgsFunction: d.completeSessionNames,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := d.ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := d.guestSessionKill(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return printGuestSessionSummary(cmd.OutOrStdout(), result.Session)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (d *deps) newVMSessionAttachCommand() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "attach <id-or-name> <session>",
|
||||
Short: "Attach local stdio to an attachable guest session",
|
||||
Long: "Attach local stdio to a pipe-mode session through a daemon-created local Unix socket bridge. Only one active attach is allowed at a time, and the client must run on the same host as the daemon.",
|
||||
Args: exactArgsUsage(2, "usage: banger vm session attach <id-or-name> <session>"),
|
||||
ValidArgsFunction: d.completeSessionNames,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := d.ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := d.guestSessionAttachBegin(cmd.Context(), layout.SocketPath, api.GuestSessionAttachBeginParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
socketPath := strings.TrimSpace(result.SocketPath)
|
||||
if socketPath == "" && result.TransportKind == "unix_socket" {
|
||||
socketPath = strings.TrimSpace(result.TransportTarget)
|
||||
}
|
||||
return runGuestSessionAttach(cmd.Context(), cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), socketPath)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (d *deps) newVMSessionSendCommand() *cobra.Command {
|
||||
var message string
|
||||
cmd := &cobra.Command{
|
||||
Use: "send <id-or-name> <session>",
|
||||
Short: "Write bytes to a running guest session's stdin pipe",
|
||||
Long: "Write a payload to the stdin pipe of a running pipe-mode guest session without holding the exclusive attach. Use --message for an inline JSONL string, or pipe bytes via stdin when --message is omitted. A trailing newline is appended to --message values that lack one.",
|
||||
Args: exactArgsUsage(2, "usage: banger vm session send <id-or-name> <session> [--message '<json>']"),
|
||||
ValidArgsFunction: d.completeSessionNames,
|
||||
Example: strings.TrimSpace(`
|
||||
banger vm session send devbox planner --message '{"type":"abort"}'
|
||||
banger vm session send devbox planner --message '{"type":"steer","message":"Focus on src/"}'
|
||||
echo '{"type":"prompt","prompt":"Summarize."}' | banger vm session send devbox planner
|
||||
`),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := d.ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var payload []byte
|
||||
if message != "" {
|
||||
payload = []byte(message)
|
||||
if len(payload) > 0 && payload[len(payload)-1] != '\n' {
|
||||
payload = append(payload, '\n')
|
||||
}
|
||||
} else {
|
||||
payload, err = io.ReadAll(cmd.InOrStdin())
|
||||
if err != nil {
|
||||
return fmt.Errorf("read stdin: %w", err)
|
||||
}
|
||||
}
|
||||
result, err := d.guestSessionSend(cmd.Context(), layout.SocketPath, api.GuestSessionSendParams{
|
||||
VMIDOrName: args[0],
|
||||
SessionIDOrName: args[1],
|
||||
Payload: payload,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(cmd.OutOrStdout(), "sent %d bytes to session %s\n", result.BytesWritten, result.Session.Name)
|
||||
return err
|
||||
},
|
||||
}
|
||||
cmd.Flags().StringVar(&message, "message", "", "JSONL message to send; a trailing newline is appended if absent")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func parseKeyValuePairs(values []string) (map[string]string, error) {
|
||||
if len(values) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
result := make(map[string]string, len(values))
|
||||
for _, value := range values {
|
||||
key, raw, ok := strings.Cut(value, "=")
|
||||
if !ok || strings.TrimSpace(key) == "" {
|
||||
return nil, fmt.Errorf("invalid key=value entry %q", value)
|
||||
}
|
||||
result[strings.TrimSpace(key)] = raw
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func runGuestSessionAttach(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, socketPath string) error {
|
||||
conn, err := (&net.Dialer{}).DialContext(ctx, "unix", socketPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
writeErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
writeErrCh <- streamGuestSessionAttachInput(conn, stdin)
|
||||
}()
|
||||
for {
|
||||
channel, payload, err := sessionstream.ReadFrame(conn)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
switch channel {
|
||||
case sessionstream.ChannelStdout:
|
||||
if _, err := stdout.Write(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
case sessionstream.ChannelStderr:
|
||||
if _, err := stderr.Write(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
case sessionstream.ChannelControl:
|
||||
message, err := sessionstream.ReadControl(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch message.Type {
|
||||
case "exit":
|
||||
if message.ExitCode != nil && *message.ExitCode != 0 {
|
||||
return fmt.Errorf("guest session exited with code %d", *message.ExitCode)
|
||||
}
|
||||
return nil
|
||||
case "error":
|
||||
if strings.TrimSpace(message.Error) == "" {
|
||||
return errors.New("guest session attach failed")
|
||||
}
|
||||
return errors.New(message.Error)
|
||||
}
|
||||
}
|
||||
select {
|
||||
case err := <-writeErrCh:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func streamGuestSessionAttachInput(conn net.Conn, stdin io.Reader) error {
|
||||
if stdin == nil {
|
||||
return sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "eof"})
|
||||
}
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := stdin.Read(buffer)
|
||||
if n > 0 {
|
||||
if writeErr := sessionstream.WriteFrame(conn, sessionstream.ChannelStdin, buffer[:n]); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "eof"})
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -21,9 +21,9 @@ import (
|
|||
// - Fail silently. Completion is advisory; any error path returns an
|
||||
// empty suggestion list rather than propagating to the user.
|
||||
|
||||
// defaultCompletionLister + defaultCompletionSessionLister back the
|
||||
// corresponding *deps fields; tests inject their own fakes via the
|
||||
// struct instead of mutating package-level vars.
|
||||
// defaultCompletionLister backs the *deps.completionLister field;
|
||||
// tests inject their own fake via the struct instead of mutating
|
||||
// package-level vars.
|
||||
func defaultCompletionLister(ctx context.Context, socketPath, method string) ([]string, error) {
|
||||
switch method {
|
||||
case "vm.list":
|
||||
|
|
@ -66,20 +66,6 @@ func defaultCompletionLister(ctx context.Context, socketPath, method string) ([]
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func defaultCompletionSessionLister(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) {
|
||||
result, err := rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: vmIDOrName})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
names := make([]string, 0, len(result.Sessions))
|
||||
for _, session := range result.Sessions {
|
||||
if session.Name != "" {
|
||||
names = append(names, session.Name)
|
||||
}
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
// daemonSocketForCompletion returns the socket path IFF the daemon is
|
||||
// already running. Returns "", false when no daemon is up — completion
|
||||
// callers use this as the bail signal.
|
||||
|
|
@ -177,25 +163,3 @@ func (d *deps) completeKernelNames(cmd *cobra.Command, args []string, toComplete
|
|||
}
|
||||
return filterPrefix(names, args, toComplete), cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
|
||||
// completeSessionNames handles `... <vm> <session>` commands: pos 0
|
||||
// completes VMs, pos 1 completes sessions owned by args[0], pos 2+ is
|
||||
// silent.
|
||||
func (d *deps) completeSessionNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
switch len(args) {
|
||||
case 0:
|
||||
return d.completeVMNames(cmd, args, toComplete)
|
||||
case 1:
|
||||
socket, ok := d.daemonSocketForCompletion(cmd.Context())
|
||||
if !ok {
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
names, err := d.completionSessionLister(cmd.Context(), socket, args[0])
|
||||
if err != nil {
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
return filterPrefix(names, nil, toComplete), cobra.ShellCompDirectiveNoFileComp
|
||||
default:
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,10 +19,7 @@ func stubCompletionSeams(
|
|||
d *deps,
|
||||
pingErr error,
|
||||
names map[string][]string,
|
||||
listErr error,
|
||||
sessions map[string][]string,
|
||||
sessionErr error,
|
||||
) {
|
||||
listErr error) {
|
||||
t.Helper()
|
||||
|
||||
d.daemonPing = func(ctx context.Context, socketPath string) (api.PingResult, error) {
|
||||
|
|
@ -37,12 +34,6 @@ func stubCompletionSeams(
|
|||
}
|
||||
return names[method], nil
|
||||
}
|
||||
d.completionSessionLister = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) {
|
||||
if sessionErr != nil {
|
||||
return nil, sessionErr
|
||||
}
|
||||
return sessions[vmIDOrName], nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPrefix(t *testing.T) {
|
||||
|
|
@ -82,7 +73,7 @@ func testCmdWithCtx() *cobra.Command {
|
|||
|
||||
func TestCompleteVMNamesHappyPath(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil)
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil)
|
||||
|
||||
got, directive := d.completeVMNames(testCmdWithCtx(), nil, "")
|
||||
if directive != cobra.ShellCompDirectiveNoFileComp {
|
||||
|
|
@ -95,7 +86,7 @@ func TestCompleteVMNamesHappyPath(t *testing.T) {
|
|||
|
||||
func TestCompleteVMNamesDaemonDown(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubCompletionSeams(t, d, errors.New("connection refused"), nil, nil, nil, nil)
|
||||
stubCompletionSeams(t, d, errors.New("connection refused"), nil, nil)
|
||||
|
||||
got, directive := d.completeVMNames(testCmdWithCtx(), nil, "")
|
||||
if len(got) != 0 {
|
||||
|
|
@ -108,7 +99,7 @@ func TestCompleteVMNamesDaemonDown(t *testing.T) {
|
|||
|
||||
func TestCompleteVMNamesRPCError(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubCompletionSeams(t, d, nil, nil, errors.New("rpc failed"), nil, nil)
|
||||
stubCompletionSeams(t, d, nil, nil, errors.New("rpc failed"))
|
||||
|
||||
got, _ := d.completeVMNames(testCmdWithCtx(), nil, "")
|
||||
if len(got) != 0 {
|
||||
|
|
@ -118,7 +109,7 @@ func TestCompleteVMNamesRPCError(t *testing.T) {
|
|||
|
||||
func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil)
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil)
|
||||
|
||||
got, _ := d.completeVMNames(testCmdWithCtx(), []string{"alpha"}, "")
|
||||
want := []string{"beta", "gamma"}
|
||||
|
|
@ -129,7 +120,7 @@ func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) {
|
|||
|
||||
func TestCompleteVMNamesPrefixFilter(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil, nil, nil)
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil)
|
||||
|
||||
got, _ := d.completeVMNames(testCmdWithCtx(), nil, "alp")
|
||||
want := []string{"alpha", "alphabet"}
|
||||
|
|
@ -140,7 +131,7 @@ func TestCompleteVMNamesPrefixFilter(t *testing.T) {
|
|||
|
||||
func TestCompleteVMNameOnlyAtPos0(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha"}}, nil, nil, nil)
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha"}}, nil)
|
||||
|
||||
atPos0, _ := d.completeVMNameOnlyAtPos0(testCmdWithCtx(), nil, "")
|
||||
if len(atPos0) != 1 || atPos0[0] != "alpha" {
|
||||
|
|
@ -155,7 +146,7 @@ func TestCompleteVMNameOnlyAtPos0(t *testing.T) {
|
|||
|
||||
func TestCompleteImageNames(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil, nil, nil)
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil)
|
||||
|
||||
got, _ := d.completeImageNames(testCmdWithCtx(), nil, "")
|
||||
if !reflect.DeepEqual(got, []string{"debian-bookworm", "alpine"}) {
|
||||
|
|
@ -165,7 +156,7 @@ func TestCompleteImageNames(t *testing.T) {
|
|||
|
||||
func TestCompleteKernelNames(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil, nil, nil)
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil)
|
||||
|
||||
got, _ := d.completeKernelNames(testCmdWithCtx(), nil, "")
|
||||
if len(got) != 1 || got[0] != "generic-6.12" {
|
||||
|
|
@ -175,58 +166,10 @@ func TestCompleteKernelNames(t *testing.T) {
|
|||
|
||||
func TestCompleteImageNameOnlyAtPos0SilentAfterFirst(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"alpine"}}, nil, nil, nil)
|
||||
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"alpine"}}, nil)
|
||||
|
||||
after, _ := d.completeImageNameOnlyAtPos0(testCmdWithCtx(), []string{"alpine"}, "")
|
||||
if len(after) != 0 {
|
||||
t.Errorf("expected silence at pos 1+, got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteSessionNames(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubCompletionSeams(t, d,
|
||||
nil,
|
||||
map[string][]string{"vm.list": {"devbox"}},
|
||||
nil,
|
||||
map[string][]string{"devbox": {"planner", "worker"}},
|
||||
nil,
|
||||
)
|
||||
|
||||
// Position 0 → VMs.
|
||||
vms, _ := d.completeSessionNames(testCmdWithCtx(), nil, "")
|
||||
if len(vms) != 1 || vms[0] != "devbox" {
|
||||
t.Errorf("pos 0: got %v", vms)
|
||||
}
|
||||
|
||||
// Position 1 → sessions scoped to args[0].
|
||||
sessions, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "")
|
||||
if !reflect.DeepEqual(sessions, []string{"planner", "worker"}) {
|
||||
t.Errorf("pos 1: got %v", sessions)
|
||||
}
|
||||
|
||||
// Position 1 with prefix filter.
|
||||
filtered, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "wor")
|
||||
if len(filtered) != 1 || filtered[0] != "worker" {
|
||||
t.Errorf("pos 1 prefix: got %v", filtered)
|
||||
}
|
||||
|
||||
// Position 2+ silent.
|
||||
past, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox", "planner"}, "")
|
||||
if len(past) != 0 {
|
||||
t.Errorf("pos 2+: got %v", past)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteSessionNamesDaemonDown(t *testing.T) {
|
||||
d := defaultDeps()
|
||||
stubCompletionSeams(t, d, errors.New("down"), nil, nil, nil, nil)
|
||||
|
||||
got, directive := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "")
|
||||
if len(got) != 0 {
|
||||
t.Errorf("expected no suggestions when daemon down, got %v", got)
|
||||
}
|
||||
if directive != cobra.ShellCompDirectiveNoFileComp {
|
||||
t.Errorf("directive = %d, want NoFileComp", directive)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,36 +31,27 @@ import (
|
|||
// validators) stay package-level because they hold no references to
|
||||
// external systems.
|
||||
type deps struct {
|
||||
bangerdPath func() (string, error)
|
||||
daemonExePath func(pid int) string
|
||||
doctor func(ctx context.Context) (system.Report, error)
|
||||
sshExec func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error
|
||||
hostCommandOutput func(ctx context.Context, name string, args ...string) ([]byte, error)
|
||||
vmHealth func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error)
|
||||
vmSSH func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error)
|
||||
vmDelete func(ctx context.Context, socketPath, idOrName string) error
|
||||
vmList func(ctx context.Context, socketPath string) (api.VMListResult, error)
|
||||
daemonPing func(ctx context.Context, socketPath string) (api.PingResult, error)
|
||||
vmCreateBegin func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error)
|
||||
vmCreateStatus func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error)
|
||||
vmCreateCancel func(ctx context.Context, socketPath, operationID string) error
|
||||
vmPorts func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error)
|
||||
vmWorkspacePrepare func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error)
|
||||
vmWorkspaceExport func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error)
|
||||
guestSessionStart func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error)
|
||||
guestSessionGet func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
|
||||
guestSessionList func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error)
|
||||
guestSessionStop func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
|
||||
guestSessionKill func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
|
||||
guestSessionLogs func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error)
|
||||
guestSessionAttachBegin func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error)
|
||||
guestSessionSend func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error)
|
||||
guestWaitForSSH func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error
|
||||
guestDial func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error)
|
||||
buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan
|
||||
cwd func() (string, error)
|
||||
completionLister func(ctx context.Context, socketPath, method string) ([]string, error)
|
||||
completionSessionLister func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error)
|
||||
bangerdPath func() (string, error)
|
||||
daemonExePath func(pid int) string
|
||||
doctor func(ctx context.Context) (system.Report, error)
|
||||
sshExec func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error
|
||||
hostCommandOutput func(ctx context.Context, name string, args ...string) ([]byte, error)
|
||||
vmHealth func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error)
|
||||
vmSSH func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error)
|
||||
vmDelete func(ctx context.Context, socketPath, idOrName string) error
|
||||
vmList func(ctx context.Context, socketPath string) (api.VMListResult, error)
|
||||
daemonPing func(ctx context.Context, socketPath string) (api.PingResult, error)
|
||||
vmCreateBegin func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error)
|
||||
vmCreateStatus func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error)
|
||||
vmCreateCancel func(ctx context.Context, socketPath, operationID string) error
|
||||
vmPorts func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error)
|
||||
vmWorkspacePrepare func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error)
|
||||
vmWorkspaceExport func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error)
|
||||
guestWaitForSSH func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error
|
||||
guestDial func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error)
|
||||
buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan
|
||||
cwd func() (string, error)
|
||||
completionLister func(ctx context.Context, socketPath, method string) ([]string, error)
|
||||
}
|
||||
|
||||
func defaultDeps() *deps {
|
||||
|
|
@ -125,30 +116,6 @@ func defaultDeps() *deps {
|
|||
vmWorkspaceExport: func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
|
||||
return rpc.Call[api.WorkspaceExportResult](ctx, socketPath, "vm.workspace.export", params)
|
||||
},
|
||||
guestSessionStart: func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) {
|
||||
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.start", params)
|
||||
},
|
||||
guestSessionGet: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
|
||||
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.get", params)
|
||||
},
|
||||
guestSessionList: func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error) {
|
||||
return rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: idOrName})
|
||||
},
|
||||
guestSessionStop: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
|
||||
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.stop", params)
|
||||
},
|
||||
guestSessionKill: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
|
||||
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.kill", params)
|
||||
},
|
||||
guestSessionLogs: func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) {
|
||||
return rpc.Call[api.GuestSessionLogsResult](ctx, socketPath, "guest.session.logs", params)
|
||||
},
|
||||
guestSessionAttachBegin: func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
|
||||
return rpc.Call[api.GuestSessionAttachBeginResult](ctx, socketPath, "guest.session.attach.begin", params)
|
||||
},
|
||||
guestSessionSend: func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
|
||||
return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params)
|
||||
},
|
||||
guestWaitForSSH: func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
|
||||
knownHosts, _ := bangerKnownHostsPath()
|
||||
return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval)
|
||||
|
|
@ -157,9 +124,8 @@ func defaultDeps() *deps {
|
|||
knownHosts, _ := bangerKnownHostsPath()
|
||||
return guest.Dial(ctx, address, privateKeyPath, knownHosts)
|
||||
},
|
||||
buildVMRunToolingPlan: toolingplan.Build,
|
||||
cwd: os.Getwd,
|
||||
completionLister: defaultCompletionLister,
|
||||
completionSessionLister: defaultCompletionSessionLister,
|
||||
buildVMRunToolingPlan: toolingplan.Build,
|
||||
cwd: os.Getwd,
|
||||
completionLister: defaultCompletionLister,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import (
|
|||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
|
@ -52,37 +51,6 @@ func TestDashIfEmpty(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestParseKeyValuePairs(t *testing.T) {
|
||||
t.Run("nil when empty", func(t *testing.T) {
|
||||
got, err := parseKeyValuePairs(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("got %v, want nil", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("parses entries", func(t *testing.T) {
|
||||
got, err := parseKeyValuePairs([]string{"a=1", " b = two", "c=x=y"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := map[string]string{"a": "1", "b": " two", "c": "x=y"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("got %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects malformed entries", func(t *testing.T) {
|
||||
for _, bad := range []string{"noequals", "=noKey", " =v"} {
|
||||
if _, err := parseKeyValuePairs([]string{bad}); err == nil {
|
||||
t.Errorf("expected error for %q", bad)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExitCodeErrorError(t *testing.T) {
|
||||
e := ExitCodeError{Code: 42}
|
||||
got := e.Error()
|
||||
|
|
@ -234,38 +202,6 @@ func TestPrintKernelCatalogTable(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPrintGuestSessionTable(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
sessions := []model.GuestSession{
|
||||
{ID: "abcdef0123456789", Name: "planner", Status: "running", Command: "pi", CWD: "/root/repo", Attachable: true},
|
||||
{ID: "short", Name: "once", Status: "exited", Command: "true", CWD: "/tmp", Attachable: false},
|
||||
}
|
||||
if err := printGuestSessionTable(&buf, sessions); err != nil {
|
||||
t.Fatalf("printGuestSessionTable: %v", err)
|
||||
}
|
||||
got := buf.String()
|
||||
for _, want := range []string{"ID", "NAME", "planner", "once", "yes", "no", "pi"} {
|
||||
if !strings.Contains(got, want) {
|
||||
t.Errorf("output missing %q:\n%s", want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintGuestSessionSummary(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
session := model.GuestSession{
|
||||
ID: "id1", Name: "s", Status: "exited", Command: "true", CWD: "/root",
|
||||
}
|
||||
if err := printGuestSessionSummary(&buf, session); err != nil {
|
||||
t.Fatalf("printGuestSessionSummary: %v", err)
|
||||
}
|
||||
got := buf.String()
|
||||
fields := strings.Split(strings.TrimRight(got, "\n"), "\t")
|
||||
if len(fields) != 5 {
|
||||
t.Fatalf("expected 5 tab-separated fields, got %d: %q", len(fields), got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintJSON(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
if err := printJSON(&buf, map[string]int{"a": 1, "b": 2}); err != nil {
|
||||
|
|
@ -340,10 +276,6 @@ type failWriter struct{}
|
|||
func (failWriter) Write([]byte) (int, error) { return 0, fmt.Errorf("boom") }
|
||||
|
||||
func TestPrintersPropagateWriteErrors(t *testing.T) {
|
||||
sessions := []model.GuestSession{{ID: "id", Name: "n"}}
|
||||
if err := printGuestSessionTable(failWriter{}, sessions); err == nil {
|
||||
t.Error("expected write error from printGuestSessionTable")
|
||||
}
|
||||
kernels := []api.KernelEntry{{Name: "k"}}
|
||||
if err := printKernelListTable(failWriter{}, kernels); err == nil {
|
||||
t.Error("expected write error from printKernelListTable")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package cli
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
|
@ -276,30 +275,6 @@ func printKernelCatalogTable(out anyWriter, entries []api.KernelCatalogEntry) er
|
|||
return w.Flush()
|
||||
}
|
||||
|
||||
// -- guest session printers -----------------------------------------
|
||||
|
||||
func printGuestSessionSummary(out anyWriter, session model.GuestSession) error {
|
||||
_, err := fmt.Fprintf(out, "%s\t%s\t%s\t%s\t%s\n", session.ID, session.Name, session.Status, session.Command, session.CWD)
|
||||
return err
|
||||
}
|
||||
|
||||
func printGuestSessionTable(out io.Writer, sessions []model.GuestSession) error {
|
||||
tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0)
|
||||
if _, err := fmt.Fprintln(tw, "ID\tNAME\tSTATUS\tATTACH\tCOMMAND\tCWD"); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, session := range sessions {
|
||||
attach := "no"
|
||||
if session.Attachable {
|
||||
attach = "yes"
|
||||
}
|
||||
if _, err := fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n", shortID(session.ID), session.Name, session.Status, attach, session.Command, session.CWD); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tw.Flush()
|
||||
}
|
||||
|
||||
// -- doctor printer -------------------------------------------------
|
||||
|
||||
func printDoctorReport(out anyWriter, report system.Report) error {
|
||||
|
|
|
|||
|
|
@ -32,13 +32,11 @@ owning types:
|
|||
- `createOps opstate.Registry[*vmCreateOperationState]` — in-flight VM
|
||||
create operations; owns its own lock.
|
||||
- `tapPool tapPool` — TAP interface pool; owns its own lock.
|
||||
- `sessions sessionRegistry` — active guest session controllers; owns
|
||||
its own lock.
|
||||
- `listener`, `vmDNS` — networking.
|
||||
- `vmCaps` — registered VM capability hooks.
|
||||
- `pullAndFlatten`, `finalizePulledRootfs`, `bundleFetch`,
|
||||
`requestHandler`, `guestWaitForSSH`, `guestDial`,
|
||||
`waitForGuestSessionReady` — injectable seams used by tests.
|
||||
`workspaceInspectRepo`, `workspaceImport` — injectable seams used by tests.
|
||||
|
||||
## Subpackages
|
||||
|
||||
|
|
@ -53,11 +51,9 @@ state beyond small test seams.
|
|||
| `internal/daemon/dmsnap` | Device-mapper COW snapshot create/cleanup/remove. |
|
||||
| `internal/daemon/fcproc` | Firecracker process primitives (bridge, tap, binary, PID, kill, wait). |
|
||||
| `internal/daemon/imagemgr` | Image subsystem pure helpers: validators, staging, build script gen. |
|
||||
| `internal/daemon/session` | Guest-session helpers: state paths, scripts, parsing, utilities. |
|
||||
| `internal/daemon/workspace` | Workspace helpers: git inspection, copy prep, guest import script. |
|
||||
|
||||
`workspace` imports `session` for `ShellQuote` and `FormatStepError`; all
|
||||
other subpackages are leaves (no other intra-daemon subpackage imports).
|
||||
All subpackages are leaves — no intra-daemon subpackage imports another.
|
||||
|
||||
## Lock ordering
|
||||
|
||||
|
|
@ -73,9 +69,8 @@ time. `workspace.prepare` acquires `vmLocks[id]` just long enough to
|
|||
validate VM state, releases it, then acquires `workspaceLocks[id]`
|
||||
for the guest I/O phase.
|
||||
|
||||
Subsystem-local locks (`tapPool.mu`, `sessionRegistry.mu`,
|
||||
`opstate.Registry` mu, `guestSessionController.attachMu` /
|
||||
`writeMu`) are leaves. They do not contend with each other.
|
||||
Subsystem-local locks (`tapPool.mu`, `opstate.Registry` mu) are leaves.
|
||||
They do not contend with each other.
|
||||
|
||||
Notes:
|
||||
|
||||
|
|
|
|||
|
|
@ -51,24 +51,22 @@ type Daemon struct {
|
|||
// See internal/daemon/vm_handles.go — persistent durable state
|
||||
// lives in the store, this is rebuildable from a per-VM
|
||||
// handles.json scratch file and OS inspection.
|
||||
handles *handleCache
|
||||
sessions sessionRegistry
|
||||
tapPool tapPool
|
||||
closing chan struct{}
|
||||
once sync.Once
|
||||
pid int
|
||||
listener net.Listener
|
||||
vmDNS *vmdns.Server
|
||||
vmCaps []vmCapability
|
||||
pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error)
|
||||
finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error
|
||||
bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error)
|
||||
requestHandler func(context.Context, rpc.Request) rpc.Response
|
||||
guestWaitForSSH func(context.Context, string, string, time.Duration) error
|
||||
guestDial func(context.Context, string, string) (guestSSHClient, error)
|
||||
waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error)
|
||||
workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error)
|
||||
workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error
|
||||
handles *handleCache
|
||||
tapPool tapPool
|
||||
closing chan struct{}
|
||||
once sync.Once
|
||||
pid int
|
||||
listener net.Listener
|
||||
vmDNS *vmdns.Server
|
||||
vmCaps []vmCapability
|
||||
pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error)
|
||||
finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error
|
||||
bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error)
|
||||
requestHandler func(context.Context, rpc.Request) rpc.Response
|
||||
guestWaitForSSH func(context.Context, string, string, time.Duration) error
|
||||
guestDial func(context.Context, string, string) (guestSSHClient, error)
|
||||
workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error)
|
||||
workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error
|
||||
}
|
||||
|
||||
func Open(ctx context.Context) (d *Daemon, err error) {
|
||||
|
|
@ -93,15 +91,14 @@ func Open(ctx context.Context) (d *Daemon, err error) {
|
|||
return nil, err
|
||||
}
|
||||
d = &Daemon{
|
||||
layout: layout,
|
||||
config: cfg,
|
||||
store: db,
|
||||
runner: system.NewRunner(),
|
||||
logger: logger,
|
||||
closing: make(chan struct{}),
|
||||
pid: os.Getpid(),
|
||||
handles: newHandleCache(),
|
||||
sessions: newSessionRegistry(),
|
||||
layout: layout,
|
||||
config: cfg,
|
||||
store: db,
|
||||
runner: system.NewRunner(),
|
||||
logger: logger,
|
||||
closing: make(chan struct{}),
|
||||
pid: os.Getpid(),
|
||||
handles: newHandleCache(),
|
||||
}
|
||||
// From here on, every failure path must run Close() so the host
|
||||
// state we touched (DNS listener goroutine, resolvectl routing,
|
||||
|
|
@ -144,7 +141,7 @@ func (d *Daemon) Close() error {
|
|||
if d.listener != nil {
|
||||
_ = d.listener.Close()
|
||||
}
|
||||
err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.closeGuestSessionControllers(), d.store.Close())
|
||||
err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.store.Close())
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
|
@ -424,62 +421,6 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
|
|||
}
|
||||
result, err := d.ExportVMWorkspace(ctx, params)
|
||||
return marshalResultOrError(result, err)
|
||||
case "guest.session.start":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionStartParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
session, err := d.StartGuestSession(ctx, params)
|
||||
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
|
||||
case "guest.session.get":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
session, err := d.GetGuestSession(ctx, params)
|
||||
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
|
||||
case "guest.session.list":
|
||||
params, err := rpc.DecodeParams[api.VMRefParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
sessions, err := d.ListGuestSessions(ctx, params)
|
||||
return marshalResultOrError(api.GuestSessionListResult{Sessions: sessions}, err)
|
||||
case "guest.session.stop":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
session, err := d.StopGuestSession(ctx, params)
|
||||
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
|
||||
case "guest.session.kill":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
session, err := d.KillGuestSession(ctx, params)
|
||||
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
|
||||
case "guest.session.logs":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionLogsParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
result, err := d.GuestSessionLogs(ctx, params)
|
||||
return marshalResultOrError(result, err)
|
||||
case "guest.session.attach.begin":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionAttachBeginParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
result, err := d.BeginGuestSessionAttach(ctx, params)
|
||||
return marshalResultOrError(result, err)
|
||||
case "guest.session.send":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionSendParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
result, err := d.SendToGuestSession(ctx, params)
|
||||
return marshalResultOrError(result, err)
|
||||
case "image.list":
|
||||
images, err := d.store.ListImages(ctx)
|
||||
return marshalResultOrError(api.ImageListResult{Images: images}, err)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
// Package daemon hosts the Banger daemon process.
|
||||
//
|
||||
// The daemon exposes a JSON-RPC endpoint over a Unix socket. It owns VM
|
||||
// lifecycle, image management, guest sessions, host networking bootstrap,
|
||||
// and state persistence via internal/store.
|
||||
// lifecycle, image management, host networking bootstrap, and state
|
||||
// persistence via internal/store.
|
||||
//
|
||||
// The package is organised into cohesive groups. Pure stateless helpers for
|
||||
// each group have been lifted into subpackages; orchestrator methods
|
||||
|
|
@ -18,13 +18,9 @@
|
|||
// internal/daemon/imagemgr Image subsystem helpers: path validation,
|
||||
// artifact staging, guest provisioning script
|
||||
// generator, metadata.
|
||||
// internal/daemon/session Guest-session helpers: state paths, runner
|
||||
// / inspect / signal scripts, state snapshot
|
||||
// parsing, launch helpers, ShellQuote,
|
||||
// FormatStepError.
|
||||
// internal/daemon/workspace Workspace helpers: git repo inspection,
|
||||
// shallow copy prep, guest-side import,
|
||||
// finalize script generation.
|
||||
// finalize script generation, shell quoting.
|
||||
//
|
||||
// VM lifecycle (in this package):
|
||||
//
|
||||
|
|
@ -50,11 +46,7 @@
|
|||
//
|
||||
// Guest interaction (in this package):
|
||||
//
|
||||
// guest_sessions.go dialGuest, waitForGuestSSH, refresh/inspect
|
||||
// session_lifecycle.go Start/Stop/Kill/Get/List/signal orchestrators
|
||||
// session_attach.go BeginGuestSessionAttach + bridge/forward/watch
|
||||
// session_stream.go GuestSessionLogs, SendToGuestSession
|
||||
// session_controller.go guestSessionController, sessionRegistry
|
||||
// guest_ssh.go guestSSHClient, dialGuest, waitForGuestSSH
|
||||
// ssh_client_config.go daemon-managed SSH client key material
|
||||
// workspace.go ExportVMWorkspace, PrepareVMWorkspace
|
||||
//
|
||||
|
|
@ -73,10 +65,8 @@
|
|||
//
|
||||
// Lock ordering:
|
||||
//
|
||||
// vmLocks[id] → {createVMMu, imageOpsMu} → subsystem-local locks
|
||||
// vmLocks[id] → workspaceLocks[id] → {createVMMu, imageOpsMu} → subsystem-local locks
|
||||
//
|
||||
// Subsystem-local locks live on their owning type (tapPool.mu,
|
||||
// sessionRegistry.mu, opstate.Registry mu, guestSessionController.attachMu /
|
||||
// writeMu) and do not contend with each other. See ARCHITECTURE.md for
|
||||
// details.
|
||||
// Subsystem-local locks (tapPool.mu, opstate.Registry mu) are leaves and
|
||||
// do not contend with each other. See ARCHITECTURE.md for details.
|
||||
package daemon
|
||||
|
|
|
|||
26
internal/daemon/fake_firecracker_test.go
Normal file
26
internal/daemon/fake_firecracker_test.go
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// startFakeFirecracker launches a bash sleep-loop rewritten to match
|
||||
// the firecracker command line a real process would expose, so
|
||||
// reconcile / handle-cache paths that grep /proc/<pid>/cmdline accept
|
||||
// it as a firecracker process. Killed on test cleanup.
|
||||
func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd {
|
||||
t.Helper()
|
||||
cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("start fake firecracker: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
}
|
||||
})
|
||||
return cmd
|
||||
}
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"banger/internal/daemon/session"
|
||||
"banger/internal/guest"
|
||||
"banger/internal/model"
|
||||
"banger/internal/system"
|
||||
)
|
||||
|
||||
type guestSSHClient interface {
|
||||
Close() error
|
||||
RunScript(context.Context, string, io.Writer) error
|
||||
RunScriptOutput(context.Context, string) ([]byte, error)
|
||||
UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error
|
||||
StreamTar(context.Context, string, string, io.Writer) error
|
||||
StreamTarEntries(context.Context, string, []string, string, io.Writer) error
|
||||
}
|
||||
|
||||
func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error {
|
||||
if d != nil && d.guestWaitForSSH != nil {
|
||||
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
|
||||
}
|
||||
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval)
|
||||
}
|
||||
|
||||
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
|
||||
if d != nil && d.guestDial != nil {
|
||||
return d.guestDial(ctx, address, d.config.SSHKeyPath)
|
||||
}
|
||||
return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
}
|
||||
|
||||
func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
|
||||
if d != nil && d.waitForGuestSessionReady != nil {
|
||||
return d.waitForGuestSessionReady(ctx, vm, s)
|
||||
}
|
||||
return d.waitForGuestSessionReadyDefault(ctx, vm, s)
|
||||
}
|
||||
|
||||
func (d *Daemon) waitForGuestSessionReadyDefault(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
|
||||
for {
|
||||
updated, err := d.refreshGuestSession(ctx, vm, s)
|
||||
if err == nil {
|
||||
s = updated
|
||||
if s.GuestPID != 0 || s.ExitCode != nil || s.Status == model.GuestSessionStatusRunning || s.Status == model.GuestSessionStatusFailed || s.Status == model.GuestSessionStatusExited {
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return s, ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) refreshGuestSession(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
|
||||
if s.Status != model.GuestSessionStatusStarting && s.Status != model.GuestSessionStatusRunning && s.Status != model.GuestSessionStatusStopping {
|
||||
return s, nil
|
||||
}
|
||||
snapshot, err := d.inspectGuestSessionState(ctx, vm, s)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
original := s
|
||||
session.ApplyStateSnapshot(&s, snapshot, d.vmAlive(vm))
|
||||
if session.StateChanged(original, s) {
|
||||
s.UpdatedAt = model.Now()
|
||||
if err := d.store.UpsertGuestSession(ctx, s); err != nil {
|
||||
return s, err
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) inspectGuestSessionState(ctx context.Context, vm model.VMRecord, s model.GuestSession) (session.StateSnapshot, error) {
|
||||
if d.vmAlive(vm) {
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return session.StateSnapshot{}, err
|
||||
}
|
||||
defer client.Close()
|
||||
var output bytes.Buffer
|
||||
if err := client.RunScript(ctx, session.InspectScript(s.ID), &output); err != nil {
|
||||
return session.StateSnapshot{}, session.FormatStepError("inspect guest session state", err, output.String())
|
||||
}
|
||||
return session.ParseState(output.String())
|
||||
}
|
||||
return d.inspectGuestSessionStateFromWorkDisk(ctx, vm, s.ID)
|
||||
}
|
||||
|
||||
func (d *Daemon) inspectGuestSessionStateFromWorkDisk(ctx context.Context, vm model.VMRecord, sessionID string) (session.StateSnapshot, error) {
|
||||
runner := d.runner
|
||||
if runner == nil {
|
||||
runner = system.NewRunner()
|
||||
}
|
||||
workMount, cleanup, err := system.MountTempDir(ctx, runner, vm.Runtime.WorkDiskPath, false)
|
||||
if err != nil {
|
||||
return session.StateSnapshot{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
stateDir := filepath.Join(workMount, session.RelativeStateDir(sessionID))
|
||||
return session.InspectStateFromDir(stateDir)
|
||||
}
|
||||
|
||||
func (d *Daemon) findGuestSession(ctx context.Context, vmID, idOrName string) (model.GuestSession, error) {
|
||||
if strings.TrimSpace(idOrName) == "" {
|
||||
return model.GuestSession{}, errors.New("session id or name is required")
|
||||
}
|
||||
if s, err := d.store.GetGuestSession(ctx, vmID, idOrName); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
sessions, err := d.store.ListGuestSessionsByVM(ctx, vmID)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
matches := make([]model.GuestSession, 0, 1)
|
||||
for _, s := range sessions {
|
||||
if strings.HasPrefix(s.ID, idOrName) || strings.HasPrefix(s.Name, idOrName) {
|
||||
matches = append(matches, s)
|
||||
}
|
||||
}
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
return model.GuestSession{}, fmt.Errorf("session %q not found", idOrName)
|
||||
case 1:
|
||||
return matches[0], nil
|
||||
default:
|
||||
return model.GuestSession{}, fmt.Errorf("multiple sessions match %q", idOrName)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,492 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"banger/internal/api"
|
||||
sess "banger/internal/daemon/session"
|
||||
"banger/internal/model"
|
||||
"banger/internal/store"
|
||||
)
|
||||
|
||||
type fakeGuestSSHClient struct {
|
||||
t *testing.T
|
||||
existingDirs map[string]bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) Close() error {
|
||||
f.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error {
|
||||
f.t.Helper()
|
||||
switch {
|
||||
case strings.Contains(script, `\n`):
|
||||
return fmt.Errorf("script still contains escaped newline literals: %q", script)
|
||||
case strings.Contains(script, `echo "missing cwd: $DIR"`):
|
||||
if strings.Contains(script, "DIR='/root/repo'\n") && f.existingDirs["/root/repo"] {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("missing cwd")
|
||||
case strings.Contains(script, "check_command() {"):
|
||||
return nil
|
||||
case strings.Contains(script, `git config --global --add safe.directory "$DIR"`):
|
||||
if strings.Contains(script, "DIR='/root/repo'\n") {
|
||||
f.existingDirs["/root/repo"] = true
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("workspace finalize used unexpected guest path")
|
||||
case strings.Contains(script, "chmod -R a-w"):
|
||||
if f.existingDirs["/root/repo"] {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("workspace path missing during readonly chmod")
|
||||
case strings.Contains(script, "nohup bash "):
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) StreamTar(_ context.Context, _ string, command string, _ io.Writer) error {
|
||||
if strings.Contains(command, "/root/repo") {
|
||||
f.existingDirs["/root/repo"] = true
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unexpected StreamTar command: %s", command)
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, command string, _ io.Writer) error {
|
||||
if strings.Contains(command, "/root/repo") {
|
||||
f.existingDirs["/root/repo"] = true
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unexpected StreamTarEntries command: %s", command)
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_HappyPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := startFakeFirecracker(t, apiSock)
|
||||
|
||||
vm := testVM("sendbox", "image-send", "172.16.0.88")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
fake := &recordingGuestSSHClient{}
|
||||
d := newSendTestDaemon(t, db, fake)
|
||||
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
|
||||
|
||||
payload := []byte(`{"type":"abort"}` + "\n")
|
||||
result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: payload,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendToGuestSession: %v", err)
|
||||
}
|
||||
if result.BytesWritten != len(payload) {
|
||||
t.Fatalf("BytesWritten = %d, want %d", result.BytesWritten, len(payload))
|
||||
}
|
||||
if result.Session.ID != session.ID {
|
||||
t.Fatalf("Session.ID = %q, want %q", result.Session.ID, session.ID)
|
||||
}
|
||||
if len(fake.uploadedFiles) != 1 {
|
||||
t.Fatalf("UploadFile call count = %d, want 1", len(fake.uploadedFiles))
|
||||
}
|
||||
for path, data := range fake.uploadedFiles {
|
||||
if !strings.HasPrefix(path, "/tmp/banger-send-") {
|
||||
t.Fatalf("upload path = %q, want /tmp/banger-send-... prefix", path)
|
||||
}
|
||||
if string(data) != string(payload) {
|
||||
t.Fatalf("upload data = %q, want %q", data, payload)
|
||||
}
|
||||
}
|
||||
if len(fake.ranScripts) != 1 {
|
||||
t.Fatalf("RunScript call count = %d, want 1", len(fake.ranScripts))
|
||||
}
|
||||
script := fake.ranScripts[0]
|
||||
pipePath := sess.StdinPipePath(session.ID)
|
||||
if !strings.Contains(script, "cat ") {
|
||||
t.Fatalf("send script missing cat command: %q", script)
|
||||
}
|
||||
if !strings.Contains(script, pipePath) {
|
||||
t.Fatalf("send script missing pipe path %q: %q", pipePath, script)
|
||||
}
|
||||
if !strings.Contains(script, "rm -f ") {
|
||||
t.Fatalf("send script missing rm cleanup: %q", script)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_EmptyPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := startFakeFirecracker(t, apiSock)
|
||||
|
||||
vm := testVM("sendbox-empty", "image-send", "172.16.0.89")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
fake := &recordingGuestSSHClient{}
|
||||
d := newSendTestDaemon(t, db, fake)
|
||||
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
|
||||
|
||||
result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendToGuestSession(empty): %v", err)
|
||||
}
|
||||
if result.BytesWritten != 0 {
|
||||
t.Fatalf("BytesWritten = %d, want 0", result.BytesWritten)
|
||||
}
|
||||
if fake.dialCount != 0 {
|
||||
t.Fatalf("SSH dial count = %d, want 0 for empty payload", fake.dialCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_NotPipeMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
vm := testVM("sendbox-closed", "image-send", "172.16.0.90")
|
||||
vm.State = model.VMStateRunning
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinClosed, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
d := &Daemon{store: db}
|
||||
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: []byte("hello\n"),
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "stdin pipe") {
|
||||
t.Fatalf("error = %v, want 'stdin pipe' error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_SessionNotRunning(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
vm := testVM("sendbox-failed", "image-send", "172.16.0.91")
|
||||
vm.State = model.VMStateRunning
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusFailed)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
d := &Daemon{store: db}
|
||||
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: []byte("hello\n"),
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "not running") {
|
||||
t.Fatalf("error = %v, want 'not running' error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_VMNotRunning(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
vm := testVM("sendbox-stopped", "image-send", "172.16.0.92")
|
||||
vm.State = model.VMStateStopped
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
d := &Daemon{store: db}
|
||||
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: []byte("hello\n"),
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "not running") {
|
||||
t.Fatalf("error = %v, want 'not running' error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// recordingGuestSSHClient captures UploadFile and RunScript calls for send tests.
|
||||
type recordingGuestSSHClient struct {
|
||||
dialCount int
|
||||
uploadedFiles map[string][]byte
|
||||
ranScripts []string
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) Close() error { return nil }
|
||||
|
||||
func (r *recordingGuestSSHClient) UploadFile(_ context.Context, path string, _ os.FileMode, data []byte, _ io.Writer) error {
|
||||
if r.uploadedFiles == nil {
|
||||
r.uploadedFiles = make(map[string][]byte)
|
||||
}
|
||||
copy := make([]byte, len(data))
|
||||
_ = copy[:len(data):len(data)]
|
||||
for i, b := range data {
|
||||
copy[i] = b
|
||||
}
|
||||
r.uploadedFiles[path] = copy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error {
|
||||
r.ranScripts = append(r.ranScripts, script)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) StreamTar(_ context.Context, _ string, _ string, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, _ string, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSendTestDaemon(t *testing.T, db *store.Store, fake *recordingGuestSSHClient) *Daemon {
|
||||
t.Helper()
|
||||
d := &Daemon{
|
||||
store: db,
|
||||
config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")},
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
d.guestDial = func(_ context.Context, _ string, _ string) (guestSSHClient, error) {
|
||||
fake.dialCount++
|
||||
return fake, nil
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func testGuestSession(vmID string, stdinMode model.GuestSessionStdinMode, status model.GuestSessionStatus) model.GuestSession {
|
||||
now := model.Now()
|
||||
id := vmID + "-sess-id"
|
||||
return model.GuestSession{
|
||||
ID: id,
|
||||
VMID: vmID,
|
||||
Name: vmID + "-sess",
|
||||
Backend: sess.BackendSSH,
|
||||
Command: "pi",
|
||||
Args: []string{"--mode", "rpc"},
|
||||
CWD: "/root/repo",
|
||||
StdinMode: stdinMode,
|
||||
Status: status,
|
||||
GuestStateDir: sess.StateDir(id),
|
||||
StdoutLogPath: sess.StdoutLogPath(id),
|
||||
StderrLogPath: sess.StderrLogPath(id),
|
||||
Attachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning,
|
||||
Reattachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd {
|
||||
t.Helper()
|
||||
cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("start fake firecracker: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
}
|
||||
})
|
||||
return cmd
|
||||
}
|
||||
|
||||
func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cwdScript := sess.CWDPreflightScript("/root/repo")
|
||||
if strings.Contains(cwdScript, `\n`) {
|
||||
t.Fatalf("cwd preflight script still contains escaped newline literals: %q", cwdScript)
|
||||
}
|
||||
if !strings.Contains(cwdScript, "\n") {
|
||||
t.Fatalf("cwd preflight script should contain real newlines: %q", cwdScript)
|
||||
}
|
||||
|
||||
commandScript := sess.CommandPreflightScript([]string{"git", "pi"})
|
||||
if strings.Contains(commandScript, `\n`) {
|
||||
t.Fatalf("command preflight script still contains escaped newline literals: %q", commandScript)
|
||||
}
|
||||
if !strings.Contains(commandScript, "\n") {
|
||||
t.Fatalf("command preflight script should contain real newlines: %q", commandScript)
|
||||
}
|
||||
|
||||
attachInput := sess.AttachInputCommand("session-id")
|
||||
if strings.Contains(attachInput, `\n`) {
|
||||
t.Fatalf("attach input command still contains escaped newline literals: %q", attachInput)
|
||||
}
|
||||
|
||||
attachTail := sess.AttachTailCommand("/tmp/stdout.log")
|
||||
if strings.Contains(attachTail, `\n`) {
|
||||
t.Fatalf("attach tail command still contains escaped newline literals: %q", attachTail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareWorkspaceThenStartGuestSessionPassesCWDPreflight(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
repoRoot := filepath.Join(t.TempDir(), "repo")
|
||||
if err := os.MkdirAll(repoRoot, 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(repoRoot): %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(repoRoot, "README.md"), []byte("hello\n"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(README.md): %v", err)
|
||||
}
|
||||
runGit := func(args ...string) {
|
||||
t.Helper()
|
||||
cmd := exec.Command("git", append([]string{"-C", repoRoot}, args...)...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("git %s: %v\n%s", strings.Join(args, " "), err, output)
|
||||
}
|
||||
}
|
||||
runGit("init", "-b", "main")
|
||||
runGit("config", "user.name", "Test User")
|
||||
runGit("config", "user.email", "test@example.com")
|
||||
runGit("add", ".")
|
||||
runGit("commit", "-m", "initial")
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
|
||||
if err := firecracker.Start(); err != nil {
|
||||
t.Fatalf("start fake firecracker: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if firecracker.Process != nil {
|
||||
_ = firecracker.Process.Kill()
|
||||
_, _ = firecracker.Process.Wait()
|
||||
}
|
||||
})
|
||||
|
||||
vm := testVM("pi-devbox", "image-pi", "172.16.0.77")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
fakeClient := &fakeGuestSSHClient{t: t, existingDirs: map[string]bool{}}
|
||||
d := &Daemon{
|
||||
store: db,
|
||||
config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")},
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
|
||||
d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
|
||||
d.guestDial = func(context.Context, string, string) (guestSSHClient, error) { return fakeClient, nil }
|
||||
d.waitForGuestSessionReady = func(_ context.Context, _ model.VMRecord, session model.GuestSession) (model.GuestSession, error) {
|
||||
now := model.Now()
|
||||
session.Status = model.GuestSessionStatusRunning
|
||||
session.GuestPID = 4242
|
||||
session.StartedAt = now
|
||||
session.UpdatedAt = now
|
||||
session.Attachable = session.StdinMode == model.GuestSessionStdinPipe
|
||||
session.Reattachable = session.StdinMode == model.GuestSessionStdinPipe
|
||||
return session, nil
|
||||
}
|
||||
|
||||
workspace, err := d.PrepareVMWorkspace(ctx, api.VMWorkspacePrepareParams{
|
||||
IDOrName: vm.Name,
|
||||
SourcePath: repoRoot,
|
||||
GuestPath: "/root/repo",
|
||||
ReadOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PrepareVMWorkspace: %v", err)
|
||||
}
|
||||
if workspace.GuestPath != "/root/repo" {
|
||||
t.Fatalf("PrepareVMWorkspace guest path = %q, want /root/repo", workspace.GuestPath)
|
||||
}
|
||||
if !fakeClient.existingDirs["/root/repo"] {
|
||||
t.Fatalf("workspace prepare did not mark /root/repo as present in fake guest")
|
||||
}
|
||||
|
||||
session, err := d.StartGuestSession(ctx, api.GuestSessionStartParams{
|
||||
VMIDOrName: vm.Name,
|
||||
Name: "testpi",
|
||||
Command: "pi",
|
||||
Args: []string{"--mode", "rpc", "--no-session"},
|
||||
CWD: "/root/repo",
|
||||
StdinMode: string(model.GuestSessionStdinPipe),
|
||||
RequiredCommands: []string{"git"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartGuestSession: %v", err)
|
||||
}
|
||||
if session.Status != model.GuestSessionStatusRunning {
|
||||
t.Fatalf("session status = %q, want %q", session.Status, model.GuestSessionStatusRunning)
|
||||
}
|
||||
if session.LaunchStage != "" {
|
||||
t.Fatalf("session launch stage = %q, want empty", session.LaunchStage)
|
||||
}
|
||||
if session.LaunchMessage != "" {
|
||||
t.Fatalf("session launch message = %q, want empty", session.LaunchMessage)
|
||||
}
|
||||
if session.GuestPID == 0 {
|
||||
t.Fatalf("session guest pid = 0, want non-zero")
|
||||
}
|
||||
if !session.Attachable {
|
||||
t.Fatalf("session should be attachable for pipe stdin mode")
|
||||
}
|
||||
}
|
||||
35
internal/daemon/guest_ssh.go
Normal file
35
internal/daemon/guest_ssh.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"banger/internal/guest"
|
||||
)
|
||||
|
||||
// guestSSHClient is the narrow guest-SSH surface the daemon uses for
|
||||
// workspace prepare / export and ad-hoc guest interactions.
|
||||
type guestSSHClient interface {
|
||||
Close() error
|
||||
RunScript(context.Context, string, io.Writer) error
|
||||
RunScriptOutput(context.Context, string) ([]byte, error)
|
||||
UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error
|
||||
StreamTar(context.Context, string, string, io.Writer) error
|
||||
StreamTarEntries(context.Context, string, []string, string, io.Writer) error
|
||||
}
|
||||
|
||||
func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error {
|
||||
if d != nil && d.guestWaitForSSH != nil {
|
||||
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
|
||||
}
|
||||
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval)
|
||||
}
|
||||
|
||||
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
|
||||
if d != nil && d.guestDial != nil {
|
||||
return d.guestDial(ctx, address, d.config.SSHKeyPath)
|
||||
}
|
||||
return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
}
|
||||
|
|
@ -26,10 +26,9 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
|
|||
name: "only store + closing channel (early failure)",
|
||||
build: func(t *testing.T) *Daemon {
|
||||
return &Daemon{
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
sessions: newSessionRegistry(),
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
},
|
||||
verify: func(t *testing.T, d *Daemon) {
|
||||
|
|
@ -49,11 +48,10 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
|
|||
t.Fatalf("vmdns.New: %v", err)
|
||||
}
|
||||
return &Daemon{
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
sessions: newSessionRegistry(),
|
||||
vmDNS: server,
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
vmDNS: server,
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
},
|
||||
verify: func(t *testing.T, d *Daemon) {
|
||||
|
|
@ -86,11 +84,10 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
|
|||
// returns and also calls Close afterwards, both paths must survive.
|
||||
func TestCloseIdempotentUnderConcurrency(t *testing.T) {
|
||||
d := &Daemon{
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
sessions: newSessionRegistry(),
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
config: model.DaemonConfig{BridgeName: ""},
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
config: model.DaemonConfig{BridgeName: ""},
|
||||
}
|
||||
|
||||
var count atomic.Int32
|
||||
|
|
|
|||
|
|
@ -1,521 +0,0 @@
|
|||
// Package session contains the pure helpers of the guest-session subsystem:
|
||||
// bash script generators, on-guest state path helpers, state snapshot
|
||||
// parsing, and small utilities like ShellQuote and FormatStepError.
|
||||
//
|
||||
// The orchestrator methods (StartGuestSession, BeginGuestSessionAttach,
|
||||
// etc.) stay on *daemon.Daemon and compose these helpers.
|
||||
package session
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"banger/internal/model"
|
||||
"banger/internal/system"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Constants shared between orchestration and helpers.
|
||||
const (
|
||||
BackendSSH = "ssh"
|
||||
AttachBackendNone = "none"
|
||||
AttachBackendSSHBridge = "ssh_rehydratable"
|
||||
AttachModeExclusive = "exclusive"
|
||||
TransportUnixSocket = "unix_socket"
|
||||
StateRoot = "/root/.local/state/banger/sessions"
|
||||
LogTailLineDefault = 200
|
||||
)
|
||||
|
||||
// StateSnapshot is the decoded per-session state as read from the guest.
|
||||
type StateSnapshot struct {
|
||||
Status string
|
||||
GuestPID int
|
||||
ExitCode *int
|
||||
Alive bool
|
||||
LastError string
|
||||
}
|
||||
|
||||
// -- Guest filesystem paths -------------------------------------------------
|
||||
|
||||
func StateDir(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateRoot, id))
|
||||
}
|
||||
|
||||
func RelativeStateDir(id string) string {
|
||||
return strings.TrimPrefix(StateDir(id), "/root/")
|
||||
}
|
||||
|
||||
func ScriptPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "run.sh")) }
|
||||
func PIDPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "pid")) }
|
||||
func MonitorPIDPath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "monitor_pid"))
|
||||
}
|
||||
func ExitCodePath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "exit_code"))
|
||||
}
|
||||
func StdinPipePath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "stdin.pipe"))
|
||||
}
|
||||
func StdinKeepalivePIDPath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "stdin_keepalive.pid"))
|
||||
}
|
||||
func StatusPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "status")) }
|
||||
func ErrorPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "error")) }
|
||||
func StdoutLogPath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "stdout.log"))
|
||||
}
|
||||
func StderrLogPath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "stderr.log"))
|
||||
}
|
||||
|
||||
// -- Script generators ------------------------------------------------------
|
||||
|
||||
// Script returns the bash runner installed into the guest for session. It
|
||||
// sets up state/log paths, optional stdin fifo, and wait-loop around the
|
||||
// user command.
|
||||
func Script(sess model.GuestSession) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
fmt.Fprintf(&script, "STATE_DIR=%s\n", ShellQuote(sess.GuestStateDir))
|
||||
fmt.Fprintf(&script, "STDOUT_LOG=%s\n", ShellQuote(sess.StdoutLogPath))
|
||||
fmt.Fprintf(&script, "STDERR_LOG=%s\n", ShellQuote(sess.StderrLogPath))
|
||||
fmt.Fprintf(&script, "PID_FILE=%s\n", ShellQuote(PIDPath(sess.ID)))
|
||||
fmt.Fprintf(&script, "MONITOR_PID_FILE=%s\n", ShellQuote(MonitorPIDPath(sess.ID)))
|
||||
fmt.Fprintf(&script, "EXIT_FILE=%s\n", ShellQuote(ExitCodePath(sess.ID)))
|
||||
fmt.Fprintf(&script, "STATUS_FILE=%s\n", ShellQuote(StatusPath(sess.ID)))
|
||||
fmt.Fprintf(&script, "ERROR_FILE=%s\n", ShellQuote(ErrorPath(sess.ID)))
|
||||
fmt.Fprintf(&script, "STDIN_PIPE=%s\n", ShellQuote(StdinPipePath(sess.ID)))
|
||||
fmt.Fprintf(&script, "STDIN_KEEPALIVE_PID_FILE=%s\n", ShellQuote(StdinKeepalivePIDPath(sess.ID)))
|
||||
fmt.Fprintf(&script, "SESSION_CWD=%s\n", ShellQuote(DefaultCWD(sess.CWD)))
|
||||
script.WriteString("mkdir -p \"$STATE_DIR\"\n")
|
||||
script.WriteString(": >\"$STDOUT_LOG\"\n")
|
||||
script.WriteString(": >\"$STDERR_LOG\"\n")
|
||||
script.WriteString("rm -f \"$EXIT_FILE\" \"$ERROR_FILE\" \"$STDIN_KEEPALIVE_PID_FILE\"\n")
|
||||
if sess.StdinMode == model.GuestSessionStdinPipe {
|
||||
script.WriteString("rm -f \"$STDIN_PIPE\"\n")
|
||||
script.WriteString("mkfifo -m 600 \"$STDIN_PIPE\"\n")
|
||||
}
|
||||
script.WriteString("printf '%s\\n' \"${BASHPID:-$$}\" >\"$MONITOR_PID_FILE\"\n")
|
||||
script.WriteString("printf 'starting\\n' >\"$STATUS_FILE\"\n")
|
||||
script.WriteString("cd \"$SESSION_CWD\"\n")
|
||||
script.WriteString("exec > >(tee -a \"$STDOUT_LOG\") 2> >(tee -a \"$STDERR_LOG\" >&2)\n")
|
||||
for _, line := range EnvLines(sess.Env) {
|
||||
script.WriteString(line)
|
||||
script.WriteByte('\n')
|
||||
}
|
||||
script.WriteString("COMMAND=(")
|
||||
for _, value := range append([]string{sess.Command}, sess.Args...) {
|
||||
script.WriteByte(' ')
|
||||
script.WriteString(ShellQuote(value))
|
||||
}
|
||||
script.WriteString(" )\n")
|
||||
if sess.StdinMode == model.GuestSessionStdinPipe {
|
||||
script.WriteString("( while :; do sleep 3600; done ) >\"$STDIN_PIPE\" &\n")
|
||||
script.WriteString("keepalive=$!\n")
|
||||
script.WriteString("printf '%s\\n' \"$keepalive\" >\"$STDIN_KEEPALIVE_PID_FILE\"\n")
|
||||
script.WriteString("\"${COMMAND[@]}\" <\"$STDIN_PIPE\" &\n")
|
||||
} else {
|
||||
script.WriteString("\"${COMMAND[@]}\" &\n")
|
||||
}
|
||||
script.WriteString("child=$!\n")
|
||||
script.WriteString("printf '%s\\n' \"$child\" >\"$PID_FILE\"\n")
|
||||
script.WriteString("printf 'running\\n' >\"$STATUS_FILE\"\n")
|
||||
script.WriteString("wait \"$child\"\n")
|
||||
script.WriteString("rc=$?\n")
|
||||
if sess.StdinMode == model.GuestSessionStdinPipe {
|
||||
script.WriteString("if [ -f \"$STDIN_KEEPALIVE_PID_FILE\" ]; then kill \"$(cat \"$STDIN_KEEPALIVE_PID_FILE\")\" 2>/dev/null || true; fi\n")
|
||||
}
|
||||
script.WriteString("printf '%s\\n' \"$rc\" >\"$EXIT_FILE\"\n")
|
||||
script.WriteString("if [ \"$rc\" -eq 0 ]; then printf 'exited\\n' >\"$STATUS_FILE\"; else printf 'failed\\n' >\"$STATUS_FILE\"; fi\n")
|
||||
script.WriteString("exit \"$rc\"\n")
|
||||
return script.String()
|
||||
}
|
||||
|
||||
// InspectScript reads the on-guest state files for sessionID and prints a
|
||||
// key=value block parseable by ParseState.
|
||||
func InspectScript(sessionID string) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID)))
|
||||
script.WriteString("status=''\n")
|
||||
script.WriteString("pid=''\n")
|
||||
script.WriteString("exit_code=''\n")
|
||||
script.WriteString("last_error=''\n")
|
||||
script.WriteString("alive=false\n")
|
||||
script.WriteString("[ -f \"$DIR/status\" ] && status=\"$(cat \"$DIR/status\")\"\n")
|
||||
script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n")
|
||||
script.WriteString("[ -f \"$DIR/exit_code\" ] && exit_code=\"$(cat \"$DIR/exit_code\")\"\n")
|
||||
script.WriteString("[ -f \"$DIR/error\" ] && last_error=\"$(cat \"$DIR/error\")\"\n")
|
||||
script.WriteString("if [ -n \"$pid\" ] && kill -0 \"$pid\" 2>/dev/null; then alive=true; fi\n")
|
||||
script.WriteString("printf 'status=%s\\n' \"$status\"\n")
|
||||
script.WriteString("printf 'pid=%s\\n' \"$pid\"\n")
|
||||
script.WriteString("printf 'exit=%s\\n' \"$exit_code\"\n")
|
||||
script.WriteString("printf 'alive=%s\\n' \"$alive\"\n")
|
||||
script.WriteString("printf 'error=%s\\n' \"$last_error\"\n")
|
||||
return script.String()
|
||||
}
|
||||
|
||||
// SignalScript sends signal to sessionID's runner and monitor processes.
|
||||
func SignalScript(sessionID, signal string) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID)))
|
||||
fmt.Fprintf(&script, "SIGNAL=%s\n", ShellQuote(signal))
|
||||
script.WriteString("pid=''\n")
|
||||
script.WriteString("monitor=''\n")
|
||||
script.WriteString("keepalive=''\n")
|
||||
script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n")
|
||||
script.WriteString("[ -f \"$DIR/monitor_pid\" ] && monitor=\"$(cat \"$DIR/monitor_pid\")\"\n")
|
||||
script.WriteString("[ -f \"$DIR/stdin_keepalive.pid\" ] && keepalive=\"$(cat \"$DIR/stdin_keepalive.pid\")\"\n")
|
||||
script.WriteString("printf 'stopping\\n' >\"$DIR/status\"\n")
|
||||
script.WriteString("if [ -n \"$pid\" ]; then kill -${SIGNAL} \"$pid\" 2>/dev/null || true; fi\n")
|
||||
script.WriteString("if [ -n \"$monitor\" ]; then kill -${SIGNAL} \"$monitor\" 2>/dev/null || true; fi\n")
|
||||
script.WriteString("if [ -n \"$keepalive\" ]; then kill -${SIGNAL} \"$keepalive\" 2>/dev/null || true; fi\n")
|
||||
return script.String()
|
||||
}
|
||||
|
||||
// CWDPreflightScript verifies cwd exists on the guest.
|
||||
func CWDPreflightScript(cwd string) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(DefaultCWD(cwd)))
|
||||
script.WriteString("if [ ! -d \"$DIR\" ]; then echo \"missing cwd: $DIR\"; exit 1; fi\n")
|
||||
return script.String()
|
||||
}
|
||||
|
||||
// CommandPreflightScript verifies each command is resolvable on the guest.
|
||||
func CommandPreflightScript(commands []string) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
script.WriteString("check_command() {\n")
|
||||
script.WriteString(" cmd=\"$1\"\n")
|
||||
script.WriteString(" case \"$cmd\" in\n")
|
||||
script.WriteString(" */*) [ -x \"$cmd\" ] || { echo \"missing command: $cmd\"; exit 1; } ;;\n")
|
||||
script.WriteString(" *) command -v \"$cmd\" >/dev/null 2>&1 || { echo \"missing command: $cmd\"; exit 1; } ;;\n")
|
||||
script.WriteString(" esac\n")
|
||||
script.WriteString("}\n")
|
||||
for _, command := range commands {
|
||||
fmt.Fprintf(&script, "check_command %s\n", ShellQuote(command))
|
||||
}
|
||||
return script.String()
|
||||
}
|
||||
|
||||
// AttachInputCommand returns the guest command that creates/opens the stdin
|
||||
// fifo for sessionID and cats attach-side bytes into it.
|
||||
func AttachInputCommand(sessionID string) string {
|
||||
path := StdinPipePath(sessionID)
|
||||
return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\n[ -p %s ] || mkfifo -m 600 %s\nexec cat > %s\n", ShellQuote(path), ShellQuote(path), ShellQuote(path)))
|
||||
}
|
||||
|
||||
// AttachTailCommand returns the guest command that tails a log file and
|
||||
// streams new content back to the attach bridge.
|
||||
func AttachTailCommand(path string) string {
|
||||
return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\ntouch %s\nexec tail -n 0 -F %s 2>/dev/null\n", ShellQuote(path), ShellQuote(path)))
|
||||
}
|
||||
|
||||
// EnvLines returns deterministic `export KEY=value` lines for the session
|
||||
// launcher, ordered by key.
|
||||
func EnvLines(values map[string]string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
lines := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
lines = append(lines, "export "+key+"="+ShellQuote(values[key]))
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
// -- State snapshot helpers -------------------------------------------------
|
||||
|
||||
// ParseState decodes the key=value output produced by InspectScript.
|
||||
func ParseState(raw string) (StateSnapshot, error) {
|
||||
var snapshot StateSnapshot
|
||||
scanner := bufio.NewScanner(strings.NewReader(raw))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
key, value, ok := strings.Cut(line, "=")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch strings.TrimSpace(key) {
|
||||
case "status":
|
||||
snapshot.Status = strings.TrimSpace(value)
|
||||
case "pid":
|
||||
if pid, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
|
||||
snapshot.GuestPID = pid
|
||||
}
|
||||
case "exit":
|
||||
if exitCode, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
|
||||
snapshot.ExitCode = &exitCode
|
||||
}
|
||||
case "alive":
|
||||
snapshot.Alive = strings.TrimSpace(value) == "true"
|
||||
case "error":
|
||||
snapshot.LastError = strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return snapshot, scanner.Err()
|
||||
}
|
||||
|
||||
// InspectStateFromDir reads the state files directly from stateDir (used
|
||||
// when the guest is offline and we can mount the work disk from the host).
|
||||
func InspectStateFromDir(stateDir string) (StateSnapshot, error) {
|
||||
var snapshot StateSnapshot
|
||||
statusData, _ := os.ReadFile(filepath.Join(stateDir, "status"))
|
||||
snapshot.Status = strings.TrimSpace(string(statusData))
|
||||
pidData, _ := os.ReadFile(filepath.Join(stateDir, "pid"))
|
||||
if pidValue, err := strconv.Atoi(strings.TrimSpace(string(pidData))); err == nil {
|
||||
snapshot.GuestPID = pidValue
|
||||
}
|
||||
exitData, _ := os.ReadFile(filepath.Join(stateDir, "exit_code"))
|
||||
if exitValue, err := strconv.Atoi(strings.TrimSpace(string(exitData))); err == nil {
|
||||
snapshot.ExitCode = &exitValue
|
||||
}
|
||||
errorData, _ := os.ReadFile(filepath.Join(stateDir, "error"))
|
||||
snapshot.LastError = strings.TrimSpace(string(errorData))
|
||||
if snapshot.GuestPID != 0 {
|
||||
snapshot.Alive = ProcessAlive(snapshot.GuestPID)
|
||||
}
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
// ApplyStateSnapshot mutates sess in place to reflect snapshot. vmRunning
|
||||
// captures whether the VM is currently up so stale in-flight sessions can be
|
||||
// failed when the VM is gone.
|
||||
func ApplyStateSnapshot(sess *model.GuestSession, snapshot StateSnapshot, vmRunning bool) {
|
||||
if sess == nil {
|
||||
return
|
||||
}
|
||||
if snapshot.GuestPID != 0 {
|
||||
sess.GuestPID = snapshot.GuestPID
|
||||
}
|
||||
if snapshot.LastError != "" {
|
||||
sess.LastError = snapshot.LastError
|
||||
}
|
||||
if snapshot.ExitCode != nil {
|
||||
sess.ExitCode = snapshot.ExitCode
|
||||
sess.Attachable = false
|
||||
sess.Reattachable = false
|
||||
if sess.StartedAt.IsZero() {
|
||||
sess.StartedAt = model.Now()
|
||||
}
|
||||
if sess.EndedAt.IsZero() {
|
||||
sess.EndedAt = model.Now()
|
||||
}
|
||||
if *snapshot.ExitCode == 0 {
|
||||
sess.Status = model.GuestSessionStatusExited
|
||||
} else {
|
||||
sess.Status = model.GuestSessionStatusFailed
|
||||
}
|
||||
return
|
||||
}
|
||||
if snapshot.Alive {
|
||||
if sess.StartedAt.IsZero() {
|
||||
sess.StartedAt = model.Now()
|
||||
}
|
||||
sess.Status = model.GuestSessionStatusRunning
|
||||
return
|
||||
}
|
||||
if !vmRunning && (sess.Status == model.GuestSessionStatusStarting || sess.Status == model.GuestSessionStatusRunning || sess.Status == model.GuestSessionStatusStopping) {
|
||||
sess.Status = model.GuestSessionStatusFailed
|
||||
sess.Attachable = false
|
||||
sess.Reattachable = false
|
||||
if sess.LastError == "" {
|
||||
sess.LastError = "vm is not running"
|
||||
}
|
||||
if sess.EndedAt.IsZero() {
|
||||
sess.EndedAt = model.Now()
|
||||
}
|
||||
return
|
||||
}
|
||||
if snapshot.Status == string(model.GuestSessionStatusRunning) {
|
||||
if sess.StartedAt.IsZero() {
|
||||
sess.StartedAt = model.Now()
|
||||
}
|
||||
sess.Status = model.GuestSessionStatusRunning
|
||||
}
|
||||
if sess.Status == model.GuestSessionStatusRunning && sess.StdinMode == model.GuestSessionStdinPipe {
|
||||
sess.Attachable = true
|
||||
sess.Reattachable = true
|
||||
if sess.AttachBackend == "" {
|
||||
sess.AttachBackend = AttachBackendSSHBridge
|
||||
}
|
||||
if sess.AttachMode == "" {
|
||||
sess.AttachMode = AttachModeExclusive
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StateChanged reports whether any materially observable field differs
|
||||
// between before and after, guiding whether to persist an update.
|
||||
func StateChanged(before, after model.GuestSession) bool {
|
||||
if before.Status != after.Status || before.GuestPID != after.GuestPID || before.LastError != after.LastError || before.Attachable != after.Attachable || before.Reattachable != after.Reattachable || before.AttachBackend != after.AttachBackend || before.AttachMode != after.AttachMode || before.LaunchStage != after.LaunchStage || before.LaunchMessage != after.LaunchMessage || before.LaunchRawLog != after.LaunchRawLog {
|
||||
return true
|
||||
}
|
||||
if before.StartedAt != after.StartedAt || before.EndedAt != after.EndedAt {
|
||||
return true
|
||||
}
|
||||
switch {
|
||||
case before.ExitCode == nil && after.ExitCode == nil:
|
||||
return false
|
||||
case before.ExitCode == nil || after.ExitCode == nil:
|
||||
return true
|
||||
default:
|
||||
return *before.ExitCode != *after.ExitCode
|
||||
}
|
||||
}
|
||||
|
||||
// -- Launch helpers ---------------------------------------------------------
|
||||
|
||||
// DefaultName returns a friendly session name: caller-provided if non-empty,
|
||||
// otherwise `<command-base>-<short-id>`.
|
||||
func DefaultName(id, command, explicit string) string {
|
||||
if trimmed := strings.TrimSpace(explicit); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
base := filepath.Base(strings.TrimSpace(command))
|
||||
if base == "." || base == string(filepath.Separator) || base == "" {
|
||||
base = "session"
|
||||
}
|
||||
return base + "-" + system.ShortID(id)
|
||||
}
|
||||
|
||||
// DefaultCWD returns value if non-empty, else /root.
|
||||
func DefaultCWD(value string) string {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
return "/root"
|
||||
}
|
||||
|
||||
// FailLaunch annotates sess as launch-failed with stage/message/raw log and
|
||||
// returns it for persistence.
|
||||
func FailLaunch(sess model.GuestSession, stage, message, rawLog string) model.GuestSession {
|
||||
now := model.Now()
|
||||
sess.Status = model.GuestSessionStatusFailed
|
||||
sess.LastError = strings.TrimSpace(message)
|
||||
sess.Attachable = false
|
||||
sess.Reattachable = false
|
||||
sess.LaunchStage = strings.TrimSpace(stage)
|
||||
sess.LaunchMessage = strings.TrimSpace(message)
|
||||
sess.LaunchRawLog = strings.TrimSpace(rawLog)
|
||||
sess.UpdatedAt = now
|
||||
sess.EndedAt = now
|
||||
return sess
|
||||
}
|
||||
|
||||
// NormalizeRequiredCommands returns a de-duplicated, order-preserving list
|
||||
// of required commands, with the session command first.
|
||||
func NormalizeRequiredCommands(command string, extras []string) []string {
|
||||
ordered := make([]string, 0, len(extras)+1)
|
||||
seen := map[string]struct{}{}
|
||||
appendValue := func(value string) {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[trimmed]; ok {
|
||||
return
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
ordered = append(ordered, trimmed)
|
||||
}
|
||||
appendValue(command)
|
||||
for _, extra := range extras {
|
||||
appendValue(extra)
|
||||
}
|
||||
return ordered
|
||||
}
|
||||
|
||||
// -- Small utilities --------------------------------------------------------
|
||||
|
||||
// ShellQuote returns value single-quoted for bash, escaping embedded quotes.
|
||||
func ShellQuote(value string) string {
|
||||
return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
|
||||
}
|
||||
|
||||
// ExitCode extracts the exit status from an ssh.ExitError, returning
|
||||
// (0, true) for nil errors.
|
||||
func ExitCode(err error) (int, bool) {
|
||||
if err == nil {
|
||||
return 0, true
|
||||
}
|
||||
var exitErr *ssh.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
return exitErr.ExitStatus(), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// CloneStringMap returns a shallow copy of values, or nil if empty.
|
||||
func CloneStringMap(values map[string]string) map[string]string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make(map[string]string, len(values))
|
||||
for key, value := range values {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
// TailFileContent returns the last N lines of a file, or "" if the file is
|
||||
// missing.
|
||||
func TailFileContent(path string, lines int) (string, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if lines <= 0 {
|
||||
return string(data), nil
|
||||
}
|
||||
parts := strings.Split(string(data), "\n")
|
||||
if len(parts) <= lines {
|
||||
return string(data), nil
|
||||
}
|
||||
return strings.Join(parts[len(parts)-lines-1:], "\n"), nil
|
||||
}
|
||||
|
||||
// ProcessAlive returns true if the process with pid exists. The syscallKill
|
||||
// override is exposed for tests that need to simulate alive/dead processes.
|
||||
func ProcessAlive(pid int) bool {
|
||||
if pid <= 0 {
|
||||
return false
|
||||
}
|
||||
return syscallKill(pid, syscall.Signal(0)) == nil
|
||||
}
|
||||
|
||||
// syscallKill is a test seam: tests replace it to stub process-alive checks.
|
||||
var syscallKill = func(pid int, signal os.Signal) error {
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return proc.Signal(signal)
|
||||
}
|
||||
|
||||
// FormatStepError wraps err with an action label and trimmed on-guest log.
|
||||
func FormatStepError(action string, err error, log string) error {
|
||||
log = strings.TrimSpace(log)
|
||||
if log == "" {
|
||||
return fmt.Errorf("%s: %w", action, err)
|
||||
}
|
||||
return fmt.Errorf("%s: %w: %s", action, err, log)
|
||||
}
|
||||
|
|
@ -1,440 +0,0 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"banger/internal/model"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestRelativeStateDir(t *testing.T) {
|
||||
got := RelativeStateDir("abc")
|
||||
if strings.HasPrefix(got, "/root/") {
|
||||
t.Fatalf("RelativeStateDir(%q) = %q, should strip /root/ prefix", "abc", got)
|
||||
}
|
||||
if !strings.Contains(got, "abc") {
|
||||
t.Fatalf("missing session id in %q", got)
|
||||
}
|
||||
absolute := StateDir("abc")
|
||||
if got != strings.TrimPrefix(absolute, "/root/") {
|
||||
t.Fatalf("relative = %q, want %q", got, strings.TrimPrefix(absolute, "/root/"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCWD(t *testing.T) {
|
||||
if DefaultCWD("") != "/root" {
|
||||
t.Error("empty should return /root")
|
||||
}
|
||||
if DefaultCWD(" ") != "/root" {
|
||||
t.Error("whitespace should return /root")
|
||||
}
|
||||
if DefaultCWD("/work") != "/work" {
|
||||
t.Error("explicit should pass through")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellQuote(t *testing.T) {
|
||||
if got := ShellQuote(""); got != "''" {
|
||||
t.Errorf("empty: got %q, want ''", got)
|
||||
}
|
||||
if got := ShellQuote("x"); got != "'x'" {
|
||||
t.Errorf("plain: got %q", got)
|
||||
}
|
||||
if got := ShellQuote("it's"); got != `'it'"'"'s'` {
|
||||
t.Errorf("apostrophe: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitCode(t *testing.T) {
|
||||
if code, ok := ExitCode(nil); !ok || code != 0 {
|
||||
t.Errorf("nil err: got (%d, %v), want (0, true)", code, ok)
|
||||
}
|
||||
// Build an ssh.ExitError using its real type — can't hand-construct,
|
||||
// so wrap via errors.As check with a stub.
|
||||
raw := &ssh.ExitError{}
|
||||
if _, ok := ExitCode(raw); !ok {
|
||||
t.Error("ssh.ExitError: ok should be true")
|
||||
}
|
||||
if _, ok := ExitCode(errors.New("bare error")); ok {
|
||||
t.Error("bare error: ok should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneStringMap(t *testing.T) {
|
||||
if CloneStringMap(nil) != nil {
|
||||
t.Error("nil in → nil out")
|
||||
}
|
||||
if CloneStringMap(map[string]string{}) != nil {
|
||||
t.Error("empty in → nil out")
|
||||
}
|
||||
src := map[string]string{"a": "1", "b": "2"}
|
||||
cloned := CloneStringMap(src)
|
||||
if len(cloned) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(cloned))
|
||||
}
|
||||
cloned["a"] = "changed"
|
||||
if src["a"] != "1" {
|
||||
t.Error("mutating clone leaked back to source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTailFileContent(t *testing.T) {
|
||||
// Missing file → empty, no error.
|
||||
got, err := TailFileContent(filepath.Join(t.TempDir(), "missing"), 10)
|
||||
if err != nil || got != "" {
|
||||
t.Errorf("missing: got (%q, %v), want ('', nil)", got, err)
|
||||
}
|
||||
|
||||
path := filepath.Join(t.TempDir(), "log")
|
||||
lines := "one\ntwo\nthree\nfour\nfive"
|
||||
if err := os.WriteFile(path, []byte(lines), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
full, err := TailFileContent(path, 0)
|
||||
if err != nil || full != lines {
|
||||
t.Errorf("0 lines: got (%q, %v), want (%q, nil)", full, err, lines)
|
||||
}
|
||||
|
||||
// Request more lines than exist → full content.
|
||||
all, err := TailFileContent(path, 999)
|
||||
if err != nil || all != lines {
|
||||
t.Errorf("999 lines: got %q", all)
|
||||
}
|
||||
|
||||
last2, err := TailFileContent(path, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("2 lines: %v", err)
|
||||
}
|
||||
if !strings.Contains(last2, "five") {
|
||||
t.Errorf("2 lines missing last line: %q", last2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessAlive(t *testing.T) {
|
||||
if ProcessAlive(0) {
|
||||
t.Error("pid 0 should not be alive")
|
||||
}
|
||||
if ProcessAlive(-1) {
|
||||
t.Error("negative pid should not be alive")
|
||||
}
|
||||
// Swap the syscall seam.
|
||||
original := syscallKill
|
||||
t.Cleanup(func() { syscallKill = original })
|
||||
|
||||
syscallKill = func(pid int, signal os.Signal) error { return nil }
|
||||
if !ProcessAlive(42) {
|
||||
t.Error("syscallKill=nil should report alive")
|
||||
}
|
||||
|
||||
syscallKill = func(pid int, signal os.Signal) error { return fmt.Errorf("no such process") }
|
||||
if ProcessAlive(42) {
|
||||
t.Error("syscallKill error should report dead")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatStepError(t *testing.T) {
|
||||
base := errors.New("boom")
|
||||
err := FormatStepError("prepare", base, "")
|
||||
if !errors.Is(err, base) {
|
||||
t.Error("FormatStepError should wrap the base error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "prepare") {
|
||||
t.Errorf("missing action: %v", err)
|
||||
}
|
||||
|
||||
errWithLog := FormatStepError("prepare", base, " log line\n")
|
||||
if !strings.Contains(errWithLog.Error(), "log line") {
|
||||
t.Errorf("missing log: %v", errWithLog)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStateHappyPath(t *testing.T) {
|
||||
raw := `status=running
|
||||
pid=123
|
||||
exit=
|
||||
alive=true
|
||||
error=
|
||||
`
|
||||
snap, err := ParseState(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseState: %v", err)
|
||||
}
|
||||
if snap.Status != "running" {
|
||||
t.Errorf("Status = %q", snap.Status)
|
||||
}
|
||||
if snap.GuestPID != 123 {
|
||||
t.Errorf("GuestPID = %d", snap.GuestPID)
|
||||
}
|
||||
if snap.ExitCode != nil {
|
||||
t.Errorf("ExitCode should be nil when empty, got %v", snap.ExitCode)
|
||||
}
|
||||
if !snap.Alive {
|
||||
t.Error("Alive should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStateWithExit(t *testing.T) {
|
||||
raw := `status=exited
|
||||
pid=123
|
||||
exit=7
|
||||
alive=false
|
||||
error=something bad
|
||||
`
|
||||
snap, err := ParseState(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseState: %v", err)
|
||||
}
|
||||
if snap.ExitCode == nil || *snap.ExitCode != 7 {
|
||||
t.Errorf("ExitCode = %v, want 7", snap.ExitCode)
|
||||
}
|
||||
if snap.LastError != "something bad" {
|
||||
t.Errorf("LastError = %q", snap.LastError)
|
||||
}
|
||||
if snap.Alive {
|
||||
t.Error("Alive should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStateIgnoresMalformedLines(t *testing.T) {
|
||||
raw := "no-equals-here\nstatus=ok\n"
|
||||
snap, err := ParseState(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseState: %v", err)
|
||||
}
|
||||
if snap.Status != "ok" {
|
||||
t.Errorf("Status = %q, want ok", snap.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectStateFromDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeFile := func(name, content string) {
|
||||
if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(%s): %v", name, err)
|
||||
}
|
||||
}
|
||||
writeFile("status", "running\n")
|
||||
writeFile("pid", "42\n")
|
||||
writeFile("exit_code", "0\n")
|
||||
writeFile("error", "\n")
|
||||
|
||||
original := syscallKill
|
||||
t.Cleanup(func() { syscallKill = original })
|
||||
syscallKill = func(pid int, signal os.Signal) error { return nil }
|
||||
|
||||
snap, err := InspectStateFromDir(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("InspectStateFromDir: %v", err)
|
||||
}
|
||||
if snap.Status != "running" {
|
||||
t.Errorf("Status = %q", snap.Status)
|
||||
}
|
||||
if snap.GuestPID != 42 {
|
||||
t.Errorf("GuestPID = %d", snap.GuestPID)
|
||||
}
|
||||
if snap.ExitCode == nil || *snap.ExitCode != 0 {
|
||||
t.Errorf("ExitCode = %v, want 0", snap.ExitCode)
|
||||
}
|
||||
if !snap.Alive {
|
||||
t.Error("Alive should reflect syscallKill result (true)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectStateFromDirMissingFiles(t *testing.T) {
|
||||
snap, err := InspectStateFromDir(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("InspectStateFromDir (empty): %v", err)
|
||||
}
|
||||
if snap.Status != "" || snap.GuestPID != 0 || snap.ExitCode != nil {
|
||||
t.Errorf("empty dir: snap = %+v", snap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotNilReceiver(t *testing.T) {
|
||||
ApplyStateSnapshot(nil, StateSnapshot{}, true) // should not panic
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotExitedSuccess(t *testing.T) {
|
||||
exit := 0
|
||||
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning, Attachable: true, Reattachable: true}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true)
|
||||
if sess.Status != model.GuestSessionStatusExited {
|
||||
t.Errorf("Status = %q, want exited", sess.Status)
|
||||
}
|
||||
if sess.Attachable || sess.Reattachable {
|
||||
t.Error("attach flags should be cleared on exit")
|
||||
}
|
||||
if sess.EndedAt.IsZero() {
|
||||
t.Error("EndedAt should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotExitedFailure(t *testing.T) {
|
||||
exit := 2
|
||||
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true)
|
||||
if sess.Status != model.GuestSessionStatusFailed {
|
||||
t.Errorf("Status = %q, want failed", sess.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotVMGone(t *testing.T) {
|
||||
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{Alive: false}, false)
|
||||
if sess.Status != model.GuestSessionStatusFailed {
|
||||
t.Errorf("Status = %q, want failed", sess.Status)
|
||||
}
|
||||
if sess.LastError == "" {
|
||||
t.Error("LastError should be populated when VM is gone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotRunningStatusSetsAttachableForPipe(t *testing.T) {
|
||||
// When the guest-side status file reports "running" (Alive=false from
|
||||
// kill -0 may still fail transiently), ApplyStateSnapshot transitions
|
||||
// the session to running and sets attach flags for pipe-mode.
|
||||
sess := &model.GuestSession{
|
||||
Status: model.GuestSessionStatusStarting,
|
||||
StdinMode: model.GuestSessionStdinPipe,
|
||||
}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{Status: string(model.GuestSessionStatusRunning), GuestPID: 11}, true)
|
||||
if sess.Status != model.GuestSessionStatusRunning {
|
||||
t.Errorf("Status = %q, want running", sess.Status)
|
||||
}
|
||||
if !sess.Attachable || !sess.Reattachable {
|
||||
t.Error("pipe-mode running session should be attachable + reattachable")
|
||||
}
|
||||
if sess.AttachBackend != AttachBackendSSHBridge {
|
||||
t.Errorf("AttachBackend = %q, want %q", sess.AttachBackend, AttachBackendSSHBridge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotAliveEarlyReturn(t *testing.T) {
|
||||
// Alive-true returns immediately after setting status; no attach
|
||||
// flags set on this path (by design — attach metadata only attaches
|
||||
// to status-driven transitions).
|
||||
sess := &model.GuestSession{
|
||||
Status: model.GuestSessionStatusStarting,
|
||||
StdinMode: model.GuestSessionStdinPipe,
|
||||
}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{Alive: true, GuestPID: 11}, true)
|
||||
if sess.Status != model.GuestSessionStatusRunning {
|
||||
t.Errorf("Status = %q, want running", sess.Status)
|
||||
}
|
||||
if sess.StartedAt.IsZero() {
|
||||
t.Error("StartedAt should have been set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateChanged(t *testing.T) {
|
||||
base := model.GuestSession{Status: model.GuestSessionStatusRunning, GuestPID: 10}
|
||||
|
||||
// Identical → no change.
|
||||
if StateChanged(base, base) {
|
||||
t.Error("identical states should not be considered changed")
|
||||
}
|
||||
|
||||
// Status change.
|
||||
changed := base
|
||||
changed.Status = model.GuestSessionStatusExited
|
||||
if !StateChanged(base, changed) {
|
||||
t.Error("status change should be detected")
|
||||
}
|
||||
|
||||
// ExitCode change from nil → value.
|
||||
exit := 3
|
||||
changed = base
|
||||
changed.ExitCode = &exit
|
||||
if !StateChanged(base, changed) {
|
||||
t.Error("exit-code appearing should be detected")
|
||||
}
|
||||
|
||||
// Both have the same exit code → no change.
|
||||
a := base
|
||||
a.ExitCode = &exit
|
||||
b := base
|
||||
b.ExitCode = &exit
|
||||
if StateChanged(a, b) {
|
||||
t.Error("matching exit codes should not trigger change")
|
||||
}
|
||||
|
||||
// Different exit codes.
|
||||
other := 5
|
||||
b.ExitCode = &other
|
||||
if !StateChanged(a, b) {
|
||||
t.Error("differing exit codes should be detected")
|
||||
}
|
||||
|
||||
// Timestamp change.
|
||||
changed = base
|
||||
changed.StartedAt = time.Now()
|
||||
if !StateChanged(base, changed) {
|
||||
t.Error("StartedAt change should be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailLaunch(t *testing.T) {
|
||||
in := model.GuestSession{Status: model.GuestSessionStatusStarting, Attachable: true}
|
||||
out := FailLaunch(in, "provision", " ssh did not come up ", " raw output\n")
|
||||
if out.Status != model.GuestSessionStatusFailed {
|
||||
t.Errorf("Status = %q, want failed", out.Status)
|
||||
}
|
||||
if out.LastError != "ssh did not come up" {
|
||||
t.Errorf("LastError = %q (not trimmed?)", out.LastError)
|
||||
}
|
||||
if out.LaunchStage != "provision" || out.LaunchMessage != "ssh did not come up" {
|
||||
t.Errorf("launch fields not set: %+v", out)
|
||||
}
|
||||
if out.LaunchRawLog != "raw output" {
|
||||
t.Errorf("rawLog = %q (not trimmed?)", out.LaunchRawLog)
|
||||
}
|
||||
if out.Attachable {
|
||||
t.Error("Attachable should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeRequiredCommands(t *testing.T) {
|
||||
got := NormalizeRequiredCommands("pi", []string{"pi", "git", "", "git", " ", "make"})
|
||||
want := []string{"pi", "git", "make"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("len = %d, want %d (%v)", len(got), len(want), got)
|
||||
}
|
||||
for i, v := range want {
|
||||
if got[i] != v {
|
||||
t.Errorf("position %d: got %q, want %q", i, got[i], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectScriptContainsAllStateFiles(t *testing.T) {
|
||||
script := InspectScript("sess-abc")
|
||||
for _, key := range []string{"status", "pid", "exit_code", "error", "alive"} {
|
||||
if !strings.Contains(script, key) {
|
||||
t.Errorf("script missing %q:\n%s", key, script)
|
||||
}
|
||||
}
|
||||
if !strings.Contains(script, "sess-abc") {
|
||||
t.Error("script missing session id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalScriptIncludesSignalAndDirPaths(t *testing.T) {
|
||||
script := SignalScript("sess-x", "TERM")
|
||||
if !strings.Contains(script, "TERM") {
|
||||
t.Error("missing signal")
|
||||
}
|
||||
if !strings.Contains(script, "sess-x") {
|
||||
t.Error("missing session id")
|
||||
}
|
||||
if !strings.Contains(script, "monitor_pid") || !strings.Contains(script, "stdin_keepalive") {
|
||||
t.Errorf("expected both monitor + stdin_keepalive kills, got:\n%s", script)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,224 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"banger/internal/api"
|
||||
sess "banger/internal/daemon/session"
|
||||
"banger/internal/guest"
|
||||
"banger/internal/model"
|
||||
"banger/internal/sessionstream"
|
||||
)
|
||||
|
||||
func (d *Daemon) BeginGuestSessionAttach(ctx context.Context, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionAttachBeginResult{}, err
|
||||
}
|
||||
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionAttachBeginResult{}, err
|
||||
}
|
||||
session, _ = d.refreshGuestSession(ctx, vm, session)
|
||||
if !session.Attachable {
|
||||
return api.GuestSessionAttachBeginResult{}, errors.New("session is not attachable")
|
||||
}
|
||||
controller := &guestSessionController{}
|
||||
if !d.claimGuestSessionController(session.ID, controller) {
|
||||
return api.GuestSessionAttachBeginResult{}, errors.New("session already has an active attach")
|
||||
}
|
||||
attachID, err := model.NewID()
|
||||
if err != nil {
|
||||
d.clearGuestSessionController(session.ID)
|
||||
return api.GuestSessionAttachBeginResult{}, err
|
||||
}
|
||||
socketPath := filepath.Join(d.layout.RuntimeDir, "guest-session-attach-"+attachID[:12]+".sock")
|
||||
_ = os.Remove(socketPath)
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
d.clearGuestSessionController(session.ID)
|
||||
return api.GuestSessionAttachBeginResult{}, err
|
||||
}
|
||||
if err := os.Chmod(socketPath, 0o600); err != nil {
|
||||
_ = listener.Close()
|
||||
_ = os.Remove(socketPath)
|
||||
d.clearGuestSessionController(session.ID)
|
||||
return api.GuestSessionAttachBeginResult{}, err
|
||||
}
|
||||
go d.serveGuestSessionAttach(session, controller, attachID, socketPath, listener)
|
||||
return api.GuestSessionAttachBeginResult{
|
||||
Session: session,
|
||||
AttachID: attachID,
|
||||
TransportKind: sess.TransportUnixSocket,
|
||||
TransportTarget: socketPath,
|
||||
SocketPath: socketPath,
|
||||
StreamFormat: sessionstream.FormatV1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) forwardGuestSessionOutput(_ string, controller *guestSessionController, channel byte, reader io.Reader) {
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := reader.Read(buffer)
|
||||
if n > 0 {
|
||||
controller.writeFrame(channel, buffer[:n])
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) waitForGuestSessionExit(id string, controller *guestSessionController, session model.GuestSession) {
|
||||
err := controller.stream.Wait()
|
||||
updated := session
|
||||
updated.Attachable = false
|
||||
now := model.Now()
|
||||
updated.UpdatedAt = now
|
||||
updated.EndedAt = now
|
||||
if exitCode, ok := sess.ExitCode(err); ok {
|
||||
updated.ExitCode = &exitCode
|
||||
if exitCode == 0 {
|
||||
updated.Status = model.GuestSessionStatusExited
|
||||
} else {
|
||||
updated.Status = model.GuestSessionStatusFailed
|
||||
}
|
||||
}
|
||||
if err != nil && updated.LastError == "" {
|
||||
updated.LastError = err.Error()
|
||||
}
|
||||
if vm, getErr := d.store.GetVMByID(context.Background(), updated.VMID); getErr == nil {
|
||||
if refreshed, refreshErr := d.refreshGuestSession(context.Background(), vm, updated); refreshErr == nil {
|
||||
updated = refreshed
|
||||
}
|
||||
}
|
||||
_ = d.store.UpsertGuestSession(context.Background(), updated)
|
||||
controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: updated.ExitCode})
|
||||
_ = controller.close()
|
||||
d.clearGuestSessionController(id)
|
||||
}
|
||||
|
||||
func (d *Daemon) serveGuestSessionAttach(session model.GuestSession, controller *guestSessionController, _ string, socketPath string, listener net.Listener) {
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
_ = os.Remove(socketPath)
|
||||
_ = controller.close()
|
||||
d.clearGuestSessionController(session.ID)
|
||||
}()
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
if err := controller.setAttach(conn); err != nil {
|
||||
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
return
|
||||
}
|
||||
defer controller.clearAttach(conn)
|
||||
if err := d.attachGuestSessionBridge(session, controller); err != nil {
|
||||
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
return
|
||||
}
|
||||
for {
|
||||
channel, payload, err := sessionstream.ReadFrame(conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch channel {
|
||||
case sessionstream.ChannelStdin:
|
||||
if controller.stdin == nil {
|
||||
continue
|
||||
}
|
||||
if _, err := controller.stdin.Write(payload); err != nil {
|
||||
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
return
|
||||
}
|
||||
case sessionstream.ChannelControl:
|
||||
message, err := sessionstream.ReadControl(payload)
|
||||
if err != nil {
|
||||
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
return
|
||||
}
|
||||
if message.Type == "eof" && controller.stdin != nil {
|
||||
_ = controller.stdin.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) attachGuestSessionBridge(session model.GuestSession, controller *guestSessionController) error {
|
||||
vm, err := d.store.GetVMByID(context.Background(), session.VMID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !d.vmAlive(vm) {
|
||||
return fmt.Errorf("vm %q is not running", vm.Name)
|
||||
}
|
||||
address := net.JoinHostPort(vm.Runtime.GuestIP, "22")
|
||||
stdinStream, err := d.openGuestSessionAttachStream(address, sess.AttachInputCommand(session.ID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("open guest session stdin stream: %w", err)
|
||||
}
|
||||
stdoutStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StdoutLogPath))
|
||||
if err != nil {
|
||||
_ = stdinStream.Close()
|
||||
return fmt.Errorf("open guest session stdout stream: %w", err)
|
||||
}
|
||||
stderrStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StderrLogPath))
|
||||
if err != nil {
|
||||
_ = stdinStream.Close()
|
||||
_ = stdoutStream.Close()
|
||||
return fmt.Errorf("open guest session stderr stream: %w", err)
|
||||
}
|
||||
controller.streams = append(controller.streams, stdinStream, stdoutStream, stderrStream)
|
||||
controller.stdin = stdinStream.Stdin()
|
||||
go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStdout, stdoutStream.Stdout())
|
||||
go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStderr, stderrStream.Stdout())
|
||||
go d.watchGuestSessionAttach(session.ID, controller, session)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Daemon) openGuestSessionAttachStream(address, command string) (*guest.StreamSession, error) {
|
||||
client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream, err := client.StartCommand(context.Background(), command)
|
||||
if err != nil {
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) watchGuestSessionAttach(id string, controller *guestSessionController, session model.GuestSession) {
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
vm, err := d.store.GetVMByID(context.Background(), session.VMID)
|
||||
if err != nil {
|
||||
controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
_ = controller.close()
|
||||
return
|
||||
}
|
||||
refreshed, err := d.refreshGuestSession(context.Background(), vm, session)
|
||||
if err == nil {
|
||||
session = refreshed
|
||||
}
|
||||
if session.Status == model.GuestSessionStatusExited || session.Status == model.GuestSessionStatusFailed {
|
||||
controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: session.ExitCode})
|
||||
_ = controller.close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,184 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"banger/internal/guest"
|
||||
"banger/internal/sessionstream"
|
||||
)
|
||||
|
||||
type guestSessionController struct {
|
||||
stream *guest.StreamSession
|
||||
streams []*guest.StreamSession
|
||||
stdin io.WriteCloser
|
||||
attachMu sync.Mutex
|
||||
attach net.Conn
|
||||
writeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (c *guestSessionController) setAttach(conn net.Conn) error {
|
||||
c.attachMu.Lock()
|
||||
defer c.attachMu.Unlock()
|
||||
if c.attach != nil {
|
||||
return errors.New("session already has an active attach")
|
||||
}
|
||||
c.attach = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *guestSessionController) clearAttach(conn net.Conn) {
|
||||
c.attachMu.Lock()
|
||||
defer c.attachMu.Unlock()
|
||||
if c.attach == conn {
|
||||
c.attach = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *guestSessionController) writeFrame(channel byte, payload []byte) {
|
||||
c.attachMu.Lock()
|
||||
conn := c.attach
|
||||
c.attachMu.Unlock()
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
c.writeMu.Lock()
|
||||
err := sessionstream.WriteFrame(conn, channel, payload)
|
||||
c.writeMu.Unlock()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
c.clearAttach(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *guestSessionController) writeControl(message sessionstream.ControlMessage) {
|
||||
c.attachMu.Lock()
|
||||
conn := c.attach
|
||||
c.attachMu.Unlock()
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
c.writeMu.Lock()
|
||||
err := sessionstream.WriteControl(conn, message)
|
||||
c.writeMu.Unlock()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
c.clearAttach(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *guestSessionController) close() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
c.closeOnce.Do(func() {
|
||||
c.attachMu.Lock()
|
||||
conn := c.attach
|
||||
c.attach = nil
|
||||
c.attachMu.Unlock()
|
||||
if conn != nil {
|
||||
err = errors.Join(err, conn.Close())
|
||||
}
|
||||
if c.stdin != nil {
|
||||
err = errors.Join(err, c.stdin.Close())
|
||||
}
|
||||
if c.stream != nil {
|
||||
err = errors.Join(err, c.stream.Close())
|
||||
}
|
||||
for _, stream := range c.streams {
|
||||
if stream != nil {
|
||||
err = errors.Join(err, stream.Close())
|
||||
}
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// sessionRegistry owns the live guest-session controller map. Its lock is
|
||||
// independent of Daemon.mu so guest-session lookups do not contend with
|
||||
// unrelated daemon state.
|
||||
type sessionRegistry struct {
|
||||
mu sync.Mutex
|
||||
byID map[string]*guestSessionController
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newSessionRegistry() sessionRegistry {
|
||||
return sessionRegistry{byID: make(map[string]*guestSessionController)}
|
||||
}
|
||||
|
||||
func (r *sessionRegistry) set(id string, controller *guestSessionController) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.closed {
|
||||
return
|
||||
}
|
||||
r.byID[id] = controller
|
||||
}
|
||||
|
||||
func (r *sessionRegistry) claim(id string, controller *guestSessionController) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.closed {
|
||||
return false
|
||||
}
|
||||
if r.byID[id] != nil {
|
||||
return false
|
||||
}
|
||||
r.byID[id] = controller
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *sessionRegistry) get(id string) *guestSessionController {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.byID[id]
|
||||
}
|
||||
|
||||
func (r *sessionRegistry) clear(id string) *guestSessionController {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
controller := r.byID[id]
|
||||
delete(r.byID, id)
|
||||
return controller
|
||||
}
|
||||
|
||||
func (r *sessionRegistry) closeAll() error {
|
||||
r.mu.Lock()
|
||||
controllers := make([]*guestSessionController, 0, len(r.byID))
|
||||
for _, controller := range r.byID {
|
||||
controllers = append(controllers, controller)
|
||||
}
|
||||
r.byID = nil
|
||||
r.closed = true
|
||||
r.mu.Unlock()
|
||||
var err error
|
||||
for _, controller := range controllers {
|
||||
err = errors.Join(err, controller.close())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Daemon) setGuestSessionController(id string, controller *guestSessionController) {
|
||||
d.sessions.set(id, controller)
|
||||
}
|
||||
|
||||
func (d *Daemon) claimGuestSessionController(id string, controller *guestSessionController) bool {
|
||||
return d.sessions.claim(id, controller)
|
||||
}
|
||||
|
||||
func (d *Daemon) getGuestSessionController(id string) *guestSessionController {
|
||||
return d.sessions.get(id)
|
||||
}
|
||||
|
||||
func (d *Daemon) clearGuestSessionController(id string) *guestSessionController {
|
||||
return d.sessions.clear(id)
|
||||
}
|
||||
|
||||
func (d *Daemon) closeGuestSessionControllers() error {
|
||||
return d.sessions.closeAll()
|
||||
}
|
||||
|
|
@ -1,213 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"banger/internal/api"
|
||||
sess "banger/internal/daemon/session"
|
||||
"banger/internal/guest"
|
||||
"banger/internal/model"
|
||||
)
|
||||
|
||||
func (d *Daemon) StartGuestSession(ctx context.Context, params api.GuestSessionStartParams) (model.GuestSession, error) {
|
||||
stdinMode := model.GuestSessionStdinMode(strings.TrimSpace(params.StdinMode))
|
||||
if stdinMode == "" {
|
||||
stdinMode = model.GuestSessionStdinClosed
|
||||
}
|
||||
if stdinMode != model.GuestSessionStdinClosed && stdinMode != model.GuestSessionStdinPipe {
|
||||
return model.GuestSession{}, fmt.Errorf("unsupported stdin mode %q", params.StdinMode)
|
||||
}
|
||||
if strings.TrimSpace(params.Command) == "" {
|
||||
return model.GuestSession{}, errors.New("session command is required")
|
||||
}
|
||||
var created model.GuestSession
|
||||
_, err := d.withVMLockByRef(ctx, params.VMIDOrName, func(vm model.VMRecord) (model.VMRecord, error) {
|
||||
if !d.vmAlive(vm) {
|
||||
return model.VMRecord{}, fmt.Errorf("vm %q is not running", vm.Name)
|
||||
}
|
||||
session, err := d.startGuestSessionLocked(ctx, vm, params, stdinMode)
|
||||
if err != nil {
|
||||
return model.VMRecord{}, err
|
||||
}
|
||||
created = session
|
||||
return vm, nil
|
||||
})
|
||||
return created, err
|
||||
}
|
||||
|
||||
func (d *Daemon) startGuestSessionLocked(ctx context.Context, vm model.VMRecord, params api.GuestSessionStartParams, stdinMode model.GuestSessionStdinMode) (model.GuestSession, error) {
|
||||
id, err := model.NewID()
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
now := model.Now()
|
||||
session := model.GuestSession{
|
||||
ID: id,
|
||||
VMID: vm.ID,
|
||||
Name: sess.DefaultName(id, params.Command, params.Name),
|
||||
Backend: sess.BackendSSH,
|
||||
Command: params.Command,
|
||||
Args: append([]string(nil), params.Args...),
|
||||
CWD: strings.TrimSpace(params.CWD),
|
||||
Env: sess.CloneStringMap(params.Env),
|
||||
StdinMode: stdinMode,
|
||||
Status: model.GuestSessionStatusStarting,
|
||||
GuestStateDir: sess.StateDir(id),
|
||||
StdoutLogPath: sess.StdoutLogPath(id),
|
||||
StderrLogPath: sess.StderrLogPath(id),
|
||||
Tags: sess.CloneStringMap(params.Tags),
|
||||
Attachable: stdinMode == model.GuestSessionStdinPipe,
|
||||
Reattachable: stdinMode == model.GuestSessionStdinPipe,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if session.Attachable {
|
||||
session.AttachBackend = sess.AttachBackendSSHBridge
|
||||
session.AttachMode = sess.AttachModeExclusive
|
||||
} else {
|
||||
session.AttachBackend = sess.AttachBackendNone
|
||||
}
|
||||
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
fail := func(stage, message, rawLog string) (model.GuestSession, error) {
|
||||
session = sess.FailLaunch(session, stage, message, rawLog)
|
||||
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
address := net.JoinHostPort(vm.Runtime.GuestIP, "22")
|
||||
if err := d.waitForGuestSSH(ctx, address, 250*time.Millisecond); err != nil {
|
||||
return fail("ssh_unavailable", fmt.Sprintf("guest ssh unavailable: %v", err), "")
|
||||
}
|
||||
client, err := d.dialGuest(ctx, address)
|
||||
if err != nil {
|
||||
return fail("dial_guest", fmt.Sprintf("dial guest ssh: %v", err), "")
|
||||
}
|
||||
defer client.Close()
|
||||
var preflightLog bytes.Buffer
|
||||
if err := client.RunScript(ctx, sess.CWDPreflightScript(session.CWD), &preflightLog); err != nil {
|
||||
return fail("preflight_cwd", fmt.Sprintf("guest working directory is unavailable: %s", sess.DefaultCWD(session.CWD)), preflightLog.String())
|
||||
}
|
||||
preflightLog.Reset()
|
||||
requiredCommands := sess.NormalizeRequiredCommands(params.Command, params.RequiredCommands)
|
||||
if err := client.RunScript(ctx, sess.CommandPreflightScript(requiredCommands), &preflightLog); err != nil {
|
||||
return fail("preflight_command", fmt.Sprintf("required guest command is unavailable: %s", strings.TrimSpace(preflightLog.String())), preflightLog.String())
|
||||
}
|
||||
var uploadLog bytes.Buffer
|
||||
if err := client.UploadFile(ctx, sess.ScriptPath(id), 0o755, []byte(sess.Script(session)), &uploadLog); err != nil {
|
||||
return fail("upload_script", "upload guest session script failed", uploadLog.String())
|
||||
}
|
||||
var launchLog bytes.Buffer
|
||||
launchScript := fmt.Sprintf("set -euo pipefail\nnohup bash %s >/dev/null 2>&1 </dev/null &\ndisown || true\n", sess.ShellQuote(sess.ScriptPath(id)))
|
||||
if err := client.RunScript(ctx, launchScript, &launchLog); err != nil {
|
||||
return fail("launch", "launch guest session failed", launchLog.String())
|
||||
}
|
||||
readyCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
updated, err := d.waitForGuestSessionReadyHook(readyCtx, vm, session)
|
||||
if err != nil {
|
||||
return fail("ready_wait", "guest session did not report ready state", err.Error())
|
||||
}
|
||||
session = updated
|
||||
if session.Status == model.GuestSessionStatusStarting {
|
||||
session.Status = model.GuestSessionStatusRunning
|
||||
session.StartedAt = model.Now()
|
||||
session.UpdatedAt = model.Now()
|
||||
}
|
||||
session.LaunchStage = ""
|
||||
session.LaunchMessage = ""
|
||||
session.LaunchRawLog = ""
|
||||
session.LastError = ""
|
||||
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) GetGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
return d.refreshGuestSession(ctx, vm, session)
|
||||
}
|
||||
|
||||
func (d *Daemon) ListGuestSessions(ctx context.Context, params api.VMRefParams) ([]model.GuestSession, error) {
|
||||
vm, err := d.FindVM(ctx, params.IDOrName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions, err := d.store.ListGuestSessionsByVM(ctx, vm.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for index := range sessions {
|
||||
refreshed, refreshErr := d.refreshGuestSession(ctx, vm, sessions[index])
|
||||
if refreshErr == nil {
|
||||
sessions[index] = refreshed
|
||||
}
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) StopGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
|
||||
return d.signalGuestSession(ctx, params, "TERM")
|
||||
}
|
||||
|
||||
func (d *Daemon) KillGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
|
||||
return d.signalGuestSession(ctx, params, "KILL")
|
||||
}
|
||||
|
||||
func (d *Daemon) signalGuestSession(ctx context.Context, params api.GuestSessionRefParams, signal string) (model.GuestSession, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
session, _ = d.refreshGuestSession(ctx, vm, session)
|
||||
if session.Status == model.GuestSessionStatusExited || session.Status == model.GuestSessionStatusFailed {
|
||||
return session, nil
|
||||
}
|
||||
if !d.vmAlive(vm) {
|
||||
session.Status = model.GuestSessionStatusFailed
|
||||
session.LastError = "vm is not running"
|
||||
now := model.Now()
|
||||
session.UpdatedAt = now
|
||||
session.EndedAt = now
|
||||
session.Attachable = false
|
||||
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
defer client.Close()
|
||||
var log bytes.Buffer
|
||||
if err := client.RunScript(ctx, sess.SignalScript(session.ID, signal), &log); err != nil {
|
||||
return model.GuestSession{}, sess.FormatStepError("signal guest session", err, log.String())
|
||||
}
|
||||
session.Status = model.GuestSessionStatusStopping
|
||||
session.UpdatedAt = model.Now()
|
||||
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
|
@ -1,120 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"banger/internal/api"
|
||||
sess "banger/internal/daemon/session"
|
||||
"banger/internal/guest"
|
||||
"banger/internal/model"
|
||||
"banger/internal/system"
|
||||
)
|
||||
|
||||
func (d *Daemon) GuestSessionLogs(ctx context.Context, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionLogsResult{}, err
|
||||
}
|
||||
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionLogsResult{}, err
|
||||
}
|
||||
streamName := strings.TrimSpace(params.Stream)
|
||||
if streamName == "" {
|
||||
streamName = "stdout"
|
||||
}
|
||||
tailLines := params.TailLines
|
||||
if tailLines <= 0 {
|
||||
tailLines = sess.LogTailLineDefault
|
||||
}
|
||||
path := session.StdoutLogPath
|
||||
if streamName == "stderr" {
|
||||
path = session.StderrLogPath
|
||||
}
|
||||
content, err := d.readGuestSessionLog(ctx, vm, session, streamName, tailLines)
|
||||
if err != nil {
|
||||
return api.GuestSessionLogsResult{}, err
|
||||
}
|
||||
return api.GuestSessionLogsResult{Session: session, Stream: streamName, Path: path, Content: content}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) SendToGuestSession(ctx context.Context, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionSendResult{}, err
|
||||
}
|
||||
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionSendResult{}, err
|
||||
}
|
||||
if session.StdinMode != model.GuestSessionStdinPipe {
|
||||
return api.GuestSessionSendResult{}, errors.New("session does not have a stdin pipe")
|
||||
}
|
||||
if session.Status != model.GuestSessionStatusRunning {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("session is not running (status=%s)", session.Status)
|
||||
}
|
||||
if !d.vmAlive(vm) {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("vm %q is not running", vm.Name)
|
||||
}
|
||||
if len(params.Payload) == 0 {
|
||||
return api.GuestSessionSendResult{Session: session}, nil
|
||||
}
|
||||
client, err := d.dialGuest(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"))
|
||||
if err != nil {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("dial guest: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
tmpPath := fmt.Sprintf("/tmp/banger-send-%s.bin", session.ID[:8])
|
||||
var uploadLog bytes.Buffer
|
||||
if err := client.UploadFile(ctx, tmpPath, 0o600, params.Payload, &uploadLog); err != nil {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("upload payload: %w", err)
|
||||
}
|
||||
sendScript := fmt.Sprintf(
|
||||
"set -euo pipefail\ncat %s >> %s\nrm -f %s\n",
|
||||
sess.ShellQuote(tmpPath),
|
||||
sess.ShellQuote(sess.StdinPipePath(session.ID)),
|
||||
sess.ShellQuote(tmpPath),
|
||||
)
|
||||
var sendLog bytes.Buffer
|
||||
if err := client.RunScript(ctx, sendScript, &sendLog); err != nil {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("send to session: %w: %s", err, strings.TrimSpace(sendLog.String()))
|
||||
}
|
||||
return api.GuestSessionSendResult{Session: session, BytesWritten: len(params.Payload)}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) readGuestSessionLog(ctx context.Context, vm model.VMRecord, session model.GuestSession, stream string, tailLines int) (string, error) {
|
||||
if d.vmAlive(vm) {
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer client.Close()
|
||||
path := session.StdoutLogPath
|
||||
if stream == "stderr" {
|
||||
path = session.StderrLogPath
|
||||
}
|
||||
var output bytes.Buffer
|
||||
script := fmt.Sprintf("set -euo pipefail\nif [ -f %s ]; then tail -n %d %s; fi\n", sess.ShellQuote(path), tailLines, sess.ShellQuote(path))
|
||||
if err := client.RunScript(ctx, script, &output); err != nil {
|
||||
return "", sess.FormatStepError("read guest session log", err, output.String())
|
||||
}
|
||||
return output.String(), nil
|
||||
}
|
||||
runner := d.runner
|
||||
if runner == nil {
|
||||
runner = system.NewRunner()
|
||||
}
|
||||
workMount, cleanup, err := system.MountTempDir(ctx, runner, vm.Runtime.WorkDiskPath, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer cleanup()
|
||||
logPath := filepath.Join(workMount, sess.RelativeStateDir(session.ID), stream+".log")
|
||||
return sess.TailFileContent(logPath, tailLines)
|
||||
}
|
||||
|
|
@ -10,7 +10,6 @@ import (
|
|||
"time"
|
||||
|
||||
"banger/internal/api"
|
||||
sess "banger/internal/daemon/session"
|
||||
ws "banger/internal/daemon/workspace"
|
||||
"banger/internal/model"
|
||||
)
|
||||
|
|
@ -114,9 +113,9 @@ func exportScript(guestPath, diffRef, diffFlag string) string {
|
|||
"git read-tree %s --index-output=\"$tmp_idx\"\n"+
|
||||
"GIT_INDEX_FILE=\"$tmp_idx\" git add -A\n"+
|
||||
"GIT_INDEX_FILE=\"$tmp_idx\" git diff --cached %s %s\n",
|
||||
sess.ShellQuote(guestPath),
|
||||
sess.ShellQuote(diffRef),
|
||||
sess.ShellQuote(diffRef),
|
||||
ws.ShellQuote(guestPath),
|
||||
ws.ShellQuote(diffRef),
|
||||
ws.ShellQuote(diffRef),
|
||||
diffFlag,
|
||||
)
|
||||
}
|
||||
|
|
@ -189,9 +188,9 @@ func (d *Daemon) prepareVMWorkspaceGuestIO(ctx context.Context, vm model.VMRecor
|
|||
}
|
||||
if readOnly {
|
||||
var chmodLog bytes.Buffer
|
||||
chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", sess.ShellQuote(guestPath))
|
||||
chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", ws.ShellQuote(guestPath))
|
||||
if err := client.RunScript(ctx, chmodScript, &chmodLog); err != nil {
|
||||
return model.WorkspacePrepareResult{}, sess.FormatStepError("set workspace readonly", err, chmodLog.String())
|
||||
return model.WorkspacePrepareResult{}, ws.FormatStepError("set workspace readonly", err, chmodLog.String())
|
||||
}
|
||||
}
|
||||
return model.WorkspacePrepareResult{
|
||||
|
|
|
|||
20
internal/daemon/workspace/util.go
Normal file
20
internal/daemon/workspace/util.go
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
package workspace
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ShellQuote returns value single-quoted for bash, escaping embedded quotes.
|
||||
func ShellQuote(value string) string {
|
||||
return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
|
||||
}
|
||||
|
||||
// FormatStepError wraps err with an action label and trimmed on-guest log.
|
||||
func FormatStepError(action string, err error, log string) error {
|
||||
log = strings.TrimSpace(log)
|
||||
if log == "" {
|
||||
return fmt.Errorf("%s: %w", action, err)
|
||||
}
|
||||
return fmt.Errorf("%s: %w: %s", action, err, log)
|
||||
}
|
||||
|
|
@ -18,7 +18,6 @@ import (
|
|||
"sort"
|
||||
"strings"
|
||||
|
||||
sess "banger/internal/daemon/session"
|
||||
"banger/internal/model"
|
||||
"banger/internal/system"
|
||||
)
|
||||
|
|
@ -146,13 +145,13 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
|
|||
switch mode {
|
||||
case model.WorkspacePrepareModeFullCopy:
|
||||
var copyLog bytes.Buffer
|
||||
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath))
|
||||
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath), ShellQuote(guestPath), ShellQuote(guestPath))
|
||||
if err := client.StreamTar(ctx, spec.RepoRoot, command, ©Log); err != nil {
|
||||
return sess.FormatStepError("copy full workspace", err, copyLog.String())
|
||||
return FormatStepError("copy full workspace", err, copyLog.String())
|
||||
}
|
||||
var finalizeLog bytes.Buffer
|
||||
if err := client.RunScript(ctx, FinalizeScript(spec, guestPath, mode), &finalizeLog); err != nil {
|
||||
return sess.FormatStepError("finalize full workspace", err, finalizeLog.String())
|
||||
return FormatStepError("finalize full workspace", err, finalizeLog.String())
|
||||
}
|
||||
return nil
|
||||
case model.WorkspacePrepareModeMetadataOnly, model.WorkspacePrepareModeShallowOverlay:
|
||||
|
|
@ -162,21 +161,21 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
|
|||
}
|
||||
defer cleanup()
|
||||
var copyLog bytes.Buffer
|
||||
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath))
|
||||
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath), ShellQuote(guestPath), ShellQuote(guestPath))
|
||||
if err := client.StreamTar(ctx, repoCopyDir, command, ©Log); err != nil {
|
||||
return sess.FormatStepError("copy guest git metadata", err, copyLog.String())
|
||||
return FormatStepError("copy guest git metadata", err, copyLog.String())
|
||||
}
|
||||
var scriptLog bytes.Buffer
|
||||
if err := client.RunScript(ctx, FinalizeScript(spec, guestPath, mode), &scriptLog); err != nil {
|
||||
return sess.FormatStepError("prepare guest checkout", err, scriptLog.String())
|
||||
return FormatStepError("prepare guest checkout", err, scriptLog.String())
|
||||
}
|
||||
if mode == model.WorkspacePrepareModeMetadataOnly {
|
||||
return nil
|
||||
}
|
||||
var overlayLog bytes.Buffer
|
||||
command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath))
|
||||
command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath))
|
||||
if err := client.StreamTarEntries(ctx, spec.RepoRoot, spec.OverlayPaths, command, &overlayLog); err != nil {
|
||||
return sess.FormatStepError("overlay workspace working tree", err, overlayLog.String())
|
||||
return FormatStepError("overlay workspace working tree", err, overlayLog.String())
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
|
|
@ -190,22 +189,22 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
|
|||
func FinalizeScript(spec RepoSpec, guestPath string, mode model.WorkspacePrepareMode) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
fmt.Fprintf(&script, "DIR=%s\n", sess.ShellQuote(guestPath))
|
||||
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(guestPath))
|
||||
script.WriteString("git config --global --add safe.directory \"$DIR\"\n")
|
||||
if mode != model.WorkspacePrepareModeFullCopy {
|
||||
script.WriteString("find \"$DIR\" -mindepth 1 -maxdepth 1 ! -name .git -exec rm -rf {} +\n")
|
||||
}
|
||||
switch {
|
||||
case strings.TrimSpace(spec.BranchName) != "":
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.BranchName), sess.ShellQuote(spec.BaseCommit))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", ShellQuote(spec.BranchName), ShellQuote(spec.BaseCommit))
|
||||
case strings.TrimSpace(spec.CurrentBranch) != "":
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.CurrentBranch), sess.ShellQuote(spec.HeadCommit))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", ShellQuote(spec.CurrentBranch), ShellQuote(spec.HeadCommit))
|
||||
default:
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", sess.ShellQuote(spec.HeadCommit))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", ShellQuote(spec.HeadCommit))
|
||||
}
|
||||
if strings.TrimSpace(spec.GitUserName) != "" && strings.TrimSpace(spec.GitUserEmail) != "" {
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", sess.ShellQuote(spec.GitUserName))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", sess.ShellQuote(spec.GitUserEmail))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", ShellQuote(spec.GitUserName))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", ShellQuote(spec.GitUserEmail))
|
||||
}
|
||||
return script.String()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,23 +34,6 @@ const (
|
|||
VMStateError VMState = "error"
|
||||
)
|
||||
|
||||
type GuestSessionStatus string
|
||||
|
||||
const (
|
||||
GuestSessionStatusStarting GuestSessionStatus = "starting"
|
||||
GuestSessionStatusRunning GuestSessionStatus = "running"
|
||||
GuestSessionStatusExited GuestSessionStatus = "exited"
|
||||
GuestSessionStatusFailed GuestSessionStatus = "failed"
|
||||
GuestSessionStatusStopping GuestSessionStatus = "stopping"
|
||||
)
|
||||
|
||||
type GuestSessionStdinMode string
|
||||
|
||||
const (
|
||||
GuestSessionStdinClosed GuestSessionStdinMode = "closed"
|
||||
GuestSessionStdinPipe GuestSessionStdinMode = "pipe"
|
||||
)
|
||||
|
||||
type DaemonConfig struct {
|
||||
LogLevel string
|
||||
FirecrackerBin string
|
||||
|
|
@ -176,37 +159,6 @@ type VMSetRequest struct {
|
|||
NATEnabled *bool
|
||||
}
|
||||
|
||||
type GuestSession struct {
|
||||
ID string `json:"id"`
|
||||
VMID string `json:"vm_id"`
|
||||
Name string `json:"name"`
|
||||
Backend string `json:"backend"`
|
||||
AttachBackend string `json:"attach_backend,omitempty"`
|
||||
AttachMode string `json:"attach_mode,omitempty"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
CWD string `json:"cwd,omitempty"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
StdinMode GuestSessionStdinMode `json:"stdin_mode,omitempty"`
|
||||
Status GuestSessionStatus `json:"status"`
|
||||
ExitCode *int `json:"exit_code,omitempty"`
|
||||
GuestPID int `json:"guest_pid,omitempty"`
|
||||
GuestStateDir string `json:"guest_state_dir,omitempty"`
|
||||
StdoutLogPath string `json:"stdout_log_path,omitempty"`
|
||||
StderrLogPath string `json:"stderr_log_path,omitempty"`
|
||||
Tags map[string]string `json:"tags,omitempty"`
|
||||
LastError string `json:"last_error,omitempty"`
|
||||
Attachable bool `json:"attachable"`
|
||||
Reattachable bool `json:"reattachable"`
|
||||
LaunchStage string `json:"launch_stage,omitempty"`
|
||||
LaunchMessage string `json:"launch_message,omitempty"`
|
||||
LaunchRawLog string `json:"launch_raw_log,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt time.Time `json:"started_at,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
EndedAt time.Time `json:"ended_at,omitempty"`
|
||||
}
|
||||
|
||||
type WorkspacePrepareMode string
|
||||
|
||||
const (
|
||||
|
|
|
|||
|
|
@ -1,76 +0,0 @@
|
|||
package sessionstream
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelStdin byte = 0x01
|
||||
ChannelStdout byte = 0x02
|
||||
ChannelStderr byte = 0x03
|
||||
ChannelControl byte = 0x04
|
||||
FormatV1 = "stdio_mux_v1"
|
||||
)
|
||||
|
||||
type ControlMessage struct {
|
||||
Type string `json:"type"`
|
||||
ExitCode *int `json:"exit_code,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func WriteFrame(w io.Writer, channel byte, payload []byte) error {
|
||||
var header [5]byte
|
||||
header[0] = channel
|
||||
binary.BigEndian.PutUint32(header[1:], uint32(len(payload)))
|
||||
if _, err := w.Write(header[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := w.Write(payload)
|
||||
return err
|
||||
}
|
||||
|
||||
func ReadFrame(r io.Reader) (byte, []byte, error) {
|
||||
var header [5]byte
|
||||
if _, err := io.ReadFull(r, header[:]); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
length := binary.BigEndian.Uint32(header[1:])
|
||||
payload := make([]byte, length)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
return header[0], payload, nil
|
||||
}
|
||||
|
||||
func WriteControl(w io.Writer, message ControlMessage) error {
|
||||
payload, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return WriteFrame(w, ChannelControl, payload)
|
||||
}
|
||||
|
||||
func ReadControl(payload []byte) (ControlMessage, error) {
|
||||
var message ControlMessage
|
||||
if err := json.Unmarshal(payload, &message); err != nil {
|
||||
return ControlMessage{}, err
|
||||
}
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func ReadNextControl(r io.Reader) (ControlMessage, error) {
|
||||
channel, payload, err := ReadFrame(r)
|
||||
if err != nil {
|
||||
return ControlMessage{}, err
|
||||
}
|
||||
if channel != ChannelControl {
|
||||
return ControlMessage{}, fmt.Errorf("unexpected channel %d", channel)
|
||||
}
|
||||
return ReadControl(payload)
|
||||
}
|
||||
|
|
@ -1,117 +0,0 @@
|
|||
package sessionstream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriteReadFrameRoundtrip(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
channel byte
|
||||
payload []byte
|
||||
}{
|
||||
{"stdout_bytes", ChannelStdout, []byte("hello world")},
|
||||
{"stderr_bytes", ChannelStderr, []byte{0x00, 0xff, 0x7f}},
|
||||
{"empty_payload", ChannelStdin, nil},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
if err := WriteFrame(&buf, tc.channel, tc.payload); err != nil {
|
||||
t.Fatalf("WriteFrame: %v", err)
|
||||
}
|
||||
ch, got, err := ReadFrame(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame: %v", err)
|
||||
}
|
||||
if ch != tc.channel {
|
||||
t.Fatalf("channel = %d, want %d", ch, tc.channel)
|
||||
}
|
||||
if !bytes.Equal(got, tc.payload) && !(len(got) == 0 && len(tc.payload) == 0) {
|
||||
t.Fatalf("payload = %q, want %q", got, tc.payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type shortWriter struct {
|
||||
failAfter int
|
||||
written int
|
||||
}
|
||||
|
||||
func (s *shortWriter) Write(p []byte) (int, error) {
|
||||
s.written += len(p)
|
||||
if s.written > s.failAfter {
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func TestWriteFrameWriterError(t *testing.T) {
|
||||
w := &shortWriter{failAfter: 2}
|
||||
err := WriteFrame(w, ChannelStdout, []byte("payload"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error from short writer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFrameTruncated(t *testing.T) {
|
||||
_, _, err := ReadFrame(bytes.NewReader([]byte{0x02, 0x00}))
|
||||
if !errors.Is(err, io.ErrUnexpectedEOF) && err == nil {
|
||||
t.Fatalf("expected EOF-ish error, got %v", err)
|
||||
}
|
||||
|
||||
// Header OK, but payload truncated.
|
||||
var buf bytes.Buffer
|
||||
buf.Write([]byte{ChannelStdout, 0x00, 0x00, 0x00, 0x05})
|
||||
buf.Write([]byte("ab"))
|
||||
if _, _, err := ReadFrame(&buf); err == nil {
|
||||
t.Fatal("expected truncated payload error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlRoundtrip(t *testing.T) {
|
||||
code := 42
|
||||
msg := ControlMessage{Type: "exit", ExitCode: &code}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := WriteControl(&buf, msg); err != nil {
|
||||
t.Fatalf("WriteControl: %v", err)
|
||||
}
|
||||
|
||||
got, err := ReadNextControl(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadNextControl: %v", err)
|
||||
}
|
||||
if got.Type != "exit" {
|
||||
t.Fatalf("type = %q, want exit", got.Type)
|
||||
}
|
||||
if got.ExitCode == nil || *got.ExitCode != 42 {
|
||||
t.Fatalf("exit_code = %v, want 42", got.ExitCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadControlBadJSON(t *testing.T) {
|
||||
if _, err := ReadControl([]byte("{not json")); err == nil {
|
||||
t.Fatal("expected JSON error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadNextControlWrongChannel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
if err := WriteFrame(&buf, ChannelStdout, []byte("not a control frame")); err != nil {
|
||||
t.Fatalf("WriteFrame: %v", err)
|
||||
}
|
||||
if _, err := ReadNextControl(&buf); err == nil {
|
||||
t.Fatal("expected error for non-control channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatConstant(t *testing.T) {
|
||||
if FormatV1 != "stdio_mux_v1" {
|
||||
t.Fatalf("FormatV1 = %q, want stdio_mux_v1", FormatV1)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,214 +0,0 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"banger/internal/model"
|
||||
)
|
||||
|
||||
func sampleGuestSession(id, vmID, name string) model.GuestSession {
|
||||
now := fixedTime()
|
||||
exit := 7
|
||||
return model.GuestSession{
|
||||
ID: id,
|
||||
VMID: vmID,
|
||||
Name: name,
|
||||
Backend: "ssh",
|
||||
AttachBackend: "vsock",
|
||||
AttachMode: "rpc",
|
||||
Command: "pi",
|
||||
Args: []string{"--mode", "rpc"},
|
||||
CWD: "/root/repo",
|
||||
Env: map[string]string{"FOO": "bar"},
|
||||
StdinMode: model.GuestSessionStdinMode("pipe"),
|
||||
Status: model.GuestSessionStatus("exited"),
|
||||
ExitCode: &exit,
|
||||
GuestPID: 1234,
|
||||
GuestStateDir: "/tmp/guest-" + id,
|
||||
StdoutLogPath: "/tmp/" + id + ".stdout",
|
||||
StderrLogPath: "/tmp/" + id + ".stderr",
|
||||
Tags: map[string]string{"role": "planner"},
|
||||
LastError: "",
|
||||
Attachable: true,
|
||||
Reattachable: true,
|
||||
LaunchStage: "started",
|
||||
LaunchMessage: "ok",
|
||||
LaunchRawLog: "boot log...",
|
||||
CreatedAt: now,
|
||||
StartedAt: now,
|
||||
UpdatedAt: now,
|
||||
EndedAt: now.Add(time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
// openTestStoreWithVMs opens a fresh store seeded with the given VM IDs so
|
||||
// guest_sessions FK constraints are satisfied. Each VM gets a minimal
|
||||
// image it references.
|
||||
func openTestStoreWithVMs(t *testing.T, vmIDs ...string) *Store {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
store := openTestStore(t)
|
||||
|
||||
image := sampleImage("stub-image")
|
||||
if err := store.UpsertImage(ctx, image); err != nil {
|
||||
t.Fatalf("UpsertImage: %v", err)
|
||||
}
|
||||
for i, id := range vmIDs {
|
||||
vm := sampleVM(id, image.ID, fmt.Sprintf("172.16.0.%d", i+2))
|
||||
vm.ID = id
|
||||
if err := store.UpsertVM(ctx, vm); err != nil {
|
||||
t.Fatalf("UpsertVM(%s): %v", id, err)
|
||||
}
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func TestGuestSessionUpsertAndGetByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
store := openTestStoreWithVMs(t, "vm-1")
|
||||
|
||||
session := sampleGuestSession("sess-1", "vm-1", "planner")
|
||||
if err := store.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
got, err := store.GetGuestSessionByID(ctx, "sess-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetGuestSessionByID: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, session) {
|
||||
t.Fatalf("round-trip mismatch:\n got %+v\n want %+v", got, session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuestSessionUpsertIsIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
store := openTestStoreWithVMs(t, "vm-1")
|
||||
|
||||
session := sampleGuestSession("sess-1", "vm-1", "planner")
|
||||
if err := store.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession (first): %v", err)
|
||||
}
|
||||
|
||||
// Mutate + re-upsert → existing row updated.
|
||||
session.Command = "pi --other"
|
||||
session.Status = model.GuestSessionStatus("running")
|
||||
session.ExitCode = nil
|
||||
if err := store.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession (second): %v", err)
|
||||
}
|
||||
|
||||
got, err := store.GetGuestSessionByID(ctx, "sess-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetGuestSessionByID: %v", err)
|
||||
}
|
||||
if got.Command != "pi --other" {
|
||||
t.Errorf("command = %q, want 'pi --other'", got.Command)
|
||||
}
|
||||
if got.Status != model.GuestSessionStatus("running") {
|
||||
t.Errorf("status = %q, want running", got.Status)
|
||||
}
|
||||
if got.ExitCode != nil {
|
||||
t.Errorf("ExitCode = %v, want nil after clearing", got.ExitCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGuestSessionByIDOrName(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
store := openTestStoreWithVMs(t, "vm-1")
|
||||
|
||||
session := sampleGuestSession("sess-1", "vm-1", "planner")
|
||||
if err := store.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
byID, err := store.GetGuestSession(ctx, "vm-1", "sess-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetGuestSession by ID: %v", err)
|
||||
}
|
||||
if byID.ID != "sess-1" {
|
||||
t.Errorf("by-ID: got %q, want sess-1", byID.ID)
|
||||
}
|
||||
|
||||
byName, err := store.GetGuestSession(ctx, "vm-1", "planner")
|
||||
if err != nil {
|
||||
t.Fatalf("GetGuestSession by name: %v", err)
|
||||
}
|
||||
if byName.Name != "planner" {
|
||||
t.Errorf("by-name: got %q, want planner", byName.Name)
|
||||
}
|
||||
|
||||
// Scoped to the VM.
|
||||
if _, err := store.GetGuestSession(ctx, "vm-unknown", "sess-1"); !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Errorf("wrong-vm lookup = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListGuestSessionsByVMOrdersByCreatedAt(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
store := openTestStoreWithVMs(t, "vm-1", "vm-2")
|
||||
|
||||
base := fixedTime()
|
||||
first := sampleGuestSession("sess-early", "vm-1", "first")
|
||||
first.CreatedAt = base
|
||||
second := sampleGuestSession("sess-late", "vm-1", "second")
|
||||
second.CreatedAt = base.Add(time.Hour)
|
||||
other := sampleGuestSession("sess-other", "vm-2", "other")
|
||||
|
||||
for _, s := range []model.GuestSession{second, first, other} {
|
||||
if err := store.UpsertGuestSession(ctx, s); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
sessions, err := store.ListGuestSessionsByVM(ctx, "vm-1")
|
||||
if err != nil {
|
||||
t.Fatalf("ListGuestSessionsByVM: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
t.Fatalf("len = %d, want 2 (vm-1 only)", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != "sess-early" || sessions[1].ID != "sess-late" {
|
||||
t.Fatalf("order: got %q, %q; want sess-early, sess-late", sessions[0].ID, sessions[1].ID)
|
||||
}
|
||||
|
||||
empty, err := store.ListGuestSessionsByVM(ctx, "vm-unknown")
|
||||
if err != nil {
|
||||
t.Fatalf("ListGuestSessionsByVM (unknown vm): %v", err)
|
||||
}
|
||||
if len(empty) != 0 {
|
||||
t.Fatalf("unknown vm sessions = %+v, want empty", empty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteGuestSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
store := openTestStoreWithVMs(t, "vm-1")
|
||||
|
||||
session := sampleGuestSession("sess-1", "vm-1", "planner")
|
||||
if err := store.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
if err := store.DeleteGuestSession(ctx, "sess-1"); err != nil {
|
||||
t.Fatalf("DeleteGuestSession: %v", err)
|
||||
}
|
||||
if _, err := store.GetGuestSessionByID(ctx, "sess-1"); !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("after delete err = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
|
||||
// Deleting something that doesn't exist is a no-op (matches SQL DELETE semantics).
|
||||
if err := store.DeleteGuestSession(ctx, "sess-nope"); err != nil {
|
||||
t.Fatalf("DeleteGuestSession on missing row: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -99,32 +99,6 @@ func (s *Store) migrate() error {
|
|||
stats_json TEXT NOT NULL DEFAULT '{}',
|
||||
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE RESTRICT
|
||||
);`,
|
||||
`CREATE TABLE IF NOT EXISTS guest_sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
vm_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
backend TEXT NOT NULL,
|
||||
command TEXT NOT NULL,
|
||||
args_json TEXT NOT NULL DEFAULT '[]',
|
||||
cwd TEXT,
|
||||
env_json TEXT NOT NULL DEFAULT '{}',
|
||||
stdin_mode TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
exit_code INTEGER,
|
||||
guest_pid INTEGER NOT NULL DEFAULT 0,
|
||||
guest_state_dir TEXT,
|
||||
stdout_log_path TEXT,
|
||||
stderr_log_path TEXT,
|
||||
tags_json TEXT NOT NULL DEFAULT '{}',
|
||||
last_error TEXT,
|
||||
attachable INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
updated_at TEXT NOT NULL,
|
||||
ended_at TEXT,
|
||||
UNIQUE(vm_id, name),
|
||||
FOREIGN KEY(vm_id) REFERENCES vms(id) ON DELETE CASCADE
|
||||
);`,
|
||||
}
|
||||
for _, stmt := range stmts {
|
||||
if _, err := s.db.Exec(stmt); err != nil {
|
||||
|
|
@ -137,18 +111,6 @@ func (s *Store) migrate() error {
|
|||
if err := ensureColumnExists(s.db, "images", "seeded_ssh_public_key_fingerprint", "TEXT"); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, spec := range []struct{ table, column, typ string }{
|
||||
{"guest_sessions", "attach_backend", "TEXT"},
|
||||
{"guest_sessions", "attach_mode", "TEXT"},
|
||||
{"guest_sessions", "reattachable", "INTEGER NOT NULL DEFAULT 0"},
|
||||
{"guest_sessions", "launch_stage", "TEXT"},
|
||||
{"guest_sessions", "launch_message", "TEXT"},
|
||||
{"guest_sessions", "launch_raw_log", "TEXT"},
|
||||
} {
|
||||
if err := ensureColumnExists(s.db, spec.table, spec.column, spec.typ); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -336,122 +298,6 @@ func (s *Store) FindVMsUsingImage(ctx context.Context, imageID string) ([]model.
|
|||
return vms, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) UpsertGuestSession(ctx context.Context, session model.GuestSession) error {
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
argsJSON, err := json.Marshal(session.Args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
envJSON, err := json.Marshal(session.Env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tagsJSON, err := json.Marshal(session.Tags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
const query = `
|
||||
INSERT INTO guest_sessions (
|
||||
id, vm_id, name, backend, attach_backend, attach_mode, command, args_json, cwd, env_json, stdin_mode, status,
|
||||
exit_code, guest_pid, guest_state_dir, stdout_log_path, stderr_log_path, tags_json,
|
||||
last_error, attachable, reattachable, launch_stage, launch_message, launch_raw_log,
|
||||
created_at, started_at, updated_at, ended_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
vm_id=excluded.vm_id,
|
||||
name=excluded.name,
|
||||
backend=excluded.backend,
|
||||
attach_backend=excluded.attach_backend,
|
||||
attach_mode=excluded.attach_mode,
|
||||
command=excluded.command,
|
||||
args_json=excluded.args_json,
|
||||
cwd=excluded.cwd,
|
||||
env_json=excluded.env_json,
|
||||
stdin_mode=excluded.stdin_mode,
|
||||
status=excluded.status,
|
||||
exit_code=excluded.exit_code,
|
||||
guest_pid=excluded.guest_pid,
|
||||
guest_state_dir=excluded.guest_state_dir,
|
||||
stdout_log_path=excluded.stdout_log_path,
|
||||
stderr_log_path=excluded.stderr_log_path,
|
||||
tags_json=excluded.tags_json,
|
||||
last_error=excluded.last_error,
|
||||
attachable=excluded.attachable,
|
||||
reattachable=excluded.reattachable,
|
||||
launch_stage=excluded.launch_stage,
|
||||
launch_message=excluded.launch_message,
|
||||
launch_raw_log=excluded.launch_raw_log,
|
||||
started_at=excluded.started_at,
|
||||
updated_at=excluded.updated_at,
|
||||
ended_at=excluded.ended_at`
|
||||
_, err = s.db.ExecContext(ctx, query,
|
||||
session.ID,
|
||||
session.VMID,
|
||||
session.Name,
|
||||
session.Backend,
|
||||
session.AttachBackend,
|
||||
session.AttachMode,
|
||||
session.Command,
|
||||
string(argsJSON),
|
||||
session.CWD,
|
||||
string(envJSON),
|
||||
string(session.StdinMode),
|
||||
string(session.Status),
|
||||
nullableInt(session.ExitCode),
|
||||
session.GuestPID,
|
||||
session.GuestStateDir,
|
||||
session.StdoutLogPath,
|
||||
session.StderrLogPath,
|
||||
string(tagsJSON),
|
||||
session.LastError,
|
||||
boolToInt(session.Attachable),
|
||||
boolToInt(session.Reattachable),
|
||||
session.LaunchStage,
|
||||
session.LaunchMessage,
|
||||
session.LaunchRawLog,
|
||||
session.CreatedAt.Format(time.RFC3339),
|
||||
nullableTimeString(session.StartedAt),
|
||||
session.UpdatedAt.Format(time.RFC3339),
|
||||
nullableTimeString(session.EndedAt),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) GetGuestSessionByID(ctx context.Context, id string) (model.GuestSession, error) {
|
||||
row := s.db.QueryRowContext(ctx, guestSessionSelectSQL+" WHERE id = ?", id)
|
||||
return scanGuestSessionRow(row)
|
||||
}
|
||||
|
||||
func (s *Store) GetGuestSession(ctx context.Context, vmID, idOrName string) (model.GuestSession, error) {
|
||||
row := s.db.QueryRowContext(ctx, guestSessionSelectSQL+" WHERE vm_id = ? AND (id = ? OR name = ?)", vmID, idOrName, idOrName)
|
||||
return scanGuestSessionRow(row)
|
||||
}
|
||||
|
||||
func (s *Store) ListGuestSessionsByVM(ctx context.Context, vmID string) ([]model.GuestSession, error) {
|
||||
rows, err := s.db.QueryContext(ctx, guestSessionSelectSQL+" WHERE vm_id = ? ORDER BY created_at ASC", vmID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var sessions []model.GuestSession
|
||||
for rows.Next() {
|
||||
session, err := scanGuestSession(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
return sessions, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) DeleteGuestSession(ctx context.Context, id string) error {
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
_, err := s.db.ExecContext(ctx, "DELETE FROM guest_sessions WHERE id = ?", id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) NextGuestIP(ctx context.Context, bridgeIPPrefix string) (string, error) {
|
||||
used := map[string]struct{}{}
|
||||
rows, err := s.db.QueryContext(ctx, "SELECT guest_ip FROM vms")
|
||||
|
|
@ -622,113 +468,6 @@ func boolToInt(value bool) int {
|
|||
return 0
|
||||
}
|
||||
|
||||
const guestSessionSelectSQL = `
|
||||
SELECT id, vm_id, name, backend, attach_backend, attach_mode, command, args_json, cwd, env_json, stdin_mode, status,
|
||||
exit_code, guest_pid, guest_state_dir, stdout_log_path, stderr_log_path, tags_json,
|
||||
last_error, attachable, reattachable, launch_stage, launch_message, launch_raw_log,
|
||||
created_at, started_at, updated_at, ended_at
|
||||
FROM guest_sessions`
|
||||
|
||||
func scanGuestSession(rows scanner) (model.GuestSession, error) {
|
||||
return scanGuestSessionRow(rows)
|
||||
}
|
||||
|
||||
func scanGuestSessionRow(row scanner) (model.GuestSession, error) {
|
||||
var session model.GuestSession
|
||||
var (
|
||||
argsJSON string
|
||||
envJSON string
|
||||
tagsJSON string
|
||||
stdinMode string
|
||||
status string
|
||||
exitCode sql.NullInt64
|
||||
startedAt sql.NullString
|
||||
endedAt sql.NullString
|
||||
attachable int
|
||||
reattachable int
|
||||
createdRaw string
|
||||
updatedRaw string
|
||||
)
|
||||
err := row.Scan(
|
||||
&session.ID,
|
||||
&session.VMID,
|
||||
&session.Name,
|
||||
&session.Backend,
|
||||
&session.AttachBackend,
|
||||
&session.AttachMode,
|
||||
&session.Command,
|
||||
&argsJSON,
|
||||
&session.CWD,
|
||||
&envJSON,
|
||||
&stdinMode,
|
||||
&status,
|
||||
&exitCode,
|
||||
&session.GuestPID,
|
||||
&session.GuestStateDir,
|
||||
&session.StdoutLogPath,
|
||||
&session.StderrLogPath,
|
||||
&tagsJSON,
|
||||
&session.LastError,
|
||||
&attachable,
|
||||
&reattachable,
|
||||
&session.LaunchStage,
|
||||
&session.LaunchMessage,
|
||||
&session.LaunchRawLog,
|
||||
&createdRaw,
|
||||
&startedAt,
|
||||
&updatedRaw,
|
||||
&endedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return session, err
|
||||
}
|
||||
session.StdinMode = model.GuestSessionStdinMode(stdinMode)
|
||||
session.Status = model.GuestSessionStatus(status)
|
||||
session.Attachable = attachable == 1
|
||||
session.Reattachable = reattachable == 1
|
||||
if argsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(argsJSON), &session.Args); err != nil {
|
||||
return session, err
|
||||
}
|
||||
}
|
||||
if envJSON != "" {
|
||||
if err := json.Unmarshal([]byte(envJSON), &session.Env); err != nil {
|
||||
return session, err
|
||||
}
|
||||
}
|
||||
if tagsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(tagsJSON), &session.Tags); err != nil {
|
||||
return session, err
|
||||
}
|
||||
}
|
||||
if exitCode.Valid {
|
||||
value := int(exitCode.Int64)
|
||||
session.ExitCode = &value
|
||||
}
|
||||
var parseErr error
|
||||
session.CreatedAt, parseErr = time.Parse(time.RFC3339, createdRaw)
|
||||
if parseErr != nil {
|
||||
return session, parseErr
|
||||
}
|
||||
session.UpdatedAt, parseErr = time.Parse(time.RFC3339, updatedRaw)
|
||||
if parseErr != nil {
|
||||
return session, parseErr
|
||||
}
|
||||
if startedAt.Valid && startedAt.String != "" {
|
||||
session.StartedAt, parseErr = time.Parse(time.RFC3339, startedAt.String)
|
||||
if parseErr != nil {
|
||||
return session, parseErr
|
||||
}
|
||||
}
|
||||
if endedAt.Valid && endedAt.String != "" {
|
||||
session.EndedAt, parseErr = time.Parse(time.RFC3339, endedAt.String)
|
||||
if parseErr != nil {
|
||||
return session, parseErr
|
||||
}
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func nullableTimeString(value time.Time) any {
|
||||
if value.IsZero() {
|
||||
return nil
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue