remove vm session feature

Cuts the daemon-managed guest-session machinery (start/list/show/
logs/stop/kill/attach/send). The feature shipped aimed at agent-
orchestration workflows (programmatic stdin piping into a long-lived
guest process) that aren't driving any concrete user today, and the
~2.3K LOC of daemon surface area — attach bridge, FIFO keepalive,
controller registry, sessionstream framing, SQLite persistence — was
locking in an API we'd have to keep through v0.1.0.

Anything session-flavoured that people actually need today can be
done with `vm ssh + tmux` or `vm run -- cmd`.

Deleted:
- internal/cli/commands_vm_session.go
- internal/daemon/{guest_sessions,session_lifecycle,session_attach,session_stream,session_controller}.go
- internal/daemon/session/ (guest-session helpers package)
- internal/sessionstream/ (framing package)
- internal/daemon/guest_sessions_test.go
- internal/store/guest_session_test.go
- GuestSession* types from internal/{api,model}
- Store UpsertGuestSession/GetGuestSession/ListGuestSessionsByVM/DeleteGuestSession + scanner helpers
- guest.session.* RPC dispatch entries
- 5 CLI session tests, 2 completion tests, 2 printer tests

Extracted:
- ShellQuote + FormatStepError lifted to internal/daemon/workspace/util.go
  (only non-session consumer); workspace package now self-contained
- internal/daemon/guest_ssh.go keeps guestSSHClient + dialGuest +
  waitForGuestSSH — still used by workspace prepare/export
- internal/daemon/fake_firecracker_test.go preserves the test helper
  that used to live in guest_sessions_test.go

Store schema: CREATE TABLE guest_sessions and its column migrations
removed. Existing dev DBs keep an orphan table (harmless, pre-v0.1.0).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Thales Maciel 2026-04-20 12:47:58 -03:00
parent c42fcbe012
commit 2b6437d1b4
No known key found for this signature in database
GPG key ID: 33112E6833C34679
34 changed files with 194 additions and 4031 deletions

View file

@ -32,13 +32,11 @@ owning types:
- `createOps opstate.Registry[*vmCreateOperationState]` — in-flight VM
create operations; owns its own lock.
- `tapPool tapPool` — TAP interface pool; owns its own lock.
- `sessions sessionRegistry` — active guest session controllers; owns
its own lock.
- `listener`, `vmDNS` — networking.
- `vmCaps` — registered VM capability hooks.
- `pullAndFlatten`, `finalizePulledRootfs`, `bundleFetch`,
`requestHandler`, `guestWaitForSSH`, `guestDial`,
`waitForGuestSessionReady` — injectable seams used by tests.
`workspaceInspectRepo`, `workspaceImport` — injectable seams used by tests.
## Subpackages
@ -53,11 +51,9 @@ state beyond small test seams.
| `internal/daemon/dmsnap` | Device-mapper COW snapshot create/cleanup/remove. |
| `internal/daemon/fcproc` | Firecracker process primitives (bridge, tap, binary, PID, kill, wait). |
| `internal/daemon/imagemgr` | Image subsystem pure helpers: validators, staging, build script gen. |
| `internal/daemon/session` | Guest-session helpers: state paths, scripts, parsing, utilities. |
| `internal/daemon/workspace` | Workspace helpers: git inspection, copy prep, guest import script. |
`workspace` imports `session` for `ShellQuote` and `FormatStepError`; all
other subpackages are leaves (no other intra-daemon subpackage imports).
All subpackages are leaves — no intra-daemon subpackage imports another.
## Lock ordering
@ -73,9 +69,8 @@ time. `workspace.prepare` acquires `vmLocks[id]` just long enough to
validate VM state, releases it, then acquires `workspaceLocks[id]`
for the guest I/O phase.
Subsystem-local locks (`tapPool.mu`, `sessionRegistry.mu`,
`opstate.Registry` mu, `guestSessionController.attachMu` /
`writeMu`) are leaves. They do not contend with each other.
Subsystem-local locks (`tapPool.mu`, `opstate.Registry` mu) are leaves.
They do not contend with each other.
Notes:

View file

@ -51,24 +51,22 @@ type Daemon struct {
// See internal/daemon/vm_handles.go — persistent durable state
// lives in the store, this is rebuildable from a per-VM
// handles.json scratch file and OS inspection.
handles *handleCache
sessions sessionRegistry
tapPool tapPool
closing chan struct{}
once sync.Once
pid int
listener net.Listener
vmDNS *vmdns.Server
vmCaps []vmCapability
pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error)
finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error
bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error)
requestHandler func(context.Context, rpc.Request) rpc.Response
guestWaitForSSH func(context.Context, string, string, time.Duration) error
guestDial func(context.Context, string, string) (guestSSHClient, error)
waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error)
workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error)
workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error
handles *handleCache
tapPool tapPool
closing chan struct{}
once sync.Once
pid int
listener net.Listener
vmDNS *vmdns.Server
vmCaps []vmCapability
pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error)
finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error
bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error)
requestHandler func(context.Context, rpc.Request) rpc.Response
guestWaitForSSH func(context.Context, string, string, time.Duration) error
guestDial func(context.Context, string, string) (guestSSHClient, error)
workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error)
workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error
}
func Open(ctx context.Context) (d *Daemon, err error) {
@ -93,15 +91,14 @@ func Open(ctx context.Context) (d *Daemon, err error) {
return nil, err
}
d = &Daemon{
layout: layout,
config: cfg,
store: db,
runner: system.NewRunner(),
logger: logger,
closing: make(chan struct{}),
pid: os.Getpid(),
handles: newHandleCache(),
sessions: newSessionRegistry(),
layout: layout,
config: cfg,
store: db,
runner: system.NewRunner(),
logger: logger,
closing: make(chan struct{}),
pid: os.Getpid(),
handles: newHandleCache(),
}
// From here on, every failure path must run Close() so the host
// state we touched (DNS listener goroutine, resolvectl routing,
@ -144,7 +141,7 @@ func (d *Daemon) Close() error {
if d.listener != nil {
_ = d.listener.Close()
}
err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.closeGuestSessionControllers(), d.store.Close())
err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.store.Close())
})
return err
}
@ -424,62 +421,6 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
}
result, err := d.ExportVMWorkspace(ctx, params)
return marshalResultOrError(result, err)
case "guest.session.start":
params, err := rpc.DecodeParams[api.GuestSessionStartParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.StartGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.get":
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.GetGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.list":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
sessions, err := d.ListGuestSessions(ctx, params)
return marshalResultOrError(api.GuestSessionListResult{Sessions: sessions}, err)
case "guest.session.stop":
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.StopGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.kill":
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.KillGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.logs":
params, err := rpc.DecodeParams[api.GuestSessionLogsParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.GuestSessionLogs(ctx, params)
return marshalResultOrError(result, err)
case "guest.session.attach.begin":
params, err := rpc.DecodeParams[api.GuestSessionAttachBeginParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.BeginGuestSessionAttach(ctx, params)
return marshalResultOrError(result, err)
case "guest.session.send":
params, err := rpc.DecodeParams[api.GuestSessionSendParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.SendToGuestSession(ctx, params)
return marshalResultOrError(result, err)
case "image.list":
images, err := d.store.ListImages(ctx)
return marshalResultOrError(api.ImageListResult{Images: images}, err)

View file

@ -1,8 +1,8 @@
// Package daemon hosts the Banger daemon process.
//
// The daemon exposes a JSON-RPC endpoint over a Unix socket. It owns VM
// lifecycle, image management, guest sessions, host networking bootstrap,
// and state persistence via internal/store.
// lifecycle, image management, host networking bootstrap, and state
// persistence via internal/store.
//
// The package is organised into cohesive groups. Pure stateless helpers for
// each group have been lifted into subpackages; orchestrator methods
@ -18,13 +18,9 @@
// internal/daemon/imagemgr Image subsystem helpers: path validation,
// artifact staging, guest provisioning script
// generator, metadata.
// internal/daemon/session Guest-session helpers: state paths, runner
// / inspect / signal scripts, state snapshot
// parsing, launch helpers, ShellQuote,
// FormatStepError.
// internal/daemon/workspace Workspace helpers: git repo inspection,
// shallow copy prep, guest-side import,
// finalize script generation.
// finalize script generation, shell quoting.
//
// VM lifecycle (in this package):
//
@ -50,11 +46,7 @@
//
// Guest interaction (in this package):
//
// guest_sessions.go dialGuest, waitForGuestSSH, refresh/inspect
// session_lifecycle.go Start/Stop/Kill/Get/List/signal orchestrators
// session_attach.go BeginGuestSessionAttach + bridge/forward/watch
// session_stream.go GuestSessionLogs, SendToGuestSession
// session_controller.go guestSessionController, sessionRegistry
// guest_ssh.go guestSSHClient, dialGuest, waitForGuestSSH
// ssh_client_config.go daemon-managed SSH client key material
// workspace.go ExportVMWorkspace, PrepareVMWorkspace
//
@ -73,10 +65,8 @@
//
// Lock ordering:
//
// vmLocks[id] → {createVMMu, imageOpsMu} → subsystem-local locks
// vmLocks[id] → workspaceLocks[id] → {createVMMu, imageOpsMu} → subsystem-local locks
//
// Subsystem-local locks live on their owning type (tapPool.mu,
// sessionRegistry.mu, opstate.Registry mu, guestSessionController.attachMu /
// writeMu) and do not contend with each other. See ARCHITECTURE.md for
// details.
// Subsystem-local locks (tapPool.mu, opstate.Registry mu) are leaves and
// do not contend with each other. See ARCHITECTURE.md for details.
package daemon

View file

@ -0,0 +1,26 @@
package daemon
import (
"fmt"
"os/exec"
"testing"
)
// startFakeFirecracker launches a bash sleep-loop rewritten to match
// the firecracker command line a real process would expose, so
// reconcile / handle-cache paths that grep /proc/<pid>/cmdline accept
// it as a firecracker process. Killed on test cleanup.
func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd {
t.Helper()
cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
if err := cmd.Start(); err != nil {
t.Fatalf("start fake firecracker: %v", err)
}
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
}
})
return cmd
}

View file

@ -1,142 +0,0 @@
package daemon
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"time"
"banger/internal/daemon/session"
"banger/internal/guest"
"banger/internal/model"
"banger/internal/system"
)
type guestSSHClient interface {
Close() error
RunScript(context.Context, string, io.Writer) error
RunScriptOutput(context.Context, string) ([]byte, error)
UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error
StreamTar(context.Context, string, string, io.Writer) error
StreamTarEntries(context.Context, string, []string, string, io.Writer) error
}
func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error {
if d != nil && d.guestWaitForSSH != nil {
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
}
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval)
}
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
if d != nil && d.guestDial != nil {
return d.guestDial(ctx, address, d.config.SSHKeyPath)
}
return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
}
func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
if d != nil && d.waitForGuestSessionReady != nil {
return d.waitForGuestSessionReady(ctx, vm, s)
}
return d.waitForGuestSessionReadyDefault(ctx, vm, s)
}
func (d *Daemon) waitForGuestSessionReadyDefault(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
for {
updated, err := d.refreshGuestSession(ctx, vm, s)
if err == nil {
s = updated
if s.GuestPID != 0 || s.ExitCode != nil || s.Status == model.GuestSessionStatusRunning || s.Status == model.GuestSessionStatusFailed || s.Status == model.GuestSessionStatusExited {
return s, nil
}
}
select {
case <-ctx.Done():
return s, ctx.Err()
case <-time.After(100 * time.Millisecond):
}
}
}
func (d *Daemon) refreshGuestSession(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
if s.Status != model.GuestSessionStatusStarting && s.Status != model.GuestSessionStatusRunning && s.Status != model.GuestSessionStatusStopping {
return s, nil
}
snapshot, err := d.inspectGuestSessionState(ctx, vm, s)
if err != nil {
return s, err
}
original := s
session.ApplyStateSnapshot(&s, snapshot, d.vmAlive(vm))
if session.StateChanged(original, s) {
s.UpdatedAt = model.Now()
if err := d.store.UpsertGuestSession(ctx, s); err != nil {
return s, err
}
}
return s, nil
}
func (d *Daemon) inspectGuestSessionState(ctx context.Context, vm model.VMRecord, s model.GuestSession) (session.StateSnapshot, error) {
if d.vmAlive(vm) {
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return session.StateSnapshot{}, err
}
defer client.Close()
var output bytes.Buffer
if err := client.RunScript(ctx, session.InspectScript(s.ID), &output); err != nil {
return session.StateSnapshot{}, session.FormatStepError("inspect guest session state", err, output.String())
}
return session.ParseState(output.String())
}
return d.inspectGuestSessionStateFromWorkDisk(ctx, vm, s.ID)
}
func (d *Daemon) inspectGuestSessionStateFromWorkDisk(ctx context.Context, vm model.VMRecord, sessionID string) (session.StateSnapshot, error) {
runner := d.runner
if runner == nil {
runner = system.NewRunner()
}
workMount, cleanup, err := system.MountTempDir(ctx, runner, vm.Runtime.WorkDiskPath, false)
if err != nil {
return session.StateSnapshot{}, err
}
defer cleanup()
stateDir := filepath.Join(workMount, session.RelativeStateDir(sessionID))
return session.InspectStateFromDir(stateDir)
}
func (d *Daemon) findGuestSession(ctx context.Context, vmID, idOrName string) (model.GuestSession, error) {
if strings.TrimSpace(idOrName) == "" {
return model.GuestSession{}, errors.New("session id or name is required")
}
if s, err := d.store.GetGuestSession(ctx, vmID, idOrName); err == nil {
return s, nil
}
sessions, err := d.store.ListGuestSessionsByVM(ctx, vmID)
if err != nil {
return model.GuestSession{}, err
}
matches := make([]model.GuestSession, 0, 1)
for _, s := range sessions {
if strings.HasPrefix(s.ID, idOrName) || strings.HasPrefix(s.Name, idOrName) {
matches = append(matches, s)
}
}
switch len(matches) {
case 0:
return model.GuestSession{}, fmt.Errorf("session %q not found", idOrName)
case 1:
return matches[0], nil
default:
return model.GuestSession{}, fmt.Errorf("multiple sessions match %q", idOrName)
}
}

View file

@ -1,492 +0,0 @@
package daemon
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
"banger/internal/api"
sess "banger/internal/daemon/session"
"banger/internal/model"
"banger/internal/store"
)
type fakeGuestSSHClient struct {
t *testing.T
existingDirs map[string]bool
closed bool
}
func (f *fakeGuestSSHClient) Close() error {
f.closed = true
return nil
}
func (f *fakeGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error {
f.t.Helper()
switch {
case strings.Contains(script, `\n`):
return fmt.Errorf("script still contains escaped newline literals: %q", script)
case strings.Contains(script, `echo "missing cwd: $DIR"`):
if strings.Contains(script, "DIR='/root/repo'\n") && f.existingDirs["/root/repo"] {
return nil
}
return fmt.Errorf("missing cwd")
case strings.Contains(script, "check_command() {"):
return nil
case strings.Contains(script, `git config --global --add safe.directory "$DIR"`):
if strings.Contains(script, "DIR='/root/repo'\n") {
f.existingDirs["/root/repo"] = true
return nil
}
return fmt.Errorf("workspace finalize used unexpected guest path")
case strings.Contains(script, "chmod -R a-w"):
if f.existingDirs["/root/repo"] {
return nil
}
return fmt.Errorf("workspace path missing during readonly chmod")
case strings.Contains(script, "nohup bash "):
return nil
default:
return nil
}
}
func (f *fakeGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) {
return nil, nil
}
func (f *fakeGuestSSHClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error {
return nil
}
func (f *fakeGuestSSHClient) StreamTar(_ context.Context, _ string, command string, _ io.Writer) error {
if strings.Contains(command, "/root/repo") {
f.existingDirs["/root/repo"] = true
return nil
}
return fmt.Errorf("unexpected StreamTar command: %s", command)
}
func (f *fakeGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, command string, _ io.Writer) error {
if strings.Contains(command, "/root/repo") {
f.existingDirs["/root/repo"] = true
return nil
}
return fmt.Errorf("unexpected StreamTarEntries command: %s", command)
}
func TestSendToGuestSession_HappyPath(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
apiSock := filepath.Join(t.TempDir(), "fc.sock")
firecracker := startFakeFirecracker(t, apiSock)
vm := testVM("sendbox", "image-send", "172.16.0.88")
vm.State = model.VMStateRunning
vm.Runtime.State = model.VMStateRunning
vm.Runtime.APISockPath = apiSock
upsertDaemonVM(t, ctx, db, vm)
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
if err := db.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
fake := &recordingGuestSSHClient{}
d := newSendTestDaemon(t, db, fake)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
payload := []byte(`{"type":"abort"}` + "\n")
result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
VMIDOrName: vm.Name,
SessionIDOrName: session.Name,
Payload: payload,
})
if err != nil {
t.Fatalf("SendToGuestSession: %v", err)
}
if result.BytesWritten != len(payload) {
t.Fatalf("BytesWritten = %d, want %d", result.BytesWritten, len(payload))
}
if result.Session.ID != session.ID {
t.Fatalf("Session.ID = %q, want %q", result.Session.ID, session.ID)
}
if len(fake.uploadedFiles) != 1 {
t.Fatalf("UploadFile call count = %d, want 1", len(fake.uploadedFiles))
}
for path, data := range fake.uploadedFiles {
if !strings.HasPrefix(path, "/tmp/banger-send-") {
t.Fatalf("upload path = %q, want /tmp/banger-send-... prefix", path)
}
if string(data) != string(payload) {
t.Fatalf("upload data = %q, want %q", data, payload)
}
}
if len(fake.ranScripts) != 1 {
t.Fatalf("RunScript call count = %d, want 1", len(fake.ranScripts))
}
script := fake.ranScripts[0]
pipePath := sess.StdinPipePath(session.ID)
if !strings.Contains(script, "cat ") {
t.Fatalf("send script missing cat command: %q", script)
}
if !strings.Contains(script, pipePath) {
t.Fatalf("send script missing pipe path %q: %q", pipePath, script)
}
if !strings.Contains(script, "rm -f ") {
t.Fatalf("send script missing rm cleanup: %q", script)
}
}
func TestSendToGuestSession_EmptyPayload(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
apiSock := filepath.Join(t.TempDir(), "fc.sock")
firecracker := startFakeFirecracker(t, apiSock)
vm := testVM("sendbox-empty", "image-send", "172.16.0.89")
vm.State = model.VMStateRunning
vm.Runtime.State = model.VMStateRunning
vm.Runtime.APISockPath = apiSock
upsertDaemonVM(t, ctx, db, vm)
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
if err := db.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
fake := &recordingGuestSSHClient{}
d := newSendTestDaemon(t, db, fake)
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
VMIDOrName: vm.Name,
SessionIDOrName: session.Name,
Payload: nil,
})
if err != nil {
t.Fatalf("SendToGuestSession(empty): %v", err)
}
if result.BytesWritten != 0 {
t.Fatalf("BytesWritten = %d, want 0", result.BytesWritten)
}
if fake.dialCount != 0 {
t.Fatalf("SSH dial count = %d, want 0 for empty payload", fake.dialCount)
}
}
func TestSendToGuestSession_NotPipeMode(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
vm := testVM("sendbox-closed", "image-send", "172.16.0.90")
vm.State = model.VMStateRunning
upsertDaemonVM(t, ctx, db, vm)
session := testGuestSession(vm.ID, model.GuestSessionStdinClosed, model.GuestSessionStatusRunning)
if err := db.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
d := &Daemon{store: db}
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
VMIDOrName: vm.Name,
SessionIDOrName: session.Name,
Payload: []byte("hello\n"),
})
if err == nil || !strings.Contains(err.Error(), "stdin pipe") {
t.Fatalf("error = %v, want 'stdin pipe' error", err)
}
}
func TestSendToGuestSession_SessionNotRunning(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
vm := testVM("sendbox-failed", "image-send", "172.16.0.91")
vm.State = model.VMStateRunning
upsertDaemonVM(t, ctx, db, vm)
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusFailed)
if err := db.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
d := &Daemon{store: db}
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
VMIDOrName: vm.Name,
SessionIDOrName: session.Name,
Payload: []byte("hello\n"),
})
if err == nil || !strings.Contains(err.Error(), "not running") {
t.Fatalf("error = %v, want 'not running' error", err)
}
}
func TestSendToGuestSession_VMNotRunning(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
vm := testVM("sendbox-stopped", "image-send", "172.16.0.92")
vm.State = model.VMStateStopped
upsertDaemonVM(t, ctx, db, vm)
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
if err := db.UpsertGuestSession(ctx, session); err != nil {
t.Fatalf("UpsertGuestSession: %v", err)
}
d := &Daemon{store: db}
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
VMIDOrName: vm.Name,
SessionIDOrName: session.Name,
Payload: []byte("hello\n"),
})
if err == nil || !strings.Contains(err.Error(), "not running") {
t.Fatalf("error = %v, want 'not running' error", err)
}
}
// recordingGuestSSHClient captures UploadFile and RunScript calls for send tests.
type recordingGuestSSHClient struct {
dialCount int
uploadedFiles map[string][]byte
ranScripts []string
}
func (r *recordingGuestSSHClient) Close() error { return nil }
func (r *recordingGuestSSHClient) UploadFile(_ context.Context, path string, _ os.FileMode, data []byte, _ io.Writer) error {
if r.uploadedFiles == nil {
r.uploadedFiles = make(map[string][]byte)
}
copy := make([]byte, len(data))
_ = copy[:len(data):len(data)]
for i, b := range data {
copy[i] = b
}
r.uploadedFiles[path] = copy
return nil
}
func (r *recordingGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error {
r.ranScripts = append(r.ranScripts, script)
return nil
}
func (r *recordingGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) {
return nil, nil
}
func (r *recordingGuestSSHClient) StreamTar(_ context.Context, _ string, _ string, _ io.Writer) error {
return nil
}
func (r *recordingGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, _ string, _ io.Writer) error {
return nil
}
func newSendTestDaemon(t *testing.T, db *store.Store, fake *recordingGuestSSHClient) *Daemon {
t.Helper()
d := &Daemon{
store: db,
config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")},
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
}
d.guestDial = func(_ context.Context, _ string, _ string) (guestSSHClient, error) {
fake.dialCount++
return fake, nil
}
return d
}
func testGuestSession(vmID string, stdinMode model.GuestSessionStdinMode, status model.GuestSessionStatus) model.GuestSession {
now := model.Now()
id := vmID + "-sess-id"
return model.GuestSession{
ID: id,
VMID: vmID,
Name: vmID + "-sess",
Backend: sess.BackendSSH,
Command: "pi",
Args: []string{"--mode", "rpc"},
CWD: "/root/repo",
StdinMode: stdinMode,
Status: status,
GuestStateDir: sess.StateDir(id),
StdoutLogPath: sess.StdoutLogPath(id),
StderrLogPath: sess.StderrLogPath(id),
Attachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning,
Reattachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning,
CreatedAt: now,
UpdatedAt: now,
}
}
func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd {
t.Helper()
cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
if err := cmd.Start(); err != nil {
t.Fatalf("start fake firecracker: %v", err)
}
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
}
})
return cmd
}
func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) {
t.Parallel()
cwdScript := sess.CWDPreflightScript("/root/repo")
if strings.Contains(cwdScript, `\n`) {
t.Fatalf("cwd preflight script still contains escaped newline literals: %q", cwdScript)
}
if !strings.Contains(cwdScript, "\n") {
t.Fatalf("cwd preflight script should contain real newlines: %q", cwdScript)
}
commandScript := sess.CommandPreflightScript([]string{"git", "pi"})
if strings.Contains(commandScript, `\n`) {
t.Fatalf("command preflight script still contains escaped newline literals: %q", commandScript)
}
if !strings.Contains(commandScript, "\n") {
t.Fatalf("command preflight script should contain real newlines: %q", commandScript)
}
attachInput := sess.AttachInputCommand("session-id")
if strings.Contains(attachInput, `\n`) {
t.Fatalf("attach input command still contains escaped newline literals: %q", attachInput)
}
attachTail := sess.AttachTailCommand("/tmp/stdout.log")
if strings.Contains(attachTail, `\n`) {
t.Fatalf("attach tail command still contains escaped newline literals: %q", attachTail)
}
}
func TestPrepareWorkspaceThenStartGuestSessionPassesCWDPreflight(t *testing.T) {
ctx := context.Background()
db := openDaemonStore(t)
repoRoot := filepath.Join(t.TempDir(), "repo")
if err := os.MkdirAll(repoRoot, 0o755); err != nil {
t.Fatalf("MkdirAll(repoRoot): %v", err)
}
if err := os.WriteFile(filepath.Join(repoRoot, "README.md"), []byte("hello\n"), 0o644); err != nil {
t.Fatalf("WriteFile(README.md): %v", err)
}
runGit := func(args ...string) {
t.Helper()
cmd := exec.Command("git", append([]string{"-C", repoRoot}, args...)...)
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("git %s: %v\n%s", strings.Join(args, " "), err, output)
}
}
runGit("init", "-b", "main")
runGit("config", "user.name", "Test User")
runGit("config", "user.email", "test@example.com")
runGit("add", ".")
runGit("commit", "-m", "initial")
apiSock := filepath.Join(t.TempDir(), "fc.sock")
firecracker := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
if err := firecracker.Start(); err != nil {
t.Fatalf("start fake firecracker: %v", err)
}
t.Cleanup(func() {
if firecracker.Process != nil {
_ = firecracker.Process.Kill()
_, _ = firecracker.Process.Wait()
}
})
vm := testVM("pi-devbox", "image-pi", "172.16.0.77")
vm.State = model.VMStateRunning
vm.Runtime.State = model.VMStateRunning
vm.Runtime.APISockPath = apiSock
upsertDaemonVM(t, ctx, db, vm)
fakeClient := &fakeGuestSSHClient{t: t, existingDirs: map[string]bool{}}
d := &Daemon{
store: db,
config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")},
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
}
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
d.guestDial = func(context.Context, string, string) (guestSSHClient, error) { return fakeClient, nil }
d.waitForGuestSessionReady = func(_ context.Context, _ model.VMRecord, session model.GuestSession) (model.GuestSession, error) {
now := model.Now()
session.Status = model.GuestSessionStatusRunning
session.GuestPID = 4242
session.StartedAt = now
session.UpdatedAt = now
session.Attachable = session.StdinMode == model.GuestSessionStdinPipe
session.Reattachable = session.StdinMode == model.GuestSessionStdinPipe
return session, nil
}
workspace, err := d.PrepareVMWorkspace(ctx, api.VMWorkspacePrepareParams{
IDOrName: vm.Name,
SourcePath: repoRoot,
GuestPath: "/root/repo",
ReadOnly: true,
})
if err != nil {
t.Fatalf("PrepareVMWorkspace: %v", err)
}
if workspace.GuestPath != "/root/repo" {
t.Fatalf("PrepareVMWorkspace guest path = %q, want /root/repo", workspace.GuestPath)
}
if !fakeClient.existingDirs["/root/repo"] {
t.Fatalf("workspace prepare did not mark /root/repo as present in fake guest")
}
session, err := d.StartGuestSession(ctx, api.GuestSessionStartParams{
VMIDOrName: vm.Name,
Name: "testpi",
Command: "pi",
Args: []string{"--mode", "rpc", "--no-session"},
CWD: "/root/repo",
StdinMode: string(model.GuestSessionStdinPipe),
RequiredCommands: []string{"git"},
})
if err != nil {
t.Fatalf("StartGuestSession: %v", err)
}
if session.Status != model.GuestSessionStatusRunning {
t.Fatalf("session status = %q, want %q", session.Status, model.GuestSessionStatusRunning)
}
if session.LaunchStage != "" {
t.Fatalf("session launch stage = %q, want empty", session.LaunchStage)
}
if session.LaunchMessage != "" {
t.Fatalf("session launch message = %q, want empty", session.LaunchMessage)
}
if session.GuestPID == 0 {
t.Fatalf("session guest pid = 0, want non-zero")
}
if !session.Attachable {
t.Fatalf("session should be attachable for pipe stdin mode")
}
}

View file

@ -0,0 +1,35 @@
package daemon
import (
"context"
"io"
"os"
"time"
"banger/internal/guest"
)
// guestSSHClient is the narrow guest-SSH surface the daemon uses for
// workspace prepare / export and ad-hoc guest interactions.
type guestSSHClient interface {
Close() error
RunScript(context.Context, string, io.Writer) error
RunScriptOutput(context.Context, string) ([]byte, error)
UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error
StreamTar(context.Context, string, string, io.Writer) error
StreamTarEntries(context.Context, string, []string, string, io.Writer) error
}
func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error {
if d != nil && d.guestWaitForSSH != nil {
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
}
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval)
}
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
if d != nil && d.guestDial != nil {
return d.guestDial(ctx, address, d.config.SSHKeyPath)
}
return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
}

View file

@ -26,10 +26,9 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
name: "only store + closing channel (early failure)",
build: func(t *testing.T) *Daemon {
return &Daemon{
store: openDaemonStore(t),
closing: make(chan struct{}),
sessions: newSessionRegistry(),
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
store: openDaemonStore(t),
closing: make(chan struct{}),
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
}
},
verify: func(t *testing.T, d *Daemon) {
@ -49,11 +48,10 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
t.Fatalf("vmdns.New: %v", err)
}
return &Daemon{
store: openDaemonStore(t),
closing: make(chan struct{}),
sessions: newSessionRegistry(),
vmDNS: server,
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
store: openDaemonStore(t),
closing: make(chan struct{}),
vmDNS: server,
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
}
},
verify: func(t *testing.T, d *Daemon) {
@ -86,11 +84,10 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
// returns and also calls Close afterwards, both paths must survive.
func TestCloseIdempotentUnderConcurrency(t *testing.T) {
d := &Daemon{
store: openDaemonStore(t),
closing: make(chan struct{}),
sessions: newSessionRegistry(),
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
config: model.DaemonConfig{BridgeName: ""},
store: openDaemonStore(t),
closing: make(chan struct{}),
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
config: model.DaemonConfig{BridgeName: ""},
}
var count atomic.Int32

View file

@ -1,521 +0,0 @@
// Package session contains the pure helpers of the guest-session subsystem:
// bash script generators, on-guest state path helpers, state snapshot
// parsing, and small utilities like ShellQuote and FormatStepError.
//
// The orchestrator methods (StartGuestSession, BeginGuestSessionAttach,
// etc.) stay on *daemon.Daemon and compose these helpers.
package session
import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"syscall"
"banger/internal/model"
"banger/internal/system"
"golang.org/x/crypto/ssh"
)
// Constants shared between orchestration and helpers.
const (
BackendSSH = "ssh"
AttachBackendNone = "none"
AttachBackendSSHBridge = "ssh_rehydratable"
AttachModeExclusive = "exclusive"
TransportUnixSocket = "unix_socket"
StateRoot = "/root/.local/state/banger/sessions"
LogTailLineDefault = 200
)
// StateSnapshot is the decoded per-session state as read from the guest.
type StateSnapshot struct {
Status string
GuestPID int
ExitCode *int
Alive bool
LastError string
}
// -- Guest filesystem paths -------------------------------------------------
func StateDir(id string) string {
return filepath.ToSlash(filepath.Join(StateRoot, id))
}
func RelativeStateDir(id string) string {
return strings.TrimPrefix(StateDir(id), "/root/")
}
func ScriptPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "run.sh")) }
func PIDPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "pid")) }
func MonitorPIDPath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "monitor_pid"))
}
func ExitCodePath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "exit_code"))
}
func StdinPipePath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "stdin.pipe"))
}
func StdinKeepalivePIDPath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "stdin_keepalive.pid"))
}
func StatusPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "status")) }
func ErrorPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "error")) }
func StdoutLogPath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "stdout.log"))
}
func StderrLogPath(id string) string {
return filepath.ToSlash(filepath.Join(StateDir(id), "stderr.log"))
}
// -- Script generators ------------------------------------------------------
// Script returns the bash runner installed into the guest for session. It
// sets up state/log paths, optional stdin fifo, and wait-loop around the
// user command.
func Script(sess model.GuestSession) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "STATE_DIR=%s\n", ShellQuote(sess.GuestStateDir))
fmt.Fprintf(&script, "STDOUT_LOG=%s\n", ShellQuote(sess.StdoutLogPath))
fmt.Fprintf(&script, "STDERR_LOG=%s\n", ShellQuote(sess.StderrLogPath))
fmt.Fprintf(&script, "PID_FILE=%s\n", ShellQuote(PIDPath(sess.ID)))
fmt.Fprintf(&script, "MONITOR_PID_FILE=%s\n", ShellQuote(MonitorPIDPath(sess.ID)))
fmt.Fprintf(&script, "EXIT_FILE=%s\n", ShellQuote(ExitCodePath(sess.ID)))
fmt.Fprintf(&script, "STATUS_FILE=%s\n", ShellQuote(StatusPath(sess.ID)))
fmt.Fprintf(&script, "ERROR_FILE=%s\n", ShellQuote(ErrorPath(sess.ID)))
fmt.Fprintf(&script, "STDIN_PIPE=%s\n", ShellQuote(StdinPipePath(sess.ID)))
fmt.Fprintf(&script, "STDIN_KEEPALIVE_PID_FILE=%s\n", ShellQuote(StdinKeepalivePIDPath(sess.ID)))
fmt.Fprintf(&script, "SESSION_CWD=%s\n", ShellQuote(DefaultCWD(sess.CWD)))
script.WriteString("mkdir -p \"$STATE_DIR\"\n")
script.WriteString(": >\"$STDOUT_LOG\"\n")
script.WriteString(": >\"$STDERR_LOG\"\n")
script.WriteString("rm -f \"$EXIT_FILE\" \"$ERROR_FILE\" \"$STDIN_KEEPALIVE_PID_FILE\"\n")
if sess.StdinMode == model.GuestSessionStdinPipe {
script.WriteString("rm -f \"$STDIN_PIPE\"\n")
script.WriteString("mkfifo -m 600 \"$STDIN_PIPE\"\n")
}
script.WriteString("printf '%s\\n' \"${BASHPID:-$$}\" >\"$MONITOR_PID_FILE\"\n")
script.WriteString("printf 'starting\\n' >\"$STATUS_FILE\"\n")
script.WriteString("cd \"$SESSION_CWD\"\n")
script.WriteString("exec > >(tee -a \"$STDOUT_LOG\") 2> >(tee -a \"$STDERR_LOG\" >&2)\n")
for _, line := range EnvLines(sess.Env) {
script.WriteString(line)
script.WriteByte('\n')
}
script.WriteString("COMMAND=(")
for _, value := range append([]string{sess.Command}, sess.Args...) {
script.WriteByte(' ')
script.WriteString(ShellQuote(value))
}
script.WriteString(" )\n")
if sess.StdinMode == model.GuestSessionStdinPipe {
script.WriteString("( while :; do sleep 3600; done ) >\"$STDIN_PIPE\" &\n")
script.WriteString("keepalive=$!\n")
script.WriteString("printf '%s\\n' \"$keepalive\" >\"$STDIN_KEEPALIVE_PID_FILE\"\n")
script.WriteString("\"${COMMAND[@]}\" <\"$STDIN_PIPE\" &\n")
} else {
script.WriteString("\"${COMMAND[@]}\" &\n")
}
script.WriteString("child=$!\n")
script.WriteString("printf '%s\\n' \"$child\" >\"$PID_FILE\"\n")
script.WriteString("printf 'running\\n' >\"$STATUS_FILE\"\n")
script.WriteString("wait \"$child\"\n")
script.WriteString("rc=$?\n")
if sess.StdinMode == model.GuestSessionStdinPipe {
script.WriteString("if [ -f \"$STDIN_KEEPALIVE_PID_FILE\" ]; then kill \"$(cat \"$STDIN_KEEPALIVE_PID_FILE\")\" 2>/dev/null || true; fi\n")
}
script.WriteString("printf '%s\\n' \"$rc\" >\"$EXIT_FILE\"\n")
script.WriteString("if [ \"$rc\" -eq 0 ]; then printf 'exited\\n' >\"$STATUS_FILE\"; else printf 'failed\\n' >\"$STATUS_FILE\"; fi\n")
script.WriteString("exit \"$rc\"\n")
return script.String()
}
// InspectScript reads the on-guest state files for sessionID and prints a
// key=value block parseable by ParseState.
func InspectScript(sessionID string) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID)))
script.WriteString("status=''\n")
script.WriteString("pid=''\n")
script.WriteString("exit_code=''\n")
script.WriteString("last_error=''\n")
script.WriteString("alive=false\n")
script.WriteString("[ -f \"$DIR/status\" ] && status=\"$(cat \"$DIR/status\")\"\n")
script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n")
script.WriteString("[ -f \"$DIR/exit_code\" ] && exit_code=\"$(cat \"$DIR/exit_code\")\"\n")
script.WriteString("[ -f \"$DIR/error\" ] && last_error=\"$(cat \"$DIR/error\")\"\n")
script.WriteString("if [ -n \"$pid\" ] && kill -0 \"$pid\" 2>/dev/null; then alive=true; fi\n")
script.WriteString("printf 'status=%s\\n' \"$status\"\n")
script.WriteString("printf 'pid=%s\\n' \"$pid\"\n")
script.WriteString("printf 'exit=%s\\n' \"$exit_code\"\n")
script.WriteString("printf 'alive=%s\\n' \"$alive\"\n")
script.WriteString("printf 'error=%s\\n' \"$last_error\"\n")
return script.String()
}
// SignalScript sends signal to sessionID's runner and monitor processes.
func SignalScript(sessionID, signal string) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID)))
fmt.Fprintf(&script, "SIGNAL=%s\n", ShellQuote(signal))
script.WriteString("pid=''\n")
script.WriteString("monitor=''\n")
script.WriteString("keepalive=''\n")
script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n")
script.WriteString("[ -f \"$DIR/monitor_pid\" ] && monitor=\"$(cat \"$DIR/monitor_pid\")\"\n")
script.WriteString("[ -f \"$DIR/stdin_keepalive.pid\" ] && keepalive=\"$(cat \"$DIR/stdin_keepalive.pid\")\"\n")
script.WriteString("printf 'stopping\\n' >\"$DIR/status\"\n")
script.WriteString("if [ -n \"$pid\" ]; then kill -${SIGNAL} \"$pid\" 2>/dev/null || true; fi\n")
script.WriteString("if [ -n \"$monitor\" ]; then kill -${SIGNAL} \"$monitor\" 2>/dev/null || true; fi\n")
script.WriteString("if [ -n \"$keepalive\" ]; then kill -${SIGNAL} \"$keepalive\" 2>/dev/null || true; fi\n")
return script.String()
}
// CWDPreflightScript verifies cwd exists on the guest.
func CWDPreflightScript(cwd string) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(DefaultCWD(cwd)))
script.WriteString("if [ ! -d \"$DIR\" ]; then echo \"missing cwd: $DIR\"; exit 1; fi\n")
return script.String()
}
// CommandPreflightScript verifies each command is resolvable on the guest.
func CommandPreflightScript(commands []string) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
script.WriteString("check_command() {\n")
script.WriteString(" cmd=\"$1\"\n")
script.WriteString(" case \"$cmd\" in\n")
script.WriteString(" */*) [ -x \"$cmd\" ] || { echo \"missing command: $cmd\"; exit 1; } ;;\n")
script.WriteString(" *) command -v \"$cmd\" >/dev/null 2>&1 || { echo \"missing command: $cmd\"; exit 1; } ;;\n")
script.WriteString(" esac\n")
script.WriteString("}\n")
for _, command := range commands {
fmt.Fprintf(&script, "check_command %s\n", ShellQuote(command))
}
return script.String()
}
// AttachInputCommand returns the guest command that creates/opens the stdin
// fifo for sessionID and cats attach-side bytes into it.
func AttachInputCommand(sessionID string) string {
path := StdinPipePath(sessionID)
return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\n[ -p %s ] || mkfifo -m 600 %s\nexec cat > %s\n", ShellQuote(path), ShellQuote(path), ShellQuote(path)))
}
// AttachTailCommand returns the guest command that tails a log file and
// streams new content back to the attach bridge.
func AttachTailCommand(path string) string {
return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\ntouch %s\nexec tail -n 0 -F %s 2>/dev/null\n", ShellQuote(path), ShellQuote(path)))
}
// EnvLines returns deterministic `export KEY=value` lines for the session
// launcher, ordered by key.
func EnvLines(values map[string]string) []string {
if len(values) == 0 {
return nil
}
keys := make([]string, 0, len(values))
for key := range values {
keys = append(keys, key)
}
sort.Strings(keys)
lines := make([]string, 0, len(keys))
for _, key := range keys {
lines = append(lines, "export "+key+"="+ShellQuote(values[key]))
}
return lines
}
// -- State snapshot helpers -------------------------------------------------
// ParseState decodes the key=value output produced by InspectScript.
func ParseState(raw string) (StateSnapshot, error) {
var snapshot StateSnapshot
scanner := bufio.NewScanner(strings.NewReader(raw))
for scanner.Scan() {
line := scanner.Text()
key, value, ok := strings.Cut(line, "=")
if !ok {
continue
}
switch strings.TrimSpace(key) {
case "status":
snapshot.Status = strings.TrimSpace(value)
case "pid":
if pid, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
snapshot.GuestPID = pid
}
case "exit":
if exitCode, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
snapshot.ExitCode = &exitCode
}
case "alive":
snapshot.Alive = strings.TrimSpace(value) == "true"
case "error":
snapshot.LastError = strings.TrimSpace(value)
}
}
return snapshot, scanner.Err()
}
// InspectStateFromDir reads the state files directly from stateDir (used
// when the guest is offline and we can mount the work disk from the host).
func InspectStateFromDir(stateDir string) (StateSnapshot, error) {
var snapshot StateSnapshot
statusData, _ := os.ReadFile(filepath.Join(stateDir, "status"))
snapshot.Status = strings.TrimSpace(string(statusData))
pidData, _ := os.ReadFile(filepath.Join(stateDir, "pid"))
if pidValue, err := strconv.Atoi(strings.TrimSpace(string(pidData))); err == nil {
snapshot.GuestPID = pidValue
}
exitData, _ := os.ReadFile(filepath.Join(stateDir, "exit_code"))
if exitValue, err := strconv.Atoi(strings.TrimSpace(string(exitData))); err == nil {
snapshot.ExitCode = &exitValue
}
errorData, _ := os.ReadFile(filepath.Join(stateDir, "error"))
snapshot.LastError = strings.TrimSpace(string(errorData))
if snapshot.GuestPID != 0 {
snapshot.Alive = ProcessAlive(snapshot.GuestPID)
}
return snapshot, nil
}
// ApplyStateSnapshot mutates sess in place to reflect snapshot. vmRunning
// captures whether the VM is currently up so stale in-flight sessions can be
// failed when the VM is gone.
func ApplyStateSnapshot(sess *model.GuestSession, snapshot StateSnapshot, vmRunning bool) {
if sess == nil {
return
}
if snapshot.GuestPID != 0 {
sess.GuestPID = snapshot.GuestPID
}
if snapshot.LastError != "" {
sess.LastError = snapshot.LastError
}
if snapshot.ExitCode != nil {
sess.ExitCode = snapshot.ExitCode
sess.Attachable = false
sess.Reattachable = false
if sess.StartedAt.IsZero() {
sess.StartedAt = model.Now()
}
if sess.EndedAt.IsZero() {
sess.EndedAt = model.Now()
}
if *snapshot.ExitCode == 0 {
sess.Status = model.GuestSessionStatusExited
} else {
sess.Status = model.GuestSessionStatusFailed
}
return
}
if snapshot.Alive {
if sess.StartedAt.IsZero() {
sess.StartedAt = model.Now()
}
sess.Status = model.GuestSessionStatusRunning
return
}
if !vmRunning && (sess.Status == model.GuestSessionStatusStarting || sess.Status == model.GuestSessionStatusRunning || sess.Status == model.GuestSessionStatusStopping) {
sess.Status = model.GuestSessionStatusFailed
sess.Attachable = false
sess.Reattachable = false
if sess.LastError == "" {
sess.LastError = "vm is not running"
}
if sess.EndedAt.IsZero() {
sess.EndedAt = model.Now()
}
return
}
if snapshot.Status == string(model.GuestSessionStatusRunning) {
if sess.StartedAt.IsZero() {
sess.StartedAt = model.Now()
}
sess.Status = model.GuestSessionStatusRunning
}
if sess.Status == model.GuestSessionStatusRunning && sess.StdinMode == model.GuestSessionStdinPipe {
sess.Attachable = true
sess.Reattachable = true
if sess.AttachBackend == "" {
sess.AttachBackend = AttachBackendSSHBridge
}
if sess.AttachMode == "" {
sess.AttachMode = AttachModeExclusive
}
}
}
// StateChanged reports whether any materially observable field differs
// between before and after, guiding whether to persist an update.
func StateChanged(before, after model.GuestSession) bool {
if before.Status != after.Status || before.GuestPID != after.GuestPID || before.LastError != after.LastError || before.Attachable != after.Attachable || before.Reattachable != after.Reattachable || before.AttachBackend != after.AttachBackend || before.AttachMode != after.AttachMode || before.LaunchStage != after.LaunchStage || before.LaunchMessage != after.LaunchMessage || before.LaunchRawLog != after.LaunchRawLog {
return true
}
if before.StartedAt != after.StartedAt || before.EndedAt != after.EndedAt {
return true
}
switch {
case before.ExitCode == nil && after.ExitCode == nil:
return false
case before.ExitCode == nil || after.ExitCode == nil:
return true
default:
return *before.ExitCode != *after.ExitCode
}
}
// -- Launch helpers ---------------------------------------------------------
// DefaultName returns a friendly session name: caller-provided if non-empty,
// otherwise `<command-base>-<short-id>`.
func DefaultName(id, command, explicit string) string {
if trimmed := strings.TrimSpace(explicit); trimmed != "" {
return trimmed
}
base := filepath.Base(strings.TrimSpace(command))
if base == "." || base == string(filepath.Separator) || base == "" {
base = "session"
}
return base + "-" + system.ShortID(id)
}
// DefaultCWD returns value if non-empty, else /root.
func DefaultCWD(value string) string {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
return "/root"
}
// FailLaunch annotates sess as launch-failed with stage/message/raw log and
// returns it for persistence.
func FailLaunch(sess model.GuestSession, stage, message, rawLog string) model.GuestSession {
now := model.Now()
sess.Status = model.GuestSessionStatusFailed
sess.LastError = strings.TrimSpace(message)
sess.Attachable = false
sess.Reattachable = false
sess.LaunchStage = strings.TrimSpace(stage)
sess.LaunchMessage = strings.TrimSpace(message)
sess.LaunchRawLog = strings.TrimSpace(rawLog)
sess.UpdatedAt = now
sess.EndedAt = now
return sess
}
// NormalizeRequiredCommands returns a de-duplicated, order-preserving list
// of required commands, with the session command first.
func NormalizeRequiredCommands(command string, extras []string) []string {
ordered := make([]string, 0, len(extras)+1)
seen := map[string]struct{}{}
appendValue := func(value string) {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return
}
if _, ok := seen[trimmed]; ok {
return
}
seen[trimmed] = struct{}{}
ordered = append(ordered, trimmed)
}
appendValue(command)
for _, extra := range extras {
appendValue(extra)
}
return ordered
}
// -- Small utilities --------------------------------------------------------
// ShellQuote returns value single-quoted for bash, escaping embedded quotes.
func ShellQuote(value string) string {
return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
}
// ExitCode extracts the exit status from an ssh.ExitError, returning
// (0, true) for nil errors.
func ExitCode(err error) (int, bool) {
if err == nil {
return 0, true
}
var exitErr *ssh.ExitError
if errors.As(err, &exitErr) {
return exitErr.ExitStatus(), true
}
return 0, false
}
// CloneStringMap returns a shallow copy of values, or nil if empty.
func CloneStringMap(values map[string]string) map[string]string {
if len(values) == 0 {
return nil
}
cloned := make(map[string]string, len(values))
for key, value := range values {
cloned[key] = value
}
return cloned
}
// TailFileContent returns the last N lines of a file, or "" if the file is
// missing.
func TailFileContent(path string, lines int) (string, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return "", nil
}
return "", err
}
if lines <= 0 {
return string(data), nil
}
parts := strings.Split(string(data), "\n")
if len(parts) <= lines {
return string(data), nil
}
return strings.Join(parts[len(parts)-lines-1:], "\n"), nil
}
// ProcessAlive returns true if the process with pid exists. The syscallKill
// override is exposed for tests that need to simulate alive/dead processes.
func ProcessAlive(pid int) bool {
if pid <= 0 {
return false
}
return syscallKill(pid, syscall.Signal(0)) == nil
}
// syscallKill is a test seam: tests replace it to stub process-alive checks.
var syscallKill = func(pid int, signal os.Signal) error {
proc, err := os.FindProcess(pid)
if err != nil {
return err
}
return proc.Signal(signal)
}
// FormatStepError wraps err with an action label and trimmed on-guest log.
func FormatStepError(action string, err error, log string) error {
log = strings.TrimSpace(log)
if log == "" {
return fmt.Errorf("%s: %w", action, err)
}
return fmt.Errorf("%s: %w: %s", action, err, log)
}

View file

@ -1,440 +0,0 @@
package session
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
"banger/internal/model"
"golang.org/x/crypto/ssh"
)
func TestRelativeStateDir(t *testing.T) {
got := RelativeStateDir("abc")
if strings.HasPrefix(got, "/root/") {
t.Fatalf("RelativeStateDir(%q) = %q, should strip /root/ prefix", "abc", got)
}
if !strings.Contains(got, "abc") {
t.Fatalf("missing session id in %q", got)
}
absolute := StateDir("abc")
if got != strings.TrimPrefix(absolute, "/root/") {
t.Fatalf("relative = %q, want %q", got, strings.TrimPrefix(absolute, "/root/"))
}
}
func TestDefaultCWD(t *testing.T) {
if DefaultCWD("") != "/root" {
t.Error("empty should return /root")
}
if DefaultCWD(" ") != "/root" {
t.Error("whitespace should return /root")
}
if DefaultCWD("/work") != "/work" {
t.Error("explicit should pass through")
}
}
func TestShellQuote(t *testing.T) {
if got := ShellQuote(""); got != "''" {
t.Errorf("empty: got %q, want ''", got)
}
if got := ShellQuote("x"); got != "'x'" {
t.Errorf("plain: got %q", got)
}
if got := ShellQuote("it's"); got != `'it'"'"'s'` {
t.Errorf("apostrophe: got %q", got)
}
}
func TestExitCode(t *testing.T) {
if code, ok := ExitCode(nil); !ok || code != 0 {
t.Errorf("nil err: got (%d, %v), want (0, true)", code, ok)
}
// Build an ssh.ExitError using its real type — can't hand-construct,
// so wrap via errors.As check with a stub.
raw := &ssh.ExitError{}
if _, ok := ExitCode(raw); !ok {
t.Error("ssh.ExitError: ok should be true")
}
if _, ok := ExitCode(errors.New("bare error")); ok {
t.Error("bare error: ok should be false")
}
}
func TestCloneStringMap(t *testing.T) {
if CloneStringMap(nil) != nil {
t.Error("nil in → nil out")
}
if CloneStringMap(map[string]string{}) != nil {
t.Error("empty in → nil out")
}
src := map[string]string{"a": "1", "b": "2"}
cloned := CloneStringMap(src)
if len(cloned) != 2 {
t.Fatalf("len = %d, want 2", len(cloned))
}
cloned["a"] = "changed"
if src["a"] != "1" {
t.Error("mutating clone leaked back to source")
}
}
func TestTailFileContent(t *testing.T) {
// Missing file → empty, no error.
got, err := TailFileContent(filepath.Join(t.TempDir(), "missing"), 10)
if err != nil || got != "" {
t.Errorf("missing: got (%q, %v), want ('', nil)", got, err)
}
path := filepath.Join(t.TempDir(), "log")
lines := "one\ntwo\nthree\nfour\nfive"
if err := os.WriteFile(path, []byte(lines), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
full, err := TailFileContent(path, 0)
if err != nil || full != lines {
t.Errorf("0 lines: got (%q, %v), want (%q, nil)", full, err, lines)
}
// Request more lines than exist → full content.
all, err := TailFileContent(path, 999)
if err != nil || all != lines {
t.Errorf("999 lines: got %q", all)
}
last2, err := TailFileContent(path, 2)
if err != nil {
t.Fatalf("2 lines: %v", err)
}
if !strings.Contains(last2, "five") {
t.Errorf("2 lines missing last line: %q", last2)
}
}
func TestProcessAlive(t *testing.T) {
if ProcessAlive(0) {
t.Error("pid 0 should not be alive")
}
if ProcessAlive(-1) {
t.Error("negative pid should not be alive")
}
// Swap the syscall seam.
original := syscallKill
t.Cleanup(func() { syscallKill = original })
syscallKill = func(pid int, signal os.Signal) error { return nil }
if !ProcessAlive(42) {
t.Error("syscallKill=nil should report alive")
}
syscallKill = func(pid int, signal os.Signal) error { return fmt.Errorf("no such process") }
if ProcessAlive(42) {
t.Error("syscallKill error should report dead")
}
}
func TestFormatStepError(t *testing.T) {
base := errors.New("boom")
err := FormatStepError("prepare", base, "")
if !errors.Is(err, base) {
t.Error("FormatStepError should wrap the base error")
}
if !strings.Contains(err.Error(), "prepare") {
t.Errorf("missing action: %v", err)
}
errWithLog := FormatStepError("prepare", base, " log line\n")
if !strings.Contains(errWithLog.Error(), "log line") {
t.Errorf("missing log: %v", errWithLog)
}
}
func TestParseStateHappyPath(t *testing.T) {
raw := `status=running
pid=123
exit=
alive=true
error=
`
snap, err := ParseState(raw)
if err != nil {
t.Fatalf("ParseState: %v", err)
}
if snap.Status != "running" {
t.Errorf("Status = %q", snap.Status)
}
if snap.GuestPID != 123 {
t.Errorf("GuestPID = %d", snap.GuestPID)
}
if snap.ExitCode != nil {
t.Errorf("ExitCode should be nil when empty, got %v", snap.ExitCode)
}
if !snap.Alive {
t.Error("Alive should be true")
}
}
func TestParseStateWithExit(t *testing.T) {
raw := `status=exited
pid=123
exit=7
alive=false
error=something bad
`
snap, err := ParseState(raw)
if err != nil {
t.Fatalf("ParseState: %v", err)
}
if snap.ExitCode == nil || *snap.ExitCode != 7 {
t.Errorf("ExitCode = %v, want 7", snap.ExitCode)
}
if snap.LastError != "something bad" {
t.Errorf("LastError = %q", snap.LastError)
}
if snap.Alive {
t.Error("Alive should be false")
}
}
func TestParseStateIgnoresMalformedLines(t *testing.T) {
raw := "no-equals-here\nstatus=ok\n"
snap, err := ParseState(raw)
if err != nil {
t.Fatalf("ParseState: %v", err)
}
if snap.Status != "ok" {
t.Errorf("Status = %q, want ok", snap.Status)
}
}
func TestInspectStateFromDir(t *testing.T) {
dir := t.TempDir()
writeFile := func(name, content string) {
if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o600); err != nil {
t.Fatalf("WriteFile(%s): %v", name, err)
}
}
writeFile("status", "running\n")
writeFile("pid", "42\n")
writeFile("exit_code", "0\n")
writeFile("error", "\n")
original := syscallKill
t.Cleanup(func() { syscallKill = original })
syscallKill = func(pid int, signal os.Signal) error { return nil }
snap, err := InspectStateFromDir(dir)
if err != nil {
t.Fatalf("InspectStateFromDir: %v", err)
}
if snap.Status != "running" {
t.Errorf("Status = %q", snap.Status)
}
if snap.GuestPID != 42 {
t.Errorf("GuestPID = %d", snap.GuestPID)
}
if snap.ExitCode == nil || *snap.ExitCode != 0 {
t.Errorf("ExitCode = %v, want 0", snap.ExitCode)
}
if !snap.Alive {
t.Error("Alive should reflect syscallKill result (true)")
}
}
func TestInspectStateFromDirMissingFiles(t *testing.T) {
snap, err := InspectStateFromDir(t.TempDir())
if err != nil {
t.Fatalf("InspectStateFromDir (empty): %v", err)
}
if snap.Status != "" || snap.GuestPID != 0 || snap.ExitCode != nil {
t.Errorf("empty dir: snap = %+v", snap)
}
}
func TestApplyStateSnapshotNilReceiver(t *testing.T) {
ApplyStateSnapshot(nil, StateSnapshot{}, true) // should not panic
}
func TestApplyStateSnapshotExitedSuccess(t *testing.T) {
exit := 0
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning, Attachable: true, Reattachable: true}
ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true)
if sess.Status != model.GuestSessionStatusExited {
t.Errorf("Status = %q, want exited", sess.Status)
}
if sess.Attachable || sess.Reattachable {
t.Error("attach flags should be cleared on exit")
}
if sess.EndedAt.IsZero() {
t.Error("EndedAt should be set")
}
}
func TestApplyStateSnapshotExitedFailure(t *testing.T) {
exit := 2
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning}
ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true)
if sess.Status != model.GuestSessionStatusFailed {
t.Errorf("Status = %q, want failed", sess.Status)
}
}
func TestApplyStateSnapshotVMGone(t *testing.T) {
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning}
ApplyStateSnapshot(sess, StateSnapshot{Alive: false}, false)
if sess.Status != model.GuestSessionStatusFailed {
t.Errorf("Status = %q, want failed", sess.Status)
}
if sess.LastError == "" {
t.Error("LastError should be populated when VM is gone")
}
}
func TestApplyStateSnapshotRunningStatusSetsAttachableForPipe(t *testing.T) {
// When the guest-side status file reports "running" (Alive=false from
// kill -0 may still fail transiently), ApplyStateSnapshot transitions
// the session to running and sets attach flags for pipe-mode.
sess := &model.GuestSession{
Status: model.GuestSessionStatusStarting,
StdinMode: model.GuestSessionStdinPipe,
}
ApplyStateSnapshot(sess, StateSnapshot{Status: string(model.GuestSessionStatusRunning), GuestPID: 11}, true)
if sess.Status != model.GuestSessionStatusRunning {
t.Errorf("Status = %q, want running", sess.Status)
}
if !sess.Attachable || !sess.Reattachable {
t.Error("pipe-mode running session should be attachable + reattachable")
}
if sess.AttachBackend != AttachBackendSSHBridge {
t.Errorf("AttachBackend = %q, want %q", sess.AttachBackend, AttachBackendSSHBridge)
}
}
func TestApplyStateSnapshotAliveEarlyReturn(t *testing.T) {
// Alive-true returns immediately after setting status; no attach
// flags set on this path (by design — attach metadata only attaches
// to status-driven transitions).
sess := &model.GuestSession{
Status: model.GuestSessionStatusStarting,
StdinMode: model.GuestSessionStdinPipe,
}
ApplyStateSnapshot(sess, StateSnapshot{Alive: true, GuestPID: 11}, true)
if sess.Status != model.GuestSessionStatusRunning {
t.Errorf("Status = %q, want running", sess.Status)
}
if sess.StartedAt.IsZero() {
t.Error("StartedAt should have been set")
}
}
func TestStateChanged(t *testing.T) {
base := model.GuestSession{Status: model.GuestSessionStatusRunning, GuestPID: 10}
// Identical → no change.
if StateChanged(base, base) {
t.Error("identical states should not be considered changed")
}
// Status change.
changed := base
changed.Status = model.GuestSessionStatusExited
if !StateChanged(base, changed) {
t.Error("status change should be detected")
}
// ExitCode change from nil → value.
exit := 3
changed = base
changed.ExitCode = &exit
if !StateChanged(base, changed) {
t.Error("exit-code appearing should be detected")
}
// Both have the same exit code → no change.
a := base
a.ExitCode = &exit
b := base
b.ExitCode = &exit
if StateChanged(a, b) {
t.Error("matching exit codes should not trigger change")
}
// Different exit codes.
other := 5
b.ExitCode = &other
if !StateChanged(a, b) {
t.Error("differing exit codes should be detected")
}
// Timestamp change.
changed = base
changed.StartedAt = time.Now()
if !StateChanged(base, changed) {
t.Error("StartedAt change should be detected")
}
}
func TestFailLaunch(t *testing.T) {
in := model.GuestSession{Status: model.GuestSessionStatusStarting, Attachable: true}
out := FailLaunch(in, "provision", " ssh did not come up ", " raw output\n")
if out.Status != model.GuestSessionStatusFailed {
t.Errorf("Status = %q, want failed", out.Status)
}
if out.LastError != "ssh did not come up" {
t.Errorf("LastError = %q (not trimmed?)", out.LastError)
}
if out.LaunchStage != "provision" || out.LaunchMessage != "ssh did not come up" {
t.Errorf("launch fields not set: %+v", out)
}
if out.LaunchRawLog != "raw output" {
t.Errorf("rawLog = %q (not trimmed?)", out.LaunchRawLog)
}
if out.Attachable {
t.Error("Attachable should be cleared")
}
}
func TestNormalizeRequiredCommands(t *testing.T) {
got := NormalizeRequiredCommands("pi", []string{"pi", "git", "", "git", " ", "make"})
want := []string{"pi", "git", "make"}
if len(got) != len(want) {
t.Fatalf("len = %d, want %d (%v)", len(got), len(want), got)
}
for i, v := range want {
if got[i] != v {
t.Errorf("position %d: got %q, want %q", i, got[i], v)
}
}
}
func TestInspectScriptContainsAllStateFiles(t *testing.T) {
script := InspectScript("sess-abc")
for _, key := range []string{"status", "pid", "exit_code", "error", "alive"} {
if !strings.Contains(script, key) {
t.Errorf("script missing %q:\n%s", key, script)
}
}
if !strings.Contains(script, "sess-abc") {
t.Error("script missing session id")
}
}
func TestSignalScriptIncludesSignalAndDirPaths(t *testing.T) {
script := SignalScript("sess-x", "TERM")
if !strings.Contains(script, "TERM") {
t.Error("missing signal")
}
if !strings.Contains(script, "sess-x") {
t.Error("missing session id")
}
if !strings.Contains(script, "monitor_pid") || !strings.Contains(script, "stdin_keepalive") {
t.Errorf("expected both monitor + stdin_keepalive kills, got:\n%s", script)
}
}

View file

@ -1,224 +0,0 @@
package daemon
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"time"
"banger/internal/api"
sess "banger/internal/daemon/session"
"banger/internal/guest"
"banger/internal/model"
"banger/internal/sessionstream"
)
func (d *Daemon) BeginGuestSessionAttach(ctx context.Context, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
vm, err := d.FindVM(ctx, params.VMIDOrName)
if err != nil {
return api.GuestSessionAttachBeginResult{}, err
}
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
if err != nil {
return api.GuestSessionAttachBeginResult{}, err
}
session, _ = d.refreshGuestSession(ctx, vm, session)
if !session.Attachable {
return api.GuestSessionAttachBeginResult{}, errors.New("session is not attachable")
}
controller := &guestSessionController{}
if !d.claimGuestSessionController(session.ID, controller) {
return api.GuestSessionAttachBeginResult{}, errors.New("session already has an active attach")
}
attachID, err := model.NewID()
if err != nil {
d.clearGuestSessionController(session.ID)
return api.GuestSessionAttachBeginResult{}, err
}
socketPath := filepath.Join(d.layout.RuntimeDir, "guest-session-attach-"+attachID[:12]+".sock")
_ = os.Remove(socketPath)
listener, err := net.Listen("unix", socketPath)
if err != nil {
d.clearGuestSessionController(session.ID)
return api.GuestSessionAttachBeginResult{}, err
}
if err := os.Chmod(socketPath, 0o600); err != nil {
_ = listener.Close()
_ = os.Remove(socketPath)
d.clearGuestSessionController(session.ID)
return api.GuestSessionAttachBeginResult{}, err
}
go d.serveGuestSessionAttach(session, controller, attachID, socketPath, listener)
return api.GuestSessionAttachBeginResult{
Session: session,
AttachID: attachID,
TransportKind: sess.TransportUnixSocket,
TransportTarget: socketPath,
SocketPath: socketPath,
StreamFormat: sessionstream.FormatV1,
}, nil
}
func (d *Daemon) forwardGuestSessionOutput(_ string, controller *guestSessionController, channel byte, reader io.Reader) {
buffer := make([]byte, 32*1024)
for {
n, err := reader.Read(buffer)
if n > 0 {
controller.writeFrame(channel, buffer[:n])
}
if err != nil {
if !errors.Is(err, io.EOF) {
controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()})
}
return
}
}
}
func (d *Daemon) waitForGuestSessionExit(id string, controller *guestSessionController, session model.GuestSession) {
err := controller.stream.Wait()
updated := session
updated.Attachable = false
now := model.Now()
updated.UpdatedAt = now
updated.EndedAt = now
if exitCode, ok := sess.ExitCode(err); ok {
updated.ExitCode = &exitCode
if exitCode == 0 {
updated.Status = model.GuestSessionStatusExited
} else {
updated.Status = model.GuestSessionStatusFailed
}
}
if err != nil && updated.LastError == "" {
updated.LastError = err.Error()
}
if vm, getErr := d.store.GetVMByID(context.Background(), updated.VMID); getErr == nil {
if refreshed, refreshErr := d.refreshGuestSession(context.Background(), vm, updated); refreshErr == nil {
updated = refreshed
}
}
_ = d.store.UpsertGuestSession(context.Background(), updated)
controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: updated.ExitCode})
_ = controller.close()
d.clearGuestSessionController(id)
}
func (d *Daemon) serveGuestSessionAttach(session model.GuestSession, controller *guestSessionController, _ string, socketPath string, listener net.Listener) {
defer func() {
_ = listener.Close()
_ = os.Remove(socketPath)
_ = controller.close()
d.clearGuestSessionController(session.ID)
}()
conn, err := listener.Accept()
if err != nil {
return
}
defer conn.Close()
if err := controller.setAttach(conn); err != nil {
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
return
}
defer controller.clearAttach(conn)
if err := d.attachGuestSessionBridge(session, controller); err != nil {
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
return
}
for {
channel, payload, err := sessionstream.ReadFrame(conn)
if err != nil {
return
}
switch channel {
case sessionstream.ChannelStdin:
if controller.stdin == nil {
continue
}
if _, err := controller.stdin.Write(payload); err != nil {
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
return
}
case sessionstream.ChannelControl:
message, err := sessionstream.ReadControl(payload)
if err != nil {
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
return
}
if message.Type == "eof" && controller.stdin != nil {
_ = controller.stdin.Close()
}
}
}
}
func (d *Daemon) attachGuestSessionBridge(session model.GuestSession, controller *guestSessionController) error {
vm, err := d.store.GetVMByID(context.Background(), session.VMID)
if err != nil {
return err
}
if !d.vmAlive(vm) {
return fmt.Errorf("vm %q is not running", vm.Name)
}
address := net.JoinHostPort(vm.Runtime.GuestIP, "22")
stdinStream, err := d.openGuestSessionAttachStream(address, sess.AttachInputCommand(session.ID))
if err != nil {
return fmt.Errorf("open guest session stdin stream: %w", err)
}
stdoutStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StdoutLogPath))
if err != nil {
_ = stdinStream.Close()
return fmt.Errorf("open guest session stdout stream: %w", err)
}
stderrStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StderrLogPath))
if err != nil {
_ = stdinStream.Close()
_ = stdoutStream.Close()
return fmt.Errorf("open guest session stderr stream: %w", err)
}
controller.streams = append(controller.streams, stdinStream, stdoutStream, stderrStream)
controller.stdin = stdinStream.Stdin()
go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStdout, stdoutStream.Stdout())
go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStderr, stderrStream.Stdout())
go d.watchGuestSessionAttach(session.ID, controller, session)
return nil
}
func (d *Daemon) openGuestSessionAttachStream(address, command string) (*guest.StreamSession, error) {
client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return nil, err
}
stream, err := client.StartCommand(context.Background(), command)
if err != nil {
_ = client.Close()
return nil, err
}
return stream, nil
}
func (d *Daemon) watchGuestSessionAttach(id string, controller *guestSessionController, session model.GuestSession) {
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
for range ticker.C {
vm, err := d.store.GetVMByID(context.Background(), session.VMID)
if err != nil {
controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()})
_ = controller.close()
return
}
refreshed, err := d.refreshGuestSession(context.Background(), vm, session)
if err == nil {
session = refreshed
}
if session.Status == model.GuestSessionStatusExited || session.Status == model.GuestSessionStatusFailed {
controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: session.ExitCode})
_ = controller.close()
return
}
}
}

View file

@ -1,184 +0,0 @@
package daemon
import (
"errors"
"io"
"net"
"sync"
"banger/internal/guest"
"banger/internal/sessionstream"
)
type guestSessionController struct {
stream *guest.StreamSession
streams []*guest.StreamSession
stdin io.WriteCloser
attachMu sync.Mutex
attach net.Conn
writeMu sync.Mutex
closeOnce sync.Once
}
func (c *guestSessionController) setAttach(conn net.Conn) error {
c.attachMu.Lock()
defer c.attachMu.Unlock()
if c.attach != nil {
return errors.New("session already has an active attach")
}
c.attach = conn
return nil
}
func (c *guestSessionController) clearAttach(conn net.Conn) {
c.attachMu.Lock()
defer c.attachMu.Unlock()
if c.attach == conn {
c.attach = nil
}
}
func (c *guestSessionController) writeFrame(channel byte, payload []byte) {
c.attachMu.Lock()
conn := c.attach
c.attachMu.Unlock()
if conn == nil {
return
}
c.writeMu.Lock()
err := sessionstream.WriteFrame(conn, channel, payload)
c.writeMu.Unlock()
if err != nil {
_ = conn.Close()
c.clearAttach(conn)
}
}
func (c *guestSessionController) writeControl(message sessionstream.ControlMessage) {
c.attachMu.Lock()
conn := c.attach
c.attachMu.Unlock()
if conn == nil {
return
}
c.writeMu.Lock()
err := sessionstream.WriteControl(conn, message)
c.writeMu.Unlock()
if err != nil {
_ = conn.Close()
c.clearAttach(conn)
}
}
func (c *guestSessionController) close() error {
if c == nil {
return nil
}
var err error
c.closeOnce.Do(func() {
c.attachMu.Lock()
conn := c.attach
c.attach = nil
c.attachMu.Unlock()
if conn != nil {
err = errors.Join(err, conn.Close())
}
if c.stdin != nil {
err = errors.Join(err, c.stdin.Close())
}
if c.stream != nil {
err = errors.Join(err, c.stream.Close())
}
for _, stream := range c.streams {
if stream != nil {
err = errors.Join(err, stream.Close())
}
}
})
return err
}
// sessionRegistry owns the live guest-session controller map. Its lock is
// independent of Daemon.mu so guest-session lookups do not contend with
// unrelated daemon state.
type sessionRegistry struct {
mu sync.Mutex
byID map[string]*guestSessionController
closed bool
}
func newSessionRegistry() sessionRegistry {
return sessionRegistry{byID: make(map[string]*guestSessionController)}
}
func (r *sessionRegistry) set(id string, controller *guestSessionController) {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return
}
r.byID[id] = controller
}
func (r *sessionRegistry) claim(id string, controller *guestSessionController) bool {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return false
}
if r.byID[id] != nil {
return false
}
r.byID[id] = controller
return true
}
func (r *sessionRegistry) get(id string) *guestSessionController {
r.mu.Lock()
defer r.mu.Unlock()
return r.byID[id]
}
func (r *sessionRegistry) clear(id string) *guestSessionController {
r.mu.Lock()
defer r.mu.Unlock()
controller := r.byID[id]
delete(r.byID, id)
return controller
}
func (r *sessionRegistry) closeAll() error {
r.mu.Lock()
controllers := make([]*guestSessionController, 0, len(r.byID))
for _, controller := range r.byID {
controllers = append(controllers, controller)
}
r.byID = nil
r.closed = true
r.mu.Unlock()
var err error
for _, controller := range controllers {
err = errors.Join(err, controller.close())
}
return err
}
func (d *Daemon) setGuestSessionController(id string, controller *guestSessionController) {
d.sessions.set(id, controller)
}
func (d *Daemon) claimGuestSessionController(id string, controller *guestSessionController) bool {
return d.sessions.claim(id, controller)
}
func (d *Daemon) getGuestSessionController(id string) *guestSessionController {
return d.sessions.get(id)
}
func (d *Daemon) clearGuestSessionController(id string) *guestSessionController {
return d.sessions.clear(id)
}
func (d *Daemon) closeGuestSessionControllers() error {
return d.sessions.closeAll()
}

View file

@ -1,213 +0,0 @@
package daemon
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"strings"
"time"
"banger/internal/api"
sess "banger/internal/daemon/session"
"banger/internal/guest"
"banger/internal/model"
)
func (d *Daemon) StartGuestSession(ctx context.Context, params api.GuestSessionStartParams) (model.GuestSession, error) {
stdinMode := model.GuestSessionStdinMode(strings.TrimSpace(params.StdinMode))
if stdinMode == "" {
stdinMode = model.GuestSessionStdinClosed
}
if stdinMode != model.GuestSessionStdinClosed && stdinMode != model.GuestSessionStdinPipe {
return model.GuestSession{}, fmt.Errorf("unsupported stdin mode %q", params.StdinMode)
}
if strings.TrimSpace(params.Command) == "" {
return model.GuestSession{}, errors.New("session command is required")
}
var created model.GuestSession
_, err := d.withVMLockByRef(ctx, params.VMIDOrName, func(vm model.VMRecord) (model.VMRecord, error) {
if !d.vmAlive(vm) {
return model.VMRecord{}, fmt.Errorf("vm %q is not running", vm.Name)
}
session, err := d.startGuestSessionLocked(ctx, vm, params, stdinMode)
if err != nil {
return model.VMRecord{}, err
}
created = session
return vm, nil
})
return created, err
}
func (d *Daemon) startGuestSessionLocked(ctx context.Context, vm model.VMRecord, params api.GuestSessionStartParams, stdinMode model.GuestSessionStdinMode) (model.GuestSession, error) {
id, err := model.NewID()
if err != nil {
return model.GuestSession{}, err
}
now := model.Now()
session := model.GuestSession{
ID: id,
VMID: vm.ID,
Name: sess.DefaultName(id, params.Command, params.Name),
Backend: sess.BackendSSH,
Command: params.Command,
Args: append([]string(nil), params.Args...),
CWD: strings.TrimSpace(params.CWD),
Env: sess.CloneStringMap(params.Env),
StdinMode: stdinMode,
Status: model.GuestSessionStatusStarting,
GuestStateDir: sess.StateDir(id),
StdoutLogPath: sess.StdoutLogPath(id),
StderrLogPath: sess.StderrLogPath(id),
Tags: sess.CloneStringMap(params.Tags),
Attachable: stdinMode == model.GuestSessionStdinPipe,
Reattachable: stdinMode == model.GuestSessionStdinPipe,
CreatedAt: now,
UpdatedAt: now,
}
if session.Attachable {
session.AttachBackend = sess.AttachBackendSSHBridge
session.AttachMode = sess.AttachModeExclusive
} else {
session.AttachBackend = sess.AttachBackendNone
}
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
return model.GuestSession{}, err
}
fail := func(stage, message, rawLog string) (model.GuestSession, error) {
session = sess.FailLaunch(session, stage, message, rawLog)
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
return model.GuestSession{}, err
}
return session, nil
}
address := net.JoinHostPort(vm.Runtime.GuestIP, "22")
if err := d.waitForGuestSSH(ctx, address, 250*time.Millisecond); err != nil {
return fail("ssh_unavailable", fmt.Sprintf("guest ssh unavailable: %v", err), "")
}
client, err := d.dialGuest(ctx, address)
if err != nil {
return fail("dial_guest", fmt.Sprintf("dial guest ssh: %v", err), "")
}
defer client.Close()
var preflightLog bytes.Buffer
if err := client.RunScript(ctx, sess.CWDPreflightScript(session.CWD), &preflightLog); err != nil {
return fail("preflight_cwd", fmt.Sprintf("guest working directory is unavailable: %s", sess.DefaultCWD(session.CWD)), preflightLog.String())
}
preflightLog.Reset()
requiredCommands := sess.NormalizeRequiredCommands(params.Command, params.RequiredCommands)
if err := client.RunScript(ctx, sess.CommandPreflightScript(requiredCommands), &preflightLog); err != nil {
return fail("preflight_command", fmt.Sprintf("required guest command is unavailable: %s", strings.TrimSpace(preflightLog.String())), preflightLog.String())
}
var uploadLog bytes.Buffer
if err := client.UploadFile(ctx, sess.ScriptPath(id), 0o755, []byte(sess.Script(session)), &uploadLog); err != nil {
return fail("upload_script", "upload guest session script failed", uploadLog.String())
}
var launchLog bytes.Buffer
launchScript := fmt.Sprintf("set -euo pipefail\nnohup bash %s >/dev/null 2>&1 </dev/null &\ndisown || true\n", sess.ShellQuote(sess.ScriptPath(id)))
if err := client.RunScript(ctx, launchScript, &launchLog); err != nil {
return fail("launch", "launch guest session failed", launchLog.String())
}
readyCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
updated, err := d.waitForGuestSessionReadyHook(readyCtx, vm, session)
if err != nil {
return fail("ready_wait", "guest session did not report ready state", err.Error())
}
session = updated
if session.Status == model.GuestSessionStatusStarting {
session.Status = model.GuestSessionStatusRunning
session.StartedAt = model.Now()
session.UpdatedAt = model.Now()
}
session.LaunchStage = ""
session.LaunchMessage = ""
session.LaunchRawLog = ""
session.LastError = ""
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
return model.GuestSession{}, err
}
return session, nil
}
func (d *Daemon) GetGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
vm, err := d.FindVM(ctx, params.VMIDOrName)
if err != nil {
return model.GuestSession{}, err
}
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
if err != nil {
return model.GuestSession{}, err
}
return d.refreshGuestSession(ctx, vm, session)
}
func (d *Daemon) ListGuestSessions(ctx context.Context, params api.VMRefParams) ([]model.GuestSession, error) {
vm, err := d.FindVM(ctx, params.IDOrName)
if err != nil {
return nil, err
}
sessions, err := d.store.ListGuestSessionsByVM(ctx, vm.ID)
if err != nil {
return nil, err
}
for index := range sessions {
refreshed, refreshErr := d.refreshGuestSession(ctx, vm, sessions[index])
if refreshErr == nil {
sessions[index] = refreshed
}
}
return sessions, nil
}
func (d *Daemon) StopGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
return d.signalGuestSession(ctx, params, "TERM")
}
func (d *Daemon) KillGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
return d.signalGuestSession(ctx, params, "KILL")
}
func (d *Daemon) signalGuestSession(ctx context.Context, params api.GuestSessionRefParams, signal string) (model.GuestSession, error) {
vm, err := d.FindVM(ctx, params.VMIDOrName)
if err != nil {
return model.GuestSession{}, err
}
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
if err != nil {
return model.GuestSession{}, err
}
session, _ = d.refreshGuestSession(ctx, vm, session)
if session.Status == model.GuestSessionStatusExited || session.Status == model.GuestSessionStatusFailed {
return session, nil
}
if !d.vmAlive(vm) {
session.Status = model.GuestSessionStatusFailed
session.LastError = "vm is not running"
now := model.Now()
session.UpdatedAt = now
session.EndedAt = now
session.Attachable = false
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
return model.GuestSession{}, err
}
return session, nil
}
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return model.GuestSession{}, err
}
defer client.Close()
var log bytes.Buffer
if err := client.RunScript(ctx, sess.SignalScript(session.ID, signal), &log); err != nil {
return model.GuestSession{}, sess.FormatStepError("signal guest session", err, log.String())
}
session.Status = model.GuestSessionStatusStopping
session.UpdatedAt = model.Now()
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
return model.GuestSession{}, err
}
return session, nil
}

View file

@ -1,120 +0,0 @@
package daemon
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"path/filepath"
"strings"
"banger/internal/api"
sess "banger/internal/daemon/session"
"banger/internal/guest"
"banger/internal/model"
"banger/internal/system"
)
func (d *Daemon) GuestSessionLogs(ctx context.Context, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) {
vm, err := d.FindVM(ctx, params.VMIDOrName)
if err != nil {
return api.GuestSessionLogsResult{}, err
}
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
if err != nil {
return api.GuestSessionLogsResult{}, err
}
streamName := strings.TrimSpace(params.Stream)
if streamName == "" {
streamName = "stdout"
}
tailLines := params.TailLines
if tailLines <= 0 {
tailLines = sess.LogTailLineDefault
}
path := session.StdoutLogPath
if streamName == "stderr" {
path = session.StderrLogPath
}
content, err := d.readGuestSessionLog(ctx, vm, session, streamName, tailLines)
if err != nil {
return api.GuestSessionLogsResult{}, err
}
return api.GuestSessionLogsResult{Session: session, Stream: streamName, Path: path, Content: content}, nil
}
func (d *Daemon) SendToGuestSession(ctx context.Context, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
vm, err := d.FindVM(ctx, params.VMIDOrName)
if err != nil {
return api.GuestSessionSendResult{}, err
}
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
if err != nil {
return api.GuestSessionSendResult{}, err
}
if session.StdinMode != model.GuestSessionStdinPipe {
return api.GuestSessionSendResult{}, errors.New("session does not have a stdin pipe")
}
if session.Status != model.GuestSessionStatusRunning {
return api.GuestSessionSendResult{}, fmt.Errorf("session is not running (status=%s)", session.Status)
}
if !d.vmAlive(vm) {
return api.GuestSessionSendResult{}, fmt.Errorf("vm %q is not running", vm.Name)
}
if len(params.Payload) == 0 {
return api.GuestSessionSendResult{Session: session}, nil
}
client, err := d.dialGuest(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"))
if err != nil {
return api.GuestSessionSendResult{}, fmt.Errorf("dial guest: %w", err)
}
defer client.Close()
tmpPath := fmt.Sprintf("/tmp/banger-send-%s.bin", session.ID[:8])
var uploadLog bytes.Buffer
if err := client.UploadFile(ctx, tmpPath, 0o600, params.Payload, &uploadLog); err != nil {
return api.GuestSessionSendResult{}, fmt.Errorf("upload payload: %w", err)
}
sendScript := fmt.Sprintf(
"set -euo pipefail\ncat %s >> %s\nrm -f %s\n",
sess.ShellQuote(tmpPath),
sess.ShellQuote(sess.StdinPipePath(session.ID)),
sess.ShellQuote(tmpPath),
)
var sendLog bytes.Buffer
if err := client.RunScript(ctx, sendScript, &sendLog); err != nil {
return api.GuestSessionSendResult{}, fmt.Errorf("send to session: %w: %s", err, strings.TrimSpace(sendLog.String()))
}
return api.GuestSessionSendResult{Session: session, BytesWritten: len(params.Payload)}, nil
}
func (d *Daemon) readGuestSessionLog(ctx context.Context, vm model.VMRecord, session model.GuestSession, stream string, tailLines int) (string, error) {
if d.vmAlive(vm) {
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
if err != nil {
return "", err
}
defer client.Close()
path := session.StdoutLogPath
if stream == "stderr" {
path = session.StderrLogPath
}
var output bytes.Buffer
script := fmt.Sprintf("set -euo pipefail\nif [ -f %s ]; then tail -n %d %s; fi\n", sess.ShellQuote(path), tailLines, sess.ShellQuote(path))
if err := client.RunScript(ctx, script, &output); err != nil {
return "", sess.FormatStepError("read guest session log", err, output.String())
}
return output.String(), nil
}
runner := d.runner
if runner == nil {
runner = system.NewRunner()
}
workMount, cleanup, err := system.MountTempDir(ctx, runner, vm.Runtime.WorkDiskPath, false)
if err != nil {
return "", err
}
defer cleanup()
logPath := filepath.Join(workMount, sess.RelativeStateDir(session.ID), stream+".log")
return sess.TailFileContent(logPath, tailLines)
}

View file

@ -10,7 +10,6 @@ import (
"time"
"banger/internal/api"
sess "banger/internal/daemon/session"
ws "banger/internal/daemon/workspace"
"banger/internal/model"
)
@ -114,9 +113,9 @@ func exportScript(guestPath, diffRef, diffFlag string) string {
"git read-tree %s --index-output=\"$tmp_idx\"\n"+
"GIT_INDEX_FILE=\"$tmp_idx\" git add -A\n"+
"GIT_INDEX_FILE=\"$tmp_idx\" git diff --cached %s %s\n",
sess.ShellQuote(guestPath),
sess.ShellQuote(diffRef),
sess.ShellQuote(diffRef),
ws.ShellQuote(guestPath),
ws.ShellQuote(diffRef),
ws.ShellQuote(diffRef),
diffFlag,
)
}
@ -189,9 +188,9 @@ func (d *Daemon) prepareVMWorkspaceGuestIO(ctx context.Context, vm model.VMRecor
}
if readOnly {
var chmodLog bytes.Buffer
chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", sess.ShellQuote(guestPath))
chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", ws.ShellQuote(guestPath))
if err := client.RunScript(ctx, chmodScript, &chmodLog); err != nil {
return model.WorkspacePrepareResult{}, sess.FormatStepError("set workspace readonly", err, chmodLog.String())
return model.WorkspacePrepareResult{}, ws.FormatStepError("set workspace readonly", err, chmodLog.String())
}
}
return model.WorkspacePrepareResult{

View file

@ -0,0 +1,20 @@
package workspace
import (
"fmt"
"strings"
)
// ShellQuote returns value single-quoted for bash, escaping embedded quotes.
func ShellQuote(value string) string {
return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
}
// FormatStepError wraps err with an action label and trimmed on-guest log.
func FormatStepError(action string, err error, log string) error {
log = strings.TrimSpace(log)
if log == "" {
return fmt.Errorf("%s: %w", action, err)
}
return fmt.Errorf("%s: %w: %s", action, err, log)
}

View file

@ -18,7 +18,6 @@ import (
"sort"
"strings"
sess "banger/internal/daemon/session"
"banger/internal/model"
"banger/internal/system"
)
@ -146,13 +145,13 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
switch mode {
case model.WorkspacePrepareModeFullCopy:
var copyLog bytes.Buffer
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath))
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath), ShellQuote(guestPath), ShellQuote(guestPath))
if err := client.StreamTar(ctx, spec.RepoRoot, command, &copyLog); err != nil {
return sess.FormatStepError("copy full workspace", err, copyLog.String())
return FormatStepError("copy full workspace", err, copyLog.String())
}
var finalizeLog bytes.Buffer
if err := client.RunScript(ctx, FinalizeScript(spec, guestPath, mode), &finalizeLog); err != nil {
return sess.FormatStepError("finalize full workspace", err, finalizeLog.String())
return FormatStepError("finalize full workspace", err, finalizeLog.String())
}
return nil
case model.WorkspacePrepareModeMetadataOnly, model.WorkspacePrepareModeShallowOverlay:
@ -162,21 +161,21 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
}
defer cleanup()
var copyLog bytes.Buffer
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath))
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath), ShellQuote(guestPath), ShellQuote(guestPath))
if err := client.StreamTar(ctx, repoCopyDir, command, &copyLog); err != nil {
return sess.FormatStepError("copy guest git metadata", err, copyLog.String())
return FormatStepError("copy guest git metadata", err, copyLog.String())
}
var scriptLog bytes.Buffer
if err := client.RunScript(ctx, FinalizeScript(spec, guestPath, mode), &scriptLog); err != nil {
return sess.FormatStepError("prepare guest checkout", err, scriptLog.String())
return FormatStepError("prepare guest checkout", err, scriptLog.String())
}
if mode == model.WorkspacePrepareModeMetadataOnly {
return nil
}
var overlayLog bytes.Buffer
command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath))
command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath))
if err := client.StreamTarEntries(ctx, spec.RepoRoot, spec.OverlayPaths, command, &overlayLog); err != nil {
return sess.FormatStepError("overlay workspace working tree", err, overlayLog.String())
return FormatStepError("overlay workspace working tree", err, overlayLog.String())
}
return nil
default:
@ -190,22 +189,22 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
func FinalizeScript(spec RepoSpec, guestPath string, mode model.WorkspacePrepareMode) string {
var script strings.Builder
script.WriteString("set -euo pipefail\n")
fmt.Fprintf(&script, "DIR=%s\n", sess.ShellQuote(guestPath))
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(guestPath))
script.WriteString("git config --global --add safe.directory \"$DIR\"\n")
if mode != model.WorkspacePrepareModeFullCopy {
script.WriteString("find \"$DIR\" -mindepth 1 -maxdepth 1 ! -name .git -exec rm -rf {} +\n")
}
switch {
case strings.TrimSpace(spec.BranchName) != "":
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.BranchName), sess.ShellQuote(spec.BaseCommit))
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", ShellQuote(spec.BranchName), ShellQuote(spec.BaseCommit))
case strings.TrimSpace(spec.CurrentBranch) != "":
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.CurrentBranch), sess.ShellQuote(spec.HeadCommit))
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", ShellQuote(spec.CurrentBranch), ShellQuote(spec.HeadCommit))
default:
fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", sess.ShellQuote(spec.HeadCommit))
fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", ShellQuote(spec.HeadCommit))
}
if strings.TrimSpace(spec.GitUserName) != "" && strings.TrimSpace(spec.GitUserEmail) != "" {
fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", sess.ShellQuote(spec.GitUserName))
fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", sess.ShellQuote(spec.GitUserEmail))
fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", ShellQuote(spec.GitUserName))
fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", ShellQuote(spec.GitUserEmail))
}
return script.String()
}