Add vsock-backed SSH session reminders

Remind users when a VM is still running after 	hanger vm ssh exits instead of silently dropping them back to the host shell.\n\nAttach a Firecracker vsock device to each VM, persist the host vsock path/CID,\nadd a new guest-side banger-vsock-pingd responder to the runtime bundle and both\nimage-build paths, and expose a vm.ping RPC that the CLI and TUI call after SSH\nreturns. Doctor and start/build preflight now validate the helper plus\n/dev/vhost-vsock so the feature fails early and clearly.\n\nValidated with go mod tidy, bash -n customize.sh, git diff --check, make build,\nand GOCACHE=/tmp/banger-gocache go test ./... outside the sandbox because the\ndaemon tests need real Unix/UDP sockets. Rebuild the image/rootfs used for new\nVMs so the guest ping service is present.
This commit is contained in:
Thales Maciel 2026-03-18 20:14:51 -03:00
parent 4930d82cb9
commit 08ef706e3f
No known key found for this signature in database
GPG key ID: 33112E6833C34679
31 changed files with 912 additions and 75 deletions

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
@ -23,6 +24,7 @@ import (
"banger/internal/rpc"
"banger/internal/system"
"banger/internal/vmdns"
"banger/internal/vsockping"
"github.com/spf13/cobra"
)
@ -32,7 +34,17 @@ var (
daemonExePath = func(pid int) string {
return filepath.Join("/proc", fmt.Sprintf("%d", pid), "exe")
}
doctorFunc = daemon.Doctor
doctorFunc = daemon.Doctor
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
sshCmd := exec.CommandContext(ctx, "ssh", args...)
sshCmd.Stdout = stdout
sshCmd.Stderr = stderr
sshCmd.Stdin = stdin
return sshCmd.Run()
}
vmPingFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPingResult, error) {
return rpc.Call[api.VMPingResult](ctx, socketPath, "vm.ping", api.VMRefParams{IDOrName: idOrName})
}
)
func NewBangerCommand() *cobra.Command {
@ -454,11 +466,7 @@ func newVMSSHCommand() *cobra.Command {
if err != nil {
return err
}
sshCmd := exec.CommandContext(cmd.Context(), "ssh", sshArgs...)
sshCmd.Stdout = cmd.OutOrStdout()
sshCmd.Stderr = cmd.ErrOrStderr()
sshCmd.Stdin = cmd.InOrStdin()
return sshCmd.Run()
return runSSHSession(cmd.Context(), layout.SocketPath, result.Name, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), sshArgs)
},
}
}
@ -953,6 +961,36 @@ func validatePositiveSetting(label string, value int) error {
return nil
}
func runSSHSession(ctx context.Context, socketPath, vmRef string, stdin io.Reader, stdout, stderr io.Writer, sshArgs []string) error {
sshErr := sshExecFunc(ctx, stdin, stdout, stderr, sshArgs)
if !shouldCheckSSHReminder(sshErr) || ctx.Err() != nil {
return sshErr
}
pingCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
ping, err := vmPingFunc(pingCtx, socketPath, vmRef)
if err != nil {
_, _ = fmt.Fprintln(stderr, vsockping.WarningMessage(vmRef, err))
return sshErr
}
if ping.Alive {
name := ping.Name
if strings.TrimSpace(name) == "" {
name = vmRef
}
_, _ = fmt.Fprintln(stderr, vsockping.ReminderMessage(name))
}
return sshErr
}
func shouldCheckSSHReminder(err error) bool {
if err == nil {
return true
}
var exitErr *exec.ExitError
return errors.As(err, &exitErr)
}
func sshCommandArgs(cfg model.DaemonConfig, guestIP string, extra []string) ([]string, error) {
if guestIP == "" {
return nil, errors.New("vm has no guest IP")

View file

@ -4,7 +4,9 @@ import (
"bytes"
"context"
"errors"
"io"
"os"
"os/exec"
"path/filepath"
"reflect"
"strings"
@ -209,6 +211,56 @@ func TestVMSetParamsFromFlagsRejectsNonPositiveCPUAndMemory(t *testing.T) {
}
}
func TestRunSSHSessionPrintsReminderWhenPingAlive(t *testing.T) {
origSSHExec := sshExecFunc
origPing := vmPingFunc
t.Cleanup(func() {
sshExecFunc = origSSHExec
vmPingFunc = origPing
})
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
return nil
}
vmPingFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPingResult, error) {
return api.VMPingResult{Name: "devbox", Alive: true}, nil
}
var stderr bytes.Buffer
if err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"}); err != nil {
t.Fatalf("runSSHSession: %v", err)
}
if !strings.Contains(stderr.String(), "devbox is still running") {
t.Fatalf("stderr = %q, want reminder", stderr.String())
}
}
func TestRunSSHSessionPreservesSSHExitStatusOnPingWarning(t *testing.T) {
origSSHExec := sshExecFunc
origPing := vmPingFunc
t.Cleanup(func() {
sshExecFunc = origSSHExec
vmPingFunc = origPing
})
sshExecFunc = func(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
return &exec.ExitError{}
}
vmPingFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPingResult, error) {
return api.VMPingResult{}, errors.New("dial failed")
}
var stderr bytes.Buffer
err := runSSHSession(context.Background(), "/tmp/bangerd.sock", "devbox", strings.NewReader(""), &bytes.Buffer{}, &stderr, []string{"root@127.0.0.1"})
var exitErr *exec.ExitError
if !errors.As(err, &exitErr) {
t.Fatalf("runSSHSession error = %v, want exit error", err)
}
if !strings.Contains(stderr.String(), "failed to check whether devbox is still running") {
t.Fatalf("stderr = %q, want warning", stderr.String())
}
}
func TestResolveVMTargetsDeduplicatesAndReportsErrors(t *testing.T) {
vms := []model.VMRecord{
testCLIResolvedVM("alpha-id", "alpha"),

View file

@ -16,6 +16,7 @@ import (
"banger/internal/paths"
"banger/internal/rpc"
"banger/internal/system"
"banger/internal/vsockping"
"github.com/charmbracelet/bubbles/help"
"github.com/charmbracelet/bubbles/key"
@ -104,6 +105,7 @@ type externalPreparedMsg struct {
action actionRequest
command *exec.Cmd
doneStatus string
done func(error) tea.Msg
refresh bool
err error
}
@ -716,10 +718,14 @@ func (m tuiModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
break
}
cmds = append(cmds, tea.ExecProcess(msg.command, func(err error) tea.Msg {
err = normalizeExecError(err)
if msg.done != nil {
return msg.done(err)
}
return actionResultMsg{
action: msg.action,
status: msg.doneStatus,
err: normalizeExecError(err),
err: err,
refresh: msg.refresh,
focusID: m.selectedID,
}
@ -1439,14 +1445,55 @@ func prepareSSHCmd(layout paths.Layout, cfg model.DaemonConfig, action actionReq
return externalPreparedMsg{action: action, err: err}
}
return externalPreparedMsg{
action: action,
command: exec.Command("ssh", args...),
doneStatus: fmt.Sprintf("ssh session ended for %s", result.Name),
refresh: true,
action: action,
command: exec.Command("ssh", args...),
done: func(execErr error) tea.Msg {
return sshDoneMsg(layout, action, result.Name, execErr)
},
refresh: true,
}
}
}
func sshDoneMsg(layout paths.Layout, action actionRequest, name string, execErr error) tea.Msg {
if execErr != nil {
return actionResultMsg{
action: action,
err: execErr,
refresh: true,
focusID: action.id,
}
}
pingCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
ping, err := vmPingFunc(pingCtx, layout.SocketPath, name)
if err != nil {
return actionResultMsg{
action: action,
status: vsockping.WarningMessage(name, err),
refresh: true,
focusID: action.id,
}
}
if ping.Alive {
if strings.TrimSpace(ping.Name) != "" {
name = ping.Name
}
return actionResultMsg{
action: action,
status: vsockping.ReminderMessage(name),
refresh: true,
focusID: action.id,
}
}
return actionResultMsg{
action: action,
status: fmt.Sprintf("ssh session ended for %s", name),
refresh: true,
focusID: action.id,
}
}
func prepareLogsCmd(layout paths.Layout, action actionRequest) tea.Cmd {
return func() tea.Msg {
result, err := rpc.Call[api.VMLogsResult](context.Background(), layout.SocketPath, "vm.logs", api.VMRefParams{IDOrName: action.id})

View file

@ -2,12 +2,14 @@ package cli
import (
"context"
"errors"
"os"
"path/filepath"
"strings"
"testing"
"time"
"banger/internal/api"
"banger/internal/model"
"banger/internal/paths"
@ -236,6 +238,41 @@ func TestTUIStatusIncludesStageDurationsAfterInitialLoad(t *testing.T) {
}
}
func TestSSHDoneMsgShowsReminderWhenPingAlive(t *testing.T) {
origPing := vmPingFunc
t.Cleanup(func() {
vmPingFunc = origPing
})
vmPingFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPingResult, error) {
return api.VMPingResult{Name: "devbox", Alive: true}, nil
}
msg := sshDoneMsg(paths.Layout{SocketPath: "/tmp/bangerd.sock"}, actionRequest{id: "devbox", name: "devbox"}, "devbox", nil)
result, ok := msg.(actionResultMsg)
if !ok {
t.Fatalf("msg = %T, want actionResultMsg", msg)
}
if !strings.Contains(result.status, "devbox is still running") {
t.Fatalf("status = %q, want reminder", result.status)
}
}
func TestSSHDoneMsgShowsWarningWhenPingFails(t *testing.T) {
origPing := vmPingFunc
t.Cleanup(func() {
vmPingFunc = origPing
})
vmPingFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPingResult, error) {
return api.VMPingResult{}, errors.New("dial failed")
}
msg := sshDoneMsg(paths.Layout{SocketPath: "/tmp/bangerd.sock"}, actionRequest{id: "devbox", name: "devbox"}, "devbox", nil)
result := msg.(actionResultMsg)
if !strings.Contains(result.status, "failed to check whether devbox is still running") {
t.Fatalf("status = %q, want warning", result.status)
}
}
func TestAggregateRunningVMResources(t *testing.T) {
t.Parallel()