remove vm session feature
Cuts the daemon-managed guest-session machinery (start/list/show/
logs/stop/kill/attach/send). The feature shipped aimed at agent-
orchestration workflows (programmatic stdin piping into a long-lived
guest process) that aren't driving any concrete user today, and the
~2.3K LOC of daemon surface area — attach bridge, FIFO keepalive,
controller registry, sessionstream framing, SQLite persistence — was
locking in an API we'd have to keep through v0.1.0.
Anything session-flavoured that people actually need today can be
done with `vm ssh + tmux` or `vm run -- cmd`.
Deleted:
- internal/cli/commands_vm_session.go
- internal/daemon/{guest_sessions,session_lifecycle,session_attach,session_stream,session_controller}.go
- internal/daemon/session/ (guest-session helpers package)
- internal/sessionstream/ (framing package)
- internal/daemon/guest_sessions_test.go
- internal/store/guest_session_test.go
- GuestSession* types from internal/{api,model}
- Store UpsertGuestSession/GetGuestSession/ListGuestSessionsByVM/DeleteGuestSession + scanner helpers
- guest.session.* RPC dispatch entries
- 5 CLI session tests, 2 completion tests, 2 printer tests
Extracted:
- ShellQuote + FormatStepError lifted to internal/daemon/workspace/util.go
(only non-session consumer); workspace package now self-contained
- internal/daemon/guest_ssh.go keeps guestSSHClient + dialGuest +
waitForGuestSSH — still used by workspace prepare/export
- internal/daemon/fake_firecracker_test.go preserves the test helper
that used to live in guest_sessions_test.go
Store schema: CREATE TABLE guest_sessions and its column migrations
removed. Existing dev DBs keep an orphan table (harmless, pre-v0.1.0).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
c42fcbe012
commit
2b6437d1b4
34 changed files with 194 additions and 4031 deletions
|
|
@ -32,13 +32,11 @@ owning types:
|
|||
- `createOps opstate.Registry[*vmCreateOperationState]` — in-flight VM
|
||||
create operations; owns its own lock.
|
||||
- `tapPool tapPool` — TAP interface pool; owns its own lock.
|
||||
- `sessions sessionRegistry` — active guest session controllers; owns
|
||||
its own lock.
|
||||
- `listener`, `vmDNS` — networking.
|
||||
- `vmCaps` — registered VM capability hooks.
|
||||
- `pullAndFlatten`, `finalizePulledRootfs`, `bundleFetch`,
|
||||
`requestHandler`, `guestWaitForSSH`, `guestDial`,
|
||||
`waitForGuestSessionReady` — injectable seams used by tests.
|
||||
`workspaceInspectRepo`, `workspaceImport` — injectable seams used by tests.
|
||||
|
||||
## Subpackages
|
||||
|
||||
|
|
@ -53,11 +51,9 @@ state beyond small test seams.
|
|||
| `internal/daemon/dmsnap` | Device-mapper COW snapshot create/cleanup/remove. |
|
||||
| `internal/daemon/fcproc` | Firecracker process primitives (bridge, tap, binary, PID, kill, wait). |
|
||||
| `internal/daemon/imagemgr` | Image subsystem pure helpers: validators, staging, build script gen. |
|
||||
| `internal/daemon/session` | Guest-session helpers: state paths, scripts, parsing, utilities. |
|
||||
| `internal/daemon/workspace` | Workspace helpers: git inspection, copy prep, guest import script. |
|
||||
|
||||
`workspace` imports `session` for `ShellQuote` and `FormatStepError`; all
|
||||
other subpackages are leaves (no other intra-daemon subpackage imports).
|
||||
All subpackages are leaves — no intra-daemon subpackage imports another.
|
||||
|
||||
## Lock ordering
|
||||
|
||||
|
|
@ -73,9 +69,8 @@ time. `workspace.prepare` acquires `vmLocks[id]` just long enough to
|
|||
validate VM state, releases it, then acquires `workspaceLocks[id]`
|
||||
for the guest I/O phase.
|
||||
|
||||
Subsystem-local locks (`tapPool.mu`, `sessionRegistry.mu`,
|
||||
`opstate.Registry` mu, `guestSessionController.attachMu` /
|
||||
`writeMu`) are leaves. They do not contend with each other.
|
||||
Subsystem-local locks (`tapPool.mu`, `opstate.Registry` mu) are leaves.
|
||||
They do not contend with each other.
|
||||
|
||||
Notes:
|
||||
|
||||
|
|
|
|||
|
|
@ -51,24 +51,22 @@ type Daemon struct {
|
|||
// See internal/daemon/vm_handles.go — persistent durable state
|
||||
// lives in the store, this is rebuildable from a per-VM
|
||||
// handles.json scratch file and OS inspection.
|
||||
handles *handleCache
|
||||
sessions sessionRegistry
|
||||
tapPool tapPool
|
||||
closing chan struct{}
|
||||
once sync.Once
|
||||
pid int
|
||||
listener net.Listener
|
||||
vmDNS *vmdns.Server
|
||||
vmCaps []vmCapability
|
||||
pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error)
|
||||
finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error
|
||||
bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error)
|
||||
requestHandler func(context.Context, rpc.Request) rpc.Response
|
||||
guestWaitForSSH func(context.Context, string, string, time.Duration) error
|
||||
guestDial func(context.Context, string, string) (guestSSHClient, error)
|
||||
waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error)
|
||||
workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error)
|
||||
workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error
|
||||
handles *handleCache
|
||||
tapPool tapPool
|
||||
closing chan struct{}
|
||||
once sync.Once
|
||||
pid int
|
||||
listener net.Listener
|
||||
vmDNS *vmdns.Server
|
||||
vmCaps []vmCapability
|
||||
pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error)
|
||||
finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error
|
||||
bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error)
|
||||
requestHandler func(context.Context, rpc.Request) rpc.Response
|
||||
guestWaitForSSH func(context.Context, string, string, time.Duration) error
|
||||
guestDial func(context.Context, string, string) (guestSSHClient, error)
|
||||
workspaceInspectRepo func(ctx context.Context, sourcePath, branchName, fromRef string) (ws.RepoSpec, error)
|
||||
workspaceImport func(ctx context.Context, client ws.GuestClient, spec ws.RepoSpec, guestPath string, mode model.WorkspacePrepareMode) error
|
||||
}
|
||||
|
||||
func Open(ctx context.Context) (d *Daemon, err error) {
|
||||
|
|
@ -93,15 +91,14 @@ func Open(ctx context.Context) (d *Daemon, err error) {
|
|||
return nil, err
|
||||
}
|
||||
d = &Daemon{
|
||||
layout: layout,
|
||||
config: cfg,
|
||||
store: db,
|
||||
runner: system.NewRunner(),
|
||||
logger: logger,
|
||||
closing: make(chan struct{}),
|
||||
pid: os.Getpid(),
|
||||
handles: newHandleCache(),
|
||||
sessions: newSessionRegistry(),
|
||||
layout: layout,
|
||||
config: cfg,
|
||||
store: db,
|
||||
runner: system.NewRunner(),
|
||||
logger: logger,
|
||||
closing: make(chan struct{}),
|
||||
pid: os.Getpid(),
|
||||
handles: newHandleCache(),
|
||||
}
|
||||
// From here on, every failure path must run Close() so the host
|
||||
// state we touched (DNS listener goroutine, resolvectl routing,
|
||||
|
|
@ -144,7 +141,7 @@ func (d *Daemon) Close() error {
|
|||
if d.listener != nil {
|
||||
_ = d.listener.Close()
|
||||
}
|
||||
err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.closeGuestSessionControllers(), d.store.Close())
|
||||
err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.store.Close())
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
|
@ -424,62 +421,6 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
|
|||
}
|
||||
result, err := d.ExportVMWorkspace(ctx, params)
|
||||
return marshalResultOrError(result, err)
|
||||
case "guest.session.start":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionStartParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
session, err := d.StartGuestSession(ctx, params)
|
||||
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
|
||||
case "guest.session.get":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
session, err := d.GetGuestSession(ctx, params)
|
||||
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
|
||||
case "guest.session.list":
|
||||
params, err := rpc.DecodeParams[api.VMRefParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
sessions, err := d.ListGuestSessions(ctx, params)
|
||||
return marshalResultOrError(api.GuestSessionListResult{Sessions: sessions}, err)
|
||||
case "guest.session.stop":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
session, err := d.StopGuestSession(ctx, params)
|
||||
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
|
||||
case "guest.session.kill":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
session, err := d.KillGuestSession(ctx, params)
|
||||
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
|
||||
case "guest.session.logs":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionLogsParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
result, err := d.GuestSessionLogs(ctx, params)
|
||||
return marshalResultOrError(result, err)
|
||||
case "guest.session.attach.begin":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionAttachBeginParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
result, err := d.BeginGuestSessionAttach(ctx, params)
|
||||
return marshalResultOrError(result, err)
|
||||
case "guest.session.send":
|
||||
params, err := rpc.DecodeParams[api.GuestSessionSendParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
result, err := d.SendToGuestSession(ctx, params)
|
||||
return marshalResultOrError(result, err)
|
||||
case "image.list":
|
||||
images, err := d.store.ListImages(ctx)
|
||||
return marshalResultOrError(api.ImageListResult{Images: images}, err)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
// Package daemon hosts the Banger daemon process.
|
||||
//
|
||||
// The daemon exposes a JSON-RPC endpoint over a Unix socket. It owns VM
|
||||
// lifecycle, image management, guest sessions, host networking bootstrap,
|
||||
// and state persistence via internal/store.
|
||||
// lifecycle, image management, host networking bootstrap, and state
|
||||
// persistence via internal/store.
|
||||
//
|
||||
// The package is organised into cohesive groups. Pure stateless helpers for
|
||||
// each group have been lifted into subpackages; orchestrator methods
|
||||
|
|
@ -18,13 +18,9 @@
|
|||
// internal/daemon/imagemgr Image subsystem helpers: path validation,
|
||||
// artifact staging, guest provisioning script
|
||||
// generator, metadata.
|
||||
// internal/daemon/session Guest-session helpers: state paths, runner
|
||||
// / inspect / signal scripts, state snapshot
|
||||
// parsing, launch helpers, ShellQuote,
|
||||
// FormatStepError.
|
||||
// internal/daemon/workspace Workspace helpers: git repo inspection,
|
||||
// shallow copy prep, guest-side import,
|
||||
// finalize script generation.
|
||||
// finalize script generation, shell quoting.
|
||||
//
|
||||
// VM lifecycle (in this package):
|
||||
//
|
||||
|
|
@ -50,11 +46,7 @@
|
|||
//
|
||||
// Guest interaction (in this package):
|
||||
//
|
||||
// guest_sessions.go dialGuest, waitForGuestSSH, refresh/inspect
|
||||
// session_lifecycle.go Start/Stop/Kill/Get/List/signal orchestrators
|
||||
// session_attach.go BeginGuestSessionAttach + bridge/forward/watch
|
||||
// session_stream.go GuestSessionLogs, SendToGuestSession
|
||||
// session_controller.go guestSessionController, sessionRegistry
|
||||
// guest_ssh.go guestSSHClient, dialGuest, waitForGuestSSH
|
||||
// ssh_client_config.go daemon-managed SSH client key material
|
||||
// workspace.go ExportVMWorkspace, PrepareVMWorkspace
|
||||
//
|
||||
|
|
@ -73,10 +65,8 @@
|
|||
//
|
||||
// Lock ordering:
|
||||
//
|
||||
// vmLocks[id] → {createVMMu, imageOpsMu} → subsystem-local locks
|
||||
// vmLocks[id] → workspaceLocks[id] → {createVMMu, imageOpsMu} → subsystem-local locks
|
||||
//
|
||||
// Subsystem-local locks live on their owning type (tapPool.mu,
|
||||
// sessionRegistry.mu, opstate.Registry mu, guestSessionController.attachMu /
|
||||
// writeMu) and do not contend with each other. See ARCHITECTURE.md for
|
||||
// details.
|
||||
// Subsystem-local locks (tapPool.mu, opstate.Registry mu) are leaves and
|
||||
// do not contend with each other. See ARCHITECTURE.md for details.
|
||||
package daemon
|
||||
|
|
|
|||
26
internal/daemon/fake_firecracker_test.go
Normal file
26
internal/daemon/fake_firecracker_test.go
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// startFakeFirecracker launches a bash sleep-loop rewritten to match
|
||||
// the firecracker command line a real process would expose, so
|
||||
// reconcile / handle-cache paths that grep /proc/<pid>/cmdline accept
|
||||
// it as a firecracker process. Killed on test cleanup.
|
||||
func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd {
|
||||
t.Helper()
|
||||
cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("start fake firecracker: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
}
|
||||
})
|
||||
return cmd
|
||||
}
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"banger/internal/daemon/session"
|
||||
"banger/internal/guest"
|
||||
"banger/internal/model"
|
||||
"banger/internal/system"
|
||||
)
|
||||
|
||||
type guestSSHClient interface {
|
||||
Close() error
|
||||
RunScript(context.Context, string, io.Writer) error
|
||||
RunScriptOutput(context.Context, string) ([]byte, error)
|
||||
UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error
|
||||
StreamTar(context.Context, string, string, io.Writer) error
|
||||
StreamTarEntries(context.Context, string, []string, string, io.Writer) error
|
||||
}
|
||||
|
||||
func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error {
|
||||
if d != nil && d.guestWaitForSSH != nil {
|
||||
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
|
||||
}
|
||||
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval)
|
||||
}
|
||||
|
||||
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
|
||||
if d != nil && d.guestDial != nil {
|
||||
return d.guestDial(ctx, address, d.config.SSHKeyPath)
|
||||
}
|
||||
return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
}
|
||||
|
||||
func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
|
||||
if d != nil && d.waitForGuestSessionReady != nil {
|
||||
return d.waitForGuestSessionReady(ctx, vm, s)
|
||||
}
|
||||
return d.waitForGuestSessionReadyDefault(ctx, vm, s)
|
||||
}
|
||||
|
||||
func (d *Daemon) waitForGuestSessionReadyDefault(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
|
||||
for {
|
||||
updated, err := d.refreshGuestSession(ctx, vm, s)
|
||||
if err == nil {
|
||||
s = updated
|
||||
if s.GuestPID != 0 || s.ExitCode != nil || s.Status == model.GuestSessionStatusRunning || s.Status == model.GuestSessionStatusFailed || s.Status == model.GuestSessionStatusExited {
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return s, ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) refreshGuestSession(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) {
|
||||
if s.Status != model.GuestSessionStatusStarting && s.Status != model.GuestSessionStatusRunning && s.Status != model.GuestSessionStatusStopping {
|
||||
return s, nil
|
||||
}
|
||||
snapshot, err := d.inspectGuestSessionState(ctx, vm, s)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
original := s
|
||||
session.ApplyStateSnapshot(&s, snapshot, d.vmAlive(vm))
|
||||
if session.StateChanged(original, s) {
|
||||
s.UpdatedAt = model.Now()
|
||||
if err := d.store.UpsertGuestSession(ctx, s); err != nil {
|
||||
return s, err
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) inspectGuestSessionState(ctx context.Context, vm model.VMRecord, s model.GuestSession) (session.StateSnapshot, error) {
|
||||
if d.vmAlive(vm) {
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return session.StateSnapshot{}, err
|
||||
}
|
||||
defer client.Close()
|
||||
var output bytes.Buffer
|
||||
if err := client.RunScript(ctx, session.InspectScript(s.ID), &output); err != nil {
|
||||
return session.StateSnapshot{}, session.FormatStepError("inspect guest session state", err, output.String())
|
||||
}
|
||||
return session.ParseState(output.String())
|
||||
}
|
||||
return d.inspectGuestSessionStateFromWorkDisk(ctx, vm, s.ID)
|
||||
}
|
||||
|
||||
func (d *Daemon) inspectGuestSessionStateFromWorkDisk(ctx context.Context, vm model.VMRecord, sessionID string) (session.StateSnapshot, error) {
|
||||
runner := d.runner
|
||||
if runner == nil {
|
||||
runner = system.NewRunner()
|
||||
}
|
||||
workMount, cleanup, err := system.MountTempDir(ctx, runner, vm.Runtime.WorkDiskPath, false)
|
||||
if err != nil {
|
||||
return session.StateSnapshot{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
stateDir := filepath.Join(workMount, session.RelativeStateDir(sessionID))
|
||||
return session.InspectStateFromDir(stateDir)
|
||||
}
|
||||
|
||||
func (d *Daemon) findGuestSession(ctx context.Context, vmID, idOrName string) (model.GuestSession, error) {
|
||||
if strings.TrimSpace(idOrName) == "" {
|
||||
return model.GuestSession{}, errors.New("session id or name is required")
|
||||
}
|
||||
if s, err := d.store.GetGuestSession(ctx, vmID, idOrName); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
sessions, err := d.store.ListGuestSessionsByVM(ctx, vmID)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
matches := make([]model.GuestSession, 0, 1)
|
||||
for _, s := range sessions {
|
||||
if strings.HasPrefix(s.ID, idOrName) || strings.HasPrefix(s.Name, idOrName) {
|
||||
matches = append(matches, s)
|
||||
}
|
||||
}
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
return model.GuestSession{}, fmt.Errorf("session %q not found", idOrName)
|
||||
case 1:
|
||||
return matches[0], nil
|
||||
default:
|
||||
return model.GuestSession{}, fmt.Errorf("multiple sessions match %q", idOrName)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,492 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"banger/internal/api"
|
||||
sess "banger/internal/daemon/session"
|
||||
"banger/internal/model"
|
||||
"banger/internal/store"
|
||||
)
|
||||
|
||||
type fakeGuestSSHClient struct {
|
||||
t *testing.T
|
||||
existingDirs map[string]bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) Close() error {
|
||||
f.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error {
|
||||
f.t.Helper()
|
||||
switch {
|
||||
case strings.Contains(script, `\n`):
|
||||
return fmt.Errorf("script still contains escaped newline literals: %q", script)
|
||||
case strings.Contains(script, `echo "missing cwd: $DIR"`):
|
||||
if strings.Contains(script, "DIR='/root/repo'\n") && f.existingDirs["/root/repo"] {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("missing cwd")
|
||||
case strings.Contains(script, "check_command() {"):
|
||||
return nil
|
||||
case strings.Contains(script, `git config --global --add safe.directory "$DIR"`):
|
||||
if strings.Contains(script, "DIR='/root/repo'\n") {
|
||||
f.existingDirs["/root/repo"] = true
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("workspace finalize used unexpected guest path")
|
||||
case strings.Contains(script, "chmod -R a-w"):
|
||||
if f.existingDirs["/root/repo"] {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("workspace path missing during readonly chmod")
|
||||
case strings.Contains(script, "nohup bash "):
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) UploadFile(_ context.Context, _ string, _ os.FileMode, _ []byte, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) StreamTar(_ context.Context, _ string, command string, _ io.Writer) error {
|
||||
if strings.Contains(command, "/root/repo") {
|
||||
f.existingDirs["/root/repo"] = true
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unexpected StreamTar command: %s", command)
|
||||
}
|
||||
|
||||
func (f *fakeGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, command string, _ io.Writer) error {
|
||||
if strings.Contains(command, "/root/repo") {
|
||||
f.existingDirs["/root/repo"] = true
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unexpected StreamTarEntries command: %s", command)
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_HappyPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := startFakeFirecracker(t, apiSock)
|
||||
|
||||
vm := testVM("sendbox", "image-send", "172.16.0.88")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
fake := &recordingGuestSSHClient{}
|
||||
d := newSendTestDaemon(t, db, fake)
|
||||
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
|
||||
|
||||
payload := []byte(`{"type":"abort"}` + "\n")
|
||||
result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: payload,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendToGuestSession: %v", err)
|
||||
}
|
||||
if result.BytesWritten != len(payload) {
|
||||
t.Fatalf("BytesWritten = %d, want %d", result.BytesWritten, len(payload))
|
||||
}
|
||||
if result.Session.ID != session.ID {
|
||||
t.Fatalf("Session.ID = %q, want %q", result.Session.ID, session.ID)
|
||||
}
|
||||
if len(fake.uploadedFiles) != 1 {
|
||||
t.Fatalf("UploadFile call count = %d, want 1", len(fake.uploadedFiles))
|
||||
}
|
||||
for path, data := range fake.uploadedFiles {
|
||||
if !strings.HasPrefix(path, "/tmp/banger-send-") {
|
||||
t.Fatalf("upload path = %q, want /tmp/banger-send-... prefix", path)
|
||||
}
|
||||
if string(data) != string(payload) {
|
||||
t.Fatalf("upload data = %q, want %q", data, payload)
|
||||
}
|
||||
}
|
||||
if len(fake.ranScripts) != 1 {
|
||||
t.Fatalf("RunScript call count = %d, want 1", len(fake.ranScripts))
|
||||
}
|
||||
script := fake.ranScripts[0]
|
||||
pipePath := sess.StdinPipePath(session.ID)
|
||||
if !strings.Contains(script, "cat ") {
|
||||
t.Fatalf("send script missing cat command: %q", script)
|
||||
}
|
||||
if !strings.Contains(script, pipePath) {
|
||||
t.Fatalf("send script missing pipe path %q: %q", pipePath, script)
|
||||
}
|
||||
if !strings.Contains(script, "rm -f ") {
|
||||
t.Fatalf("send script missing rm cleanup: %q", script)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_EmptyPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := startFakeFirecracker(t, apiSock)
|
||||
|
||||
vm := testVM("sendbox-empty", "image-send", "172.16.0.89")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
fake := &recordingGuestSSHClient{}
|
||||
d := newSendTestDaemon(t, db, fake)
|
||||
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
|
||||
|
||||
result, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendToGuestSession(empty): %v", err)
|
||||
}
|
||||
if result.BytesWritten != 0 {
|
||||
t.Fatalf("BytesWritten = %d, want 0", result.BytesWritten)
|
||||
}
|
||||
if fake.dialCount != 0 {
|
||||
t.Fatalf("SSH dial count = %d, want 0 for empty payload", fake.dialCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_NotPipeMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
vm := testVM("sendbox-closed", "image-send", "172.16.0.90")
|
||||
vm.State = model.VMStateRunning
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinClosed, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
d := &Daemon{store: db}
|
||||
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: []byte("hello\n"),
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "stdin pipe") {
|
||||
t.Fatalf("error = %v, want 'stdin pipe' error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_SessionNotRunning(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
vm := testVM("sendbox-failed", "image-send", "172.16.0.91")
|
||||
vm.State = model.VMStateRunning
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusFailed)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
d := &Daemon{store: db}
|
||||
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: []byte("hello\n"),
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "not running") {
|
||||
t.Fatalf("error = %v, want 'not running' error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToGuestSession_VMNotRunning(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
vm := testVM("sendbox-stopped", "image-send", "172.16.0.92")
|
||||
vm.State = model.VMStateStopped
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
session := testGuestSession(vm.ID, model.GuestSessionStdinPipe, model.GuestSessionStatusRunning)
|
||||
if err := db.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
d := &Daemon{store: db}
|
||||
_, err := d.SendToGuestSession(ctx, api.GuestSessionSendParams{
|
||||
VMIDOrName: vm.Name,
|
||||
SessionIDOrName: session.Name,
|
||||
Payload: []byte("hello\n"),
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "not running") {
|
||||
t.Fatalf("error = %v, want 'not running' error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// recordingGuestSSHClient captures UploadFile and RunScript calls for send tests.
|
||||
type recordingGuestSSHClient struct {
|
||||
dialCount int
|
||||
uploadedFiles map[string][]byte
|
||||
ranScripts []string
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) Close() error { return nil }
|
||||
|
||||
func (r *recordingGuestSSHClient) UploadFile(_ context.Context, path string, _ os.FileMode, data []byte, _ io.Writer) error {
|
||||
if r.uploadedFiles == nil {
|
||||
r.uploadedFiles = make(map[string][]byte)
|
||||
}
|
||||
copy := make([]byte, len(data))
|
||||
_ = copy[:len(data):len(data)]
|
||||
for i, b := range data {
|
||||
copy[i] = b
|
||||
}
|
||||
r.uploadedFiles[path] = copy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) RunScript(_ context.Context, script string, _ io.Writer) error {
|
||||
r.ranScripts = append(r.ranScripts, script)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) RunScriptOutput(_ context.Context, _ string) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) StreamTar(_ context.Context, _ string, _ string, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingGuestSSHClient) StreamTarEntries(_ context.Context, _ string, _ []string, _ string, _ io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSendTestDaemon(t *testing.T, db *store.Store, fake *recordingGuestSSHClient) *Daemon {
|
||||
t.Helper()
|
||||
d := &Daemon{
|
||||
store: db,
|
||||
config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")},
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
d.guestDial = func(_ context.Context, _ string, _ string) (guestSSHClient, error) {
|
||||
fake.dialCount++
|
||||
return fake, nil
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func testGuestSession(vmID string, stdinMode model.GuestSessionStdinMode, status model.GuestSessionStatus) model.GuestSession {
|
||||
now := model.Now()
|
||||
id := vmID + "-sess-id"
|
||||
return model.GuestSession{
|
||||
ID: id,
|
||||
VMID: vmID,
|
||||
Name: vmID + "-sess",
|
||||
Backend: sess.BackendSSH,
|
||||
Command: "pi",
|
||||
Args: []string{"--mode", "rpc"},
|
||||
CWD: "/root/repo",
|
||||
StdinMode: stdinMode,
|
||||
Status: status,
|
||||
GuestStateDir: sess.StateDir(id),
|
||||
StdoutLogPath: sess.StdoutLogPath(id),
|
||||
StderrLogPath: sess.StderrLogPath(id),
|
||||
Attachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning,
|
||||
Reattachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd {
|
||||
t.Helper()
|
||||
cmd := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("start fake firecracker: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
}
|
||||
})
|
||||
return cmd
|
||||
}
|
||||
|
||||
func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cwdScript := sess.CWDPreflightScript("/root/repo")
|
||||
if strings.Contains(cwdScript, `\n`) {
|
||||
t.Fatalf("cwd preflight script still contains escaped newline literals: %q", cwdScript)
|
||||
}
|
||||
if !strings.Contains(cwdScript, "\n") {
|
||||
t.Fatalf("cwd preflight script should contain real newlines: %q", cwdScript)
|
||||
}
|
||||
|
||||
commandScript := sess.CommandPreflightScript([]string{"git", "pi"})
|
||||
if strings.Contains(commandScript, `\n`) {
|
||||
t.Fatalf("command preflight script still contains escaped newline literals: %q", commandScript)
|
||||
}
|
||||
if !strings.Contains(commandScript, "\n") {
|
||||
t.Fatalf("command preflight script should contain real newlines: %q", commandScript)
|
||||
}
|
||||
|
||||
attachInput := sess.AttachInputCommand("session-id")
|
||||
if strings.Contains(attachInput, `\n`) {
|
||||
t.Fatalf("attach input command still contains escaped newline literals: %q", attachInput)
|
||||
}
|
||||
|
||||
attachTail := sess.AttachTailCommand("/tmp/stdout.log")
|
||||
if strings.Contains(attachTail, `\n`) {
|
||||
t.Fatalf("attach tail command still contains escaped newline literals: %q", attachTail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareWorkspaceThenStartGuestSessionPassesCWDPreflight(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
||||
repoRoot := filepath.Join(t.TempDir(), "repo")
|
||||
if err := os.MkdirAll(repoRoot, 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(repoRoot): %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(repoRoot, "README.md"), []byte("hello\n"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(README.md): %v", err)
|
||||
}
|
||||
runGit := func(args ...string) {
|
||||
t.Helper()
|
||||
cmd := exec.Command("git", append([]string{"-C", repoRoot}, args...)...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("git %s: %v\n%s", strings.Join(args, " "), err, output)
|
||||
}
|
||||
}
|
||||
runGit("init", "-b", "main")
|
||||
runGit("config", "user.name", "Test User")
|
||||
runGit("config", "user.email", "test@example.com")
|
||||
runGit("add", ".")
|
||||
runGit("commit", "-m", "initial")
|
||||
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
firecracker := exec.Command("bash", "-lc", fmt.Sprintf("exec -a %q sleep 60", "firecracker --api-sock "+apiSock))
|
||||
if err := firecracker.Start(); err != nil {
|
||||
t.Fatalf("start fake firecracker: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if firecracker.Process != nil {
|
||||
_ = firecracker.Process.Kill()
|
||||
_, _ = firecracker.Process.Wait()
|
||||
}
|
||||
})
|
||||
|
||||
vm := testVM("pi-devbox", "image-pi", "172.16.0.77")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
fakeClient := &fakeGuestSSHClient{t: t, existingDirs: map[string]bool{}}
|
||||
d := &Daemon{
|
||||
store: db,
|
||||
config: model.DaemonConfig{SSHKeyPath: filepath.Join(t.TempDir(), "id_ed25519")},
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
d.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: firecracker.Process.Pid})
|
||||
d.guestWaitForSSH = func(context.Context, string, string, time.Duration) error { return nil }
|
||||
d.guestDial = func(context.Context, string, string) (guestSSHClient, error) { return fakeClient, nil }
|
||||
d.waitForGuestSessionReady = func(_ context.Context, _ model.VMRecord, session model.GuestSession) (model.GuestSession, error) {
|
||||
now := model.Now()
|
||||
session.Status = model.GuestSessionStatusRunning
|
||||
session.GuestPID = 4242
|
||||
session.StartedAt = now
|
||||
session.UpdatedAt = now
|
||||
session.Attachable = session.StdinMode == model.GuestSessionStdinPipe
|
||||
session.Reattachable = session.StdinMode == model.GuestSessionStdinPipe
|
||||
return session, nil
|
||||
}
|
||||
|
||||
workspace, err := d.PrepareVMWorkspace(ctx, api.VMWorkspacePrepareParams{
|
||||
IDOrName: vm.Name,
|
||||
SourcePath: repoRoot,
|
||||
GuestPath: "/root/repo",
|
||||
ReadOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PrepareVMWorkspace: %v", err)
|
||||
}
|
||||
if workspace.GuestPath != "/root/repo" {
|
||||
t.Fatalf("PrepareVMWorkspace guest path = %q, want /root/repo", workspace.GuestPath)
|
||||
}
|
||||
if !fakeClient.existingDirs["/root/repo"] {
|
||||
t.Fatalf("workspace prepare did not mark /root/repo as present in fake guest")
|
||||
}
|
||||
|
||||
session, err := d.StartGuestSession(ctx, api.GuestSessionStartParams{
|
||||
VMIDOrName: vm.Name,
|
||||
Name: "testpi",
|
||||
Command: "pi",
|
||||
Args: []string{"--mode", "rpc", "--no-session"},
|
||||
CWD: "/root/repo",
|
||||
StdinMode: string(model.GuestSessionStdinPipe),
|
||||
RequiredCommands: []string{"git"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartGuestSession: %v", err)
|
||||
}
|
||||
if session.Status != model.GuestSessionStatusRunning {
|
||||
t.Fatalf("session status = %q, want %q", session.Status, model.GuestSessionStatusRunning)
|
||||
}
|
||||
if session.LaunchStage != "" {
|
||||
t.Fatalf("session launch stage = %q, want empty", session.LaunchStage)
|
||||
}
|
||||
if session.LaunchMessage != "" {
|
||||
t.Fatalf("session launch message = %q, want empty", session.LaunchMessage)
|
||||
}
|
||||
if session.GuestPID == 0 {
|
||||
t.Fatalf("session guest pid = 0, want non-zero")
|
||||
}
|
||||
if !session.Attachable {
|
||||
t.Fatalf("session should be attachable for pipe stdin mode")
|
||||
}
|
||||
}
|
||||
35
internal/daemon/guest_ssh.go
Normal file
35
internal/daemon/guest_ssh.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"banger/internal/guest"
|
||||
)
|
||||
|
||||
// guestSSHClient is the narrow guest-SSH surface the daemon uses for
|
||||
// workspace prepare / export and ad-hoc guest interactions.
|
||||
type guestSSHClient interface {
|
||||
Close() error
|
||||
RunScript(context.Context, string, io.Writer) error
|
||||
RunScriptOutput(context.Context, string) ([]byte, error)
|
||||
UploadFile(context.Context, string, os.FileMode, []byte, io.Writer) error
|
||||
StreamTar(context.Context, string, string, io.Writer) error
|
||||
StreamTarEntries(context.Context, string, []string, string, io.Writer) error
|
||||
}
|
||||
|
||||
func (d *Daemon) waitForGuestSSH(ctx context.Context, address string, interval time.Duration) error {
|
||||
if d != nil && d.guestWaitForSSH != nil {
|
||||
return d.guestWaitForSSH(ctx, address, d.config.SSHKeyPath, interval)
|
||||
}
|
||||
return guest.WaitForSSH(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath, interval)
|
||||
}
|
||||
|
||||
func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, error) {
|
||||
if d != nil && d.guestDial != nil {
|
||||
return d.guestDial(ctx, address, d.config.SSHKeyPath)
|
||||
}
|
||||
return guest.Dial(ctx, address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
}
|
||||
|
|
@ -26,10 +26,9 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
|
|||
name: "only store + closing channel (early failure)",
|
||||
build: func(t *testing.T) *Daemon {
|
||||
return &Daemon{
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
sessions: newSessionRegistry(),
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
},
|
||||
verify: func(t *testing.T, d *Daemon) {
|
||||
|
|
@ -49,11 +48,10 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
|
|||
t.Fatalf("vmdns.New: %v", err)
|
||||
}
|
||||
return &Daemon{
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
sessions: newSessionRegistry(),
|
||||
vmDNS: server,
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
vmDNS: server,
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
},
|
||||
verify: func(t *testing.T, d *Daemon) {
|
||||
|
|
@ -86,11 +84,10 @@ func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) {
|
|||
// returns and also calls Close afterwards, both paths must survive.
|
||||
func TestCloseIdempotentUnderConcurrency(t *testing.T) {
|
||||
d := &Daemon{
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
sessions: newSessionRegistry(),
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
config: model.DaemonConfig{BridgeName: ""},
|
||||
store: openDaemonStore(t),
|
||||
closing: make(chan struct{}),
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
config: model.DaemonConfig{BridgeName: ""},
|
||||
}
|
||||
|
||||
var count atomic.Int32
|
||||
|
|
|
|||
|
|
@ -1,521 +0,0 @@
|
|||
// Package session contains the pure helpers of the guest-session subsystem:
|
||||
// bash script generators, on-guest state path helpers, state snapshot
|
||||
// parsing, and small utilities like ShellQuote and FormatStepError.
|
||||
//
|
||||
// The orchestrator methods (StartGuestSession, BeginGuestSessionAttach,
|
||||
// etc.) stay on *daemon.Daemon and compose these helpers.
|
||||
package session
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"banger/internal/model"
|
||||
"banger/internal/system"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Constants shared between orchestration and helpers.
|
||||
const (
|
||||
BackendSSH = "ssh"
|
||||
AttachBackendNone = "none"
|
||||
AttachBackendSSHBridge = "ssh_rehydratable"
|
||||
AttachModeExclusive = "exclusive"
|
||||
TransportUnixSocket = "unix_socket"
|
||||
StateRoot = "/root/.local/state/banger/sessions"
|
||||
LogTailLineDefault = 200
|
||||
)
|
||||
|
||||
// StateSnapshot is the decoded per-session state as read from the guest.
|
||||
type StateSnapshot struct {
|
||||
Status string
|
||||
GuestPID int
|
||||
ExitCode *int
|
||||
Alive bool
|
||||
LastError string
|
||||
}
|
||||
|
||||
// -- Guest filesystem paths -------------------------------------------------
|
||||
|
||||
func StateDir(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateRoot, id))
|
||||
}
|
||||
|
||||
func RelativeStateDir(id string) string {
|
||||
return strings.TrimPrefix(StateDir(id), "/root/")
|
||||
}
|
||||
|
||||
func ScriptPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "run.sh")) }
|
||||
func PIDPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "pid")) }
|
||||
func MonitorPIDPath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "monitor_pid"))
|
||||
}
|
||||
func ExitCodePath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "exit_code"))
|
||||
}
|
||||
func StdinPipePath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "stdin.pipe"))
|
||||
}
|
||||
func StdinKeepalivePIDPath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "stdin_keepalive.pid"))
|
||||
}
|
||||
func StatusPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "status")) }
|
||||
func ErrorPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "error")) }
|
||||
func StdoutLogPath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "stdout.log"))
|
||||
}
|
||||
func StderrLogPath(id string) string {
|
||||
return filepath.ToSlash(filepath.Join(StateDir(id), "stderr.log"))
|
||||
}
|
||||
|
||||
// -- Script generators ------------------------------------------------------
|
||||
|
||||
// Script returns the bash runner installed into the guest for session. It
|
||||
// sets up state/log paths, optional stdin fifo, and wait-loop around the
|
||||
// user command.
|
||||
func Script(sess model.GuestSession) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
fmt.Fprintf(&script, "STATE_DIR=%s\n", ShellQuote(sess.GuestStateDir))
|
||||
fmt.Fprintf(&script, "STDOUT_LOG=%s\n", ShellQuote(sess.StdoutLogPath))
|
||||
fmt.Fprintf(&script, "STDERR_LOG=%s\n", ShellQuote(sess.StderrLogPath))
|
||||
fmt.Fprintf(&script, "PID_FILE=%s\n", ShellQuote(PIDPath(sess.ID)))
|
||||
fmt.Fprintf(&script, "MONITOR_PID_FILE=%s\n", ShellQuote(MonitorPIDPath(sess.ID)))
|
||||
fmt.Fprintf(&script, "EXIT_FILE=%s\n", ShellQuote(ExitCodePath(sess.ID)))
|
||||
fmt.Fprintf(&script, "STATUS_FILE=%s\n", ShellQuote(StatusPath(sess.ID)))
|
||||
fmt.Fprintf(&script, "ERROR_FILE=%s\n", ShellQuote(ErrorPath(sess.ID)))
|
||||
fmt.Fprintf(&script, "STDIN_PIPE=%s\n", ShellQuote(StdinPipePath(sess.ID)))
|
||||
fmt.Fprintf(&script, "STDIN_KEEPALIVE_PID_FILE=%s\n", ShellQuote(StdinKeepalivePIDPath(sess.ID)))
|
||||
fmt.Fprintf(&script, "SESSION_CWD=%s\n", ShellQuote(DefaultCWD(sess.CWD)))
|
||||
script.WriteString("mkdir -p \"$STATE_DIR\"\n")
|
||||
script.WriteString(": >\"$STDOUT_LOG\"\n")
|
||||
script.WriteString(": >\"$STDERR_LOG\"\n")
|
||||
script.WriteString("rm -f \"$EXIT_FILE\" \"$ERROR_FILE\" \"$STDIN_KEEPALIVE_PID_FILE\"\n")
|
||||
if sess.StdinMode == model.GuestSessionStdinPipe {
|
||||
script.WriteString("rm -f \"$STDIN_PIPE\"\n")
|
||||
script.WriteString("mkfifo -m 600 \"$STDIN_PIPE\"\n")
|
||||
}
|
||||
script.WriteString("printf '%s\\n' \"${BASHPID:-$$}\" >\"$MONITOR_PID_FILE\"\n")
|
||||
script.WriteString("printf 'starting\\n' >\"$STATUS_FILE\"\n")
|
||||
script.WriteString("cd \"$SESSION_CWD\"\n")
|
||||
script.WriteString("exec > >(tee -a \"$STDOUT_LOG\") 2> >(tee -a \"$STDERR_LOG\" >&2)\n")
|
||||
for _, line := range EnvLines(sess.Env) {
|
||||
script.WriteString(line)
|
||||
script.WriteByte('\n')
|
||||
}
|
||||
script.WriteString("COMMAND=(")
|
||||
for _, value := range append([]string{sess.Command}, sess.Args...) {
|
||||
script.WriteByte(' ')
|
||||
script.WriteString(ShellQuote(value))
|
||||
}
|
||||
script.WriteString(" )\n")
|
||||
if sess.StdinMode == model.GuestSessionStdinPipe {
|
||||
script.WriteString("( while :; do sleep 3600; done ) >\"$STDIN_PIPE\" &\n")
|
||||
script.WriteString("keepalive=$!\n")
|
||||
script.WriteString("printf '%s\\n' \"$keepalive\" >\"$STDIN_KEEPALIVE_PID_FILE\"\n")
|
||||
script.WriteString("\"${COMMAND[@]}\" <\"$STDIN_PIPE\" &\n")
|
||||
} else {
|
||||
script.WriteString("\"${COMMAND[@]}\" &\n")
|
||||
}
|
||||
script.WriteString("child=$!\n")
|
||||
script.WriteString("printf '%s\\n' \"$child\" >\"$PID_FILE\"\n")
|
||||
script.WriteString("printf 'running\\n' >\"$STATUS_FILE\"\n")
|
||||
script.WriteString("wait \"$child\"\n")
|
||||
script.WriteString("rc=$?\n")
|
||||
if sess.StdinMode == model.GuestSessionStdinPipe {
|
||||
script.WriteString("if [ -f \"$STDIN_KEEPALIVE_PID_FILE\" ]; then kill \"$(cat \"$STDIN_KEEPALIVE_PID_FILE\")\" 2>/dev/null || true; fi\n")
|
||||
}
|
||||
script.WriteString("printf '%s\\n' \"$rc\" >\"$EXIT_FILE\"\n")
|
||||
script.WriteString("if [ \"$rc\" -eq 0 ]; then printf 'exited\\n' >\"$STATUS_FILE\"; else printf 'failed\\n' >\"$STATUS_FILE\"; fi\n")
|
||||
script.WriteString("exit \"$rc\"\n")
|
||||
return script.String()
|
||||
}
|
||||
|
||||
// InspectScript reads the on-guest state files for sessionID and prints a
|
||||
// key=value block parseable by ParseState.
|
||||
func InspectScript(sessionID string) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID)))
|
||||
script.WriteString("status=''\n")
|
||||
script.WriteString("pid=''\n")
|
||||
script.WriteString("exit_code=''\n")
|
||||
script.WriteString("last_error=''\n")
|
||||
script.WriteString("alive=false\n")
|
||||
script.WriteString("[ -f \"$DIR/status\" ] && status=\"$(cat \"$DIR/status\")\"\n")
|
||||
script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n")
|
||||
script.WriteString("[ -f \"$DIR/exit_code\" ] && exit_code=\"$(cat \"$DIR/exit_code\")\"\n")
|
||||
script.WriteString("[ -f \"$DIR/error\" ] && last_error=\"$(cat \"$DIR/error\")\"\n")
|
||||
script.WriteString("if [ -n \"$pid\" ] && kill -0 \"$pid\" 2>/dev/null; then alive=true; fi\n")
|
||||
script.WriteString("printf 'status=%s\\n' \"$status\"\n")
|
||||
script.WriteString("printf 'pid=%s\\n' \"$pid\"\n")
|
||||
script.WriteString("printf 'exit=%s\\n' \"$exit_code\"\n")
|
||||
script.WriteString("printf 'alive=%s\\n' \"$alive\"\n")
|
||||
script.WriteString("printf 'error=%s\\n' \"$last_error\"\n")
|
||||
return script.String()
|
||||
}
|
||||
|
||||
// SignalScript sends signal to sessionID's runner and monitor processes.
|
||||
func SignalScript(sessionID, signal string) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID)))
|
||||
fmt.Fprintf(&script, "SIGNAL=%s\n", ShellQuote(signal))
|
||||
script.WriteString("pid=''\n")
|
||||
script.WriteString("monitor=''\n")
|
||||
script.WriteString("keepalive=''\n")
|
||||
script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n")
|
||||
script.WriteString("[ -f \"$DIR/monitor_pid\" ] && monitor=\"$(cat \"$DIR/monitor_pid\")\"\n")
|
||||
script.WriteString("[ -f \"$DIR/stdin_keepalive.pid\" ] && keepalive=\"$(cat \"$DIR/stdin_keepalive.pid\")\"\n")
|
||||
script.WriteString("printf 'stopping\\n' >\"$DIR/status\"\n")
|
||||
script.WriteString("if [ -n \"$pid\" ]; then kill -${SIGNAL} \"$pid\" 2>/dev/null || true; fi\n")
|
||||
script.WriteString("if [ -n \"$monitor\" ]; then kill -${SIGNAL} \"$monitor\" 2>/dev/null || true; fi\n")
|
||||
script.WriteString("if [ -n \"$keepalive\" ]; then kill -${SIGNAL} \"$keepalive\" 2>/dev/null || true; fi\n")
|
||||
return script.String()
|
||||
}
|
||||
|
||||
// CWDPreflightScript verifies cwd exists on the guest.
|
||||
func CWDPreflightScript(cwd string) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(DefaultCWD(cwd)))
|
||||
script.WriteString("if [ ! -d \"$DIR\" ]; then echo \"missing cwd: $DIR\"; exit 1; fi\n")
|
||||
return script.String()
|
||||
}
|
||||
|
||||
// CommandPreflightScript verifies each command is resolvable on the guest.
|
||||
func CommandPreflightScript(commands []string) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
script.WriteString("check_command() {\n")
|
||||
script.WriteString(" cmd=\"$1\"\n")
|
||||
script.WriteString(" case \"$cmd\" in\n")
|
||||
script.WriteString(" */*) [ -x \"$cmd\" ] || { echo \"missing command: $cmd\"; exit 1; } ;;\n")
|
||||
script.WriteString(" *) command -v \"$cmd\" >/dev/null 2>&1 || { echo \"missing command: $cmd\"; exit 1; } ;;\n")
|
||||
script.WriteString(" esac\n")
|
||||
script.WriteString("}\n")
|
||||
for _, command := range commands {
|
||||
fmt.Fprintf(&script, "check_command %s\n", ShellQuote(command))
|
||||
}
|
||||
return script.String()
|
||||
}
|
||||
|
||||
// AttachInputCommand returns the guest command that creates/opens the stdin
|
||||
// fifo for sessionID and cats attach-side bytes into it.
|
||||
func AttachInputCommand(sessionID string) string {
|
||||
path := StdinPipePath(sessionID)
|
||||
return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\n[ -p %s ] || mkfifo -m 600 %s\nexec cat > %s\n", ShellQuote(path), ShellQuote(path), ShellQuote(path)))
|
||||
}
|
||||
|
||||
// AttachTailCommand returns the guest command that tails a log file and
|
||||
// streams new content back to the attach bridge.
|
||||
func AttachTailCommand(path string) string {
|
||||
return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\ntouch %s\nexec tail -n 0 -F %s 2>/dev/null\n", ShellQuote(path), ShellQuote(path)))
|
||||
}
|
||||
|
||||
// EnvLines returns deterministic `export KEY=value` lines for the session
|
||||
// launcher, ordered by key.
|
||||
func EnvLines(values map[string]string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
lines := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
lines = append(lines, "export "+key+"="+ShellQuote(values[key]))
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
// -- State snapshot helpers -------------------------------------------------
|
||||
|
||||
// ParseState decodes the key=value output produced by InspectScript.
|
||||
func ParseState(raw string) (StateSnapshot, error) {
|
||||
var snapshot StateSnapshot
|
||||
scanner := bufio.NewScanner(strings.NewReader(raw))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
key, value, ok := strings.Cut(line, "=")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch strings.TrimSpace(key) {
|
||||
case "status":
|
||||
snapshot.Status = strings.TrimSpace(value)
|
||||
case "pid":
|
||||
if pid, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
|
||||
snapshot.GuestPID = pid
|
||||
}
|
||||
case "exit":
|
||||
if exitCode, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
|
||||
snapshot.ExitCode = &exitCode
|
||||
}
|
||||
case "alive":
|
||||
snapshot.Alive = strings.TrimSpace(value) == "true"
|
||||
case "error":
|
||||
snapshot.LastError = strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return snapshot, scanner.Err()
|
||||
}
|
||||
|
||||
// InspectStateFromDir reads the state files directly from stateDir (used
|
||||
// when the guest is offline and we can mount the work disk from the host).
|
||||
func InspectStateFromDir(stateDir string) (StateSnapshot, error) {
|
||||
var snapshot StateSnapshot
|
||||
statusData, _ := os.ReadFile(filepath.Join(stateDir, "status"))
|
||||
snapshot.Status = strings.TrimSpace(string(statusData))
|
||||
pidData, _ := os.ReadFile(filepath.Join(stateDir, "pid"))
|
||||
if pidValue, err := strconv.Atoi(strings.TrimSpace(string(pidData))); err == nil {
|
||||
snapshot.GuestPID = pidValue
|
||||
}
|
||||
exitData, _ := os.ReadFile(filepath.Join(stateDir, "exit_code"))
|
||||
if exitValue, err := strconv.Atoi(strings.TrimSpace(string(exitData))); err == nil {
|
||||
snapshot.ExitCode = &exitValue
|
||||
}
|
||||
errorData, _ := os.ReadFile(filepath.Join(stateDir, "error"))
|
||||
snapshot.LastError = strings.TrimSpace(string(errorData))
|
||||
if snapshot.GuestPID != 0 {
|
||||
snapshot.Alive = ProcessAlive(snapshot.GuestPID)
|
||||
}
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
// ApplyStateSnapshot mutates sess in place to reflect snapshot. vmRunning
|
||||
// captures whether the VM is currently up so stale in-flight sessions can be
|
||||
// failed when the VM is gone.
|
||||
func ApplyStateSnapshot(sess *model.GuestSession, snapshot StateSnapshot, vmRunning bool) {
|
||||
if sess == nil {
|
||||
return
|
||||
}
|
||||
if snapshot.GuestPID != 0 {
|
||||
sess.GuestPID = snapshot.GuestPID
|
||||
}
|
||||
if snapshot.LastError != "" {
|
||||
sess.LastError = snapshot.LastError
|
||||
}
|
||||
if snapshot.ExitCode != nil {
|
||||
sess.ExitCode = snapshot.ExitCode
|
||||
sess.Attachable = false
|
||||
sess.Reattachable = false
|
||||
if sess.StartedAt.IsZero() {
|
||||
sess.StartedAt = model.Now()
|
||||
}
|
||||
if sess.EndedAt.IsZero() {
|
||||
sess.EndedAt = model.Now()
|
||||
}
|
||||
if *snapshot.ExitCode == 0 {
|
||||
sess.Status = model.GuestSessionStatusExited
|
||||
} else {
|
||||
sess.Status = model.GuestSessionStatusFailed
|
||||
}
|
||||
return
|
||||
}
|
||||
if snapshot.Alive {
|
||||
if sess.StartedAt.IsZero() {
|
||||
sess.StartedAt = model.Now()
|
||||
}
|
||||
sess.Status = model.GuestSessionStatusRunning
|
||||
return
|
||||
}
|
||||
if !vmRunning && (sess.Status == model.GuestSessionStatusStarting || sess.Status == model.GuestSessionStatusRunning || sess.Status == model.GuestSessionStatusStopping) {
|
||||
sess.Status = model.GuestSessionStatusFailed
|
||||
sess.Attachable = false
|
||||
sess.Reattachable = false
|
||||
if sess.LastError == "" {
|
||||
sess.LastError = "vm is not running"
|
||||
}
|
||||
if sess.EndedAt.IsZero() {
|
||||
sess.EndedAt = model.Now()
|
||||
}
|
||||
return
|
||||
}
|
||||
if snapshot.Status == string(model.GuestSessionStatusRunning) {
|
||||
if sess.StartedAt.IsZero() {
|
||||
sess.StartedAt = model.Now()
|
||||
}
|
||||
sess.Status = model.GuestSessionStatusRunning
|
||||
}
|
||||
if sess.Status == model.GuestSessionStatusRunning && sess.StdinMode == model.GuestSessionStdinPipe {
|
||||
sess.Attachable = true
|
||||
sess.Reattachable = true
|
||||
if sess.AttachBackend == "" {
|
||||
sess.AttachBackend = AttachBackendSSHBridge
|
||||
}
|
||||
if sess.AttachMode == "" {
|
||||
sess.AttachMode = AttachModeExclusive
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StateChanged reports whether any materially observable field differs
|
||||
// between before and after, guiding whether to persist an update.
|
||||
func StateChanged(before, after model.GuestSession) bool {
|
||||
if before.Status != after.Status || before.GuestPID != after.GuestPID || before.LastError != after.LastError || before.Attachable != after.Attachable || before.Reattachable != after.Reattachable || before.AttachBackend != after.AttachBackend || before.AttachMode != after.AttachMode || before.LaunchStage != after.LaunchStage || before.LaunchMessage != after.LaunchMessage || before.LaunchRawLog != after.LaunchRawLog {
|
||||
return true
|
||||
}
|
||||
if before.StartedAt != after.StartedAt || before.EndedAt != after.EndedAt {
|
||||
return true
|
||||
}
|
||||
switch {
|
||||
case before.ExitCode == nil && after.ExitCode == nil:
|
||||
return false
|
||||
case before.ExitCode == nil || after.ExitCode == nil:
|
||||
return true
|
||||
default:
|
||||
return *before.ExitCode != *after.ExitCode
|
||||
}
|
||||
}
|
||||
|
||||
// -- Launch helpers ---------------------------------------------------------
|
||||
|
||||
// DefaultName returns a friendly session name: caller-provided if non-empty,
|
||||
// otherwise `<command-base>-<short-id>`.
|
||||
func DefaultName(id, command, explicit string) string {
|
||||
if trimmed := strings.TrimSpace(explicit); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
base := filepath.Base(strings.TrimSpace(command))
|
||||
if base == "." || base == string(filepath.Separator) || base == "" {
|
||||
base = "session"
|
||||
}
|
||||
return base + "-" + system.ShortID(id)
|
||||
}
|
||||
|
||||
// DefaultCWD returns value if non-empty, else /root.
|
||||
func DefaultCWD(value string) string {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
return "/root"
|
||||
}
|
||||
|
||||
// FailLaunch annotates sess as launch-failed with stage/message/raw log and
|
||||
// returns it for persistence.
|
||||
func FailLaunch(sess model.GuestSession, stage, message, rawLog string) model.GuestSession {
|
||||
now := model.Now()
|
||||
sess.Status = model.GuestSessionStatusFailed
|
||||
sess.LastError = strings.TrimSpace(message)
|
||||
sess.Attachable = false
|
||||
sess.Reattachable = false
|
||||
sess.LaunchStage = strings.TrimSpace(stage)
|
||||
sess.LaunchMessage = strings.TrimSpace(message)
|
||||
sess.LaunchRawLog = strings.TrimSpace(rawLog)
|
||||
sess.UpdatedAt = now
|
||||
sess.EndedAt = now
|
||||
return sess
|
||||
}
|
||||
|
||||
// NormalizeRequiredCommands returns a de-duplicated, order-preserving list
|
||||
// of required commands, with the session command first.
|
||||
func NormalizeRequiredCommands(command string, extras []string) []string {
|
||||
ordered := make([]string, 0, len(extras)+1)
|
||||
seen := map[string]struct{}{}
|
||||
appendValue := func(value string) {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[trimmed]; ok {
|
||||
return
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
ordered = append(ordered, trimmed)
|
||||
}
|
||||
appendValue(command)
|
||||
for _, extra := range extras {
|
||||
appendValue(extra)
|
||||
}
|
||||
return ordered
|
||||
}
|
||||
|
||||
// -- Small utilities --------------------------------------------------------
|
||||
|
||||
// ShellQuote returns value single-quoted for bash, escaping embedded quotes.
|
||||
func ShellQuote(value string) string {
|
||||
return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
|
||||
}
|
||||
|
||||
// ExitCode extracts the exit status from an ssh.ExitError, returning
|
||||
// (0, true) for nil errors.
|
||||
func ExitCode(err error) (int, bool) {
|
||||
if err == nil {
|
||||
return 0, true
|
||||
}
|
||||
var exitErr *ssh.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
return exitErr.ExitStatus(), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// CloneStringMap returns a shallow copy of values, or nil if empty.
|
||||
func CloneStringMap(values map[string]string) map[string]string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make(map[string]string, len(values))
|
||||
for key, value := range values {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
// TailFileContent returns the last N lines of a file, or "" if the file is
|
||||
// missing.
|
||||
func TailFileContent(path string, lines int) (string, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if lines <= 0 {
|
||||
return string(data), nil
|
||||
}
|
||||
parts := strings.Split(string(data), "\n")
|
||||
if len(parts) <= lines {
|
||||
return string(data), nil
|
||||
}
|
||||
return strings.Join(parts[len(parts)-lines-1:], "\n"), nil
|
||||
}
|
||||
|
||||
// ProcessAlive returns true if the process with pid exists. The syscallKill
|
||||
// override is exposed for tests that need to simulate alive/dead processes.
|
||||
func ProcessAlive(pid int) bool {
|
||||
if pid <= 0 {
|
||||
return false
|
||||
}
|
||||
return syscallKill(pid, syscall.Signal(0)) == nil
|
||||
}
|
||||
|
||||
// syscallKill is a test seam: tests replace it to stub process-alive checks.
|
||||
var syscallKill = func(pid int, signal os.Signal) error {
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return proc.Signal(signal)
|
||||
}
|
||||
|
||||
// FormatStepError wraps err with an action label and trimmed on-guest log.
|
||||
func FormatStepError(action string, err error, log string) error {
|
||||
log = strings.TrimSpace(log)
|
||||
if log == "" {
|
||||
return fmt.Errorf("%s: %w", action, err)
|
||||
}
|
||||
return fmt.Errorf("%s: %w: %s", action, err, log)
|
||||
}
|
||||
|
|
@ -1,440 +0,0 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"banger/internal/model"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestRelativeStateDir(t *testing.T) {
|
||||
got := RelativeStateDir("abc")
|
||||
if strings.HasPrefix(got, "/root/") {
|
||||
t.Fatalf("RelativeStateDir(%q) = %q, should strip /root/ prefix", "abc", got)
|
||||
}
|
||||
if !strings.Contains(got, "abc") {
|
||||
t.Fatalf("missing session id in %q", got)
|
||||
}
|
||||
absolute := StateDir("abc")
|
||||
if got != strings.TrimPrefix(absolute, "/root/") {
|
||||
t.Fatalf("relative = %q, want %q", got, strings.TrimPrefix(absolute, "/root/"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCWD(t *testing.T) {
|
||||
if DefaultCWD("") != "/root" {
|
||||
t.Error("empty should return /root")
|
||||
}
|
||||
if DefaultCWD(" ") != "/root" {
|
||||
t.Error("whitespace should return /root")
|
||||
}
|
||||
if DefaultCWD("/work") != "/work" {
|
||||
t.Error("explicit should pass through")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellQuote(t *testing.T) {
|
||||
if got := ShellQuote(""); got != "''" {
|
||||
t.Errorf("empty: got %q, want ''", got)
|
||||
}
|
||||
if got := ShellQuote("x"); got != "'x'" {
|
||||
t.Errorf("plain: got %q", got)
|
||||
}
|
||||
if got := ShellQuote("it's"); got != `'it'"'"'s'` {
|
||||
t.Errorf("apostrophe: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitCode(t *testing.T) {
|
||||
if code, ok := ExitCode(nil); !ok || code != 0 {
|
||||
t.Errorf("nil err: got (%d, %v), want (0, true)", code, ok)
|
||||
}
|
||||
// Build an ssh.ExitError using its real type — can't hand-construct,
|
||||
// so wrap via errors.As check with a stub.
|
||||
raw := &ssh.ExitError{}
|
||||
if _, ok := ExitCode(raw); !ok {
|
||||
t.Error("ssh.ExitError: ok should be true")
|
||||
}
|
||||
if _, ok := ExitCode(errors.New("bare error")); ok {
|
||||
t.Error("bare error: ok should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneStringMap(t *testing.T) {
|
||||
if CloneStringMap(nil) != nil {
|
||||
t.Error("nil in → nil out")
|
||||
}
|
||||
if CloneStringMap(map[string]string{}) != nil {
|
||||
t.Error("empty in → nil out")
|
||||
}
|
||||
src := map[string]string{"a": "1", "b": "2"}
|
||||
cloned := CloneStringMap(src)
|
||||
if len(cloned) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(cloned))
|
||||
}
|
||||
cloned["a"] = "changed"
|
||||
if src["a"] != "1" {
|
||||
t.Error("mutating clone leaked back to source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTailFileContent(t *testing.T) {
|
||||
// Missing file → empty, no error.
|
||||
got, err := TailFileContent(filepath.Join(t.TempDir(), "missing"), 10)
|
||||
if err != nil || got != "" {
|
||||
t.Errorf("missing: got (%q, %v), want ('', nil)", got, err)
|
||||
}
|
||||
|
||||
path := filepath.Join(t.TempDir(), "log")
|
||||
lines := "one\ntwo\nthree\nfour\nfive"
|
||||
if err := os.WriteFile(path, []byte(lines), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
full, err := TailFileContent(path, 0)
|
||||
if err != nil || full != lines {
|
||||
t.Errorf("0 lines: got (%q, %v), want (%q, nil)", full, err, lines)
|
||||
}
|
||||
|
||||
// Request more lines than exist → full content.
|
||||
all, err := TailFileContent(path, 999)
|
||||
if err != nil || all != lines {
|
||||
t.Errorf("999 lines: got %q", all)
|
||||
}
|
||||
|
||||
last2, err := TailFileContent(path, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("2 lines: %v", err)
|
||||
}
|
||||
if !strings.Contains(last2, "five") {
|
||||
t.Errorf("2 lines missing last line: %q", last2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessAlive(t *testing.T) {
|
||||
if ProcessAlive(0) {
|
||||
t.Error("pid 0 should not be alive")
|
||||
}
|
||||
if ProcessAlive(-1) {
|
||||
t.Error("negative pid should not be alive")
|
||||
}
|
||||
// Swap the syscall seam.
|
||||
original := syscallKill
|
||||
t.Cleanup(func() { syscallKill = original })
|
||||
|
||||
syscallKill = func(pid int, signal os.Signal) error { return nil }
|
||||
if !ProcessAlive(42) {
|
||||
t.Error("syscallKill=nil should report alive")
|
||||
}
|
||||
|
||||
syscallKill = func(pid int, signal os.Signal) error { return fmt.Errorf("no such process") }
|
||||
if ProcessAlive(42) {
|
||||
t.Error("syscallKill error should report dead")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatStepError(t *testing.T) {
|
||||
base := errors.New("boom")
|
||||
err := FormatStepError("prepare", base, "")
|
||||
if !errors.Is(err, base) {
|
||||
t.Error("FormatStepError should wrap the base error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "prepare") {
|
||||
t.Errorf("missing action: %v", err)
|
||||
}
|
||||
|
||||
errWithLog := FormatStepError("prepare", base, " log line\n")
|
||||
if !strings.Contains(errWithLog.Error(), "log line") {
|
||||
t.Errorf("missing log: %v", errWithLog)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStateHappyPath(t *testing.T) {
|
||||
raw := `status=running
|
||||
pid=123
|
||||
exit=
|
||||
alive=true
|
||||
error=
|
||||
`
|
||||
snap, err := ParseState(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseState: %v", err)
|
||||
}
|
||||
if snap.Status != "running" {
|
||||
t.Errorf("Status = %q", snap.Status)
|
||||
}
|
||||
if snap.GuestPID != 123 {
|
||||
t.Errorf("GuestPID = %d", snap.GuestPID)
|
||||
}
|
||||
if snap.ExitCode != nil {
|
||||
t.Errorf("ExitCode should be nil when empty, got %v", snap.ExitCode)
|
||||
}
|
||||
if !snap.Alive {
|
||||
t.Error("Alive should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStateWithExit(t *testing.T) {
|
||||
raw := `status=exited
|
||||
pid=123
|
||||
exit=7
|
||||
alive=false
|
||||
error=something bad
|
||||
`
|
||||
snap, err := ParseState(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseState: %v", err)
|
||||
}
|
||||
if snap.ExitCode == nil || *snap.ExitCode != 7 {
|
||||
t.Errorf("ExitCode = %v, want 7", snap.ExitCode)
|
||||
}
|
||||
if snap.LastError != "something bad" {
|
||||
t.Errorf("LastError = %q", snap.LastError)
|
||||
}
|
||||
if snap.Alive {
|
||||
t.Error("Alive should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStateIgnoresMalformedLines(t *testing.T) {
|
||||
raw := "no-equals-here\nstatus=ok\n"
|
||||
snap, err := ParseState(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseState: %v", err)
|
||||
}
|
||||
if snap.Status != "ok" {
|
||||
t.Errorf("Status = %q, want ok", snap.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectStateFromDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeFile := func(name, content string) {
|
||||
if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(%s): %v", name, err)
|
||||
}
|
||||
}
|
||||
writeFile("status", "running\n")
|
||||
writeFile("pid", "42\n")
|
||||
writeFile("exit_code", "0\n")
|
||||
writeFile("error", "\n")
|
||||
|
||||
original := syscallKill
|
||||
t.Cleanup(func() { syscallKill = original })
|
||||
syscallKill = func(pid int, signal os.Signal) error { return nil }
|
||||
|
||||
snap, err := InspectStateFromDir(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("InspectStateFromDir: %v", err)
|
||||
}
|
||||
if snap.Status != "running" {
|
||||
t.Errorf("Status = %q", snap.Status)
|
||||
}
|
||||
if snap.GuestPID != 42 {
|
||||
t.Errorf("GuestPID = %d", snap.GuestPID)
|
||||
}
|
||||
if snap.ExitCode == nil || *snap.ExitCode != 0 {
|
||||
t.Errorf("ExitCode = %v, want 0", snap.ExitCode)
|
||||
}
|
||||
if !snap.Alive {
|
||||
t.Error("Alive should reflect syscallKill result (true)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectStateFromDirMissingFiles(t *testing.T) {
|
||||
snap, err := InspectStateFromDir(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("InspectStateFromDir (empty): %v", err)
|
||||
}
|
||||
if snap.Status != "" || snap.GuestPID != 0 || snap.ExitCode != nil {
|
||||
t.Errorf("empty dir: snap = %+v", snap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotNilReceiver(t *testing.T) {
|
||||
ApplyStateSnapshot(nil, StateSnapshot{}, true) // should not panic
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotExitedSuccess(t *testing.T) {
|
||||
exit := 0
|
||||
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning, Attachable: true, Reattachable: true}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true)
|
||||
if sess.Status != model.GuestSessionStatusExited {
|
||||
t.Errorf("Status = %q, want exited", sess.Status)
|
||||
}
|
||||
if sess.Attachable || sess.Reattachable {
|
||||
t.Error("attach flags should be cleared on exit")
|
||||
}
|
||||
if sess.EndedAt.IsZero() {
|
||||
t.Error("EndedAt should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotExitedFailure(t *testing.T) {
|
||||
exit := 2
|
||||
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true)
|
||||
if sess.Status != model.GuestSessionStatusFailed {
|
||||
t.Errorf("Status = %q, want failed", sess.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotVMGone(t *testing.T) {
|
||||
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{Alive: false}, false)
|
||||
if sess.Status != model.GuestSessionStatusFailed {
|
||||
t.Errorf("Status = %q, want failed", sess.Status)
|
||||
}
|
||||
if sess.LastError == "" {
|
||||
t.Error("LastError should be populated when VM is gone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotRunningStatusSetsAttachableForPipe(t *testing.T) {
|
||||
// When the guest-side status file reports "running" (Alive=false from
|
||||
// kill -0 may still fail transiently), ApplyStateSnapshot transitions
|
||||
// the session to running and sets attach flags for pipe-mode.
|
||||
sess := &model.GuestSession{
|
||||
Status: model.GuestSessionStatusStarting,
|
||||
StdinMode: model.GuestSessionStdinPipe,
|
||||
}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{Status: string(model.GuestSessionStatusRunning), GuestPID: 11}, true)
|
||||
if sess.Status != model.GuestSessionStatusRunning {
|
||||
t.Errorf("Status = %q, want running", sess.Status)
|
||||
}
|
||||
if !sess.Attachable || !sess.Reattachable {
|
||||
t.Error("pipe-mode running session should be attachable + reattachable")
|
||||
}
|
||||
if sess.AttachBackend != AttachBackendSSHBridge {
|
||||
t.Errorf("AttachBackend = %q, want %q", sess.AttachBackend, AttachBackendSSHBridge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotAliveEarlyReturn(t *testing.T) {
|
||||
// Alive-true returns immediately after setting status; no attach
|
||||
// flags set on this path (by design — attach metadata only attaches
|
||||
// to status-driven transitions).
|
||||
sess := &model.GuestSession{
|
||||
Status: model.GuestSessionStatusStarting,
|
||||
StdinMode: model.GuestSessionStdinPipe,
|
||||
}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{Alive: true, GuestPID: 11}, true)
|
||||
if sess.Status != model.GuestSessionStatusRunning {
|
||||
t.Errorf("Status = %q, want running", sess.Status)
|
||||
}
|
||||
if sess.StartedAt.IsZero() {
|
||||
t.Error("StartedAt should have been set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateChanged(t *testing.T) {
|
||||
base := model.GuestSession{Status: model.GuestSessionStatusRunning, GuestPID: 10}
|
||||
|
||||
// Identical → no change.
|
||||
if StateChanged(base, base) {
|
||||
t.Error("identical states should not be considered changed")
|
||||
}
|
||||
|
||||
// Status change.
|
||||
changed := base
|
||||
changed.Status = model.GuestSessionStatusExited
|
||||
if !StateChanged(base, changed) {
|
||||
t.Error("status change should be detected")
|
||||
}
|
||||
|
||||
// ExitCode change from nil → value.
|
||||
exit := 3
|
||||
changed = base
|
||||
changed.ExitCode = &exit
|
||||
if !StateChanged(base, changed) {
|
||||
t.Error("exit-code appearing should be detected")
|
||||
}
|
||||
|
||||
// Both have the same exit code → no change.
|
||||
a := base
|
||||
a.ExitCode = &exit
|
||||
b := base
|
||||
b.ExitCode = &exit
|
||||
if StateChanged(a, b) {
|
||||
t.Error("matching exit codes should not trigger change")
|
||||
}
|
||||
|
||||
// Different exit codes.
|
||||
other := 5
|
||||
b.ExitCode = &other
|
||||
if !StateChanged(a, b) {
|
||||
t.Error("differing exit codes should be detected")
|
||||
}
|
||||
|
||||
// Timestamp change.
|
||||
changed = base
|
||||
changed.StartedAt = time.Now()
|
||||
if !StateChanged(base, changed) {
|
||||
t.Error("StartedAt change should be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailLaunch(t *testing.T) {
|
||||
in := model.GuestSession{Status: model.GuestSessionStatusStarting, Attachable: true}
|
||||
out := FailLaunch(in, "provision", " ssh did not come up ", " raw output\n")
|
||||
if out.Status != model.GuestSessionStatusFailed {
|
||||
t.Errorf("Status = %q, want failed", out.Status)
|
||||
}
|
||||
if out.LastError != "ssh did not come up" {
|
||||
t.Errorf("LastError = %q (not trimmed?)", out.LastError)
|
||||
}
|
||||
if out.LaunchStage != "provision" || out.LaunchMessage != "ssh did not come up" {
|
||||
t.Errorf("launch fields not set: %+v", out)
|
||||
}
|
||||
if out.LaunchRawLog != "raw output" {
|
||||
t.Errorf("rawLog = %q (not trimmed?)", out.LaunchRawLog)
|
||||
}
|
||||
if out.Attachable {
|
||||
t.Error("Attachable should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeRequiredCommands(t *testing.T) {
|
||||
got := NormalizeRequiredCommands("pi", []string{"pi", "git", "", "git", " ", "make"})
|
||||
want := []string{"pi", "git", "make"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("len = %d, want %d (%v)", len(got), len(want), got)
|
||||
}
|
||||
for i, v := range want {
|
||||
if got[i] != v {
|
||||
t.Errorf("position %d: got %q, want %q", i, got[i], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectScriptContainsAllStateFiles(t *testing.T) {
|
||||
script := InspectScript("sess-abc")
|
||||
for _, key := range []string{"status", "pid", "exit_code", "error", "alive"} {
|
||||
if !strings.Contains(script, key) {
|
||||
t.Errorf("script missing %q:\n%s", key, script)
|
||||
}
|
||||
}
|
||||
if !strings.Contains(script, "sess-abc") {
|
||||
t.Error("script missing session id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalScriptIncludesSignalAndDirPaths(t *testing.T) {
|
||||
script := SignalScript("sess-x", "TERM")
|
||||
if !strings.Contains(script, "TERM") {
|
||||
t.Error("missing signal")
|
||||
}
|
||||
if !strings.Contains(script, "sess-x") {
|
||||
t.Error("missing session id")
|
||||
}
|
||||
if !strings.Contains(script, "monitor_pid") || !strings.Contains(script, "stdin_keepalive") {
|
||||
t.Errorf("expected both monitor + stdin_keepalive kills, got:\n%s", script)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,224 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"banger/internal/api"
|
||||
sess "banger/internal/daemon/session"
|
||||
"banger/internal/guest"
|
||||
"banger/internal/model"
|
||||
"banger/internal/sessionstream"
|
||||
)
|
||||
|
||||
func (d *Daemon) BeginGuestSessionAttach(ctx context.Context, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionAttachBeginResult{}, err
|
||||
}
|
||||
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionAttachBeginResult{}, err
|
||||
}
|
||||
session, _ = d.refreshGuestSession(ctx, vm, session)
|
||||
if !session.Attachable {
|
||||
return api.GuestSessionAttachBeginResult{}, errors.New("session is not attachable")
|
||||
}
|
||||
controller := &guestSessionController{}
|
||||
if !d.claimGuestSessionController(session.ID, controller) {
|
||||
return api.GuestSessionAttachBeginResult{}, errors.New("session already has an active attach")
|
||||
}
|
||||
attachID, err := model.NewID()
|
||||
if err != nil {
|
||||
d.clearGuestSessionController(session.ID)
|
||||
return api.GuestSessionAttachBeginResult{}, err
|
||||
}
|
||||
socketPath := filepath.Join(d.layout.RuntimeDir, "guest-session-attach-"+attachID[:12]+".sock")
|
||||
_ = os.Remove(socketPath)
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
d.clearGuestSessionController(session.ID)
|
||||
return api.GuestSessionAttachBeginResult{}, err
|
||||
}
|
||||
if err := os.Chmod(socketPath, 0o600); err != nil {
|
||||
_ = listener.Close()
|
||||
_ = os.Remove(socketPath)
|
||||
d.clearGuestSessionController(session.ID)
|
||||
return api.GuestSessionAttachBeginResult{}, err
|
||||
}
|
||||
go d.serveGuestSessionAttach(session, controller, attachID, socketPath, listener)
|
||||
return api.GuestSessionAttachBeginResult{
|
||||
Session: session,
|
||||
AttachID: attachID,
|
||||
TransportKind: sess.TransportUnixSocket,
|
||||
TransportTarget: socketPath,
|
||||
SocketPath: socketPath,
|
||||
StreamFormat: sessionstream.FormatV1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) forwardGuestSessionOutput(_ string, controller *guestSessionController, channel byte, reader io.Reader) {
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := reader.Read(buffer)
|
||||
if n > 0 {
|
||||
controller.writeFrame(channel, buffer[:n])
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) waitForGuestSessionExit(id string, controller *guestSessionController, session model.GuestSession) {
|
||||
err := controller.stream.Wait()
|
||||
updated := session
|
||||
updated.Attachable = false
|
||||
now := model.Now()
|
||||
updated.UpdatedAt = now
|
||||
updated.EndedAt = now
|
||||
if exitCode, ok := sess.ExitCode(err); ok {
|
||||
updated.ExitCode = &exitCode
|
||||
if exitCode == 0 {
|
||||
updated.Status = model.GuestSessionStatusExited
|
||||
} else {
|
||||
updated.Status = model.GuestSessionStatusFailed
|
||||
}
|
||||
}
|
||||
if err != nil && updated.LastError == "" {
|
||||
updated.LastError = err.Error()
|
||||
}
|
||||
if vm, getErr := d.store.GetVMByID(context.Background(), updated.VMID); getErr == nil {
|
||||
if refreshed, refreshErr := d.refreshGuestSession(context.Background(), vm, updated); refreshErr == nil {
|
||||
updated = refreshed
|
||||
}
|
||||
}
|
||||
_ = d.store.UpsertGuestSession(context.Background(), updated)
|
||||
controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: updated.ExitCode})
|
||||
_ = controller.close()
|
||||
d.clearGuestSessionController(id)
|
||||
}
|
||||
|
||||
func (d *Daemon) serveGuestSessionAttach(session model.GuestSession, controller *guestSessionController, _ string, socketPath string, listener net.Listener) {
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
_ = os.Remove(socketPath)
|
||||
_ = controller.close()
|
||||
d.clearGuestSessionController(session.ID)
|
||||
}()
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
if err := controller.setAttach(conn); err != nil {
|
||||
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
return
|
||||
}
|
||||
defer controller.clearAttach(conn)
|
||||
if err := d.attachGuestSessionBridge(session, controller); err != nil {
|
||||
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
return
|
||||
}
|
||||
for {
|
||||
channel, payload, err := sessionstream.ReadFrame(conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch channel {
|
||||
case sessionstream.ChannelStdin:
|
||||
if controller.stdin == nil {
|
||||
continue
|
||||
}
|
||||
if _, err := controller.stdin.Write(payload); err != nil {
|
||||
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
return
|
||||
}
|
||||
case sessionstream.ChannelControl:
|
||||
message, err := sessionstream.ReadControl(payload)
|
||||
if err != nil {
|
||||
_ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
return
|
||||
}
|
||||
if message.Type == "eof" && controller.stdin != nil {
|
||||
_ = controller.stdin.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) attachGuestSessionBridge(session model.GuestSession, controller *guestSessionController) error {
|
||||
vm, err := d.store.GetVMByID(context.Background(), session.VMID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !d.vmAlive(vm) {
|
||||
return fmt.Errorf("vm %q is not running", vm.Name)
|
||||
}
|
||||
address := net.JoinHostPort(vm.Runtime.GuestIP, "22")
|
||||
stdinStream, err := d.openGuestSessionAttachStream(address, sess.AttachInputCommand(session.ID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("open guest session stdin stream: %w", err)
|
||||
}
|
||||
stdoutStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StdoutLogPath))
|
||||
if err != nil {
|
||||
_ = stdinStream.Close()
|
||||
return fmt.Errorf("open guest session stdout stream: %w", err)
|
||||
}
|
||||
stderrStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StderrLogPath))
|
||||
if err != nil {
|
||||
_ = stdinStream.Close()
|
||||
_ = stdoutStream.Close()
|
||||
return fmt.Errorf("open guest session stderr stream: %w", err)
|
||||
}
|
||||
controller.streams = append(controller.streams, stdinStream, stdoutStream, stderrStream)
|
||||
controller.stdin = stdinStream.Stdin()
|
||||
go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStdout, stdoutStream.Stdout())
|
||||
go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStderr, stderrStream.Stdout())
|
||||
go d.watchGuestSessionAttach(session.ID, controller, session)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Daemon) openGuestSessionAttachStream(address, command string) (*guest.StreamSession, error) {
|
||||
client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream, err := client.StartCommand(context.Background(), command)
|
||||
if err != nil {
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) watchGuestSessionAttach(id string, controller *guestSessionController, session model.GuestSession) {
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
vm, err := d.store.GetVMByID(context.Background(), session.VMID)
|
||||
if err != nil {
|
||||
controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()})
|
||||
_ = controller.close()
|
||||
return
|
||||
}
|
||||
refreshed, err := d.refreshGuestSession(context.Background(), vm, session)
|
||||
if err == nil {
|
||||
session = refreshed
|
||||
}
|
||||
if session.Status == model.GuestSessionStatusExited || session.Status == model.GuestSessionStatusFailed {
|
||||
controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: session.ExitCode})
|
||||
_ = controller.close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,184 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"banger/internal/guest"
|
||||
"banger/internal/sessionstream"
|
||||
)
|
||||
|
||||
type guestSessionController struct {
|
||||
stream *guest.StreamSession
|
||||
streams []*guest.StreamSession
|
||||
stdin io.WriteCloser
|
||||
attachMu sync.Mutex
|
||||
attach net.Conn
|
||||
writeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (c *guestSessionController) setAttach(conn net.Conn) error {
|
||||
c.attachMu.Lock()
|
||||
defer c.attachMu.Unlock()
|
||||
if c.attach != nil {
|
||||
return errors.New("session already has an active attach")
|
||||
}
|
||||
c.attach = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *guestSessionController) clearAttach(conn net.Conn) {
|
||||
c.attachMu.Lock()
|
||||
defer c.attachMu.Unlock()
|
||||
if c.attach == conn {
|
||||
c.attach = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *guestSessionController) writeFrame(channel byte, payload []byte) {
|
||||
c.attachMu.Lock()
|
||||
conn := c.attach
|
||||
c.attachMu.Unlock()
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
c.writeMu.Lock()
|
||||
err := sessionstream.WriteFrame(conn, channel, payload)
|
||||
c.writeMu.Unlock()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
c.clearAttach(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *guestSessionController) writeControl(message sessionstream.ControlMessage) {
|
||||
c.attachMu.Lock()
|
||||
conn := c.attach
|
||||
c.attachMu.Unlock()
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
c.writeMu.Lock()
|
||||
err := sessionstream.WriteControl(conn, message)
|
||||
c.writeMu.Unlock()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
c.clearAttach(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *guestSessionController) close() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
c.closeOnce.Do(func() {
|
||||
c.attachMu.Lock()
|
||||
conn := c.attach
|
||||
c.attach = nil
|
||||
c.attachMu.Unlock()
|
||||
if conn != nil {
|
||||
err = errors.Join(err, conn.Close())
|
||||
}
|
||||
if c.stdin != nil {
|
||||
err = errors.Join(err, c.stdin.Close())
|
||||
}
|
||||
if c.stream != nil {
|
||||
err = errors.Join(err, c.stream.Close())
|
||||
}
|
||||
for _, stream := range c.streams {
|
||||
if stream != nil {
|
||||
err = errors.Join(err, stream.Close())
|
||||
}
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// sessionRegistry owns the live guest-session controller map. Its lock is
|
||||
// independent of Daemon.mu so guest-session lookups do not contend with
|
||||
// unrelated daemon state.
|
||||
type sessionRegistry struct {
|
||||
mu sync.Mutex
|
||||
byID map[string]*guestSessionController
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newSessionRegistry() sessionRegistry {
|
||||
return sessionRegistry{byID: make(map[string]*guestSessionController)}
|
||||
}
|
||||
|
||||
func (r *sessionRegistry) set(id string, controller *guestSessionController) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.closed {
|
||||
return
|
||||
}
|
||||
r.byID[id] = controller
|
||||
}
|
||||
|
||||
func (r *sessionRegistry) claim(id string, controller *guestSessionController) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.closed {
|
||||
return false
|
||||
}
|
||||
if r.byID[id] != nil {
|
||||
return false
|
||||
}
|
||||
r.byID[id] = controller
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *sessionRegistry) get(id string) *guestSessionController {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.byID[id]
|
||||
}
|
||||
|
||||
func (r *sessionRegistry) clear(id string) *guestSessionController {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
controller := r.byID[id]
|
||||
delete(r.byID, id)
|
||||
return controller
|
||||
}
|
||||
|
||||
func (r *sessionRegistry) closeAll() error {
|
||||
r.mu.Lock()
|
||||
controllers := make([]*guestSessionController, 0, len(r.byID))
|
||||
for _, controller := range r.byID {
|
||||
controllers = append(controllers, controller)
|
||||
}
|
||||
r.byID = nil
|
||||
r.closed = true
|
||||
r.mu.Unlock()
|
||||
var err error
|
||||
for _, controller := range controllers {
|
||||
err = errors.Join(err, controller.close())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Daemon) setGuestSessionController(id string, controller *guestSessionController) {
|
||||
d.sessions.set(id, controller)
|
||||
}
|
||||
|
||||
func (d *Daemon) claimGuestSessionController(id string, controller *guestSessionController) bool {
|
||||
return d.sessions.claim(id, controller)
|
||||
}
|
||||
|
||||
func (d *Daemon) getGuestSessionController(id string) *guestSessionController {
|
||||
return d.sessions.get(id)
|
||||
}
|
||||
|
||||
func (d *Daemon) clearGuestSessionController(id string) *guestSessionController {
|
||||
return d.sessions.clear(id)
|
||||
}
|
||||
|
||||
func (d *Daemon) closeGuestSessionControllers() error {
|
||||
return d.sessions.closeAll()
|
||||
}
|
||||
|
|
@ -1,213 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"banger/internal/api"
|
||||
sess "banger/internal/daemon/session"
|
||||
"banger/internal/guest"
|
||||
"banger/internal/model"
|
||||
)
|
||||
|
||||
func (d *Daemon) StartGuestSession(ctx context.Context, params api.GuestSessionStartParams) (model.GuestSession, error) {
|
||||
stdinMode := model.GuestSessionStdinMode(strings.TrimSpace(params.StdinMode))
|
||||
if stdinMode == "" {
|
||||
stdinMode = model.GuestSessionStdinClosed
|
||||
}
|
||||
if stdinMode != model.GuestSessionStdinClosed && stdinMode != model.GuestSessionStdinPipe {
|
||||
return model.GuestSession{}, fmt.Errorf("unsupported stdin mode %q", params.StdinMode)
|
||||
}
|
||||
if strings.TrimSpace(params.Command) == "" {
|
||||
return model.GuestSession{}, errors.New("session command is required")
|
||||
}
|
||||
var created model.GuestSession
|
||||
_, err := d.withVMLockByRef(ctx, params.VMIDOrName, func(vm model.VMRecord) (model.VMRecord, error) {
|
||||
if !d.vmAlive(vm) {
|
||||
return model.VMRecord{}, fmt.Errorf("vm %q is not running", vm.Name)
|
||||
}
|
||||
session, err := d.startGuestSessionLocked(ctx, vm, params, stdinMode)
|
||||
if err != nil {
|
||||
return model.VMRecord{}, err
|
||||
}
|
||||
created = session
|
||||
return vm, nil
|
||||
})
|
||||
return created, err
|
||||
}
|
||||
|
||||
func (d *Daemon) startGuestSessionLocked(ctx context.Context, vm model.VMRecord, params api.GuestSessionStartParams, stdinMode model.GuestSessionStdinMode) (model.GuestSession, error) {
|
||||
id, err := model.NewID()
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
now := model.Now()
|
||||
session := model.GuestSession{
|
||||
ID: id,
|
||||
VMID: vm.ID,
|
||||
Name: sess.DefaultName(id, params.Command, params.Name),
|
||||
Backend: sess.BackendSSH,
|
||||
Command: params.Command,
|
||||
Args: append([]string(nil), params.Args...),
|
||||
CWD: strings.TrimSpace(params.CWD),
|
||||
Env: sess.CloneStringMap(params.Env),
|
||||
StdinMode: stdinMode,
|
||||
Status: model.GuestSessionStatusStarting,
|
||||
GuestStateDir: sess.StateDir(id),
|
||||
StdoutLogPath: sess.StdoutLogPath(id),
|
||||
StderrLogPath: sess.StderrLogPath(id),
|
||||
Tags: sess.CloneStringMap(params.Tags),
|
||||
Attachable: stdinMode == model.GuestSessionStdinPipe,
|
||||
Reattachable: stdinMode == model.GuestSessionStdinPipe,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if session.Attachable {
|
||||
session.AttachBackend = sess.AttachBackendSSHBridge
|
||||
session.AttachMode = sess.AttachModeExclusive
|
||||
} else {
|
||||
session.AttachBackend = sess.AttachBackendNone
|
||||
}
|
||||
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
fail := func(stage, message, rawLog string) (model.GuestSession, error) {
|
||||
session = sess.FailLaunch(session, stage, message, rawLog)
|
||||
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
address := net.JoinHostPort(vm.Runtime.GuestIP, "22")
|
||||
if err := d.waitForGuestSSH(ctx, address, 250*time.Millisecond); err != nil {
|
||||
return fail("ssh_unavailable", fmt.Sprintf("guest ssh unavailable: %v", err), "")
|
||||
}
|
||||
client, err := d.dialGuest(ctx, address)
|
||||
if err != nil {
|
||||
return fail("dial_guest", fmt.Sprintf("dial guest ssh: %v", err), "")
|
||||
}
|
||||
defer client.Close()
|
||||
var preflightLog bytes.Buffer
|
||||
if err := client.RunScript(ctx, sess.CWDPreflightScript(session.CWD), &preflightLog); err != nil {
|
||||
return fail("preflight_cwd", fmt.Sprintf("guest working directory is unavailable: %s", sess.DefaultCWD(session.CWD)), preflightLog.String())
|
||||
}
|
||||
preflightLog.Reset()
|
||||
requiredCommands := sess.NormalizeRequiredCommands(params.Command, params.RequiredCommands)
|
||||
if err := client.RunScript(ctx, sess.CommandPreflightScript(requiredCommands), &preflightLog); err != nil {
|
||||
return fail("preflight_command", fmt.Sprintf("required guest command is unavailable: %s", strings.TrimSpace(preflightLog.String())), preflightLog.String())
|
||||
}
|
||||
var uploadLog bytes.Buffer
|
||||
if err := client.UploadFile(ctx, sess.ScriptPath(id), 0o755, []byte(sess.Script(session)), &uploadLog); err != nil {
|
||||
return fail("upload_script", "upload guest session script failed", uploadLog.String())
|
||||
}
|
||||
var launchLog bytes.Buffer
|
||||
launchScript := fmt.Sprintf("set -euo pipefail\nnohup bash %s >/dev/null 2>&1 </dev/null &\ndisown || true\n", sess.ShellQuote(sess.ScriptPath(id)))
|
||||
if err := client.RunScript(ctx, launchScript, &launchLog); err != nil {
|
||||
return fail("launch", "launch guest session failed", launchLog.String())
|
||||
}
|
||||
readyCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
updated, err := d.waitForGuestSessionReadyHook(readyCtx, vm, session)
|
||||
if err != nil {
|
||||
return fail("ready_wait", "guest session did not report ready state", err.Error())
|
||||
}
|
||||
session = updated
|
||||
if session.Status == model.GuestSessionStatusStarting {
|
||||
session.Status = model.GuestSessionStatusRunning
|
||||
session.StartedAt = model.Now()
|
||||
session.UpdatedAt = model.Now()
|
||||
}
|
||||
session.LaunchStage = ""
|
||||
session.LaunchMessage = ""
|
||||
session.LaunchRawLog = ""
|
||||
session.LastError = ""
|
||||
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) GetGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
return d.refreshGuestSession(ctx, vm, session)
|
||||
}
|
||||
|
||||
func (d *Daemon) ListGuestSessions(ctx context.Context, params api.VMRefParams) ([]model.GuestSession, error) {
|
||||
vm, err := d.FindVM(ctx, params.IDOrName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions, err := d.store.ListGuestSessionsByVM(ctx, vm.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for index := range sessions {
|
||||
refreshed, refreshErr := d.refreshGuestSession(ctx, vm, sessions[index])
|
||||
if refreshErr == nil {
|
||||
sessions[index] = refreshed
|
||||
}
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) StopGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
|
||||
return d.signalGuestSession(ctx, params, "TERM")
|
||||
}
|
||||
|
||||
func (d *Daemon) KillGuestSession(ctx context.Context, params api.GuestSessionRefParams) (model.GuestSession, error) {
|
||||
return d.signalGuestSession(ctx, params, "KILL")
|
||||
}
|
||||
|
||||
func (d *Daemon) signalGuestSession(ctx context.Context, params api.GuestSessionRefParams, signal string) (model.GuestSession, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
session, _ = d.refreshGuestSession(ctx, vm, session)
|
||||
if session.Status == model.GuestSessionStatusExited || session.Status == model.GuestSessionStatusFailed {
|
||||
return session, nil
|
||||
}
|
||||
if !d.vmAlive(vm) {
|
||||
session.Status = model.GuestSessionStatusFailed
|
||||
session.LastError = "vm is not running"
|
||||
now := model.Now()
|
||||
session.UpdatedAt = now
|
||||
session.EndedAt = now
|
||||
session.Attachable = false
|
||||
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
defer client.Close()
|
||||
var log bytes.Buffer
|
||||
if err := client.RunScript(ctx, sess.SignalScript(session.ID, signal), &log); err != nil {
|
||||
return model.GuestSession{}, sess.FormatStepError("signal guest session", err, log.String())
|
||||
}
|
||||
session.Status = model.GuestSessionStatusStopping
|
||||
session.UpdatedAt = model.Now()
|
||||
if err := d.store.UpsertGuestSession(ctx, session); err != nil {
|
||||
return model.GuestSession{}, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
|
@ -1,120 +0,0 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"banger/internal/api"
|
||||
sess "banger/internal/daemon/session"
|
||||
"banger/internal/guest"
|
||||
"banger/internal/model"
|
||||
"banger/internal/system"
|
||||
)
|
||||
|
||||
func (d *Daemon) GuestSessionLogs(ctx context.Context, params api.GuestSessionLogsParams) (api.GuestSessionLogsResult, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionLogsResult{}, err
|
||||
}
|
||||
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionLogsResult{}, err
|
||||
}
|
||||
streamName := strings.TrimSpace(params.Stream)
|
||||
if streamName == "" {
|
||||
streamName = "stdout"
|
||||
}
|
||||
tailLines := params.TailLines
|
||||
if tailLines <= 0 {
|
||||
tailLines = sess.LogTailLineDefault
|
||||
}
|
||||
path := session.StdoutLogPath
|
||||
if streamName == "stderr" {
|
||||
path = session.StderrLogPath
|
||||
}
|
||||
content, err := d.readGuestSessionLog(ctx, vm, session, streamName, tailLines)
|
||||
if err != nil {
|
||||
return api.GuestSessionLogsResult{}, err
|
||||
}
|
||||
return api.GuestSessionLogsResult{Session: session, Stream: streamName, Path: path, Content: content}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) SendToGuestSession(ctx context.Context, params api.GuestSessionSendParams) (api.GuestSessionSendResult, error) {
|
||||
vm, err := d.FindVM(ctx, params.VMIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionSendResult{}, err
|
||||
}
|
||||
session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName)
|
||||
if err != nil {
|
||||
return api.GuestSessionSendResult{}, err
|
||||
}
|
||||
if session.StdinMode != model.GuestSessionStdinPipe {
|
||||
return api.GuestSessionSendResult{}, errors.New("session does not have a stdin pipe")
|
||||
}
|
||||
if session.Status != model.GuestSessionStatusRunning {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("session is not running (status=%s)", session.Status)
|
||||
}
|
||||
if !d.vmAlive(vm) {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("vm %q is not running", vm.Name)
|
||||
}
|
||||
if len(params.Payload) == 0 {
|
||||
return api.GuestSessionSendResult{Session: session}, nil
|
||||
}
|
||||
client, err := d.dialGuest(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"))
|
||||
if err != nil {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("dial guest: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
tmpPath := fmt.Sprintf("/tmp/banger-send-%s.bin", session.ID[:8])
|
||||
var uploadLog bytes.Buffer
|
||||
if err := client.UploadFile(ctx, tmpPath, 0o600, params.Payload, &uploadLog); err != nil {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("upload payload: %w", err)
|
||||
}
|
||||
sendScript := fmt.Sprintf(
|
||||
"set -euo pipefail\ncat %s >> %s\nrm -f %s\n",
|
||||
sess.ShellQuote(tmpPath),
|
||||
sess.ShellQuote(sess.StdinPipePath(session.ID)),
|
||||
sess.ShellQuote(tmpPath),
|
||||
)
|
||||
var sendLog bytes.Buffer
|
||||
if err := client.RunScript(ctx, sendScript, &sendLog); err != nil {
|
||||
return api.GuestSessionSendResult{}, fmt.Errorf("send to session: %w: %s", err, strings.TrimSpace(sendLog.String()))
|
||||
}
|
||||
return api.GuestSessionSendResult{Session: session, BytesWritten: len(params.Payload)}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) readGuestSessionLog(ctx context.Context, vm model.VMRecord, session model.GuestSession, stream string, tailLines int) (string, error) {
|
||||
if d.vmAlive(vm) {
|
||||
client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath, d.layout.KnownHostsPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer client.Close()
|
||||
path := session.StdoutLogPath
|
||||
if stream == "stderr" {
|
||||
path = session.StderrLogPath
|
||||
}
|
||||
var output bytes.Buffer
|
||||
script := fmt.Sprintf("set -euo pipefail\nif [ -f %s ]; then tail -n %d %s; fi\n", sess.ShellQuote(path), tailLines, sess.ShellQuote(path))
|
||||
if err := client.RunScript(ctx, script, &output); err != nil {
|
||||
return "", sess.FormatStepError("read guest session log", err, output.String())
|
||||
}
|
||||
return output.String(), nil
|
||||
}
|
||||
runner := d.runner
|
||||
if runner == nil {
|
||||
runner = system.NewRunner()
|
||||
}
|
||||
workMount, cleanup, err := system.MountTempDir(ctx, runner, vm.Runtime.WorkDiskPath, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer cleanup()
|
||||
logPath := filepath.Join(workMount, sess.RelativeStateDir(session.ID), stream+".log")
|
||||
return sess.TailFileContent(logPath, tailLines)
|
||||
}
|
||||
|
|
@ -10,7 +10,6 @@ import (
|
|||
"time"
|
||||
|
||||
"banger/internal/api"
|
||||
sess "banger/internal/daemon/session"
|
||||
ws "banger/internal/daemon/workspace"
|
||||
"banger/internal/model"
|
||||
)
|
||||
|
|
@ -114,9 +113,9 @@ func exportScript(guestPath, diffRef, diffFlag string) string {
|
|||
"git read-tree %s --index-output=\"$tmp_idx\"\n"+
|
||||
"GIT_INDEX_FILE=\"$tmp_idx\" git add -A\n"+
|
||||
"GIT_INDEX_FILE=\"$tmp_idx\" git diff --cached %s %s\n",
|
||||
sess.ShellQuote(guestPath),
|
||||
sess.ShellQuote(diffRef),
|
||||
sess.ShellQuote(diffRef),
|
||||
ws.ShellQuote(guestPath),
|
||||
ws.ShellQuote(diffRef),
|
||||
ws.ShellQuote(diffRef),
|
||||
diffFlag,
|
||||
)
|
||||
}
|
||||
|
|
@ -189,9 +188,9 @@ func (d *Daemon) prepareVMWorkspaceGuestIO(ctx context.Context, vm model.VMRecor
|
|||
}
|
||||
if readOnly {
|
||||
var chmodLog bytes.Buffer
|
||||
chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", sess.ShellQuote(guestPath))
|
||||
chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", ws.ShellQuote(guestPath))
|
||||
if err := client.RunScript(ctx, chmodScript, &chmodLog); err != nil {
|
||||
return model.WorkspacePrepareResult{}, sess.FormatStepError("set workspace readonly", err, chmodLog.String())
|
||||
return model.WorkspacePrepareResult{}, ws.FormatStepError("set workspace readonly", err, chmodLog.String())
|
||||
}
|
||||
}
|
||||
return model.WorkspacePrepareResult{
|
||||
|
|
|
|||
20
internal/daemon/workspace/util.go
Normal file
20
internal/daemon/workspace/util.go
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
package workspace
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ShellQuote returns value single-quoted for bash, escaping embedded quotes.
|
||||
func ShellQuote(value string) string {
|
||||
return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
|
||||
}
|
||||
|
||||
// FormatStepError wraps err with an action label and trimmed on-guest log.
|
||||
func FormatStepError(action string, err error, log string) error {
|
||||
log = strings.TrimSpace(log)
|
||||
if log == "" {
|
||||
return fmt.Errorf("%s: %w", action, err)
|
||||
}
|
||||
return fmt.Errorf("%s: %w: %s", action, err, log)
|
||||
}
|
||||
|
|
@ -18,7 +18,6 @@ import (
|
|||
"sort"
|
||||
"strings"
|
||||
|
||||
sess "banger/internal/daemon/session"
|
||||
"banger/internal/model"
|
||||
"banger/internal/system"
|
||||
)
|
||||
|
|
@ -146,13 +145,13 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
|
|||
switch mode {
|
||||
case model.WorkspacePrepareModeFullCopy:
|
||||
var copyLog bytes.Buffer
|
||||
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath))
|
||||
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath), ShellQuote(guestPath), ShellQuote(guestPath))
|
||||
if err := client.StreamTar(ctx, spec.RepoRoot, command, ©Log); err != nil {
|
||||
return sess.FormatStepError("copy full workspace", err, copyLog.String())
|
||||
return FormatStepError("copy full workspace", err, copyLog.String())
|
||||
}
|
||||
var finalizeLog bytes.Buffer
|
||||
if err := client.RunScript(ctx, FinalizeScript(spec, guestPath, mode), &finalizeLog); err != nil {
|
||||
return sess.FormatStepError("finalize full workspace", err, finalizeLog.String())
|
||||
return FormatStepError("finalize full workspace", err, finalizeLog.String())
|
||||
}
|
||||
return nil
|
||||
case model.WorkspacePrepareModeMetadataOnly, model.WorkspacePrepareModeShallowOverlay:
|
||||
|
|
@ -162,21 +161,21 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
|
|||
}
|
||||
defer cleanup()
|
||||
var copyLog bytes.Buffer
|
||||
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath))
|
||||
command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath), ShellQuote(guestPath), ShellQuote(guestPath))
|
||||
if err := client.StreamTar(ctx, repoCopyDir, command, ©Log); err != nil {
|
||||
return sess.FormatStepError("copy guest git metadata", err, copyLog.String())
|
||||
return FormatStepError("copy guest git metadata", err, copyLog.String())
|
||||
}
|
||||
var scriptLog bytes.Buffer
|
||||
if err := client.RunScript(ctx, FinalizeScript(spec, guestPath, mode), &scriptLog); err != nil {
|
||||
return sess.FormatStepError("prepare guest checkout", err, scriptLog.String())
|
||||
return FormatStepError("prepare guest checkout", err, scriptLog.String())
|
||||
}
|
||||
if mode == model.WorkspacePrepareModeMetadataOnly {
|
||||
return nil
|
||||
}
|
||||
var overlayLog bytes.Buffer
|
||||
command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath))
|
||||
command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", ShellQuote(guestPath))
|
||||
if err := client.StreamTarEntries(ctx, spec.RepoRoot, spec.OverlayPaths, command, &overlayLog); err != nil {
|
||||
return sess.FormatStepError("overlay workspace working tree", err, overlayLog.String())
|
||||
return FormatStepError("overlay workspace working tree", err, overlayLog.String())
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
|
|
@ -190,22 +189,22 @@ func ImportRepoToGuest(ctx context.Context, client GuestClient, spec RepoSpec, g
|
|||
func FinalizeScript(spec RepoSpec, guestPath string, mode model.WorkspacePrepareMode) string {
|
||||
var script strings.Builder
|
||||
script.WriteString("set -euo pipefail\n")
|
||||
fmt.Fprintf(&script, "DIR=%s\n", sess.ShellQuote(guestPath))
|
||||
fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(guestPath))
|
||||
script.WriteString("git config --global --add safe.directory \"$DIR\"\n")
|
||||
if mode != model.WorkspacePrepareModeFullCopy {
|
||||
script.WriteString("find \"$DIR\" -mindepth 1 -maxdepth 1 ! -name .git -exec rm -rf {} +\n")
|
||||
}
|
||||
switch {
|
||||
case strings.TrimSpace(spec.BranchName) != "":
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.BranchName), sess.ShellQuote(spec.BaseCommit))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", ShellQuote(spec.BranchName), ShellQuote(spec.BaseCommit))
|
||||
case strings.TrimSpace(spec.CurrentBranch) != "":
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.CurrentBranch), sess.ShellQuote(spec.HeadCommit))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", ShellQuote(spec.CurrentBranch), ShellQuote(spec.HeadCommit))
|
||||
default:
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", sess.ShellQuote(spec.HeadCommit))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", ShellQuote(spec.HeadCommit))
|
||||
}
|
||||
if strings.TrimSpace(spec.GitUserName) != "" && strings.TrimSpace(spec.GitUserEmail) != "" {
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", sess.ShellQuote(spec.GitUserName))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", sess.ShellQuote(spec.GitUserEmail))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", ShellQuote(spec.GitUserName))
|
||||
fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", ShellQuote(spec.GitUserEmail))
|
||||
}
|
||||
return script.String()
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue