Add vsock-backed SSH session reminders
Remind users when a VM is still running after hanger vm ssh exits instead of silently dropping them back to the host shell.\n\nAttach a Firecracker vsock device to each VM, persist the host vsock path/CID,\nadd a new guest-side banger-vsock-pingd responder to the runtime bundle and both\nimage-build paths, and expose a vm.ping RPC that the CLI and TUI call after SSH\nreturns. Doctor and start/build preflight now validate the helper plus\n/dev/vhost-vsock so the feature fails early and clearly.\n\nValidated with go mod tidy, bash -n customize.sh, git diff --check, make build,\nand GOCACHE=/tmp/banger-gocache go test ./... outside the sandbox because the\ndaemon tests need real Unix/UDP sockets. Rebuild the image/rootfs used for new\nVMs so the guest ping service is present.
This commit is contained in:
parent
4930d82cb9
commit
08ef706e3f
31 changed files with 912 additions and 75 deletions
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
|
@ -23,6 +24,7 @@ import (
|
|||
"banger/internal/rpc"
|
||||
"banger/internal/system"
|
||||
"banger/internal/vmdns"
|
||||
"banger/internal/vsockping"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
|
@ -32,7 +34,17 @@ var (
|
|||
daemonExePath = func(pid int) string {
|
||||
return filepath.Join("/proc", fmt.Sprintf("%d", pid), "exe")
|
||||
}
|
||||
doctorFunc = daemon.Doctor
|
||||
doctorFunc = daemon.Doctor
|
||||
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
|
||||
sshCmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
sshCmd.Stdout = stdout
|
||||
sshCmd.Stderr = stderr
|
||||
sshCmd.Stdin = stdin
|
||||
return sshCmd.Run()
|
||||
}
|
||||
vmPingFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPingResult, error) {
|
||||
return rpc.Call[api.VMPingResult](ctx, socketPath, "vm.ping", api.VMRefParams{IDOrName: idOrName})
|
||||
}
|
||||
)
|
||||
|
||||
func NewBangerCommand() *cobra.Command {
|
||||
|
|
@ -454,11 +466,7 @@ func newVMSSHCommand() *cobra.Command {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sshCmd := exec.CommandContext(cmd.Context(), "ssh", sshArgs...)
|
||||
sshCmd.Stdout = cmd.OutOrStdout()
|
||||
sshCmd.Stderr = cmd.ErrOrStderr()
|
||||
sshCmd.Stdin = cmd.InOrStdin()
|
||||
return sshCmd.Run()
|
||||
return runSSHSession(cmd.Context(), layout.SocketPath, result.Name, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), sshArgs)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
@ -953,6 +961,36 @@ func validatePositiveSetting(label string, value int) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func runSSHSession(ctx context.Context, socketPath, vmRef string, stdin io.Reader, stdout, stderr io.Writer, sshArgs []string) error {
|
||||
sshErr := sshExecFunc(ctx, stdin, stdout, stderr, sshArgs)
|
||||
if !shouldCheckSSHReminder(sshErr) || ctx.Err() != nil {
|
||||
return sshErr
|
||||
}
|
||||
pingCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
ping, err := vmPingFunc(pingCtx, socketPath, vmRef)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(stderr, vsockping.WarningMessage(vmRef, err))
|
||||
return sshErr
|
||||
}
|
||||
if ping.Alive {
|
||||
name := ping.Name
|
||||
if strings.TrimSpace(name) == "" {
|
||||
name = vmRef
|
||||
}
|
||||
_, _ = fmt.Fprintln(stderr, vsockping.ReminderMessage(name))
|
||||
}
|
||||
return sshErr
|
||||
}
|
||||
|
||||
func shouldCheckSSHReminder(err error) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
var exitErr *exec.ExitError
|
||||
return errors.As(err, &exitErr)
|
||||
}
|
||||
|
||||
func sshCommandArgs(cfg model.DaemonConfig, guestIP string, extra []string) ([]string, error) {
|
||||
if guestIP == "" {
|
||||
return nil, errors.New("vm has no guest IP")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
|
@ -209,6 +211,56 @@ func TestVMSetParamsFromFlagsRejectsNonPositiveCPUAndMemory(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRunSSHSessionPrintsReminderWhenPingAlive(t *testing.T) {
|
||||
origSSHExec := sshExecFunc
|
||||
origPing := vmPingFunc
|
||||
t.Cleanup(func() {
|
||||
sshExecFunc = origSSHExec
|
||||
vmPingFunc = origPing
|
||||
})
|
||||
|
||||
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
|
||||
return nil
|
||||
}
|
||||
vmPingFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPingResult, error) {
|
||||
return api.VMPingResult{Name: "devbox", Alive: true}, nil
|
||||
}
|
||||
|
||||
var stderr bytes.Buffer
|
||||
if err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}); err != nil {
|
||||
t.Fatalf("runSSHSession: %v", err)
|
||||
}
|
||||
if !strings.Contains(stderr.String(), "devbox is still running") {
|
||||
t.Fatalf("stderr = %q, want reminder", stderr.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSSHSessionPreservesSSHExitStatusOnPingWarning(t *testing.T) {
|
||||
origSSHExec := sshExecFunc
|
||||
origPing := vmPingFunc
|
||||
t.Cleanup(func() {
|
||||
sshExecFunc = origSSHExec
|
||||
vmPingFunc = origPing
|
||||
})
|
||||
|
||||
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
|
||||
return &exec.ExitError{}
|
||||
}
|
||||
vmPingFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPingResult, error) {
|
||||
return api.VMPingResult{}, errors.New("dial failed")
|
||||
}
|
||||
|
||||
var stderr bytes.Buffer
|
||||
err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"})
|
||||
var exitErr *exec.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
t.Fatalf("runSSHSession error = %v, want exit error", err)
|
||||
}
|
||||
if !strings.Contains(stderr.String(), "failed to check whether devbox is still running") {
|
||||
t.Fatalf("stderr = %q, want warning", stderr.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveVMTargetsDeduplicatesAndReportsErrors(t *testing.T) {
|
||||
vms := []model.VMRecord{
|
||||
testCLIResolvedVM("alpha-id", "alpha"),
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import (
|
|||
"banger/internal/paths"
|
||||
"banger/internal/rpc"
|
||||
"banger/internal/system"
|
||||
"banger/internal/vsockping"
|
||||
|
||||
"github.com/charmbracelet/bubbles/help"
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
|
|
@ -104,6 +105,7 @@ type externalPreparedMsg struct {
|
|||
action actionRequest
|
||||
command *exec.Cmd
|
||||
doneStatus string
|
||||
done func(error) tea.Msg
|
||||
refresh bool
|
||||
err error
|
||||
}
|
||||
|
|
@ -716,10 +718,14 @@ func (m tuiModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
break
|
||||
}
|
||||
cmds = append(cmds, tea.ExecProcess(msg.command, func(err error) tea.Msg {
|
||||
err = normalizeExecError(err)
|
||||
if msg.done != nil {
|
||||
return msg.done(err)
|
||||
}
|
||||
return actionResultMsg{
|
||||
action: msg.action,
|
||||
status: msg.doneStatus,
|
||||
err: normalizeExecError(err),
|
||||
err: err,
|
||||
refresh: msg.refresh,
|
||||
focusID: m.selectedID,
|
||||
}
|
||||
|
|
@ -1439,14 +1445,55 @@ func prepareSSHCmd(layout paths.Layout, cfg model.DaemonConfig, action actionReq
|
|||
return externalPreparedMsg{action: action, err: err}
|
||||
}
|
||||
return externalPreparedMsg{
|
||||
action: action,
|
||||
command: exec.Command("ssh", args...),
|
||||
doneStatus: fmt.Sprintf("ssh session ended for %s", result.Name),
|
||||
refresh: true,
|
||||
action: action,
|
||||
command: exec.Command("ssh", args...),
|
||||
done: func(execErr error) tea.Msg {
|
||||
return sshDoneMsg(layout, action, result.Name, execErr)
|
||||
},
|
||||
refresh: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sshDoneMsg(layout paths.Layout, action actionRequest, name string, execErr error) tea.Msg {
|
||||
if execErr != nil {
|
||||
return actionResultMsg{
|
||||
action: action,
|
||||
err: execErr,
|
||||
refresh: true,
|
||||
focusID: action.id,
|
||||
}
|
||||
}
|
||||
pingCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
ping, err := vmPingFunc(pingCtx, layout.SocketPath, name)
|
||||
if err != nil {
|
||||
return actionResultMsg{
|
||||
action: action,
|
||||
status: vsockping.WarningMessage(name, err),
|
||||
refresh: true,
|
||||
focusID: action.id,
|
||||
}
|
||||
}
|
||||
if ping.Alive {
|
||||
if strings.TrimSpace(ping.Name) != "" {
|
||||
name = ping.Name
|
||||
}
|
||||
return actionResultMsg{
|
||||
action: action,
|
||||
status: vsockping.ReminderMessage(name),
|
||||
refresh: true,
|
||||
focusID: action.id,
|
||||
}
|
||||
}
|
||||
return actionResultMsg{
|
||||
action: action,
|
||||
status: fmt.Sprintf("ssh session ended for %s", name),
|
||||
refresh: true,
|
||||
focusID: action.id,
|
||||
}
|
||||
}
|
||||
|
||||
func prepareLogsCmd(layout paths.Layout, action actionRequest) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
result, err := rpc.Call[api.VMLogsResult](context.Background(), layout.SocketPath, "vm.logs", api.VMRefParams{IDOrName: action.id})
|
||||
|
|
|
|||
|
|
@ -2,12 +2,14 @@ package cli
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"banger/internal/api"
|
||||
"banger/internal/model"
|
||||
"banger/internal/paths"
|
||||
|
||||
|
|
@ -236,6 +238,41 @@ func TestTUIStatusIncludesStageDurationsAfterInitialLoad(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSSHDoneMsgShowsReminderWhenPingAlive(t *testing.T) {
|
||||
origPing := vmPingFunc
|
||||
t.Cleanup(func() {
|
||||
vmPingFunc = origPing
|
||||
})
|
||||
vmPingFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPingResult, error) {
|
||||
return api.VMPingResult{Name: "devbox", Alive: true}, nil
|
||||
}
|
||||
|
||||
msg := sshDoneMsg(paths.Layout{SocketPath: "/tmp/bangerd.sock"}, actionRequest{id: "devbox", name: "devbox"}, "devbox", nil)
|
||||
result, ok := msg.(actionResultMsg)
|
||||
if !ok {
|
||||
t.Fatalf("msg = %T, want actionResultMsg", msg)
|
||||
}
|
||||
if !strings.Contains(result.status, "devbox is still running") {
|
||||
t.Fatalf("status = %q, want reminder", result.status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHDoneMsgShowsWarningWhenPingFails(t *testing.T) {
|
||||
origPing := vmPingFunc
|
||||
t.Cleanup(func() {
|
||||
vmPingFunc = origPing
|
||||
})
|
||||
vmPingFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPingResult, error) {
|
||||
return api.VMPingResult{}, errors.New("dial failed")
|
||||
}
|
||||
|
||||
msg := sshDoneMsg(paths.Layout{SocketPath: "/tmp/bangerd.sock"}, actionRequest{id: "devbox", name: "devbox"}, "devbox", nil)
|
||||
result := msg.(actionResultMsg)
|
||||
if !strings.Contains(result.status, "failed to check whether devbox is still running") {
|
||||
t.Fatalf("status = %q, want warning", result.status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAggregateRunningVMResources(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue