remove vm session feature

Cuts the daemon-managed guest-session machinery (start/list/show/
logs/stop/kill/attach/send). The feature shipped aimed at agent-
orchestration workflows (programmatic stdin piping into a long-lived
guest process) that aren't driving any concrete user today, and the
~2.3K LOC of daemon surface area — attach bridge, FIFO keepalive,
controller registry, sessionstream framing, SQLite persistence — was
locking in an API we'd have to keep through v0.1.0.

Anything session-flavoured that people actually need today can be
done with `vm ssh + tmux` or `vm run -- cmd`.

Deleted:
- internal/cli/commands_vm_session.go
- internal/daemon/{guest_sessions,session_lifecycle,session_attach,session_stream,session_controller}.go
- internal/daemon/session/ (guest-session helpers package)
- internal/sessionstream/ (framing package)
- internal/daemon/guest_sessions_test.go
- internal/store/guest_session_test.go
- GuestSession* types from internal/{api,model}
- Store UpsertGuestSession/GetGuestSession/ListGuestSessionsByVM/DeleteGuestSession + scanner helpers
- guest.session.* RPC dispatch entries
- 5 CLI session tests, 2 completion tests, 2 printer tests

Extracted:
- ShellQuote + FormatStepError lifted to internal/daemon/workspace/util.go
  (only non-session consumer); workspace package now self-contained
- internal/daemon/guest_ssh.go keeps guestSSHClient + dialGuest +
  waitForGuestSSH — still used by workspace prepare/export
- internal/daemon/fake_firecracker_test.go preserves the test helper
  that used to live in guest_sessions_test.go

Store schema: CREATE TABLE guest_sessions and its column migrations
removed. Existing dev DBs keep an orphan table (harmless, pre-v0.1.0).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Thales Maciel 2026-04-20 12:47:58 -03:00
parent c42fcbe012
commit 2b6437d1b4
No known key found for this signature in database
GPG key ID: 33112E6833C34679
34 changed files with 194 additions and 4031 deletions

View file

@ -49,7 +49,7 @@ User data stays in place — the target prints the paths so you can
`banger` ships completion scripts for bash, zsh, fish, and
powershell. Tab-completion covers subcommands, flags, and live
resource names (VM, image, kernel, session) looked up from the
resource names (VM, image, kernel) looked up from the
daemon. With the daemon down, resource completion silently
returns nothing — no file-completion fallback.
@ -179,7 +179,7 @@ ones you want.
## Advanced
The common path is `vm run`. Power-user flows (`vm create`, OCI pull
for arbitrary images, `image register`, long-lived sessions) are
for arbitrary images, `image register`, manual workspace prepare) are
documented in [`docs/advanced.md`](docs/advanced.md).
## Security

View file

@ -60,33 +60,19 @@ disk, or pass `--kernel /abs/path/vmlinux` for a one-off kernel.
For reproducible custom images, write a Dockerfile and publish it to
an image catalog. See [`docs/image-catalog.md`](image-catalog.md).
## Workspace + session primitives
## Workspace primitive
Long-lived guest commands managed by the daemon, attachable over a
local Unix socket bridge. Useful for agent/background processes that
need to survive SSH disconnects.
`vm run ./repo` (see README) handles the common case. For a manual
flow against an already-running VM, `vm workspace prepare`
materialises a local git checkout into the guest:
```bash
banger vm workspace prepare <vm> ./other-repo --guest-path /root/repo
banger vm session start <vm> --name planner --cwd /root/repo \
--stdin-mode pipe -- pi --mode rpc
banger vm session attach <vm> planner
banger vm session logs <vm> planner --stream stderr
banger vm session stop <vm> planner
```
Details:
- `vm workspace prepare` materialises a local git checkout into a
running VM. Default guest path `/root/repo`; default mode is a
shallow metadata copy plus tracked and untracked non-ignored
overlay.
- `vm session start` launches a daemon-managed long-lived guest
command. The daemon preflights that the guest `cwd` exists and the
command is on guest `PATH` before launch. Use `--stdin-mode pipe`
when you need live `attach`.
- `vm session attach` is exclusive and same-host only. Pipe-mode
sessions survive daemon restarts.
Default guest path is `/root/repo`; default mode is a shallow metadata
copy plus tracked and untracked non-ignored overlay. For repositories
with submodules, pass `--mode full_copy`.
## Inspecting boot failures

View file

@ -122,70 +122,6 @@ type VMPortsResult struct {
Ports []VMPort `json:"ports"`
}
type GuestSessionStartParams struct {
VMIDOrName string `json:"vm_id_or_name"`
Name string `json:"name,omitempty"`
Command string `json:"command"`
Args []string `json:"args,omitempty"`
CWD string `json:"cwd,omitempty"`
Env map[string]string `json:"env,omitempty"`
StdinMode string `json:"stdin_mode,omitempty"`
Tags map[string]string `json:"tags,omitempty"`
RequiredCommands []string `json:"required_commands,omitempty"`
}
type GuestSessionRefParams struct {
VMIDOrName string `json:"vm_id_or_name"`
SessionIDOrName string `json:"session_id_or_name"`
}
type GuestSessionLogsParams struct {
VMIDOrName string `json:"vm_id_or_name"`
SessionIDOrName string `json:"session_id_or_name"`
Stream string `json:"stream,omitempty"`
TailLines int `json:"tail_lines,omitempty"`
}
type GuestSessionAttachBeginParams struct {
VMIDOrName string `json:"vm_id_or_name"`
SessionIDOrName string `json:"session_id_or_name"`
}
type GuestSessionListResult struct {
Sessions []model.GuestSession `json:"sessions"`
}
type GuestSessionShowResult struct {
Session model.GuestSession `json:"session"`
}
type GuestSessionLogsResult struct {
Session model.GuestSession `json:"session"`
Stream string `json:"stream"`
Path string `json:"path,omitempty"`
Content string `json:"content,omitempty"`
}
type GuestSessionAttachBeginResult struct {
Session model.GuestSession `json:"session"`
AttachID string `json:"attach_id"`
TransportKind string `json:"transport_kind"`
TransportTarget string `json:"transport_target"`
SocketPath string `json:"socket_path,omitempty"`
StreamFormat string `json:"stream_format"`
}
type GuestSessionSendParams struct {
VMIDOrName string `json:"vm_id_or_name"`
SessionIDOrName string `json:"session_id_or_name"`
Payload []byte `json:"payload"`
}
type GuestSessionSendResult struct {
Session model.GuestSession `json:"session"`
BytesWritten int `json:"bytes_written"`
}
type WorkspaceExportParams struct {
IDOrName string `json:"id_or_name"`
GuestPath string `json:"guest_path,omitempty"`

View file

@ -46,7 +46,6 @@ func TestListCommandsHaveLsAlias(t *testing.T) {
{"vm", "list"},
{"image", "list"},
{"kernel", "list"},
{"vm", "session", "list"},
}
for _, path := range cases {
t.Run(path[len(path)-1], func(t *testing.T) {

View file

@ -1918,34 +1918,9 @@ func (c *testVMRunGuestClient) StreamTarEntries(ctx context.Context, sourceDir s
return nil
}
func TestVMSessionSendCommandExists(t *testing.T) {
root := NewBangerCommand()
vm, _, err := root.Find([]string{"vm"})
if err != nil {
t.Fatalf("find vm: %v", err)
}
session, _, err := vm.Find([]string{"session"})
if err != nil {
t.Fatalf("find session: %v", err)
}
if _, _, err := session.Find([]string{"send"}); err != nil {
t.Fatalf("find session send: %v", err)
}
}
func TestVMSessionSendRejectsWrongArgCount(t *testing.T) {
cmd := NewBangerCommand()
cmd.SetArgs([]string{"vm", "session", "send", "only-one-arg"})
err := cmd.Execute()
if err == nil || !strings.Contains(err.Error(), "usage: banger vm session send") {
t.Fatalf("Execute() error = %v, want send usage error", err)
}
}
// stubEnsureDaemonForSend isolates XDG dirs and installs a daemon-ping
// fake onto the caller's *deps so `ensureDaemon` short-circuits without
// trying to spawn bangerd. `vm session send` uses this to avoid needing
// a built binary on disk.
// trying to spawn bangerd.
func stubEnsureDaemonForSend(t *testing.T, d *deps) {
t.Helper()
t.Setenv("XDG_CONFIG_HOME", filepath.Join(t.TempDir(), "config"))
@ -1956,98 +1931,6 @@ func stubEnsureDaemonForSend(t *testing.T, d *deps) {
}
}
func TestVMSessionSendWithMessageFlag(t *testing.T) {
d := defaultDeps()
stubEnsureDaemonForSend(t, d)
var capturedParams api.GuestSessionSendParams
d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
capturedParams = params
return api.GuestSessionSendResult{
Session: model.GuestSession{ID: "sess-id", Name: "planner"},
BytesWritten: len(params.Payload),
}, nil
}
cmd := d.newRootCommand()
var out bytes.Buffer
cmd.SetOut(&out)
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner", "--message", `{"type":"abort"}`})
if err := cmd.Execute(); err != nil {
t.Fatalf("Execute: %v", err)
}
wantPayload := []byte(`{"type":"abort"}` + "\n")
if string(capturedParams.Payload) != string(wantPayload) {
t.Fatalf("payload = %q, want %q", capturedParams.Payload, wantPayload)
}
if capturedParams.VMIDOrName != "devbox" {
t.Fatalf("VMIDOrName = %q, want %q", capturedParams.VMIDOrName, "devbox")
}
if capturedParams.SessionIDOrName != "planner" {
t.Fatalf("SessionIDOrName = %q, want %q", capturedParams.SessionIDOrName, "planner")
}
if !strings.Contains(out.String(), "17") {
t.Fatalf("output = %q, want bytes_written in output", out.String())
}
}
func TestVMSessionSendMessageAlreadyHasNewline(t *testing.T) {
d := defaultDeps()
stubEnsureDaemonForSend(t, d)
var capturedPayload []byte
d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
capturedPayload = params.Payload
return api.GuestSessionSendResult{
Session: model.GuestSession{Name: "s"},
BytesWritten: len(params.Payload),
}, nil
}
cmd := d.newRootCommand()
cmd.SetOut(io.Discard)
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "s", "--message", "{\"type\":\"abort\"}\n"})
if err := cmd.Execute(); err != nil {
t.Fatalf("Execute: %v", err)
}
// Must not double-append newline.
if capturedPayload[len(capturedPayload)-1] != '\n' {
t.Fatalf("payload missing trailing newline: %q", capturedPayload)
}
if len(capturedPayload) > 0 && capturedPayload[len(capturedPayload)-2] == '\n' {
t.Fatalf("payload has double trailing newline: %q", capturedPayload)
}
}
func TestVMSessionSendFromStdin(t *testing.T) {
d := defaultDeps()
stubEnsureDaemonForSend(t, d)
var capturedPayload []byte
d.guestSessionSend = func(_ context.Context, _ string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
capturedPayload = params.Payload
return api.GuestSessionSendResult{
Session: model.GuestSession{Name: "planner"},
BytesWritten: len(params.Payload),
}, nil
}
stdinPayload := `{"type":"steer","message":"Focus on src/"}` + "\n"
cmd := d.newRootCommand()
cmd.SetOut(io.Discard)
cmd.SetIn(strings.NewReader(stdinPayload))
cmd.SetArgs([]string{"vm", "session", "send", "devbox", "planner"})
if err := cmd.Execute(); err != nil {
t.Fatalf("Execute: %v", err)
}
if string(capturedPayload) != stdinPayload {
t.Fatalf("payload = %q, want %q", capturedPayload, stdinPayload)
}
}
func TestVMWorkspaceExportCommandExists(t *testing.T) {
root := NewBangerCommand()
vm, _, err := root.Find([]string{"vm"})

View file

@ -42,7 +42,6 @@ func (d *deps) newVMCommand() *cobra.Command {
d.newVMSetCommand(),
d.newVMSSHCommand(),
d.newVMWorkspaceCommand(),
d.newVMSessionCommand(),
d.newVMLogsCommand(),
d.newVMStatsCommand(),
d.newVMPortsCommand(),

View file

@ -1,370 +0,0 @@
package cli
import (
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"banger/internal/api"
"banger/internal/model"
"banger/internal/sessionstream"
"github.com/spf13/cobra"
)
func (d *deps) newVMSessionCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "session",
Short: "Manage long-lived guest commands inside a VM",
Long: "Start, inspect, stop, and attach to daemon-managed guest commands. Pipe-mode sessions expose live stdio for interactive protocols. Attach is exclusive and currently uses a same-host local bridge.",
RunE: helpNoArgs,
}
cmd.AddCommand(
d.newVMSessionStartCommand(),
d.newVMSessionListCommand(),
d.newVMSessionShowCommand(),
d.newVMSessionLogsCommand(),
d.newVMSessionStopCommand(),
d.newVMSessionKillCommand(),
d.newVMSessionAttachCommand(),
d.newVMSessionSendCommand(),
)
return cmd
}
func (d *deps) newVMSessionStartCommand() *cobra.Command {
var name string
var cwd string
var stdinMode string
var envPairs []string
var tagPairs []string
var requiredCommands []string
cmd := &cobra.Command{
Use: "start <id-or-name> <command> [args...]",
Short: "Start a managed guest command",
Long: "Start a daemon-managed guest command. The daemon verifies that the guest working directory exists and that the requested command is present in guest PATH before launch. Use --stdin-mode pipe when you need live attach.",
Args: minArgsUsage(2, "usage: banger vm session start <id-or-name> [flags] -- <command> [args...]"),
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
Example: strings.TrimSpace(`
banger vm session start devbox --name planner --cwd /root/repo --stdin-mode pipe --require-command git -- pi --mode rpc --no-session
banger vm session start devbox --name shell --stdin-mode pipe -- bash -lc 'exec bash'
`),
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
env, err := parseKeyValuePairs(envPairs)
if err != nil {
return err
}
tags, err := parseKeyValuePairs(tagPairs)
if err != nil {
return err
}
result, err := d.guestSessionStart(cmd.Context(), layout.SocketPath, api.GuestSessionStartParams{
VMIDOrName: args[0],
Name: name,
Command: args[1],
Args: append([]string(nil), args[2:]...),
CWD: cwd,
Env: env,
StdinMode: stdinMode,
Tags: tags,
RequiredCommands: append([]string(nil), requiredCommands...),
})
if err != nil {
return err
}
if err := printGuestSessionSummary(cmd.OutOrStdout(), result.Session); err != nil {
return err
}
if result.Session.Status == model.GuestSessionStatusFailed && strings.TrimSpace(result.Session.LaunchMessage) != "" {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "warning: session failed at %s: %s\n", result.Session.LaunchStage, result.Session.LaunchMessage)
}
return nil
},
}
cmd.Flags().StringVar(&name, "name", "", "session name")
cmd.Flags().StringVar(&cwd, "cwd", "", "guest working directory; must already exist")
cmd.Flags().StringVar(&stdinMode, "stdin-mode", string(model.GuestSessionStdinClosed), "stdin mode: closed or pipe (pipe enables attach)")
cmd.Flags().StringArrayVar(&envPairs, "env", nil, "environment entry in KEY=VALUE form")
cmd.Flags().StringArrayVar(&tagPairs, "tag", nil, "session tag in KEY=VALUE form")
cmd.Flags().StringArrayVar(&requiredCommands, "require-command", nil, "extra guest command that must exist in PATH before launch; repeatable")
return cmd
}
func (d *deps) newVMSessionListCommand() *cobra.Command {
return &cobra.Command{
Use: "list <id-or-name>",
Aliases: []string{"ls"},
Short: "List managed guest commands for a VM",
Args: exactArgsUsage(1, "usage: banger vm session list <id-or-name>"),
ValidArgsFunction: d.completeVMNameOnlyAtPos0,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := d.guestSessionList(cmd.Context(), layout.SocketPath, args[0])
if err != nil {
return err
}
return printGuestSessionTable(cmd.OutOrStdout(), result.Sessions)
},
}
}
func (d *deps) newVMSessionShowCommand() *cobra.Command {
return &cobra.Command{
Use: "show <id-or-name> <session>",
Short: "Show managed guest command details",
Args: exactArgsUsage(2, "usage: banger vm session show <id-or-name> <session>"),
ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := d.guestSessionGet(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil {
return err
}
return printJSON(cmd.OutOrStdout(), result.Session)
},
}
}
func (d *deps) newVMSessionLogsCommand() *cobra.Command {
var stream string
var tailLines int
cmd := &cobra.Command{
Use: "logs <id-or-name> <session>",
Short: "Show stdout or stderr for a guest session",
Args: exactArgsUsage(2, "usage: banger vm session logs [--stream stdout|stderr] [-n LINES] <id-or-name> <session>"),
ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := d.guestSessionLogs(cmd.Context(), layout.SocketPath, api.GuestSessionLogsParams{VMIDOrName: args[0], SessionIDOrName: args[1], Stream: stream, TailLines: tailLines})
if err != nil {
return err
}
_, err = fmt.Fprint(cmd.OutOrStdout(), result.Content)
return err
},
}
cmd.Flags().StringVar(&stream, "stream", "stdout", "log stream to read")
cmd.Flags().IntVarP(&tailLines, "lines", "n", 200, "number of lines to tail")
return cmd
}
func (d *deps) newVMSessionStopCommand() *cobra.Command {
return &cobra.Command{
Use: "stop <id-or-name> <session>",
Short: "Send SIGTERM to a guest session",
Args: exactArgsUsage(2, "usage: banger vm session stop <id-or-name> <session>"),
ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := d.guestSessionStop(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil {
return err
}
return printGuestSessionSummary(cmd.OutOrStdout(), result.Session)
},
}
}
func (d *deps) newVMSessionKillCommand() *cobra.Command {
return &cobra.Command{
Use: "kill <id-or-name> <session>",
Short: "Send SIGKILL to a guest session",
Args: exactArgsUsage(2, "usage: banger vm session kill <id-or-name> <session>"),
ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := d.guestSessionKill(cmd.Context(), layout.SocketPath, api.GuestSessionRefParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil {
return err
}
return printGuestSessionSummary(cmd.OutOrStdout(), result.Session)
},
}
}
func (d *deps) newVMSessionAttachCommand() *cobra.Command {
return &cobra.Command{
Use: "attach <id-or-name> <session>",
Short: "Attach local stdio to an attachable guest session",
Long: "Attach local stdio to a pipe-mode session through a daemon-created local Unix socket bridge. Only one active attach is allowed at a time, and the client must run on the same host as the daemon.",
Args: exactArgsUsage(2, "usage: banger vm session attach <id-or-name> <session>"),
ValidArgsFunction: d.completeSessionNames,
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
result, err := d.guestSessionAttachBegin(cmd.Context(), layout.SocketPath, api.GuestSessionAttachBeginParams{VMIDOrName: args[0], SessionIDOrName: args[1]})
if err != nil {
return err
}
socketPath := strings.TrimSpace(result.SocketPath)
if socketPath == "" && result.TransportKind == "unix_socket" {
socketPath = strings.TrimSpace(result.TransportTarget)
}
return runGuestSessionAttach(cmd.Context(), cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), socketPath)
},
}
}
func (d *deps) newVMSessionSendCommand() *cobra.Command {
var message string
cmd := &cobra.Command{
Use: "send <id-or-name> <session>",
Short: "Write bytes to a running guest session's stdin pipe",
Long: "Write a payload to the stdin pipe of a running pipe-mode guest session without holding the exclusive attach. Use --message for an inline JSONL string, or pipe bytes via stdin when --message is omitted. A trailing newline is appended to --message values that lack one.",
Args: exactArgsUsage(2, "usage: banger vm session send <id-or-name> <session> [--message '<json>']"),
ValidArgsFunction: d.completeSessionNames,
Example: strings.TrimSpace(`
banger vm session send devbox planner --message '{"type":"abort"}'
banger vm session send devbox planner --message '{"type":"steer","message":"Focus on src/"}'
echo '{"type":"prompt","prompt":"Summarize."}' | banger vm session send devbox planner
`),
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := d.ensureDaemon(cmd.Context())
if err != nil {
return err
}
var payload []byte
if message != "" {
payload = []byte(message)
if len(payload) > 0 && payload[len(payload)-1] != '\n' {
payload = append(payload, '\n')
}
} else {
payload, err = io.ReadAll(cmd.InOrStdin())
if err != nil {
return fmt.Errorf("read stdin: %w", err)
}
}
result, err := d.guestSessionSend(cmd.Context(), layout.SocketPath, api.GuestSessionSendParams{
VMIDOrName: args[0],
SessionIDOrName: args[1],
Payload: payload,
})
if err != nil {
return err
}
_, err = fmt.Fprintf(cmd.OutOrStdout(), "sent %d bytes to session %s\n", result.BytesWritten, result.Session.Name)
return err
},
}
cmd.Flags().StringVar(&message, "message", "", "JSONL message to send; a trailing newline is appended if absent")
return cmd
}
func parseKeyValuePairs(values []string) (map[string]string, error) {
if len(values) == 0 {
return nil, nil
}
result := make(map[string]string, len(values))
for _, value := range values {
key, raw, ok := strings.Cut(value, "=")
if !ok || strings.TrimSpace(key) == "" {
return nil, fmt.Errorf("invalid key=value entry %q", value)
}
result[strings.TrimSpace(key)] = raw
}
return result, nil
}
func runGuestSessionAttach(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, socketPath string) error {
conn, err := (&net.Dialer{}).DialContext(ctx, "unix", socketPath)
if err != nil {
return err
}
defer conn.Close()
writeErrCh := make(chan error, 1)
go func() {
writeErrCh <- streamGuestSessionAttachInput(conn, stdin)
}()
for {
channel, payload, err := sessionstream.ReadFrame(conn)
if err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
if errors.Is(err, io.EOF) {
return nil
}
return err
}
switch channel {
case sessionstream.ChannelStdout:
if _, err := stdout.Write(payload); err != nil {
return err
}
case sessionstream.ChannelStderr:
if _, err := stderr.Write(payload); err != nil {
return err
}
case sessionstream.ChannelControl:
message, err := sessionstream.ReadControl(payload)
if err != nil {
return err
}
switch message.Type {
case "exit":
if message.ExitCode != nil && *message.ExitCode != 0 {
return fmt.Errorf("guest session exited with code %d", *message.ExitCode)
}
return nil
case "error":
if strings.TrimSpace(message.Error) == "" {
return errors.New("guest session attach failed")
}
return errors.New(message.Error)
}
}
select {
case err := <-writeErrCh:
if err != nil {
return err
}
default:
}
}
}
func streamGuestSessionAttachInput(conn net.Conn, stdin io.Reader) error {
if stdin == nil {
return sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "eof"})
}
buffer := make([]byte, 32*1024)
for {
n, err := stdin.Read(buffer)
if n > 0 {
if writeErr := sessionstream.WriteFrame(conn, sessionstream.ChannelStdin, buffer[:n]); writeErr != nil {
return writeErr
}
}
if err != nil {
if errors.Is(err, io.EOF) {
return sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "eof"})
}
return err
}
}
}

View file

@ -21,9 +21,9 @@ import (
// - Fail silently. Completion is advisory; any error path returns an
// empty suggestion list rather than propagating to the user.
// defaultCompletionLister + defaultCompletionSessionLister back the
// corresponding *deps fields; tests inject their own fakes via the
// struct instead of mutating package-level vars.
// defaultCompletionLister backs the *deps.completionLister field;
// tests inject their own fake via the struct instead of mutating
// package-level vars.
func defaultCompletionLister(ctx context.Context, socketPath, method string) ([]string, error) {
switch method {
case "vm.list":
@ -66,20 +66,6 @@ func defaultCompletionLister(ctx context.Context, socketPath, method string) ([]
return nil, nil
}
func defaultCompletionSessionLister(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) {
result, err := rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: vmIDOrName})
if err != nil {
return nil, err
}
names := make([]string, 0, len(result.Sessions))
for _, session := range result.Sessions {
if session.Name != "" {
names = append(names, session.Name)
}
}
return names, nil
}
// daemonSocketForCompletion returns the socket path IFF the daemon is
// already running. Returns "", false when no daemon is up — completion
// callers use this as the bail signal.
@ -177,25 +163,3 @@ func (d *deps) completeKernelNames(cmd *cobra.Command, args []string, toComplete
}
return filterPrefix(names, args, toComplete), cobra.ShellCompDirectiveNoFileComp
}
// completeSessionNames handles `... <vm> <session>` commands: pos 0
// completes VMs, pos 1 completes sessions owned by args[0], pos 2+ is
// silent.
func (d *deps) completeSessionNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0:
return d.completeVMNames(cmd, args, toComplete)
case 1:
socket, ok := d.daemonSocketForCompletion(cmd.Context())
if !ok {
return nil, cobra.ShellCompDirectiveNoFileComp
}
names, err := d.completionSessionLister(cmd.Context(), socket, args[0])
if err != nil {
return nil, cobra.ShellCompDirectiveNoFileComp
}
return filterPrefix(names, nil, toComplete), cobra.ShellCompDirectiveNoFileComp
default:
return nil, cobra.ShellCompDirectiveNoFileComp
}
}

View file

@ -19,10 +19,7 @@ func stubCompletionSeams(
d *deps,
pingErr error,
names map[string][]string,
listErr error,
sessions map[string][]string,
sessionErr error,
) {
listErr error) {
t.Helper()
d.daemonPing = func(ctx context.Context, socketPath string) (api.PingResult, error) {
@ -37,12 +34,6 @@ func stubCompletionSeams(
}
return names[method], nil
}
d.completionSessionLister = func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error) {
if sessionErr != nil {
return nil, sessionErr
}
return sessions[vmIDOrName], nil
}
}
func TestFilterPrefix(t *testing.T) {
@ -82,7 +73,7 @@ func testCmdWithCtx() *cobra.Command {
func TestCompleteVMNamesHappyPath(t *testing.T) {
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil)
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil)
got, directive := d.completeVMNames(testCmdWithCtx(), nil, "")
if directive != cobra.ShellCompDirectiveNoFileComp {
@ -95,7 +86,7 @@ func TestCompleteVMNamesHappyPath(t *testing.T) {
func TestCompleteVMNamesDaemonDown(t *testing.T) {
d := defaultDeps()
stubCompletionSeams(t, d, errors.New("connection refused"), nil, nil, nil, nil)
stubCompletionSeams(t, d, errors.New("connection refused"), nil, nil)
got, directive := d.completeVMNames(testCmdWithCtx(), nil, "")
if len(got) != 0 {
@ -108,7 +99,7 @@ func TestCompleteVMNamesDaemonDown(t *testing.T) {
func TestCompleteVMNamesRPCError(t *testing.T) {
d := defaultDeps()
stubCompletionSeams(t, d, nil, nil, errors.New("rpc failed"), nil, nil)
stubCompletionSeams(t, d, nil, nil, errors.New("rpc failed"))
got, _ := d.completeVMNames(testCmdWithCtx(), nil, "")
if len(got) != 0 {
@ -118,7 +109,7 @@ func TestCompleteVMNamesRPCError(t *testing.T) {
func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) {
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil, nil, nil)
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "gamma"}}, nil)
got, _ := d.completeVMNames(testCmdWithCtx(), []string{"alpha"}, "")
want := []string{"beta", "gamma"}
@ -129,7 +120,7 @@ func TestCompleteVMNamesExcludesAlreadyEntered(t *testing.T) {
func TestCompleteVMNamesPrefixFilter(t *testing.T) {
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil, nil, nil)
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha", "beta", "alphabet"}}, nil)
got, _ := d.completeVMNames(testCmdWithCtx(), nil, "alp")
want := []string{"alpha", "alphabet"}
@ -140,7 +131,7 @@ func TestCompleteVMNamesPrefixFilter(t *testing.T) {
func TestCompleteVMNameOnlyAtPos0(t *testing.T) {
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha"}}, nil, nil, nil)
stubCompletionSeams(t, d, nil, map[string][]string{"vm.list": {"alpha"}}, nil)
atPos0, _ := d.completeVMNameOnlyAtPos0(testCmdWithCtx(), nil, "")
if len(atPos0) != 1 || atPos0[0] != "alpha" {
@ -155,7 +146,7 @@ func TestCompleteVMNameOnlyAtPos0(t *testing.T) {
func TestCompleteImageNames(t *testing.T) {
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil, nil, nil)
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"debian-bookworm", "alpine"}}, nil)
got, _ := d.completeImageNames(testCmdWithCtx(), nil, "")
if !reflect.DeepEqual(got, []string{"debian-bookworm", "alpine"}) {
@ -165,7 +156,7 @@ func TestCompleteImageNames(t *testing.T) {
func TestCompleteKernelNames(t *testing.T) {
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil, nil, nil)
stubCompletionSeams(t, d, nil, map[string][]string{"kernel.list": {"generic-6.12"}}, nil)
got, _ := d.completeKernelNames(testCmdWithCtx(), nil, "")
if len(got) != 1 || got[0] != "generic-6.12" {
@ -175,58 +166,10 @@ func TestCompleteKernelNames(t *testing.T) {
func TestCompleteImageNameOnlyAtPos0SilentAfterFirst(t *testing.T) {
d := defaultDeps()
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"alpine"}}, nil, nil, nil)
stubCompletionSeams(t, d, nil, map[string][]string{"image.list": {"alpine"}}, nil)
after, _ := d.completeImageNameOnlyAtPos0(testCmdWithCtx(), []string{"alpine"}, "")
if len(after) != 0 {
t.Errorf("expected silence at pos 1+, got %v", after)
}
}
func TestCompleteSessionNames(t *testing.T) {
d := defaultDeps()
stubCompletionSeams(t, d,
nil,
map[string][]string{"vm.list": {"devbox"}},
nil,
map[string][]string{"devbox": {"planner", "worker"}},
nil,
)
// Position 0 → VMs.
vms, _ := d.completeSessionNames(testCmdWithCtx(), nil, "")
if len(vms) != 1 || vms[0] != "devbox" {
t.Errorf("pos 0: got %v", vms)
}
// Position 1 → sessions scoped to args[0].
sessions, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "")
if !reflect.DeepEqual(sessions, []string{"planner", "worker"}) {
t.Errorf("pos 1: got %v", sessions)
}
// Position 1 with prefix filter.
filtered, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "wor")
if len(filtered) != 1 || filtered[0] != "worker" {
t.Errorf("pos 1 prefix: got %v", filtered)
}
// Position 2+ silent.
past, _ := d.completeSessionNames(testCmdWithCtx(), []string{"devbox", "planner"}, "")
if len(past) != 0 {
t.Errorf("pos 2+: got %v", past)
}
}
func TestCompleteSessionNamesDaemonDown(t *testing.T) {
d := defaultDeps()
stubCompletionSeams(t, d, errors.New("down"), nil, nil, nil, nil)
got, directive := d.completeSessionNames(testCmdWithCtx(), []string{"devbox"}, "")
if len(got) != 0 {
t.Errorf("expected no suggestions when daemon down, got %v", got)
}
if directive != cobra.ShellCompDirectiveNoFileComp {
t.Errorf("directive = %d, want NoFileComp", directive)
}
}

View file

@ -31,36 +31,27 @@ import (
// validators) stay package-level because they hold no references to
// external systems.
type deps struct {
bangerdPath func() (string, error)
daemonExePath func(pid int) string
doctor func(ctx context.Context) (system.Report, error)
sshExec func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error
hostCommandOutput func(ctx context.Context, name string, args ...string) ([]byte, error)
vmHealth func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error)
vmSSH func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error)
vmDelete func(ctx context.Context, socketPath, idOrName string) error
vmList func(ctx context.Context, socketPath string) (api.VMListResult, error)
daemonPing func(ctx context.Context, socketPath string) (api.PingResult, error)
vmCreateBegin func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error)
vmCreateStatus func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error)
vmCreateCancel func(ctx context.Context, socketPath, operationID string) error
vmPorts func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error)
vmWorkspacePrepare func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error)
vmWorkspaceExport func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error)
guestSessionStart func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error)
guestSessionGet func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
guestSessionList func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error)
guestSessionStop func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
guestSessionKill func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error)
guestSessionLogs func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error)
guestSessionAttachBegin func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error)
guestSessionSend func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error)
guestWaitForSSH func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error
guestDial func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error)
buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan
cwd func() (string, error)
completionLister func(ctx context.Context, socketPath, method string) ([]string, error)
completionSessionLister func(ctx context.Context, socketPath, vmIDOrName string) ([]string, error)
bangerdPath func() (string, error)
daemonExePath func(pid int) string
doctor func(ctx context.Context) (system.Report, error)
sshExec func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error
hostCommandOutput func(ctx context.Context, name string, args ...string) ([]byte, error)
vmHealth func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error)
vmSSH func(ctx context.Context, socketPath, idOrName string) (api.VMSSHResult, error)
vmDelete func(ctx context.Context, socketPath, idOrName string) error
vmList func(ctx context.Context, socketPath string) (api.VMListResult, error)
daemonPing func(ctx context.Context, socketPath string) (api.PingResult, error)
vmCreateBegin func(ctx context.Context, socketPath string, params api.VMCreateParams) (api.VMCreateBeginResult, error)
vmCreateStatus func(ctx context.Context, socketPath, operationID string) (api.VMCreateStatusResult, error)
vmCreateCancel func(ctx context.Context, socketPath, operationID string) error
vmPorts func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error)
vmWorkspacePrepare func(ctx context.Context, socketPath string, params api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error)
vmWorkspaceExport func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error)
guestWaitForSSH func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error
guestDial func(ctx context.Context, address, privateKeyPath string) (vmRunGuestClient, error)
buildVMRunToolingPlan func(ctx context.Context, repoRoot string) toolingplan.Plan
cwd func() (string, error)
completionLister func(ctx context.Context, socketPath, method string) ([]string, error)
}
func defaultDeps() *deps {
@ -125,30 +116,6 @@ func defaultDeps() *deps {
vmWorkspaceExport: func(ctx context.Context, socketPath string, params api.WorkspaceExportParams) (api.WorkspaceExportResult, error) {
return rpc.Call[api.WorkspaceExportResult](ctx, socketPath, "vm.workspace.export", params)
},
guestSessionStart: func(ctx context.Context, socketPath string, params api.GuestSessionStartParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.start", params)
},
guestSessionGet: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.get", params)
},
guestSessionList: func(ctx context.Context, socketPath, idOrName string) (api.GuestSessionListResult, error) {
return rpc.Call[api.GuestSessionListResult](ctx, socketPath, "guest.session.list", api.VMRefParams{IDOrName: idOrName})
},
guestSessionStop: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.stop", params)
},
guestSessionKill: func(ctx context.Context, socketPath string, params api.GuestSessionRefParams) (api.GuestSessionShowResult, error) {
return rpc.Call[api.GuestSessionShowResult](ctx, socketPath, "guest.session.kill", params)
},
guestSessionLogs: func(ctx context.Context, socketPath string, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) {
return rpc.Call[api.GuestSessionLogsResult](ctx, socketPath, "guest.session.logs", params)
},
guestSessionAttachBegin: func(ctx context.Context, socketPath string, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
return rpc.Call[api.GuestSessionAttachBeginResult](ctx, socketPath, "guest.session.attach.begin", params)
},
guestSessionSend: func(ctx context.Context, socketPath string, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
return rpc.Call[api.GuestSessionSendResult](ctx, socketPath, "guest.session.send", params)
},
guestWaitForSSH: func(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
knownHosts, _ := bangerKnownHostsPath()
return guest.WaitForSSH(ctx, address, privateKeyPath, knownHosts, interval)
@ -157,9 +124,8 @@ func defaultDeps() *deps {
knownHosts, _ := bangerKnownHostsPath()
return guest.Dial(ctx, address, privateKeyPath, knownHosts)
},
buildVMRunToolingPlan: toolingplan.Build,
cwd: os.Getwd,
completionLister: defaultCompletionLister,
completionSessionLister: defaultCompletionSessionLister,
buildVMRunToolingPlan: toolingplan.Build,
cwd: os.Getwd,
completionLister: defaultCompletionLister,
}
}

View file

@ -4,7 +4,6 @@ import (
"bytes"
"errors"
"fmt"
"reflect"
"strings"
"testing"
@ -52,37 +51,6 @@ func TestDashIfEmpty(t *testing.T) {
}
}
func TestParseKeyValuePairs(t *testing.T) {
t.Run("nil when empty", func(t *testing.T) {
got, err := parseKeyValuePairs(nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != nil {
t.Fatalf("got %v, want nil", got)
}
})
t.Run("parses entries", func(t *testing.T) {
got, err := parseKeyValuePairs([]string{"a=1", " b = two", "c=x=y"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := map[string]string{"a": "1", "b": " two", "c": "x=y"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("got %v, want %v", got, want)
}
})
t.Run("rejects malformed entries", func(t *testing.T) {
for _, bad := range []string{"noequals", "=noKey", " =v"} {
if _, err := parseKeyValuePairs([]string{bad}); err == nil {
t.Errorf("expected error for %q", bad)
}
}
})
}
func TestExitCodeErrorError(t *testing.T) {
e := ExitCodeError{Code: 42}
got := e.Error()
@ -234,38 +202,6 @@ func TestPrintKernelCatalogTable(t *testing.T) {
}
}
func TestPrintGuestSessionTable(t *testing.T) {
var buf bytes.Buffer
sessions := []model.GuestSession{
{ID: "abcdef0123456789", Name: "planner", Status: "running", Command: "pi", CWD: "/root/repo", Attachable: true},
{ID: "short", Name: "once", Status: "exited", Command: "true", CWD: "/tmp", Attachable: false},
}
if err := printGuestSessionTable(&buf, sessions); err != nil {
t.Fatalf("printGuestSessionTable: %v", err)
}
got := buf.String()
for _, want := range []string{"ID", "NAME", "planner", "once", "yes", "no", "pi"} {
if !strings.Contains(got, want) {
t.Errorf("output missing %q:\n%s", want, got)
}
}
}
func TestPrintGuestSessionSummary(t *testing.T) {
var buf bytes.Buffer
session := model.GuestSession{
ID: "id1", Name: "s", Status: "exited", Command: "true", CWD: "/root",
}
if err := printGuestSessionSummary(&buf, session); err != nil {
t.Fatalf("printGuestSessionSummary: %v", err)
}
got := buf.String()
fields := strings.Split(strings.TrimRight(got, "\n"), "\t")
if len(fields) != 5 {
t.Fatalf("expected 5 tab-separated fields, got %d: %q", len(fields), got)
}
}
func TestPrintJSON(t *testing.T) {
var buf bytes.Buffer
if err := printJSON(&buf, map[string]int{"a": 1, "b": 2}); err != nil {
@ -340,10 +276,6 @@ type failWriter struct{}
func (failWriter) Write([]byte) (int, error) { return 0, fmt.Errorf("boom") }
func TestPrintersPropagateWriteErrors(t *testing.T) {
sessions := []model.GuestSession{{ID: "id", Name: "n"}}
if err := printGuestSessionTable(failWriter{}, sessions); err == nil {
t.Error("expected write error from printGuestSessionTable")
}
kernels := []api.KernelEntry{{Name: "k"}}
if err := printKernelListTable(failWriter{}, kernels); err == nil {
t.Error("expected write error from printKernelListTable")

View file

@ -3,7 +3,6 @@ package cli
import (
"encoding/json"
"fmt"
"io"
"os"
"sort"
"strings"
@ -276,30 +275,6 @@ func printKernelCatalogTable(out anyWriter, entries []api.KernelCatalogEntry) er
return w.Flush()
}
// -- guest session printers -----------------------------------------
func printGuestSessionSummary(out anyWriter, session model.GuestSession) error {
_, err := fmt.Fprintf(out, "%s\t%s\t%s\t%s\t%s\n", session.ID, session.Name, session.Status, session.Command, session.CWD)
return err
}
func printGuestSessionTable(out io.Writer, sessions []model.GuestSession) error {
tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0)
if _, err := fmt.Fprintln(tw, "ID\tNAME\tSTATUS\tATTACH\tCOMMAND\tCWD"); err != nil {
return err
}
for _, session := range sessions {
attach := "no"
if session.Attachable {
attach = "yes"
}
if _, err := fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n", shortID(session.ID), session.Name, session.Status, attach, session.Command, session.CWD); err != nil {
return err
}
}
return tw.Flush()
}
// -- doctor printer -------------------------------------------------
func printDoctorReport(out anyWriter, report system.Report) error {

View file

@ -32,13 +32,11 @@ owning types:
- `createOps opstate.Registry[*vmCreateOperationState]` — in-flight VM
create operations; owns its own lock.
- `tapPool tapPool` — TAP interface pool; owns its own lock.
- `sessions sessionRegistry` — active guest session controllers; owns
its own lock.
- `listener`, `vmDNS` — networking.
- `vmCaps` — registered VM capability hooks.
- `pullAndFlatten`, `finalizePulledRootfs`, `bundleFetch`,
`requestHandler`, `guestWaitForSSH`, `guestDial`,
`waitForGuestSessionReady` — injectable seams used by tests.
`workspaceInspectRepo`, `workspaceImport` — injectable seams used by tests.
## Subpackages
@ -53,11 +51,9 @@ state beyond small test seams.
| `internal/daemon/dmsnap` | Device-mapper COW snapshot create/cleanup/remove. |
| `internal/daemon/fcproc` | Firecracker process primitives (bridge, tap, binary, PID, kill, wait). |
| `internal/daemon/imagemgr` | Image subsystem pure helpers: validators, staging, build script gen. |
| `internal/daemon/session` | Guest-session helpers: state paths, scripts, parsing, utilities. |
| `internal/daemon/workspace` | Workspace helpers: git inspection, copy prep, guest import script. |
`workspace` imports `session` for `ShellQuote` and `FormatStepError`; all
other subpackages are leaves (no other intra-daemon subpackage imports).
All subpackages are leaves — no intra-daemon subpackage imports another.
## Lock ordering
@ -73,9 +69,8 @@ time. `workspace.prepare` acquires `vmLocks[id]` just long enough to
validate VM state, releases it, then acquires `workspaceLocks[id]`
for the guest I/O phase.
Subsystem-local locks (`tapPool.mu`, `sessionRegistry.mu`,
`opstate.Registry` mu, `guestSessionController.attachMu` /
`writeMu`) are leaves. They do not contend with each other.
Subsystem-local locks (`tapPool.mu`, `opstate.Registry` mu) are leaves.
They do not contend with each other.
Notes:

View file

@ -51,24 +51,22 @@ type Daemon struct {
// See internal/daemon/vm_handles.go — persistent durable state
// lives in the store, this is rebuildable from a per-VM
// handles.json scratch file and OS inspection.
handles *handleCache
sessions sessionRegistry
tapPool tapPool
closing chan struct{}
once sync.Once
pid int
listener net.Listener
vmDNS *vmdns.Server
vmCaps []vmCapability
pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error)
finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error
bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error)
requestHandler func(context.Context, rpc.Request) rpc.Response
guestWaitForSSH func(context.Context, string, string, time.Duration) error
guestDial func(context.Context, string, string) (guestSSHClient, error)
waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error)
workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error)
workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error
handles *handleCache
tapPool tapPool
closing chan struct{}
once sync.Once
pid int
listener net.Listener
vmDNS *vmdns.Server
vmCaps []vmCapability
pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error)
finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error
bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error)
requestHandler func(context.Context, rpc.Request) rpc.Response
guestWaitForSSH func(context.Context, string, string, time.Duration) error
guestDial func(context.Context, string, string) (guestSSHClient, error)
workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error)
workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error
}
func Open(ctx context.Context) (d *Daemon, err error) {
@ -93,15 +91,14 @@ func Open(ctx context.Context) (d *Daemon, err error) {
return nil, err
}
d = &Daemon{
layout: layout,
config: cfg,
store: db,
runner: system.NewRunner(),
logger: logger,
closing: make(chan struct{}),
pid: os.Getpid(),
handles: newHandleCache(),
sessions: newSessionRegistry(),
layout: layout,
config: cfg,
store: db,
runner: system.NewRunner(),
logger: logger,
closing: make(chan struct{}),
pid: os.Getpid(),
handles: newHandleCache(),
}
// From here on, every failure path must run Close() so the host
// state we touched (DNS listener goroutine, resolvectl routing,
@ -144,7 +141,7 @@ func (d *Daemon) Close() error {
if d.listener != nil {
_ = d.listener.Close()
}
err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.closeGuestSessionControllers(), d.store.Close())
err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.store.Close())
})
return err
}
@ -424,62 +421,6 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
}
result, err := d.ExportVMWorkspace(ctx, params)
return marshalResultOrError(result, err)
case "guest.session.start":
params, err := rpc.DecodeParams[api.GuestSessionStartParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.StartGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.get":
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.GetGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.list":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
sessions, err := d.ListGuestSessions(ctx, params)
return marshalResultOrError(api.GuestSessionListResult{Sessions: sessions}, err)
case "guest.session.stop":
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.StopGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.kill":
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.KillGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.logs":
params, err := rpc.DecodeParams[api.GuestSessionLogsParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.GuestSessionLogs(ctx, params)
return marshalResultOrError(result, err)
case "guest.session.attach.begin":
params, err := rpc.DecodeParams[api.GuestSessionAttachBeginParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.BeginGuestSessionAttach(ctx, params)
return marshalResultOrError(result, err)
case "guest.session.send":
params, err := rpc.DecodeParams[api.GuestSessionSendParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.SendToGuestSession(ctx, params)
return marshalResultOrError(result, err)
case "image.list":
images, err := d.store.ListImages(ctx)
return marshalResultOrError(api.ImageListResult{Images: images}, err)

View file

@ -1,8 +1,8 @@
// Package daemon hosts the Banger daemon process.
//
// The daemon exposes a JSON-RPC endpoint over a Unix socket. It owns VM
// lifecycle, image management, guest sessions, host networking bootstrap,
// and state persistence via internal/store.
// lifecycle, image management, host networking bootstrap, and state
// persistence via internal/store.
//
// The package is organised into cohesive groups. Pure stateless helpers for
// each group have been lifted into subpackages; orchestrator methods
@ -18,13 +18,9 @@
// internal/daemon/imagemgr Image subsystem helpers: path validation,
// artifact staging, guest provisioning script
// generator, metadata.
// internal/daemon/session Guest-session helpers: state paths, runner
// / inspect / signal scripts, state snapshot
// parsing, launch helpers, ShellQuote,
// FormatStepError.
// internal/daemon/workspace Workspace helpers: git repo inspection,
// shallow copy prep, guest-side import,
// finalize script generation.
// finalize script generation, shell quoting.
//
// VM lifecycle (in this package):
//
@ -50,11 +46,7 @@
//
// Guest interaction (in this package):
//
// guest_sessions.go dialGuest, waitForGuestSSH, refresh/inspect
// session_lifecycle.go Start/Stop/Kill/Get/List/signal orchestrators
// session_attach.go BeginGuestSessionAttach + bridge/forward/watch
// session_stream.go GuestSessionLogs, SendToGuestSession
// session_controller.go guestSessionController, sessionRegistry
// guest_ssh.go guestSSHClient, dialGuest, waitForGuestSSH
// ssh_client_config.go daemon-managed SSH client key material
// workspace.go ExportVMWorkspace, PrepareVMWorkspace
//
@ -73,10 +65,8 @@
//
// Lock ordering:
//
// vmLocks[id] → {createVMMu, imageOpsMu} → subsystem-local locks
// vmLocks[id] → workspaceLocks[id] → {createVMMu, imageOpsMu} → subsystem-local locks
//
// Subsystem-local locks live on their owning type (tapPool.mu,
// sessionRegistry.mu, opstate.Registry mu, guestSessionController.attachMu /
// writeMu) and do not contend with each other. See ARCHITECTURE.md for
// details.
// Subsystem-local locks (tapPool.mu, opstate.Registry mu) are leaves and
// do not contend with each other. See ARCHITECTURE.md for details.
package daemon

View file

@ -0,0 +1,26 @@
package daemon
import (
"fmt"
"os/exec"
"testing"
)
// startFakeFirecracker launches a bash sleep-loop rewritten to match
// the firecracker command line a real process would expose, so
// reconcile / handle-cache paths that grep /proc/<pid>/cmdline accept
// it as a firecracker process. Killed on test cleanup.
func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd {
t.Helper()
cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
if err := cmd.Start(); err != nil {
t.Fatalf("start fake firecracker: %v", err)
}
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
}
})
return cmd
}

View file

@ -1,142 +0,0 @@
package daemon
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"time"
"banger/internal/daemon/session"
"banger/internal/guest"
"banger/internal/model"
"banger/internal/system"
)
type guestSSHClient interface {
Close() error
RunScript(context.Context, string, io.Writer) error
RunScriptOutput(context.Context, string) ([]byte, error)
UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error
StreamTar(context.Context, string, string, io.Writer) error
StreamTarEntries(context.Context, string, []string, string, io.Writer) error
}
func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error {
if d != nil && d.guestWaitForSSH != nil {
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
}
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval)
}
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
if d != nil && d.guestDial != nil {
return d.guestDial(ctx, address, d.config.SSHKeyPath)
}
return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
}
func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
if d != nil && d.waitForGuestSessionReady != nil {
return d.waitForGuestSessionReady(ctx, vm, s)
}
return d.waitForGuestSessionReadyDefault(ctx, vm, s)
}
func (d *Daemon) waitForGuestSessionReadyDefault(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
for {
updated, err := d.refreshGuestSession(ctx, vm, s)
if err == nil {
s = updated
if s.GuestPID != 0 || s.ExitCode != nil || s.Status == model.GuestSessionStatusRunning || s.Status == model.GuestSessionStatusFailed || s.Status == model.GuestSessionStatusExited {
return s, nil
}
}
select {
case <-ctx.Done():
return s, ctx.Err()
case <-time.After(100 * time.Millisecond):
}
}
}
func (d *Daemon) refreshGuestSession(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
if s.Status != model.GuestSessionStatusStarting && s.Status != model.GuestSessionStatusRunning && s.Status != model.GuestSessionStatusStopping {
return s, nil
}
snapshot, err := d.inspectGuestSessionState(ctx, vm, s)
if err != nil {
return s, err
}
original := s
session.ApplyStateSnapshot(&s, snapshot, d.vmAlive(vm))
if session.StateChanged(original, s) {
s.UpdatedAt = model.Now()
if err := d.store.UpsertGuestSession(ctx, s); err != nil {
return s, err
}
}
return s, nil
}
func (d *Daemon) inspectGuestSessionState(ctx context.Context, vm model.VMRecord, s model.GuestSession) (session.StateSnapshot, error) {
if d.vmAlive(vm) {
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return session.StateSnapshot{}, err
}
defer client.Close()
var output bytes.Buffer
if err := client.RunScript(ctx, session.InspectScript(s.ID), &output); err != nil {
return session.StateSnapshot{}, session.FormatStepError("inspect guest session state", err, output.String())
}
return session.ParseState(output.String())
}
return d.inspectGuestSessionStateFromWorkDisk(ctx, vm, s.ID)
}
func (d *Daemon) inspectGuestSessionStateFromWorkDisk(ctx context.Context, vm model.VMRecord, sessionID string) (session.StateSnapshot, error) {
runner := d.runner
if runner == nil {
runner = system.NewRunner()
}
workMount, cleanup, err := system.MountTempDir(ctx, runner, vm.Runtime.WorkDiskPath, false)
if err != nil {
return session.StateSnapshot{}, err
}
defer cleanup()
stateDir := filepath.Join(workMount, session.RelativeStateDir(sessionID))
return session.InspectStateFromDir(stateDir)
}
func (d *Daemon) findGuestSession(ctx context.Context, vmID, idOrName string) (model.GuestSession, error) {
if strings.TrimSpace(idOrName) == "" {
return model.GuestSession{}, errors.New("session id or name is required")
}
if s, err := d.store.GetGuestSession(ctx, vmID, idOrName); err == nil {
return s, nil
}
sessions, err := d.store.ListGuestSessionsByVM(ctx, vmID)
if err != nil {
return model.GuestSession{}, err
}
matches := make([]model.GuestSession, 0, 1)
for _, s := range sessions {
if strings.HasPrefix(s.ID, idOrName) || strings.HasPrefix(s.Name, idOrName) {
matches = append(matches, s)
}
}
switch len(matches) {
case 0:
return model.GuestSession{}, fmt.Errorf("session %q not found", idOrName)
case 1:
return matches[0], nil
default:
return model.GuestSession{}, fmt.Errorf("multiple sessions match %q", idOrName)
}
}

View file

@ -1,492 +0,0 @@
package daemon
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
"banger/internal/api"
sess "banger/internal/daemon/session"
"banger/internal/model"
"banger/internal/store"
)
type fakeGuestSSHClient struct {
t *testing.T
existingDirs map[string]bool
closed bool
}
func (f *fakeGuestSSHClient) Close() error {
f.closed = true
return nil
}
func (f *fakeGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error {
f.t.Helper()
switch {
case strings.Contains(script, `\n`):
return fmt.Errorf("script still contains escaped newline literals: %q", script)
case strings.Contains(script, `echo "missing cwd: $DIR"`):
if strings.Contains(script, "DIR='/root/repo'\n") && f.existingDirs["/root/repo"] {
return nil
}
return fmt.Errorf("missing cwd")
case strings.Contains(script, "check_command() {"):
return nil
case strings.Contains(script, `git config --global --add safe.directory "$DIR"`):
if strings.Contains(script, "DIR='/root/repo'\n") {
f.existingDirs["/root/repo"] = true
return nil
}
return fmt.Errorf("workspace finalize used unexpected guest path")
case strings.Contains(script, "chmod -R a-w"):
if f.existingDirs["/root/repo"] {
return nil
}
return fmt.Errorf("workspace path missing during readonly chmod")
case strings.Contains(script, "nohup bash "):
return nil
default:
return nil
}
}
func (f *fakeGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) {
return nil, nil
}
func (f *fakeGuestSSHClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error {
return nil
}
func (f *fakeGuestSSHClient) StreamTar(_ context.Context, _ string, command string, _ io.Writer) error {
if strings.Contains(command, "/root/repo") {
f.existingDirs["/root/repo"] = true
return nil
}
return fmt.Errorf("unexpected StreamTar command: %s", command)
}
func (f *fakeGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, command string, _ io.Writer) error {
if strings.Contains(command, "/root/repo") {
f.existingDirs["/root/repo"] = true
return nil
}
return fmt.Errorf("unexpected StreamTarEntries command: %s", command)
}
func TestSendToGuestSession_HappyPath(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
apiSock := filepath.Join(t.TempDir(), "fc.sock")
firecracker := startFakeFirecracker(t, apiSock)
vm := testVM("sendbox", "image-send", "172.16.0.88")
vm.State = model.VMStateRunning
vm.Runtime.State = model.VMStateRunning
vm.Runtime.APISockPath = apiSock
upsertDaemonVM(t, ctx, db, vm)
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
if err := db.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
fake := &recordingGuestSSHClient{}
d := newSendTestDaemon(t, db, fake)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
payload := []byte(`{"type":"abort"}` + "\n")
result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
VMIDOrName: vm.Name,
SessionIDOrName: session.Name,
Payload: payload,
})
if err != nil {
t.Fatalf("SendToGuestSession: %v", err)
}
if result.BytesWritten != len(payload) {
t.Fatalf("BytesWritten = %d, want %d", result.BytesWritten, len(payload))
}
if result.Session.ID != session.ID {
t.Fatalf("Session.ID = %q, want %q", result.Session.ID, session.ID)
}
if len(fake.uploadedFiles) != 1 {
t.Fatalf("UploadFile call count = %d, want 1", len(fake.uploadedFiles))
}
for path, data := range fake.uploadedFiles {
if !strings.HasPrefix(path, "/tmp/banger-send-") {
t.Fatalf("upload path = %q, want /tmp/banger-send-... prefix", path)
}
if string(data) != string(payload) {
t.Fatalf("upload data = %q, want %q", data, payload)
}
}
if len(fake.ranScripts) != 1 {
t.Fatalf("RunScript call count = %d, want 1", len(fake.ranScripts))
}
script := fake.ranScripts[0]
pipePath := sess.StdinPipePath(session.ID)
if !strings.Contains(script, "cat ") {
t.Fatalf("send script missing cat command: %q", script)
}
if !strings.Contains(script, pipePath) {
t.Fatalf("send script missing pipe path %q: %q", pipePath, script)
}
if !strings.Contains(script, "rm -f ") {
t.Fatalf("send script missing rm cleanup: %q", script)
}
}
func TestSendToGuestSession_EmptyPayload(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
apiSock := filepath.Join(t.TempDir(), "fc.sock")
firecracker := startFakeFirecracker(t, apiSock)
vm := testVM("sendbox-empty", "image-send", "172.16.0.89")
vm.State = model.VMStateRunning
vm.Runtime.State = model.VMStateRunning
vm.Runtime.APISockPath = apiSock
upsertDaemonVM(t, ctx, db, vm)
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
if err := db.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
fake := &recordingGuestSSHClient{}
d := newSendTestDaemon(t, db, fake)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
VMIDOrName: vm.Name,
SessionIDOrName: session.Name,
Payload: nil,
})
if err != nil {
t.Fatalf("SendToGuestSession(empty): %v", err)
}
if result.BytesWritten != 0 {
t.Fatalf("BytesWritten = %d, want 0", result.BytesWritten)
}
if fake.dialCount != 0 {
t.Fatalf("SSH dial count = %d, want 0 for empty payload", fake.dialCount)
}
}
func TestSendToGuestSession_NotPipeMode(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
vm := testVM("sendbox-closed", "image-send", "172.16.0.90")
vm.State = model.VMStateRunning
upsertDaemonVM(t, ctx, db, vm)
session := testGuestSession(vm.ID, model.GuestSessionStdinClosed, model.GuestSessionStatusRunning)
if err := db.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
d := &Daemon{store: db}
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
VMIDOrName: vm.Name,
SessionIDOrName: session.Name,
Payload: []byte("hello\n"),
})
if err == nil || !strings.Contains(err.Error(), "stdin pipe") {
t.Fatalf("error = %v, want 'stdin pipe' error", err)
}
}
func TestSendToGuestSession_SessionNotRunning(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
vm := testVM("sendbox-failed", "image-send", "172.16.0.91")
vm.State = model.VMStateRunning
upsertDaemonVM(t, ctx, db, vm)
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusFailed)
if err := db.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
d := &Daemon{store: db}
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
VMIDOrName: vm.Name,
SessionIDOrName: session.Name,
Payload: []byte("hello\n"),
})
if err == nil || !strings.Contains(err.Error(), "not running") {
t.Fatalf("error = %v, want 'not running' error", err)
}
}
func TestSendToGuestSession_VMNotRunning(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
vm := testVM("sendbox-stopped", "image-send", "172.16.0.92")
vm.State = model.VMStateStopped
upsertDaemonVM(t, ctx, db, vm)
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
if err := db.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
d := &Daemon{store: db}
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
VMIDOrName: vm.Name,
SessionIDOrName: session.Name,
Payload: []byte("hello\n"),
})
if err == nil || !strings.Contains(err.Error(), "not running") {
t.Fatalf("error = %v, want 'not running' error", err)
}
}
// recordingGuestSSHClient captures UploadFile and RunScript calls for send tests.
type recordingGuestSSHClient struct {
dialCount int
uploadedFiles map[string][]byte
ranScripts []string
}
func (r *recordingGuestSSHClient) Close() error { return nil }
func (r *recordingGuestSSHClient) UploadFile(_ context.Context, path string, _ os.FileMode, data []byte, _ io.Writer) error {
if r.uploadedFiles == nil {
r.uploadedFiles = make(map[string][]byte)
}
copy := make([]byte, len(data))
_ = copy[:len(data):len(data)]
for i, b := range data {
copy[i] = b
}
r.uploadedFiles[path] = copy
return nil
}
func (r *recordingGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error {
r.ranScripts = append(r.ranScripts, script)
return nil
}
func (r *recordingGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) {
return nil, nil
}
func (r *recordingGuestSSHClient) StreamTar(_ context.Context, _ string, _ string, _ io.Writer) error {
return nil
}
func (r *recordingGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, _ string, _ io.Writer) error {
return nil
}
func newSendTestDaemon(t *testing.T, db *store.Store, fake *recordingGuestSSHClient) *Daemon {
t.Helper()
d := &Daemon{
store: db,
config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")},
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
}
d.guestDial = func(_ context.Context, _ string, _ string) (guestSSHClient, error) {
fake.dialCount++
return fake, nil
}
return d
}
func testGuestSession(vmID string, stdinMode model.GuestSessionStdinMode, status model.GuestSessionStatus) model.GuestSession {
now := model.Now()
id := vmID + "-sess-id"
return model.GuestSession{
ID: id,
VMID: vmID,
Name: vmID + "-sess",
Backend: sess.BackendSSH,
Command: "pi",
Args: []string{"--mode", "rpc"},
CWD: "/root/repo",
StdinMode: stdinMode,
Status: status,
GuestStateDir: sess.StateDir(id),
StdoutLogPath: sess.StdoutLogPath(id),
StderrLogPath: sess.StderrLogPath(id),
Attachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning,
Reattachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning,
CreatedAt: now,
UpdatedAt: now,
}
}
func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd {
t.Helper()
cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
if err := cmd.Start(); err != nil {
t.Fatalf("start fake firecracker: %v", err)
}
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
}
})
return cmd
}
func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) {
t.Parallel()
cwdScript := sess.CWDPreflightScript("/root/repo")
if strings.Contains(cwdScript, `\n`) {
t.Fatalf("cwd preflight script still contains escaped newline literals: %q", cwdScript)
}
if !strings.Contains(cwdScript, "\n") {
t.Fatalf("cwd preflight script should contain real newlines: %q", cwdScript)
}
commandScript := sess.CommandPreflightScript([]string{"git", "pi"})
if strings.Contains(commandScript, `\n`) {
t.Fatalf("command preflight script still contains escaped newline literals: %q", commandScript)
}
if !strings.Contains(commandScript, "\n") {
t.Fatalf("command preflight script should contain real newlines: %q", commandScript)
}
attachInput := sess.AttachInputCommand("session-id")
if strings.Contains(attachInput, `\n`) {
t.Fatalf("attach input command still contains escaped newline literals: %q", attachInput)
}
attachTail := sess.AttachTailCommand("/tmp/stdout.log")
if strings.Contains(attachTail, `\n`) {
t.Fatalf("attach tail command still contains escaped newline literals: %q", attachTail)
}
}
func TestPrepareWorkspaceThenStartGuestSessionPassesCWDPreflight(t *testing.T) {
ctx := context.Background()
db := openDaemonStore(t)
repoRoot := filepath.Join(t.TempDir(), "repo")
if err := os.MkdirAll(repoRoot, 0o755); err != nil {
t.Fatalf("MkdirAll(repoRoot): %v", err)
}
if err := os.WriteFile(filepath.Join(repoRoot, "README.md"), []byte("hello\n"), 0o644); err != nil {
t.Fatalf("WriteFile(README.md): %v", err)
}
runGit := func(args ...string) {
t.Helper()
cmd := exec.Command("git", append([]string{"-C", repoRoot}, args...)...)
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("git %s: %v\n%s", strings.Join(args, " "), err, output)
}
}
runGit("init", "-b", "main")
runGit("config", "user.name", "Test User")
runGit("config", "user.email", "test@example.com")
runGit("add", ".")
runGit("commit", "-m", "initial")
apiSock := filepath.Join(t.TempDir(), "fc.sock")
firecracker := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
if err := firecracker.Start(); err != nil {
t.Fatalf("start fake firecracker: %v", err)
}
t.Cleanup(func() {
if firecracker.Process != nil {
_ = firecracker.Process.Kill()
_, _ = firecracker.Process.Wait()
}
})
vm := testVM("pi-devbox", "image-pi", "172.16.0.77")
vm.State = model.VMStateRunning
vm.Runtime.State = model.VMStateRunning
vm.Runtime.APISockPath = apiSock
upsertDaemonVM(t, ctx, db, vm)
fakeClient := &fakeGuestSSHClient{t: t, existingDirs: map[string]bool{}}
d := &Daemon{
store: db,
config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")},
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
}
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
d.guestDial = func(context.Context, string, string) (guestSSHClient, error) { return fakeClient, nil }
d.waitForGuestSessionReady = func(_ context.Context, _ model.VMRecord, session model.GuestSession) (model.GuestSession, error) {
now := model.Now()
session.Status = model.GuestSessionStatusRunning
session.GuestPID = 4242
session.StartedAt = now
session.UpdatedAt = now
session.Attachable = session.StdinMode == model.GuestSessionStdinPipe
session.Reattachable = session.StdinMode == model.GuestSessionStdinPipe
return session, nil
}
workspace, err := d.PrepareVMWorkspace(ctx, api.VMWorkspacePrepareParams{
IDOrName: vm.Name,
SourcePath: repoRoot,
GuestPath: "/root/repo",
ReadOnly: true,
})
if err != nil {
t.Fatalf("PrepareVMWorkspace: %v", err)
}
if workspace.GuestPath != "/root/repo" {
t.Fatalf("PrepareVMWorkspace guest path = %q, want /root/repo", workspace.GuestPath)
}
if !fakeClient.existingDirs["/root/repo"] {
t.Fatalf("workspace prepare did not mark /root/repo as present in fake guest")
}
session, err := d.StartGuestSession(ctx, api.GuestSessionStartParams{
VMIDOrName: vm.Name,
Name: "testpi",
Command: "pi",
Args: []string{"--mode", "rpc", "--no-session"},
CWD: "/root/repo",
StdinMode: string(model.GuestSessionStdinPipe),
RequiredCommands: []string{"git"},
})
if err != nil {
t.Fatalf("StartGuestSession: %v", err)
}
if session.Status != model.GuestSessionStatusRunning {
t.Fatalf("session status = %q, want %q", session.Status, model.GuestSessionStatusRunning)
}
if session.LaunchStage != "" {
t.Fatalf("session launch stage = %q, want empty", session.LaunchStage)
}
if session.LaunchMessage != "" {
t.Fatalf("session launch message = %q, want empty", session.LaunchMessage)
}
if session.GuestPID == 0 {
t.Fatalf("session guest pid = 0, want non-zero")
}
if !session.Attachable {
t.Fatalf("session should be attachable for pipe stdin mode")
}
}

View file

@ -0,0 +1,35 @@
package daemon
import (
"context"
"io"
"os"
"time"
"banger/internal/guest"
)
// guestSSHClient is the narrow guest-SSH surface the daemon uses for
// workspace prepare / export and ad-hoc guest interactions.
type guestSSHClient interface {
Close() error
RunScript(context.Context, string, io.Writer) error
RunScriptOutput(context.Context, string) ([]byte, error)
UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error
StreamTar(context.Context, string, string, io.Writer) error
StreamTarEntries(context.Context, string, []string, string, io.Writer) error
}
func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error {
if d != nil && d.guestWaitForSSH != nil {
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
}
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval)
}
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
if d != nil && d.guestDial != nil {
return d.guestDial(ctx, address, d.config.SSHKeyPath)
}
return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
}

View file

@ -26,10 +26,9 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
name: "only store + closing channel (early failure)",
build: func(t *testing.T) *Daemon {
return &Daemon{
store: openDaemonStore(t),
closing: make(chan struct{}),
sessions: newSessionRegistry(),
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
store: openDaemonStore(t),
closing: make(chan struct{}),
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
}
},
verify: func(t *testing.T, d *Daemon) {
@ -49,11 +48,10 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
t.Fatalf("vmdns.New: %v", err)
}
return &Daemon{
store: openDaemonStore(t),
closing: make(chan struct{}),
sessions: newSessionRegistry(),
vmDNS: server,
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
store: openDaemonStore(t),
closing: make(chan struct{}),
vmDNS: server,
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
}
},
verify: func(t *testing.T, d *Daemon) {
@ -86,11 +84,10 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
// returns and also calls Close afterwards, both paths must survive.
func TestCloseIdempotentUnderConcurrency(t *testing.T) {
d := &Daemon{
store: openDaemonStore(t),
closing: make(chan struct{}),
sessions: newSessionRegistry(),
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
config: model.DaemonConfig{BridgeName: ""},
store: openDaemonStore(t),
closing: make(chan struct{}),
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
config: model.DaemonConfig{BridgeName: ""},
}
var count atomic.Int32

View file

@ -1,521 +0,0 @@
// Package session contains the pure helpers of the guest-session subsystem:
// bash script generators, on-guest state path helpers, state snapshot
// parsing, and small utilities like ShellQuote and FormatStepError.
//
// The orchestrator methods (StartGuestSession, BeginGuestSessionAttach,
// etc.) stay on *daemon.Daemon and compose these helpers.
package session
import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"syscall"
"banger/internal/model"
"banger/internal/system"
"golang.org/x/crypto/ssh"
)
// Constants shared between orchestration and helpers.
const (
BackendSSH = "ssh"
AttachBackendNone = "none"
AttachBackendSSHBridge = "ssh_rehydratable"
AttachModeExclusive = "exclusive"
TransportUnixSocket = "unix_socket"
StateRoot = "/root/.local/state/banger/sessions"
LogTailLineDefault = 200
)
// StateSnapshot is the decoded per-session state as read from the guest.
type StateSnapshot struct {
Status string
GuestPID int
ExitCode *int
Alive bool
LastError string
}
// -- Guest filesystem paths -------------------------------------------------
func StateDir(id string) string {
return filepath.ToSlash(filepath.Join(StateRoot, id))
}
func RelativeStateDir(id string) string {
return strings.TrimPrefix(StateDir(id), "/root/")
}
func ScriptPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "run.sh")) }
func PIDPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "pid")) }
func MonitorPIDPath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "monitor_pid"))
}
func ExitCodePath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "exit_code"))
}
func StdinPipePath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "stdin.pipe"))
}
func StdinKeepalivePIDPath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "stdin_keepalive.pid"))
}
func StatusPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "status")) }
func ErrorPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "error")) }
func StdoutLogPath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "stdout.log"))
}
func StderrLogPath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "stderr.log"))
}
// -- Script generators ------------------------------------------------------
// Script returns the bash runner installed into the guest for session. It
// sets up state/log paths, optional stdin fifo, and wait-loop around the
// user command.
func Script(sess model.GuestSession) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "STATE_DIR=%s\n", ShellQuote(sess.GuestStateDir))
fmt.Fprintf(&script, "STDOUT_LOG=%s\n", ShellQuote(sess.StdoutLogPath))
fmt.Fprintf(&script, "STDERR_LOG=%s\n", ShellQuote(sess.StderrLogPath))
fmt.Fprintf(&script, "PID_FILE=%s\n", ShellQuote(PIDPath(sess.ID)))
fmt.Fprintf(&script, "MONITOR_PID_FILE=%s\n", ShellQuote(MonitorPIDPath(sess.ID)))
fmt.Fprintf(&script, "EXIT_FILE=%s\n", ShellQuote(ExitCodePath(sess.ID)))
fmt.Fprintf(&script, "STATUS_FILE=%s\n", ShellQuote(StatusPath(sess.ID)))
fmt.Fprintf(&script, "ERROR_FILE=%s\n", ShellQuote(ErrorPath(sess.ID)))
fmt.Fprintf(&script, "STDIN_PIPE=%s\n", ShellQuote(StdinPipePath(sess.ID)))
fmt.Fprintf(&script, "STDIN_KEEPALIVE_PID_FILE=%s\n", ShellQuote(StdinKeepalivePIDPath(sess.ID)))
fmt.Fprintf(&script, "SESSION_CWD=%s\n", ShellQuote(DefaultCWD(sess.CWD)))
script.WriteString("mkdir -p \"$STATE_DIR\"\n")
script.WriteString(": >\"$STDOUT_LOG\"\n")
script.WriteString(": >\"$STDERR_LOG\"\n")
script.WriteString("rm -f \"$EXIT_FILE\" \"$ERROR_FILE\" \"$STDIN_KEEPALIVE_PID_FILE\"\n")
if sess.StdinMode == model.GuestSessionStdinPipe {
script.WriteString("rm -f \"$STDIN_PIPE\"\n")
script.WriteString("mkfifo -m 600 \"$STDIN_PIPE\"\n")
}
script.WriteString("printf '%s\\n' \"${BASHPID:-$$}\" >\"$MONITOR_PID_FILE\"\n")
script.WriteString("printf 'starting\\n' >\"$STATUS_FILE\"\n")
script.WriteString("cd \"$SESSION_CWD\"\n")
script.WriteString("exec > >(tee -a \"$STDOUT_LOG\") 2> >(tee -a \"$STDERR_LOG\" >&2)\n")
for _, line := range EnvLines(sess.Env) {
script.WriteString(line)
script.WriteByte('\n')
}
script.WriteString("COMMAND=(")
for _, value := range append([]string{sess.Command}, sess.Args...) {
script.WriteByte(' ')
script.WriteString(ShellQuote(value))
}
script.WriteString(" )\n")
if sess.StdinMode == model.GuestSessionStdinPipe {
script.WriteString("( while :; do sleep 3600; done ) >\"$STDIN_PIPE\" &\n")
script.WriteString("keepalive=$!\n")
script.WriteString("printf '%s\\n' \"$keepalive\" >\"$STDIN_KEEPALIVE_PID_FILE\"\n")
script.WriteString("\"${COMMAND[@]}\" <\"$STDIN_PIPE\" &\n")
} else {
script.WriteString("\"${COMMAND[@]}\" &\n")
}
script.WriteString("child=$!\n")
script.WriteString("printf '%s\\n' \"$child\" >\"$PID_FILE\"\n")
script.WriteString("printf 'running\\n' >\"$STATUS_FILE\"\n")
script.WriteString("wait \"$child\"\n")
script.WriteString("rc=$?\n")
if sess.StdinMode == model.GuestSessionStdinPipe {
script.WriteString("if [ -f \"$STDIN_KEEPALIVE_PID_FILE\" ]; then kill \"$(cat \"$STDIN_KEEPALIVE_PID_FILE\")\" 2>/dev/null || true; fi\n")
}
script.WriteString("printf '%s\\n' \"$rc\" >\"$EXIT_FILE\"\n")
script.WriteString("if [ \"$rc\" -eq 0 ]; then printf 'exited\\n' >\"$STATUS_FILE\"; else printf 'failed\\n' >\"$STATUS_FILE\"; fi\n")
script.WriteString("exit \"$rc\"\n")
return script.String()
}
// InspectScript reads the on-guest state files for sessionID and prints a
// key=value block parseable by ParseState.
func InspectScript(sessionID string) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID)))
script.WriteString("status=''\n")
script.WriteString("pid=''\n")
script.WriteString("exit_code=''\n")
script.WriteString("last_error=''\n")
script.WriteString("alive=false\n")
script.WriteString("[ -f \"$DIR/status\" ] && status=\"$(cat \"$DIR/status\")\"\n")
script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n")
script.WriteString("[ -f \"$DIR/exit_code\" ] && exit_code=\"$(cat \"$DIR/exit_code\")\"\n")
script.WriteString("[ -f \"$DIR/error\" ] && last_error=\"$(cat \"$DIR/error\")\"\n")
script.WriteString("if [ -n \"$pid\" ] && kill -0 \"$pid\" 2>/dev/null; then alive=true; fi\n")
script.WriteString("printf 'status=%s\\n' \"$status\"\n")
script.WriteString("printf 'pid=%s\\n' \"$pid\"\n")
script.WriteString("printf 'exit=%s\\n' \"$exit_code\"\n")
script.WriteString("printf 'alive=%s\\n' \"$alive\"\n")
script.WriteString("printf 'error=%s\\n' \"$last_error\"\n")
return script.String()
}
// SignalScript sends signal to sessionID's runner and monitor processes.
func SignalScript(sessionID, signal string) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID)))
fmt.Fprintf(&script, "SIGNAL=%s\n", ShellQuote(signal))
script.WriteString("pid=''\n")
script.WriteString("monitor=''\n")
script.WriteString("keepalive=''\n")
script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n")
script.WriteString("[ -f \"$DIR/monitor_pid\" ] && monitor=\"$(cat \"$DIR/monitor_pid\")\"\n")
script.WriteString("[ -f \"$DIR/stdin_keepalive.pid\" ] && keepalive=\"$(cat \"$DIR/stdin_keepalive.pid\")\"\n")
script.WriteString("printf 'stopping\\n' >\"$DIR/status\"\n")
script.WriteString("if [ -n \"$pid\" ]; then kill -${SIGNAL} \"$pid\" 2>/dev/null || true; fi\n")
script.WriteString("if [ -n \"$monitor\" ]; then kill -${SIGNAL} \"$monitor\" 2>/dev/null || true; fi\n")
script.WriteString("if [ -n \"$keepalive\" ]; then kill -${SIGNAL} \"$keepalive\" 2>/dev/null || true; fi\n")
return script.String()
}
// CWDPreflightScript verifies cwd exists on the guest.
func CWDPreflightScript(cwd string) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(DefaultCWD(cwd)))
script.WriteString("if [ ! -d \"$DIR\" ]; then echo \"missing cwd: $DIR\"; exit 1; fi\n")
return script.String()
}
// CommandPreflightScript verifies each command is resolvable on the guest.
func CommandPreflightScript(commands []string) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
script.WriteString("check_command() {\n")
script.WriteString(" cmd=\"$1\"\n")
script.WriteString(" case \"$cmd\" in\n")
script.WriteString(" */*) [ -x \"$cmd\" ] || { echo \"missing command: $cmd\"; exit 1; } ;;\n")
script.WriteString(" *) command -v \"$cmd\" >/dev/null 2>&1 || { echo \"missing command: $cmd\"; exit 1; } ;;\n")
script.WriteString(" esac\n")
script.WriteString("}\n")
for _, command := range commands {
fmt.Fprintf(&script, "check_command %s\n", ShellQuote(command))
}
return script.String()
}
// AttachInputCommand returns the guest command that creates/opens the stdin
// fifo for sessionID and cats attach-side bytes into it.
func AttachInputCommand(sessionID string) string {
path := StdinPipePath(sessionID)
return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\n[ -p %s ] || mkfifo -m 600 %s\nexec cat > %s\n", ShellQuote(path), ShellQuote(path), ShellQuote(path)))
}
// AttachTailCommand returns the guest command that tails a log file and
// streams new content back to the attach bridge.
func AttachTailCommand(path string) string {
return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\ntouch %s\nexec tail -n 0 -F %s 2>/dev/null\n", ShellQuote(path), ShellQuote(path)))
}
// EnvLines returns deterministic `export KEY=value` lines for the session
// launcher, ordered by key.
func EnvLines(values map[string]string) []string {
if len(values) == 0 {
return nil
}
keys := make([]string, 0, len(values))
for key := range values {
keys = append(keys, key)
}
sort.Strings(keys)
lines := make([]string, 0, len(keys))
for _, key := range keys {
lines = append(lines, "export "+key+"="+ShellQuote(values[key]))
}
return lines
}
// -- State snapshot helpers -------------------------------------------------
// ParseState decodes the key=value output produced by InspectScript.
func ParseState(raw string) (StateSnapshot, error) {
var snapshot StateSnapshot
scanner := bufio.NewScanner(strings.NewReader(raw))
for scanner.Scan() {
line := scanner.Text()
key, value, ok := strings.Cut(line, "=")
if !ok {
continue
}
switch strings.TrimSpace(key) {
case "status":
snapshot.Status = strings.TrimSpace(value)
case "pid":
if pid, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
snapshot.GuestPID = pid
}
case "exit":
if exitCode, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
snapshot.ExitCode = &exitCode
}
case "alive":
snapshot.Alive = strings.TrimSpace(value) == "true"
case "error":
snapshot.LastError = strings.TrimSpace(value)
}
}
return snapshot, scanner.Err()
}
// InspectStateFromDir reads the state files directly from stateDir (used
// when the guest is offline and we can mount the work disk from the host).
func InspectStateFromDir(stateDir string) (StateSnapshot, error) {
var snapshot StateSnapshot
statusData, _ := os.ReadFile(filepath.Join(stateDir, "status"))
snapshot.Status = strings.TrimSpace(string(statusData))
pidData, _ := os.ReadFile(filepath.Join(stateDir, "pid"))
if pidValue, err := strconv.Atoi(strings.TrimSpace(string(pidData))); err == nil {
snapshot.GuestPID = pidValue
}
exitData, _ := os.ReadFile(filepath.Join(stateDir, "exit_code"))
if exitValue, err := strconv.Atoi(strings.TrimSpace(string(exitData))); err == nil {
snapshot.ExitCode = &exitValue
}
errorData, _ := os.ReadFile(filepath.Join(stateDir, "error"))
snapshot.LastError = strings.TrimSpace(string(errorData))
if snapshot.GuestPID != 0 {
snapshot.Alive = ProcessAlive(snapshot.GuestPID)
}
return snapshot, nil
}
// ApplyStateSnapshot mutates sess in place to reflect snapshot. vmRunning
// captures whether the VM is currently up so stale in-flight sessions can be
// failed when the VM is gone.
func ApplyStateSnapshot(sess *model.GuestSession, snapshot StateSnapshot, vmRunning bool) {
if sess == nil {
return
}
if snapshot.GuestPID != 0 {
sess.GuestPID = snapshot.GuestPID
}
if snapshot.LastError != "" {
sess.LastError = snapshot.LastError
}
if snapshot.ExitCode != nil {
sess.ExitCode = snapshot.ExitCode
sess.Attachable = false
sess.Reattachable = false
if sess.StartedAt.IsZero() {
sess.StartedAt = model.Now()
}
if sess.EndedAt.IsZero() {
sess.EndedAt = model.Now()
}
if *snapshot.ExitCode == 0 {
sess.Status = model.GuestSessionStatusExited
} else {
sess.Status = model.GuestSessionStatusFailed
}
return
}
if snapshot.Alive {
if sess.StartedAt.IsZero() {
sess.StartedAt = model.Now()
}
sess.Status = model.GuestSessionStatusRunning
return
}
if !vmRunning && (sess.Status == model.GuestSessionStatusStarting || sess.Status == model.GuestSessionStatusRunning || sess.Status == model.GuestSessionStatusStopping) {
sess.Status = model.GuestSessionStatusFailed
sess.Attachable = false
sess.Reattachable = false
if sess.LastError == "" {
sess.LastError = "vm is not running"
}
if sess.EndedAt.IsZero() {
sess.EndedAt = model.Now()
}
return
}
if snapshot.Status == string(model.GuestSessionStatusRunning) {
if sess.StartedAt.IsZero() {
sess.StartedAt = model.Now()
}
sess.Status = model.GuestSessionStatusRunning
}
if sess.Status == model.GuestSessionStatusRunning && sess.StdinMode == model.GuestSessionStdinPipe {
sess.Attachable = true
sess.Reattachable = true
if sess.AttachBackend == "" {
sess.AttachBackend = AttachBackendSSHBridge
}
if sess.AttachMode == "" {
sess.AttachMode = AttachModeExclusive
}
}
}
// StateChanged reports whether any materially observable field differs
// between before and after, guiding whether to persist an update.
func StateChanged(before, after model.GuestSession) bool {
if before.Status != after.Status || before.GuestPID != after.GuestPID || before.LastError != after.LastError || before.Attachable != after.Attachable || before.Reattachable != after.Reattachable || before.AttachBackend != after.AttachBackend || before.AttachMode != after.AttachMode || before.LaunchStage != after.LaunchStage || before.LaunchMessage != after.LaunchMessage || before.LaunchRawLog != after.LaunchRawLog {
return true
}
if before.StartedAt != after.StartedAt || before.EndedAt != after.EndedAt {
return true
}
switch {
case before.ExitCode == nil && after.ExitCode == nil:
return false
case before.ExitCode == nil || after.ExitCode == nil:
return true
default:
return *before.ExitCode != *after.ExitCode
}
}
// -- Launch helpers ---------------------------------------------------------
// DefaultName returns a friendly session name: caller-provided if non-empty,
// otherwise `<command-base>-<short-id>`.
func DefaultName(id, command, explicit string) string {
if trimmed := strings.TrimSpace(explicit); trimmed != "" {
return trimmed
}
base := filepath.Base(strings.TrimSpace(command))
if base == "." || base == string(filepath.Separator) || base == "" {
base = "session"
}
return base + "-" + system.ShortID(id)
}
// DefaultCWD returns value if non-empty, else /root.
func DefaultCWD(value string) string {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
return "/root"
}
// FailLaunch annotates sess as launch-failed with stage/message/raw log and
// returns it for persistence.
func FailLaunch(sess model.GuestSession, stage, message, rawLog string) model.GuestSession {
now := model.Now()
sess.Status = model.GuestSessionStatusFailed
sess.LastError = strings.TrimSpace(message)
sess.Attachable = false
sess.Reattachable = false
sess.LaunchStage = strings.TrimSpace(stage)
sess.LaunchMessage = strings.TrimSpace(message)
sess.LaunchRawLog = strings.TrimSpace(rawLog)
sess.UpdatedAt = now
sess.EndedAt = now
return sess
}
// NormalizeRequiredCommands returns a de-duplicated, order-preserving list
// of required commands, with the session command first.
func NormalizeRequiredCommands(command string, extras []string) []string {
ordered := make([]string, 0, len(extras)+1)
seen := map[string]struct{}{}
appendValue := func(value string) {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return
}
if _, ok := seen[trimmed]; ok {
return
}
seen[trimmed] = struct{}{}
ordered = append(ordered, trimmed)
}
appendValue(command)
for _, extra := range extras {
appendValue(extra)
}
return ordered
}
// -- Small utilities --------------------------------------------------------
// ShellQuote returns value single-quoted for bash, escaping embedded quotes.
func ShellQuote(value string) string {
return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
}
// ExitCode extracts the exit status from an ssh.ExitError, returning
// (0, true) for nil errors.
func ExitCode(err error) (int, bool) {
if err == nil {
return 0, true
}
var exitErr *ssh.ExitError
if errors.As(err, &exitErr) {
return exitErr.ExitStatus(), true
}
return 0, false
}
// CloneStringMap returns a shallow copy of values, or nil if empty.
func CloneStringMap(values map[string]string) map[string]string {
if len(values) == 0 {
return nil
}
cloned := make(map[string]string, len(values))
for key, value := range values {
cloned[key] = value
}
return cloned
}
// TailFileContent returns the last N lines of a file, or "" if the file is
// missing.
func TailFileContent(path string, lines int) (string, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return "", nil
}
return "", err
}
if lines <= 0 {
return string(data), nil
}
parts := strings.Split(string(data), "\n")
if len(parts) <= lines {
return string(data), nil
}
return strings.Join(parts[len(parts)-lines-1:], "\n"), nil
}
// ProcessAlive returns true if the process with pid exists. The syscallKill
// override is exposed for tests that need to simulate alive/dead processes.
func ProcessAlive(pid int) bool {
if pid <= 0 {
return false
}
return syscallKill(pid, syscall.Signal(0)) == nil
}
// syscallKill is a test seam: tests replace it to stub process-alive checks.
var syscallKill = func(pid int, signal os.Signal) error {
proc, err := os.FindProcess(pid)
if err != nil {
return err
}
return proc.Signal(signal)
}
// FormatStepError wraps err with an action label and trimmed on-guest log.
func FormatStepError(action string, err error, log string) error {
log = strings.TrimSpace(log)
if log == "" {
return fmt.Errorf("%s: %w", action, err)
}
return fmt.Errorf("%s: %w: %s", action, err, log)
}

View file

@ -1,440 +0,0 @@
package session
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
"banger/internal/model"
"golang.org/x/crypto/ssh"
)
func TestRelativeStateDir(t *testing.T) {
got := RelativeStateDir("abc")
if strings.HasPrefix(got, "/root/") {
t.Fatalf("RelativeStateDir(%q) = %q, should strip /root/ prefix", "abc", got)
}
if !strings.Contains(got, "abc") {
t.Fatalf("missing session id in %q", got)
}
absolute := StateDir("abc")
if got != strings.TrimPrefix(absolute, "/root/") {
t.Fatalf("relative = %q, want %q", got, strings.TrimPrefix(absolute, "/root/"))
}
}
func TestDefaultCWD(t *testing.T) {
if DefaultCWD("") != "/root" {
t.Error("empty should return /root")
}
if DefaultCWD(" ") != "/root" {
t.Error("whitespace should return /root")
}
if DefaultCWD("/work") != "/work" {
t.Error("explicit should pass through")
}
}
func TestShellQuote(t *testing.T) {
if got := ShellQuote(""); got != "''" {
t.Errorf("empty: got %q, want ''", got)
}
if got := ShellQuote("x"); got != "'x'" {
t.Errorf("plain: got %q", got)
}
if got := ShellQuote("it's"); got != `'it'"'"'s'` {
t.Errorf("apostrophe: got %q", got)
}
}
func TestExitCode(t *testing.T) {
if code, ok := ExitCode(nil); !ok || code != 0 {
t.Errorf("nil err: got (%d, %v), want (0, true)", code, ok)
}
// Build an ssh.ExitError using its real type — can't hand-construct,
// so wrap via errors.As check with a stub.
raw := &ssh.ExitError{}
if _, ok := ExitCode(raw); !ok {
t.Error("ssh.ExitError: ok should be true")
}
if _, ok := ExitCode(errors.New("bare error")); ok {
t.Error("bare error: ok should be false")
}
}
func TestCloneStringMap(t *testing.T) {
if CloneStringMap(nil) != nil {
t.Error("nil in → nil out")
}
if CloneStringMap(map[string]string{}) != nil {
t.Error("empty in → nil out")
}
src := map[string]string{"a": "1", "b": "2"}
cloned := CloneStringMap(src)
if len(cloned) != 2 {
t.Fatalf("len = %d, want 2", len(cloned))
}
cloned["a"] = "changed"
if src["a"] != "1" {
t.Error("mutating clone leaked back to source")
}
}
func TestTailFileContent(t *testing.T) {
// Missing file → empty, no error.
got, err := TailFileContent(filepath.Join(t.TempDir(), "missing"), 10)
if err != nil || got != "" {
t.Errorf("missing: got (%q, %v), want ('', nil)", got, err)
}
path := filepath.Join(t.TempDir(), "log")
lines := "one\ntwo\nthree\nfour\nfive"
if err := os.WriteFile(path, []byte(lines), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
full, err := TailFileContent(path, 0)
if err != nil || full != lines {
t.Errorf("0 lines: got (%q, %v), want (%q, nil)", full, err, lines)
}
// Request more lines than exist → full content.
all, err := TailFileContent(path, 999)
if err != nil || all != lines {
t.Errorf("999 lines: got %q", all)
}
last2, err := TailFileContent(path, 2)
if err != nil {
t.Fatalf("2 lines: %v", err)
}
if !strings.Contains(last2, "five") {
t.Errorf("2 lines missing last line: %q", last2)
}
}
func TestProcessAlive(t *testing.T) {
if ProcessAlive(0) {
t.Error("pid 0 should not be alive")
}
if ProcessAlive(-1) {
t.Error("negative pid should not be alive")
}
// Swap the syscall seam.
original := syscallKill
t.Cleanup(func() { syscallKill = original })
syscallKill = func(pid int, signal os.Signal) error { return nil }
if !ProcessAlive(42) {
t.Error("syscallKill=nil should report alive")
}
syscallKill = func(pid int, signal os.Signal) error { return fmt.Errorf("no such process") }
if ProcessAlive(42) {
t.Error("syscallKill error should report dead")
}
}
func TestFormatStepError(t *testing.T) {
base := errors.New("boom")
err := FormatStepError("prepare", base, "")
if !errors.Is(err, base) {
t.Error("FormatStepError should wrap the base error")
}
if !strings.Contains(err.Error(), "prepare") {
t.Errorf("missing action: %v", err)
}
errWithLog := FormatStepError("prepare", base, " log line\n")
if !strings.Contains(errWithLog.Error(), "log line") {
t.Errorf("missing log: %v", errWithLog)
}
}
func TestParseStateHappyPath(t *testing.T) {
raw := `status=running
pid=123
exit=
alive=true
error=
`
snap, err := ParseState(raw)
if err != nil {
t.Fatalf("ParseState: %v", err)
}
if snap.Status != "running" {
t.Errorf("Status = %q", snap.Status)
}
if snap.GuestPID != 123 {
t.Errorf("GuestPID = %d", snap.GuestPID)
}
if snap.ExitCode != nil {
t.Errorf("ExitCode should be nil when empty, got %v", snap.ExitCode)
}
if !snap.Alive {
t.Error("Alive should be true")
}
}
func TestParseStateWithExit(t *testing.T) {
raw := `status=exited
pid=123
exit=7
alive=false
error=something bad
`
snap, err := ParseState(raw)
if err != nil {
t.Fatalf("ParseState: %v", err)
}
if snap.ExitCode == nil || *snap.ExitCode != 7 {
t.Errorf("ExitCode = %v, want 7", snap.ExitCode)
}
if snap.LastError != "something bad" {
t.Errorf("LastError = %q", snap.LastError)
}
if snap.Alive {
t.Error("Alive should be false")
}
}
func TestParseStateIgnoresMalformedLines(t *testing.T) {
raw := "no-equals-here\nstatus=ok\n"
snap, err := ParseState(raw)
if err != nil {
t.Fatalf("ParseState: %v", err)
}
if snap.Status != "ok" {
t.Errorf("Status = %q, want ok", snap.Status)
}
}
func TestInspectStateFromDir(t *testing.T) {
dir := t.TempDir()
writeFile := func(name, content string) {
if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o600); err != nil {
t.Fatalf("WriteFile(%s): %v", name, err)
}
}
writeFile("status", "running\n")
writeFile("pid", "42\n")
writeFile("exit_code", "0\n")
writeFile("error", "\n")
original := syscallKill
t.Cleanup(func() { syscallKill = original })
syscallKill = func(pid int, signal os.Signal) error { return nil }
snap, err := InspectStateFromDir(dir)
if err != nil {
t.Fatalf("InspectStateFromDir: %v", err)
}
if snap.Status != "running" {
t.Errorf("Status = %q", snap.Status)
}
if snap.GuestPID != 42 {
t.Errorf("GuestPID = %d", snap.GuestPID)
}
if snap.ExitCode == nil || *snap.ExitCode != 0 {
t.Errorf("ExitCode = %v, want 0", snap.ExitCode)
}
if !snap.Alive {
t.Error("Alive should reflect syscallKill result (true)")
}
}
func TestInspectStateFromDirMissingFiles(t *testing.T) {
snap, err := InspectStateFromDir(t.TempDir())
if err != nil {
t.Fatalf("InspectStateFromDir (empty): %v", err)
}
if snap.Status != "" || snap.GuestPID != 0 || snap.ExitCode != nil {
t.Errorf("empty dir: snap = %+v", snap)
}
}
func TestApplyStateSnapshotNilReceiver(t *testing.T) {
ApplyStateSnapshot(nil, StateSnapshot{}, true) // should not panic
}
func TestApplyStateSnapshotExitedSuccess(t *testing.T) {
exit := 0
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning, Attachable: true, Reattachable: true}
ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true)
if sess.Status != model.GuestSessionStatusExited {
t.Errorf("Status = %q, want exited", sess.Status)
}
if sess.Attachable || sess.Reattachable {
t.Error("attach flags should be cleared on exit")
}
if sess.EndedAt.IsZero() {
t.Error("EndedAt should be set")
}
}
func TestApplyStateSnapshotExitedFailure(t *testing.T) {
exit := 2
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning}
ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true)
if sess.Status != model.GuestSessionStatusFailed {
t.Errorf("Status = %q, want failed", sess.Status)
}
}
func TestApplyStateSnapshotVMGone(t *testing.T) {
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning}
ApplyStateSnapshot(sess, StateSnapshot{Alive: false}, false)
if sess.Status != model.GuestSessionStatusFailed {
t.Errorf("Status = %q, want failed", sess.Status)
}
if sess.LastError == "" {
t.Error("LastError should be populated when VM is gone")
}
}
func TestApplyStateSnapshotRunningStatusSetsAttachableForPipe(t *testing.T) {
// When the guest-side status file reports "running" (Alive=false from
// kill -0 may still fail transiently), ApplyStateSnapshot transitions
// the session to running and sets attach flags for pipe-mode.
sess := &model.GuestSession{
Status: model.GuestSessionStatusStarting,
StdinMode: model.GuestSessionStdinPipe,
}
ApplyStateSnapshot(sess, StateSnapshot{Status: string(model.GuestSessionStatusRunning), GuestPID: 11}, true)
if sess.Status != model.GuestSessionStatusRunning {
t.Errorf("Status = %q, want running", sess.Status)
}
if !sess.Attachable || !sess.Reattachable {
t.Error("pipe-mode running session should be attachable + reattachable")
}
if sess.AttachBackend != AttachBackendSSHBridge {
t.Errorf("AttachBackend = %q, want %q", sess.AttachBackend, AttachBackendSSHBridge)
}
}
func TestApplyStateSnapshotAliveEarlyReturn(t *testing.T) {
// Alive-true returns immediately after setting status; no attach
// flags set on this path (by design — attach metadata only attaches
// to status-driven transitions).
sess := &model.GuestSession{
Status: model.GuestSessionStatusStarting,
StdinMode: model.GuestSessionStdinPipe,
}
ApplyStateSnapshot(sess, StateSnapshot{Alive: true, GuestPID: 11}, true)
if sess.Status != model.GuestSessionStatusRunning {
t.Errorf("Status = %q, want running", sess.Status)
}
if sess.StartedAt.IsZero() {
t.Error("StartedAt should have been set")
}
}
func TestStateChanged(t *testing.T) {
base := model.GuestSession{Status: model.GuestSessionStatusRunning, GuestPID: 10}
// Identical → no change.
if StateChanged(base, base) {
t.Error("identical states should not be considered changed")
}
// Status change.
changed := base
changed.Status = model.GuestSessionStatusExited
if !StateChanged(base, changed) {
t.Error("status change should be detected")
}
// ExitCode change from nil → value.
exit := 3
changed = base
changed.ExitCode = &exit
if !StateChanged(base, changed) {
t.Error("exit-code appearing should be detected")
}
// Both have the same exit code → no change.
a := base
a.ExitCode = &exit
b := base
b.ExitCode = &exit
if StateChanged(a, b) {
t.Error("matching exit codes should not trigger change")
}
// Different exit codes.
other := 5
b.ExitCode = &other
if !StateChanged(a, b) {
t.Error("differing exit codes should be detected")
}
// Timestamp change.
changed = base
changed.StartedAt = time.Now()
if !StateChanged(base, changed) {
t.Error("StartedAt change should be detected")
}
}
func TestFailLaunch(t *testing.T) {
in := model.GuestSession{Status: model.GuestSessionStatusStarting, Attachable: true}
out := FailLaunch(in, "provision", " ssh did not come up ", " raw output\n")
if out.Status != model.GuestSessionStatusFailed {
t.Errorf("Status = %q, want failed", out.Status)
}
if out.LastError != "ssh did not come up" {
t.Errorf("LastError = %q (not trimmed?)", out.LastError)
}
if out.LaunchStage != "provision" || out.LaunchMessage != "ssh did not come up" {
t.Errorf("launch fields not set: %+v", out)
}
if out.LaunchRawLog != "raw output" {
t.Errorf("rawLog = %q (not trimmed?)", out.LaunchRawLog)
}
if out.Attachable {
t.Error("Attachable should be cleared")
}
}
func TestNormalizeRequiredCommands(t *testing.T) {
got := NormalizeRequiredCommands("pi", []string{"pi", "git", "", "git", " ", "make"})
want := []string{"pi", "git", "make"}
if len(got) != len(want) {
t.Fatalf("len = %d, want %d (%v)", len(got), len(want), got)
}
for i, v := range want {
if got[i] != v {
t.Errorf("position %d: got %q, want %q", i, got[i], v)
}
}
}
func TestInspectScriptContainsAllStateFiles(t *testing.T) {
script := InspectScript("sess-abc")
for _, key := range []string{"status", "pid", "exit_code", "error", "alive"} {
if !strings.Contains(script, key) {
t.Errorf("script missing %q:\n%s", key, script)
}
}
if !strings.Contains(script, "sess-abc") {
t.Error("script missing session id")
}
}
func TestSignalScriptIncludesSignalAndDirPaths(t *testing.T) {
script := SignalScript("sess-x", "TERM")
if !strings.Contains(script, "TERM") {
t.Error("missing signal")
}
if !strings.Contains(script, "sess-x") {
t.Error("missing session id")
}
if !strings.Contains(script, "monitor_pid") || !strings.Contains(script, "stdin_keepalive") {
t.Errorf("expected both monitor + stdin_keepalive kills, got:\n%s", script)
}
}

View file

@ -1,224 +0,0 @@
package daemon
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"time"
"banger/internal/api"
sess "banger/internal/daemon/session"
"banger/internal/guest"
"banger/internal/model"
"banger/internal/sessionstream"
)
func (d *Daemon) BeginGuestSessionAttach(ctx context.Context, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
vm, err := d.FindVM(ctx, params.VMIDOrName)
if err != nil {
return api.GuestSessionAttachBeginResult{}, err
}
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
if err != nil {
return api.GuestSessionAttachBeginResult{}, err
}
session, _ = d.refreshGuestSession(ctx, vm, session)
if !session.Attachable {
return api.GuestSessionAttachBeginResult{}, errors.New("session is not attachable")
}
controller := &guestSessionController{}
if !d.claimGuestSessionController(session.ID, controller) {
return api.GuestSessionAttachBeginResult{}, errors.New("session already has an active attach")
}
attachID, err := model.NewID()
if err != nil {
d.clearGuestSessionController(session.ID)
return api.GuestSessionAttachBeginResult{}, err
}
socketPath := filepath.Join(d.layout.RuntimeDir, "guest-session-attach-"+attachID[:12]+".sock")
_ = os.Remove(socketPath)
listener, err := net.Listen("unix", socketPath)
if err != nil {
d.clearGuestSessionController(session.ID)
return api.GuestSessionAttachBeginResult{}, err
}
if err := os.Chmod(socketPath, 0o600); err != nil {
_ = listener.Close()
_ = os.Remove(socketPath)
d.clearGuestSessionController(session.ID)
return api.GuestSessionAttachBeginResult{}, err
}
go d.serveGuestSessionAttach(session, controller, attachID, socketPath, listener)
return api.GuestSessionAttachBeginResult{
Session: session,
AttachID: attachID,
TransportKind: sess.TransportUnixSocket,
TransportTarget: socketPath,
SocketPath: socketPath,
StreamFormat: sessionstream.FormatV1,
}, nil
}
func (d *Daemon) forwardGuestSessionOutput(_ string, controller *guestSessionController, channel byte, reader io.Reader) {
buffer := make([]byte, 32*1024)
for {
n, err := reader.Read(buffer)
if n > 0 {
controller.writeFrame(channel, buffer[:n])
}
if err != nil {
if !errors.Is(err, io.EOF) {
controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()})
}
return
}
}
}
func (d *Daemon) waitForGuestSessionExit(id string, controller *guestSessionController, session model.GuestSession) {
err := controller.stream.Wait()
updated := session
updated.Attachable = false
now := model.Now()
updated.UpdatedAt = now
updated.EndedAt = now
if exitCode, ok := sess.ExitCode(err); ok {
updated.ExitCode = &exitCode
if exitCode == 0 {
updated.Status = model.GuestSessionStatusExited
} else {
updated.Status = model.GuestSessionStatusFailed
}
}
if err != nil && updated.LastError == "" {
updated.LastError = err.Error()
}
if vm, getErr := d.store.GetVMByID(context.Background(), updated.VMID); getErr == nil {
if refreshed, refreshErr := d.refreshGuestSession(context.Background(), vm, updated); refreshErr == nil {
updated = refreshed
}
}
_ = d.store.UpsertGuestSession(context.Background(), updated)
controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: updated.ExitCode})
_ = controller.close()
d.clearGuestSessionController(id)
}
func (d *Daemon) serveGuestSessionAttach(session model.GuestSession, controller *guestSessionController, _ string, socketPath string, listener net.Listener) {
defer func() {
_ = listener.Close()
_ = os.Remove(socketPath)
_ = controller.close()
d.clearGuestSessionController(session.ID)
}()
conn, err := listener.Accept()
if err != nil {
return
}
defer conn.Close()
if err := controller.setAttach(conn); err != nil {
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
return
}
defer controller.clearAttach(conn)
if err := d.attachGuestSessionBridge(session, controller); err != nil {
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
return
}
for {
channel, payload, err := sessionstream.ReadFrame(conn)
if err != nil {
return
}
switch channel {
case sessionstream.ChannelStdin:
if controller.stdin == nil {
continue
}
if _, err := controller.stdin.Write(payload); err != nil {
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
return
}
case sessionstream.ChannelControl:
message, err := sessionstream.ReadControl(payload)
if err != nil {
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
return
}
if message.Type == "eof" && controller.stdin != nil {
_ = controller.stdin.Close()
}
}
}
}
func (d *Daemon) attachGuestSessionBridge(session model.GuestSession, controller *guestSessionController) error {
vm, err := d.store.GetVMByID(context.Background(), session.VMID)
if err != nil {
return err
}
if !d.vmAlive(vm) {
return fmt.Errorf("vm %q is not running", vm.Name)
}
address := net.JoinHostPort(vm.Runtime.GuestIP, "22")
stdinStream, err := d.openGuestSessionAttachStream(address, sess.AttachInputCommand(session.ID))
if err != nil {
return fmt.Errorf("open guest session stdin stream: %w", err)
}
stdoutStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StdoutLogPath))
if err != nil {
_ = stdinStream.Close()
return fmt.Errorf("open guest session stdout stream: %w", err)
}
stderrStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StderrLogPath))
if err != nil {
_ = stdinStream.Close()
_ = stdoutStream.Close()
return fmt.Errorf("open guest session stderr stream: %w", err)
}
controller.streams = append(controller.streams, stdinStream, stdoutStream, stderrStream)
controller.stdin = stdinStream.Stdin()
go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStdout, stdoutStream.Stdout())
go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStderr, stderrStream.Stdout())
go d.watchGuestSessionAttach(session.ID, controller, session)
return nil
}
func (d *Daemon) openGuestSessionAttachStream(address, command string) (*guest.StreamSession, error) {
client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return nil, err
}
stream, err := client.StartCommand(context.Background(), command)
if err != nil {
_ = client.Close()
return nil, err
}
return stream, nil
}
func (d *Daemon) watchGuestSessionAttach(id string, controller *guestSessionController, session model.GuestSession) {
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
for range ticker.C {
vm, err := d.store.GetVMByID(context.Background(), session.VMID)
if err != nil {
controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()})
_ = controller.close()
return
}
refreshed, err := d.refreshGuestSession(context.Background(), vm, session)
if err == nil {
session = refreshed
}
if session.Status == model.GuestSessionStatusExited || session.Status == model.GuestSessionStatusFailed {
controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: session.ExitCode})
_ = controller.close()
return
}
}
}

View file

@ -1,184 +0,0 @@
package daemon
import (
"errors"
"io"
"net"
"sync"
"banger/internal/guest"
"banger/internal/sessionstream"
)
type guestSessionController struct {
stream *guest.StreamSession
streams []*guest.StreamSession
stdin io.WriteCloser
attachMu sync.Mutex
attach net.Conn
writeMu sync.Mutex
closeOnce sync.Once
}
func (c *guestSessionController) setAttach(conn net.Conn) error {
c.attachMu.Lock()
defer c.attachMu.Unlock()
if c.attach != nil {
return errors.New("session already has an active attach")
}
c.attach = conn
return nil
}
func (c *guestSessionController) clearAttach(conn net.Conn) {
c.attachMu.Lock()
defer c.attachMu.Unlock()
if c.attach == conn {
c.attach = nil
}
}
func (c *guestSessionController) writeFrame(channel byte, payload []byte) {
c.attachMu.Lock()
conn := c.attach
c.attachMu.Unlock()
if conn == nil {
return
}
c.writeMu.Lock()
err := sessionstream.WriteFrame(conn, channel, payload)
c.writeMu.Unlock()
if err != nil {
_ = conn.Close()
c.clearAttach(conn)
}
}
func (c *guestSessionController) writeControl(message sessionstream.ControlMessage) {
c.attachMu.Lock()
conn := c.attach
c.attachMu.Unlock()
if conn == nil {
return
}
c.writeMu.Lock()
err := sessionstream.WriteControl(conn, message)
c.writeMu.Unlock()
if err != nil {
_ = conn.Close()
c.clearAttach(conn)
}
}
func (c *guestSessionController) close() error {
if c == nil {
return nil
}
var err error
c.closeOnce.Do(func() {
c.attachMu.Lock()
conn := c.attach
c.attach = nil
c.attachMu.Unlock()
if conn != nil {
err = errors.Join(err, conn.Close())
}
if c.stdin != nil {
err = errors.Join(err, c.stdin.Close())
}
if c.stream != nil {
err = errors.Join(err, c.stream.Close())
}
for _, stream := range c.streams {
if stream != nil {
err = errors.Join(err, stream.Close())
}
}
})
return err
}
// sessionRegistry owns the live guest-session controller map. Its lock is
// independent of Daemon.mu so guest-session lookups do not contend with
// unrelated daemon state.
type sessionRegistry struct {
mu sync.Mutex
byID map[string]*guestSessionController
closed bool
}
func newSessionRegistry() sessionRegistry {
return sessionRegistry{byID: make(map[string]*guestSessionController)}
}
func (r *sessionRegistry) set(id string, controller *guestSessionController) {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return
}
r.byID[id] = controller
}
func (r *sessionRegistry) claim(id string, controller *guestSessionController) bool {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return false
}
if r.byID[id] != nil {
return false
}
r.byID[id] = controller
return true
}
func (r *sessionRegistry) get(id string) *guestSessionController {
r.mu.Lock()
defer r.mu.Unlock()
return r.byID[id]
}
func (r *sessionRegistry) clear(id string) *guestSessionController {
r.mu.Lock()
defer r.mu.Unlock()
controller := r.byID[id]
delete(r.byID, id)
return controller
}
func (r *sessionRegistry) closeAll() error {
r.mu.Lock()
controllers := make([]*guestSessionController, 0, len(r.byID))
for _, controller := range r.byID {
controllers = append(controllers, controller)
}
r.byID = nil
r.closed = true
r.mu.Unlock()
var err error
for _, controller := range controllers {
err = errors.Join(err, controller.close())
}
return err
}
func (d *Daemon) setGuestSessionController(id string, controller *guestSessionController) {
d.sessions.set(id, controller)
}
func (d *Daemon) claimGuestSessionController(id string, controller *guestSessionController) bool {
return d.sessions.claim(id, controller)
}
func (d *Daemon) getGuestSessionController(id string) *guestSessionController {
return d.sessions.get(id)
}
func (d *Daemon) clearGuestSessionController(id string) *guestSessionController {
return d.sessions.clear(id)
}
func (d *Daemon) closeGuestSessionControllers() error {
return d.sessions.closeAll()
}

View file

@ -1,213 +0,0 @@
package daemon
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"strings"
"time"
"banger/internal/api"
sess "banger/internal/daemon/session"
"banger/internal/guest"
"banger/internal/model"
)
func (d *Daemon) StartGuestSession(ctx context.Context, params api.GuestSessionStartParams) (model.GuestSession, error) {
stdinMode := model.GuestSessionStdinMode(strings.TrimSpace(params.StdinMode))
if stdinMode == "" {
stdinMode = model.GuestSessionStdinClosed
}
if stdinMode != model.GuestSessionStdinClosed && stdinMode != model.GuestSessionStdinPipe {
return model.GuestSession{}, fmt.Errorf("unsupported stdin mode %q", params.StdinMode)
}
if strings.TrimSpace(params.Command) == "" {
return model.GuestSession{}, errors.New("session command is required")
}
var created model.GuestSession
_, err := d.withVMLockByRef(ctx, params.VMIDOrName, func(vm model.VMRecord) (model.VMRecord, error) {
if !d.vmAlive(vm) {
return model.VMRecord{}, fmt.Errorf("vm %q is not running", vm.Name)
}
session, err := d.startGuestSessionLocked(ctx, vm, params, stdinMode)
if err != nil {
return model.VMRecord{}, err
}
created = session
return vm, nil
})
return created, err
}
func (d *Daemon) startGuestSessionLocked(ctx context.Context, vm model.VMRecord, params api.GuestSessionStartParams, stdinMode model.GuestSessionStdinMode) (model.GuestSession, error) {
id, err := model.NewID()
if err != nil {
return model.GuestSession{}, err
}
now := model.Now()
session := model.GuestSession{
ID: id,
VMID: vm.ID,
Name: sess.DefaultName(id, params.Command, params.Name),
Backend: sess.BackendSSH,
Command: params.Command,
Args: append([]string(nil), params.Args...),
CWD: strings.TrimSpace(params.CWD),
Env: sess.CloneStringMap(params.Env),
StdinMode: stdinMode,
Status: model.GuestSessionStatusStarting,
GuestStateDir: sess.StateDir(id),
StdoutLogPath: sess.StdoutLogPath(id),
StderrLogPath: sess.StderrLogPath(id),
Tags: sess.CloneStringMap(params.Tags),
Attachable: stdinMode == model.GuestSessionStdinPipe,
Reattachable: stdinMode == model.GuestSessionStdinPipe,
CreatedAt: now,
UpdatedAt: now,
}
if session.Attachable {
session.AttachBackend = sess.AttachBackendSSHBridge
session.AttachMode = sess.AttachModeExclusive
} else {
session.AttachBackend = sess.AttachBackendNone
}
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
return model.GuestSession{}, err
}
fail := func(stage, message, rawLog string) (model.GuestSession, error) {
session = sess.FailLaunch(session, stage, message, rawLog)
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
return model.GuestSession{}, err
}
return session, nil
}
address := net.JoinHostPort(vm.Runtime.GuestIP, "22")
if err := d.waitForGuestSSH(ctx, address, 250*time.Millisecond); err != nil {
return fail("ssh_unavailable", fmt.Sprintf("guest ssh unavailable: %v", err), "")
}
client, err := d.dialGuest(ctx, address)
if err != nil {
return fail("dial_guest", fmt.Sprintf("dial guest ssh: %v", err), "")
}
defer client.Close()
var preflightLog bytes.Buffer
if err := client.RunScript(ctx, sess.CWDPreflightScript(session.CWD), &preflightLog); err != nil {
return fail("preflight_cwd", fmt.Sprintf("guest working directory is unavailable: %s", sess.DefaultCWD(session.CWD)), preflightLog.String())
}
preflightLog.Reset()
requiredCommands := sess.NormalizeRequiredCommands(params.Command, params.RequiredCommands)
if err := client.RunScript(ctx, sess.CommandPreflightScript(requiredCommands), &preflightLog); err != nil {
return fail("preflight_command", fmt.Sprintf("required guest command is unavailable: %s", strings.TrimSpace(preflightLog.String())), preflightLog.String())
}
var uploadLog bytes.Buffer
if err := client.UploadFile(ctx, sess.ScriptPath(id), 0o755, []byte(sess.Script(session)), &uploadLog); err != nil {
return fail("upload_script", "upload guest session script failed", uploadLog.String())
}
var launchLog bytes.Buffer
launchScript := fmt.Sprintf("set -euo pipefail\nnohup bash %s >/dev/null 2>&1 </dev/null &\ndisown || true\n", sess.ShellQuote(sess.ScriptPath(id)))
if err := client.RunScript(ctx, launchScript, &launchLog); err != nil {
return fail("launch", "launch guest session failed", launchLog.String())
}
readyCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
updated, err := d.waitForGuestSessionReadyHook(readyCtx, vm, session)
if err != nil {
return fail("ready_wait", "guest session did not report ready state", err.Error())
}
session = updated
if session.Status == model.GuestSessionStatusStarting {
session.Status = model.GuestSessionStatusRunning
session.StartedAt = model.Now()
session.UpdatedAt = model.Now()
}
session.LaunchStage = ""
session.LaunchMessage = ""
session.LaunchRawLog = ""
session.LastError = ""
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
return model.GuestSession{}, err
}
return session, nil
}
func (d *Daemon) GetGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
vm, err := d.FindVM(ctx, params.VMIDOrName)
if err != nil {
return model.GuestSession{}, err
}
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
if err != nil {
return model.GuestSession{}, err
}
return d.refreshGuestSession(ctx, vm, session)
}
func (d *Daemon) ListGuestSessions(ctx context.Context, params api.VMRefParams) ([]model.GuestSession, error) {
vm, err := d.FindVM(ctx, params.IDOrName)
if err != nil {
return nil, err
}
sessions, err := d.store.ListGuestSessionsByVM(ctx, vm.ID)
if err != nil {
return nil, err
}
for index := range sessions {
refreshed, refreshErr := d.refreshGuestSession(ctx, vm, sessions[index])
if refreshErr == nil {
sessions[index] = refreshed
}
}
return sessions, nil
}
func (d *Daemon) StopGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
return d.signalGuestSession(ctx, params, "TERM")
}
func (d *Daemon) KillGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
return d.signalGuestSession(ctx, params, "KILL")
}
func (d *Daemon) signalGuestSession(ctx context.Context, params api.GuestSessionRefParams, signal string) (model.GuestSession, error) {
vm, err := d.FindVM(ctx, params.VMIDOrName)
if err != nil {
return model.GuestSession{}, err
}
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
if err != nil {
return model.GuestSession{}, err
}
session, _ = d.refreshGuestSession(ctx, vm, session)
if session.Status == model.GuestSessionStatusExited || session.Status == model.GuestSessionStatusFailed {
return session, nil
}
if !d.vmAlive(vm) {
session.Status = model.GuestSessionStatusFailed
session.LastError = "vm is not running"
now := model.Now()
session.UpdatedAt = now
session.EndedAt = now
session.Attachable = false
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
return model.GuestSession{}, err
}
return session, nil
}
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return model.GuestSession{}, err
}
defer client.Close()
var log bytes.Buffer
if err := client.RunScript(ctx, sess.SignalScript(session.ID, signal), &log); err != nil {
return model.GuestSession{}, sess.FormatStepError("signal guest session", err, log.String())
}
session.Status = model.GuestSessionStatusStopping
session.UpdatedAt = model.Now()
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
return model.GuestSession{}, err
}
return session, nil
}

View file

@ -1,120 +0,0 @@
package daemon
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"path/filepath"
"strings"
"banger/internal/api"
sess "banger/internal/daemon/session"
"banger/internal/guest"
"banger/internal/model"
"banger/internal/system"
)
func (d *Daemon) GuestSessionLogs(ctx context.Context, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) {
vm, err := d.FindVM(ctx, params.VMIDOrName)
if err != nil {
return api.GuestSessionLogsResult{}, err
}
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
if err != nil {
return api.GuestSessionLogsResult{}, err
}
streamName := strings.TrimSpace(params.Stream)
if streamName == "" {
streamName = "stdout"
}
tailLines := params.TailLines
if tailLines <= 0 {
tailLines = sess.LogTailLineDefault
}
path := session.StdoutLogPath
if streamName == "stderr" {
path = session.StderrLogPath
}
content, err := d.readGuestSessionLog(ctx, vm, session, streamName, tailLines)
if err != nil {
return api.GuestSessionLogsResult{}, err
}
return api.GuestSessionLogsResult{Session: session, Stream: streamName, Path: path, Content: content}, nil
}
func (d *Daemon) SendToGuestSession(ctx context.Context, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
vm, err := d.FindVM(ctx, params.VMIDOrName)
if err != nil {
return api.GuestSessionSendResult{}, err
}
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
if err != nil {
return api.GuestSessionSendResult{}, err
}
if session.StdinMode != model.GuestSessionStdinPipe {
return api.GuestSessionSendResult{}, errors.New("session does not have a stdin pipe")
}
if session.Status != model.GuestSessionStatusRunning {
return api.GuestSessionSendResult{}, fmt.Errorf("session is not running (status=%s)", session.Status)
}
if !d.vmAlive(vm) {
return api.GuestSessionSendResult{}, fmt.Errorf("vm %q is not running", vm.Name)
}
if len(params.Payload) == 0 {
return api.GuestSessionSendResult{Session: session}, nil
}
client, err := d.dialGuest(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"))
if err != nil {
return api.GuestSessionSendResult{}, fmt.Errorf("dial guest: %w", err)
}
defer client.Close()
tmpPath := fmt.Sprintf("/tmp/banger-send-%s.bin", session.ID[:8])
var uploadLog bytes.Buffer
if err := client.UploadFile(ctx, tmpPath, 0o600, params.Payload, &uploadLog); err != nil {
return api.GuestSessionSendResult{}, fmt.Errorf("upload payload: %w", err)
}
sendScript := fmt.Sprintf(
"set -euo pipefail\ncat %s >> %s\nrm -f %s\n",
sess.ShellQuote(tmpPath),
sess.ShellQuote(sess.StdinPipePath(session.ID)),
sess.ShellQuote(tmpPath),
)
var sendLog bytes.Buffer
if err := client.RunScript(ctx, sendScript, &sendLog); err != nil {
return api.GuestSessionSendResult{}, fmt.Errorf("send to session: %w: %s", err, strings.TrimSpace(sendLog.String()))
}
return api.GuestSessionSendResult{Session: session, BytesWritten: len(params.Payload)}, nil
}
func (d *Daemon) readGuestSessionLog(ctx context.Context, vm model.VMRecord, session model.GuestSession, stream string, tailLines int) (string, error) {
if d.vmAlive(vm) {
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return "", err
}
defer client.Close()
path := session.StdoutLogPath
if stream == "stderr" {
path = session.StderrLogPath
}
var output bytes.Buffer
script := fmt.Sprintf("set -euo pipefail\nif [ -f %s ]; then tail -n %d %s; fi\n", sess.ShellQuote(path), tailLines, sess.ShellQuote(path))
if err := client.RunScript(ctx, script, &output); err != nil {
return "", sess.FormatStepError("read guest session log", err, output.String())
}
return output.String(), nil
}
runner := d.runner
if runner == nil {
runner = system.NewRunner()
}
workMount, cleanup, err := system.MountTempDir(ctx, runner, vm.Runtime.WorkDiskPath, false)
if err != nil {
return "", err
}
defer cleanup()
logPath := filepath.Join(workMount, sess.RelativeStateDir(session.ID), stream+".log")
return sess.TailFileContent(logPath, tailLines)
}

View file

@ -10,7 +10,6 @@ import (
"time"
"banger/internal/api"
sess "banger/internal/daemon/session"
ws "banger/internal/daemon/workspace"
"banger/internal/model"
)
@ -114,9 +113,9 @@ func exportScript(guestPath, diffRef, diffFlag string) string {
"git read-tree %s --index-output=\"$tmp_idx\"\n"+
"GIT_INDEX_FILE=\"$tmp_idx\" git add -A\n"+
"GIT_INDEX_FILE=\"$tmp_idx\" git diff --cached %s %s\n",
sess.ShellQuote(guestPath),
sess.ShellQuote(diffRef),
sess.ShellQuote(diffRef),
ws.ShellQuote(guestPath),
ws.ShellQuote(diffRef),
ws.ShellQuote(diffRef),
diffFlag,
)
}
@ -189,9 +188,9 @@ func (d *Daemon) prepareVMWorkspaceGuestIO(ctx context.Context, vm model.VMRecor
}
if readOnly {
var chmodLog bytes.Buffer
chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", sess.ShellQuote(guestPath))
chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", ws.ShellQuote(guestPath))
if err := client.RunScript(ctx, chmodScript, &chmodLog); err != nil {
return model.WorkspacePrepareResult{}, sess.FormatStepError("set workspace readonly", err, chmodLog.String())
return model.WorkspacePrepareResult{}, ws.FormatStepError("set workspace readonly", err, chmodLog.String())
}
}
return model.WorkspacePrepareResult{

View file

@ -0,0 +1,20 @@
package workspace
import (
"fmt"
"strings"
)
// ShellQuote returns value single-quoted for bash, escaping embedded quotes.
func ShellQuote(value string) string {
return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
}
// FormatStepError wraps err with an action label and trimmed on-guest log.
func FormatStepError(action string, err error, log string) error {
log = strings.TrimSpace(log)
if log == "" {
return fmt.Errorf("%s: %w", action, err)
}
return fmt.Errorf("%s: %w: %s", action, err, log)
}

View file

@ -18,7 +18,6 @@ import (
"sort"
"strings"
sess "banger/internal/daemon/session"
"banger/internal/model"
"banger/internal/system"
)
@ -146,13 +145,13 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
switch mode {
case model.WorkspacePrepareModeFullCopy:
var copyLog bytes.Buffer
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath))
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath), ShellQuote(guestPath), ShellQuote(guestPath))
if err := client.StreamTar(ctx, spec.RepoRoot, command, &copyLog); err != nil {
return sess.FormatStepError("copy full workspace", err, copyLog.String())
return FormatStepError("copy full workspace", err, copyLog.String())
}
var finalizeLog bytes.Buffer
if err := client.RunScript(ctx, FinalizeScript(spec, guestPath, mode), &finalizeLog); err != nil {
return sess.FormatStepError("finalize full workspace", err, finalizeLog.String())
return FormatStepError("finalize full workspace", err, finalizeLog.String())
}
return nil
case model.WorkspacePrepareModeMetadataOnly, model.WorkspacePrepareModeShallowOverlay:
@ -162,21 +161,21 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
}
defer cleanup()
var copyLog bytes.Buffer
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath))
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath), ShellQuote(guestPath), ShellQuote(guestPath))
if err := client.StreamTar(ctx, repoCopyDir, command, &copyLog); err != nil {
return sess.FormatStepError("copy guest git metadata", err, copyLog.String())
return FormatStepError("copy guest git metadata", err, copyLog.String())
}
var scriptLog bytes.Buffer
if err := client.RunScript(ctx, FinalizeScript(spec, guestPath, mode), &scriptLog); err != nil {
return sess.FormatStepError("prepare guest checkout", err, scriptLog.String())
return FormatStepError("prepare guest checkout", err, scriptLog.String())
}
if mode == model.WorkspacePrepareModeMetadataOnly {
return nil
}
var overlayLog bytes.Buffer
command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath))
command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath))
if err := client.StreamTarEntries(ctx, spec.RepoRoot, spec.OverlayPaths, command, &overlayLog); err != nil {
return sess.FormatStepError("overlay workspace working tree", err, overlayLog.String())
return FormatStepError("overlay workspace working tree", err, overlayLog.String())
}
return nil
default:
@ -190,22 +189,22 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
func FinalizeScript(spec RepoSpec, guestPath string, mode model.WorkspacePrepareMode) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "DIR=%s\n", sess.ShellQuote(guestPath))
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(guestPath))
script.WriteString("git config --global --add safe.directory \"$DIR\"\n")
if mode != model.WorkspacePrepareModeFullCopy {
script.WriteString("find \"$DIR\" -mindepth 1 -maxdepth 1 ! -name .git -exec rm -rf {} +\n")
}
switch {
case strings.TrimSpace(spec.BranchName) != "":
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.BranchName), sess.ShellQuote(spec.BaseCommit))
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", ShellQuote(spec.BranchName), ShellQuote(spec.BaseCommit))
case strings.TrimSpace(spec.CurrentBranch) != "":
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.CurrentBranch), sess.ShellQuote(spec.HeadCommit))
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", ShellQuote(spec.CurrentBranch), ShellQuote(spec.HeadCommit))
default:
fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", sess.ShellQuote(spec.HeadCommit))
fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", ShellQuote(spec.HeadCommit))
}
if strings.TrimSpace(spec.GitUserName) != "" && strings.TrimSpace(spec.GitUserEmail) != "" {
fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", sess.ShellQuote(spec.GitUserName))
fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", sess.ShellQuote(spec.GitUserEmail))
fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", ShellQuote(spec.GitUserName))
fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", ShellQuote(spec.GitUserEmail))
}
return script.String()
}

View file

@ -34,23 +34,6 @@ const (
VMStateError VMState = "error"
)
type GuestSessionStatus string
const (
GuestSessionStatusStarting GuestSessionStatus = "starting"
GuestSessionStatusRunning GuestSessionStatus = "running"
GuestSessionStatusExited GuestSessionStatus = "exited"
GuestSessionStatusFailed GuestSessionStatus = "failed"
GuestSessionStatusStopping GuestSessionStatus = "stopping"
)
type GuestSessionStdinMode string
const (
GuestSessionStdinClosed GuestSessionStdinMode = "closed"
GuestSessionStdinPipe GuestSessionStdinMode = "pipe"
)
type DaemonConfig struct {
LogLevel string
FirecrackerBin string
@ -176,37 +159,6 @@ type VMSetRequest struct {
NATEnabled *bool
}
type GuestSession struct {
ID string `json:"id"`
VMID string `json:"vm_id"`
Name string `json:"name"`
Backend string `json:"backend"`
AttachBackend string `json:"attach_backend,omitempty"`
AttachMode string `json:"attach_mode,omitempty"`
Command string `json:"command"`
Args []string `json:"args,omitempty"`
CWD string `json:"cwd,omitempty"`
Env map[string]string `json:"env,omitempty"`
StdinMode GuestSessionStdinMode `json:"stdin_mode,omitempty"`
Status GuestSessionStatus `json:"status"`
ExitCode *int `json:"exit_code,omitempty"`
GuestPID int `json:"guest_pid,omitempty"`
GuestStateDir string `json:"guest_state_dir,omitempty"`
StdoutLogPath string `json:"stdout_log_path,omitempty"`
StderrLogPath string `json:"stderr_log_path,omitempty"`
Tags map[string]string `json:"tags,omitempty"`
LastError string `json:"last_error,omitempty"`
Attachable bool `json:"attachable"`
Reattachable bool `json:"reattachable"`
LaunchStage string `json:"launch_stage,omitempty"`
LaunchMessage string `json:"launch_message,omitempty"`
LaunchRawLog string `json:"launch_raw_log,omitempty"`
CreatedAt time.Time `json:"created_at"`
StartedAt time.Time `json:"started_at,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
EndedAt time.Time `json:"ended_at,omitempty"`
}
type WorkspacePrepareMode string
const (

View file

@ -1,76 +0,0 @@
package sessionstream
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
)
const (
ChannelStdin byte = 0x01
ChannelStdout byte = 0x02
ChannelStderr byte = 0x03
ChannelControl byte = 0x04
FormatV1 = "stdio_mux_v1"
)
type ControlMessage struct {
Type string `json:"type"`
ExitCode *int `json:"exit_code,omitempty"`
Error string `json:"error,omitempty"`
}
func WriteFrame(w io.Writer, channel byte, payload []byte) error {
var header [5]byte
header[0] = channel
binary.BigEndian.PutUint32(header[1:], uint32(len(payload)))
if _, err := w.Write(header[:]); err != nil {
return err
}
if len(payload) == 0 {
return nil
}
_, err := w.Write(payload)
return err
}
func ReadFrame(r io.Reader) (byte, []byte, error) {
var header [5]byte
if _, err := io.ReadFull(r, header[:]); err != nil {
return 0, nil, err
}
length := binary.BigEndian.Uint32(header[1:])
payload := make([]byte, length)
if _, err := io.ReadFull(r, payload); err != nil {
return 0, nil, err
}
return header[0], payload, nil
}
func WriteControl(w io.Writer, message ControlMessage) error {
payload, err := json.Marshal(message)
if err != nil {
return err
}
return WriteFrame(w, ChannelControl, payload)
}
func ReadControl(payload []byte) (ControlMessage, error) {
var message ControlMessage
if err := json.Unmarshal(payload, &message); err != nil {
return ControlMessage{}, err
}
return message, nil
}
func ReadNextControl(r io.Reader) (ControlMessage, error) {
channel, payload, err := ReadFrame(r)
if err != nil {
return ControlMessage{}, err
}
if channel != ChannelControl {
return ControlMessage{}, fmt.Errorf("unexpected channel %d", channel)
}
return ReadControl(payload)
}

View file

@ -1,117 +0,0 @@
package sessionstream
import (
"bytes"
"errors"
"io"
"testing"
)
func TestWriteReadFrameRoundtrip(t *testing.T) {
cases := []struct {
name string
channel byte
payload []byte
}{
{"stdout_bytes", ChannelStdout, []byte("hello world")},
{"stderr_bytes", ChannelStderr, []byte{0x00, 0xff, 0x7f}},
{"empty_payload", ChannelStdin, nil},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
var buf bytes.Buffer
if err := WriteFrame(&buf, tc.channel, tc.payload); err != nil {
t.Fatalf("WriteFrame: %v", err)
}
ch, got, err := ReadFrame(&buf)
if err != nil {
t.Fatalf("ReadFrame: %v", err)
}
if ch != tc.channel {
t.Fatalf("channel = %d, want %d", ch, tc.channel)
}
if !bytes.Equal(got, tc.payload) && !(len(got) == 0 && len(tc.payload) == 0) {
t.Fatalf("payload = %q, want %q", got, tc.payload)
}
})
}
}
type shortWriter struct {
failAfter int
written int
}
func (s *shortWriter) Write(p []byte) (int, error) {
s.written += len(p)
if s.written > s.failAfter {
return 0, io.ErrShortWrite
}
return len(p), nil
}
func TestWriteFrameWriterError(t *testing.T) {
w := &shortWriter{failAfter: 2}
err := WriteFrame(w, ChannelStdout, []byte("payload"))
if err == nil {
t.Fatal("expected error from short writer")
}
}
func TestReadFrameTruncated(t *testing.T) {
_, _, err := ReadFrame(bytes.NewReader([]byte{0x02, 0x00}))
if !errors.Is(err, io.ErrUnexpectedEOF) && err == nil {
t.Fatalf("expected EOF-ish error, got %v", err)
}
// Header OK, but payload truncated.
var buf bytes.Buffer
buf.Write([]byte{ChannelStdout, 0x00, 0x00, 0x00, 0x05})
buf.Write([]byte("ab"))
if _, _, err := ReadFrame(&buf); err == nil {
t.Fatal("expected truncated payload error")
}
}
func TestControlRoundtrip(t *testing.T) {
code := 42
msg := ControlMessage{Type: "exit", ExitCode: &code}
var buf bytes.Buffer
if err := WriteControl(&buf, msg); err != nil {
t.Fatalf("WriteControl: %v", err)
}
got, err := ReadNextControl(&buf)
if err != nil {
t.Fatalf("ReadNextControl: %v", err)
}
if got.Type != "exit" {
t.Fatalf("type = %q, want exit", got.Type)
}
if got.ExitCode == nil || *got.ExitCode != 42 {
t.Fatalf("exit_code = %v, want 42", got.ExitCode)
}
}
func TestReadControlBadJSON(t *testing.T) {
if _, err := ReadControl([]byte("{not json")); err == nil {
t.Fatal("expected JSON error")
}
}
func TestReadNextControlWrongChannel(t *testing.T) {
var buf bytes.Buffer
if err := WriteFrame(&buf, ChannelStdout, []byte("not a control frame")); err != nil {
t.Fatalf("WriteFrame: %v", err)
}
if _, err := ReadNextControl(&buf); err == nil {
t.Fatal("expected error for non-control channel")
}
}
func TestFormatConstant(t *testing.T) {
if FormatV1 != "stdio_mux_v1" {
t.Fatalf("FormatV1 = %q, want stdio_mux_v1", FormatV1)
}
}

View file

@ -1,214 +0,0 @@
package store
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"testing"
"time"
"banger/internal/model"
)
func sampleGuestSession(id, vmID, name string) model.GuestSession {
now := fixedTime()
exit := 7
return model.GuestSession{
ID: id,
VMID: vmID,
Name: name,
Backend: "ssh",
AttachBackend: "vsock",
AttachMode: "rpc",
Command: "pi",
Args: []string{"--mode", "rpc"},
CWD: "/root/repo",
Env: map[string]string{"FOO": "bar"},
StdinMode: model.GuestSessionStdinMode("pipe"),
Status: model.GuestSessionStatus("exited"),
ExitCode: &exit,
GuestPID: 1234,
GuestStateDir: "/tmp/guest-" + id,
StdoutLogPath: "/tmp/" + id + ".stdout",
StderrLogPath: "/tmp/" + id + ".stderr",
Tags: map[string]string{"role": "planner"},
LastError: "",
Attachable: true,
Reattachable: true,
LaunchStage: "started",
LaunchMessage: "ok",
LaunchRawLog: "boot log...",
CreatedAt: now,
StartedAt: now,
UpdatedAt: now,
EndedAt: now.Add(time.Minute),
}
}
// openTestStoreWithVMs opens a fresh store seeded with the given VM IDs so
// guest_sessions FK constraints are satisfied. Each VM gets a minimal
// image it references.
func openTestStoreWithVMs(t *testing.T, vmIDs ...string) *Store {
t.Helper()
ctx := context.Background()
store := openTestStore(t)
image := sampleImage("stub-image")
if err := store.UpsertImage(ctx, image); err != nil {
t.Fatalf("UpsertImage: %v", err)
}
for i, id := range vmIDs {
vm := sampleVM(id, image.ID, fmt.Sprintf("172.16.0.%d", i+2))
vm.ID = id
if err := store.UpsertVM(ctx, vm); err != nil {
t.Fatalf("UpsertVM(%s): %v", id, err)
}
}
return store
}
func TestGuestSessionUpsertAndGetByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := openTestStoreWithVMs(t, "vm-1")
session := sampleGuestSession("sess-1", "vm-1", "planner")
if err := store.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
got, err := store.GetGuestSessionByID(ctx, "sess-1")
if err != nil {
t.Fatalf("GetGuestSessionByID: %v", err)
}
if !reflect.DeepEqual(got, session) {
t.Fatalf("round-trip mismatch:\n got %+v\n want %+v", got, session)
}
}
func TestGuestSessionUpsertIsIdempotent(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := openTestStoreWithVMs(t, "vm-1")
session := sampleGuestSession("sess-1", "vm-1", "planner")
if err := store.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession (first): %v", err)
}
// Mutate + re-upsert → existing row updated.
session.Command = "pi --other"
session.Status = model.GuestSessionStatus("running")
session.ExitCode = nil
if err := store.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession (second): %v", err)
}
got, err := store.GetGuestSessionByID(ctx, "sess-1")
if err != nil {
t.Fatalf("GetGuestSessionByID: %v", err)
}
if got.Command != "pi --other" {
t.Errorf("command = %q, want 'pi --other'", got.Command)
}
if got.Status != model.GuestSessionStatus("running") {
t.Errorf("status = %q, want running", got.Status)
}
if got.ExitCode != nil {
t.Errorf("ExitCode = %v, want nil after clearing", got.ExitCode)
}
}
func TestGetGuestSessionByIDOrName(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := openTestStoreWithVMs(t, "vm-1")
session := sampleGuestSession("sess-1", "vm-1", "planner")
if err := store.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
byID, err := store.GetGuestSession(ctx, "vm-1", "sess-1")
if err != nil {
t.Fatalf("GetGuestSession by ID: %v", err)
}
if byID.ID != "sess-1" {
t.Errorf("by-ID: got %q, want sess-1", byID.ID)
}
byName, err := store.GetGuestSession(ctx, "vm-1", "planner")
if err != nil {
t.Fatalf("GetGuestSession by name: %v", err)
}
if byName.Name != "planner" {
t.Errorf("by-name: got %q, want planner", byName.Name)
}
// Scoped to the VM.
if _, err := store.GetGuestSession(ctx, "vm-unknown", "sess-1"); !errors.Is(err, sql.ErrNoRows) {
t.Errorf("wrong-vm lookup = %v, want sql.ErrNoRows", err)
}
}
func TestListGuestSessionsByVMOrdersByCreatedAt(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := openTestStoreWithVMs(t, "vm-1", "vm-2")
base := fixedTime()
first := sampleGuestSession("sess-early", "vm-1", "first")
first.CreatedAt = base
second := sampleGuestSession("sess-late", "vm-1", "second")
second.CreatedAt = base.Add(time.Hour)
other := sampleGuestSession("sess-other", "vm-2", "other")
for _, s := range []model.GuestSession{second, first, other} {
if err := store.UpsertGuestSession(ctx, s); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
}
sessions, err := store.ListGuestSessionsByVM(ctx, "vm-1")
if err != nil {
t.Fatalf("ListGuestSessionsByVM: %v", err)
}
if len(sessions) != 2 {
t.Fatalf("len = %d, want 2 (vm-1 only)", len(sessions))
}
if sessions[0].ID != "sess-early" || sessions[1].ID != "sess-late" {
t.Fatalf("order: got %q, %q; want sess-early, sess-late", sessions[0].ID, sessions[1].ID)
}
empty, err := store.ListGuestSessionsByVM(ctx, "vm-unknown")
if err != nil {
t.Fatalf("ListGuestSessionsByVM (unknown vm): %v", err)
}
if len(empty) != 0 {
t.Fatalf("unknown vm sessions = %+v, want empty", empty)
}
}
func TestDeleteGuestSession(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := openTestStoreWithVMs(t, "vm-1")
session := sampleGuestSession("sess-1", "vm-1", "planner")
if err := store.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
if err := store.DeleteGuestSession(ctx, "sess-1"); err != nil {
t.Fatalf("DeleteGuestSession: %v", err)
}
if _, err := store.GetGuestSessionByID(ctx, "sess-1"); !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("after delete err = %v, want sql.ErrNoRows", err)
}
// Deleting something that doesn't exist is a no-op (matches SQL DELETE semantics).
if err := store.DeleteGuestSession(ctx, "sess-nope"); err != nil {
t.Fatalf("DeleteGuestSession on missing row: %v", err)
}
}

View file

@ -99,32 +99,6 @@ func (s *Store) migrate() error {
stats_json TEXT NOT NULL DEFAULT '{}',
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE RESTRICT
);`,
`CREATE TABLE IF NOT EXISTS guest_sessions (
id TEXT PRIMARY KEY,
vm_id TEXT NOT NULL,
name TEXT NOT NULL,
backend TEXT NOT NULL,
command TEXT NOT NULL,
args_json TEXT NOT NULL DEFAULT '[]',
cwd TEXT,
env_json TEXT NOT NULL DEFAULT '{}',
stdin_mode TEXT NOT NULL,
status TEXT NOT NULL,
exit_code INTEGER,
guest_pid INTEGER NOT NULL DEFAULT 0,
guest_state_dir TEXT,
stdout_log_path TEXT,
stderr_log_path TEXT,
tags_json TEXT NOT NULL DEFAULT '{}',
last_error TEXT,
attachable INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL,
started_at TEXT,
updated_at TEXT NOT NULL,
ended_at TEXT,
UNIQUE(vm_id, name),
FOREIGN KEY(vm_id) REFERENCES vms(id) ON DELETE CASCADE
);`,
}
for _, stmt := range stmts {
if _, err := s.db.Exec(stmt); err != nil {
@ -137,18 +111,6 @@ func (s *Store) migrate() error {
if err := ensureColumnExists(s.db, "images", "seeded_ssh_public_key_fingerprint", "TEXT"); err != nil {
return err
}
for _, spec := range []struct{ table, column, typ string }{
{"guest_sessions", "attach_backend", "TEXT"},
{"guest_sessions", "attach_mode", "TEXT"},
{"guest_sessions", "reattachable", "INTEGER NOT NULL DEFAULT 0"},
{"guest_sessions", "launch_stage", "TEXT"},
{"guest_sessions", "launch_message", "TEXT"},
{"guest_sessions", "launch_raw_log", "TEXT"},
} {
if err := ensureColumnExists(s.db, spec.table, spec.column, spec.typ); err != nil {
return err
}
}
return nil
}
@ -336,122 +298,6 @@ func (s *Store) FindVMsUsingImage(ctx context.Context, imageID string) ([]model.
return vms, rows.Err()
}
func (s *Store) UpsertGuestSession(ctx context.Context, session model.GuestSession) error {
s.writeMu.Lock()
defer s.writeMu.Unlock()
argsJSON, err := json.Marshal(session.Args)
if err != nil {
return err
}
envJSON, err := json.Marshal(session.Env)
if err != nil {
return err
}
tagsJSON, err := json.Marshal(session.Tags)
if err != nil {
return err
}
const query = `
INSERT INTO guest_sessions (
id, vm_id, name, backend, attach_backend, attach_mode, command, args_json, cwd, env_json, stdin_mode, status,
exit_code, guest_pid, guest_state_dir, stdout_log_path, stderr_log_path, tags_json,
last_error, attachable, reattachable, launch_stage, launch_message, launch_raw_log,
created_at, started_at, updated_at, ended_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
vm_id=excluded.vm_id,
name=excluded.name,
backend=excluded.backend,
attach_backend=excluded.attach_backend,
attach_mode=excluded.attach_mode,
command=excluded.command,
args_json=excluded.args_json,
cwd=excluded.cwd,
env_json=excluded.env_json,
stdin_mode=excluded.stdin_mode,
status=excluded.status,
exit_code=excluded.exit_code,
guest_pid=excluded.guest_pid,
guest_state_dir=excluded.guest_state_dir,
stdout_log_path=excluded.stdout_log_path,
stderr_log_path=excluded.stderr_log_path,
tags_json=excluded.tags_json,
last_error=excluded.last_error,
attachable=excluded.attachable,
reattachable=excluded.reattachable,
launch_stage=excluded.launch_stage,
launch_message=excluded.launch_message,
launch_raw_log=excluded.launch_raw_log,
started_at=excluded.started_at,
updated_at=excluded.updated_at,
ended_at=excluded.ended_at`
_, err = s.db.ExecContext(ctx, query,
session.ID,
session.VMID,
session.Name,
session.Backend,
session.AttachBackend,
session.AttachMode,
session.Command,
string(argsJSON),
session.CWD,
string(envJSON),
string(session.StdinMode),
string(session.Status),
nullableInt(session.ExitCode),
session.GuestPID,
session.GuestStateDir,
session.StdoutLogPath,
session.StderrLogPath,
string(tagsJSON),
session.LastError,
boolToInt(session.Attachable),
boolToInt(session.Reattachable),
session.LaunchStage,
session.LaunchMessage,
session.LaunchRawLog,
session.CreatedAt.Format(time.RFC3339),
nullableTimeString(session.StartedAt),
session.UpdatedAt.Format(time.RFC3339),
nullableTimeString(session.EndedAt),
)
return err
}
func (s *Store) GetGuestSessionByID(ctx context.Context, id string) (model.GuestSession, error) {
row := s.db.QueryRowContext(ctx, guestSessionSelectSQL+" WHERE id = ?", id)
return scanGuestSessionRow(row)
}
func (s *Store) GetGuestSession(ctx context.Context, vmID, idOrName string) (model.GuestSession, error) {
row := s.db.QueryRowContext(ctx, guestSessionSelectSQL+" WHERE vm_id = ? AND (id = ? OR name = ?)", vmID, idOrName, idOrName)
return scanGuestSessionRow(row)
}
func (s *Store) ListGuestSessionsByVM(ctx context.Context, vmID string) ([]model.GuestSession, error) {
rows, err := s.db.QueryContext(ctx, guestSessionSelectSQL+" WHERE vm_id = ? ORDER BY created_at ASC", vmID)
if err != nil {
return nil, err
}
defer rows.Close()
var sessions []model.GuestSession
for rows.Next() {
session, err := scanGuestSession(rows)
if err != nil {
return nil, err
}
sessions = append(sessions, session)
}
return sessions, rows.Err()
}
func (s *Store) DeleteGuestSession(ctx context.Context, id string) error {
s.writeMu.Lock()
defer s.writeMu.Unlock()
_, err := s.db.ExecContext(ctx, "DELETE FROM guest_sessions WHERE id = ?", id)
return err
}
func (s *Store) NextGuestIP(ctx context.Context, bridgeIPPrefix string) (string, error) {
used := map[string]struct{}{}
rows, err := s.db.QueryContext(ctx, "SELECT guest_ip FROM vms")
@ -622,113 +468,6 @@ func boolToInt(value bool) int {
return 0
}
const guestSessionSelectSQL = `
SELECT id, vm_id, name, backend, attach_backend, attach_mode, command, args_json, cwd, env_json, stdin_mode, status,
exit_code, guest_pid, guest_state_dir, stdout_log_path, stderr_log_path, tags_json,
last_error, attachable, reattachable, launch_stage, launch_message, launch_raw_log,
created_at, started_at, updated_at, ended_at
FROM guest_sessions`
func scanGuestSession(rows scanner) (model.GuestSession, error) {
return scanGuestSessionRow(rows)
}
func scanGuestSessionRow(row scanner) (model.GuestSession, error) {
var session model.GuestSession
var (
argsJSON string
envJSON string
tagsJSON string
stdinMode string
status string
exitCode sql.NullInt64
startedAt sql.NullString
endedAt sql.NullString
attachable int
reattachable int
createdRaw string
updatedRaw string
)
err := row.Scan(
&session.ID,
&session.VMID,
&session.Name,
&session.Backend,
&session.AttachBackend,
&session.AttachMode,
&session.Command,
&argsJSON,
&session.CWD,
&envJSON,
&stdinMode,
&status,
&exitCode,
&session.GuestPID,
&session.GuestStateDir,
&session.StdoutLogPath,
&session.StderrLogPath,
&tagsJSON,
&session.LastError,
&attachable,
&reattachable,
&session.LaunchStage,
&session.LaunchMessage,
&session.LaunchRawLog,
&createdRaw,
&startedAt,
&updatedRaw,
&endedAt,
)
if err != nil {
return session, err
}
session.StdinMode = model.GuestSessionStdinMode(stdinMode)
session.Status = model.GuestSessionStatus(status)
session.Attachable = attachable == 1
session.Reattachable = reattachable == 1
if argsJSON != "" {
if err := json.Unmarshal([]byte(argsJSON), &session.Args); err != nil {
return session, err
}
}
if envJSON != "" {
if err := json.Unmarshal([]byte(envJSON), &session.Env); err != nil {
return session, err
}
}
if tagsJSON != "" {
if err := json.Unmarshal([]byte(tagsJSON), &session.Tags); err != nil {
return session, err
}
}
if exitCode.Valid {
value := int(exitCode.Int64)
session.ExitCode = &value
}
var parseErr error
session.CreatedAt, parseErr = time.Parse(time.RFC3339, createdRaw)
if parseErr != nil {
return session, parseErr
}
session.UpdatedAt, parseErr = time.Parse(time.RFC3339, updatedRaw)
if parseErr != nil {
return session, parseErr
}
if startedAt.Valid && startedAt.String != "" {
session.StartedAt, parseErr = time.Parse(time.RFC3339, startedAt.String)
if parseErr != nil {
return session, parseErr
}
}
if endedAt.Valid && endedAt.String != "" {
session.EndedAt, parseErr = time.Parse(time.RFC3339, endedAt.String)
if parseErr != nil {
return session, parseErr
}
}
return session, nil
}
func nullableTimeString(value time.Time) any {
if value.IsZero() {
return nil