Add vsock-backed SSH session reminders

Remind users when a VM is still running after 	hanger vm ssh exits instead of silently dropping them back to the host shell.\n\nAttach a Firecracker vsock device to each VM, persist the host vsock path/CID,\nadd a new guest-side banger-vsock-pingd responder to the runtime bundle and both\nimage-build paths, and expose a vm.ping RPC that the CLI and TUI call after SSH\nreturns. Doctor and start/build preflight now validate the helper plus\n/dev/vhost-vsock so the feature fails early and clearly.\n\nValidated with go mod tidy, bash -n customize.sh, git diff --check, make build,\nand GOCACHE=/tmp/banger-gocache go test ./... outside the sandbox because the\ndaemon tests need real Unix/UDP sockets. Rebuild the image/rootfs used for new\nVMs so the guest ping service is present.
This commit is contained in:
Thales Maciel 2026-03-18 20:14:51 -03:00
parent 4930d82cb9
commit 08ef706e3f
No known key found for this signature in database
GPG key ID: 33112E6833C34679
31 changed files with 912 additions and 75 deletions

View file

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strconv"
@ -75,6 +76,10 @@ func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm mo
if err := os.MkdirAll(vmDir, 0o755); err != nil {
return model.VMRecord{}, err
}
vsockCID, err := defaultVSockCID(guestIP)
if err != nil {
return model.VMRecord{}, err
}
systemOverlaySize := int64(model.DefaultSystemOverlaySize)
if params.SystemOverlaySize != "" {
systemOverlaySize, err = model.ParseSize(params.SystemOverlaySize)
@ -111,6 +116,8 @@ func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm mo
GuestIP: guestIP,
DNSName: vmdns.RecordName(name),
VMDir: vmDir,
VSockPath: defaultVSockPath(d.layout.RuntimeDir, id),
VSockCID: vsockCID,
SystemOverlay: filepath.Join(vmDir, "system.cow"),
WorkDiskPath: filepath.Join(vmDir, "root.ext4"),
LogPath: filepath.Join(vmDir, "firecracker.log"),
@ -183,9 +190,21 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
apiSock := filepath.Join(d.layout.RuntimeDir, "fc-"+shortID+".sock")
tap := "tap-fc-" + shortID
dmName := "fc-rootfs-" + shortID
if strings.TrimSpace(vm.Runtime.VSockPath) == "" {
vm.Runtime.VSockPath = defaultVSockPath(d.layout.RuntimeDir, vm.ID)
}
if vm.Runtime.VSockCID == 0 {
vm.Runtime.VSockCID, err = defaultVSockCID(vm.Runtime.GuestIP)
if err != nil {
return model.VMRecord{}, err
}
}
if err := os.RemoveAll(apiSock); err != nil && !os.IsNotExist(err) {
return model.VMRecord{}, err
}
if err := os.RemoveAll(vm.Runtime.VSockPath); err != nil && !os.IsNotExist(err) {
return model.VMRecord{}, err
}
op.stage("system_overlay", "overlay_path", vm.Runtime.SystemOverlay)
if err := d.ensureSystemOverlay(ctx, &vm); err != nil {
@ -260,6 +279,8 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
IsRoot: true,
}},
TapDevice: tap,
VSockPath: vm.Runtime.VSockPath,
VSockCID: vm.Runtime.VSockCID,
VCPUCount: vm.Spec.VCPUCount,
MemoryMiB: vm.Spec.MemoryMiB,
Logger: d.logger,
@ -276,7 +297,11 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod
vm.Runtime.PID = d.resolveFirecrackerPID(firecrackerCtx, machine, apiSock)
op.debugStage("firecracker_started", "pid", vm.Runtime.PID)
op.stage("socket_access", "api_socket", apiSock)
if err := d.ensureSocketAccess(ctx, apiSock); err != nil {
if err := d.ensureSocketAccess(ctx, apiSock, "firecracker api socket"); err != nil {
return cleanupOnErr(err)
}
op.stage("vsock_access", "vsock_path", vm.Runtime.VSockPath, "vsock_cid", vm.Runtime.VSockCID)
if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil {
return cleanupOnErr(err)
}
op.stage("post_start_features")
@ -556,6 +581,33 @@ func (d *Daemon) GetVMStats(ctx context.Context, idOrName string) (model.VMRecor
return vm, vm.Stats, nil
}
func (d *Daemon) PingVM(ctx context.Context, idOrName string) (result api.VMPingResult, err error) {
_, err = d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
result.Name = vm.Name
if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
result.Alive = false
return vm, nil
}
if strings.TrimSpace(vm.Runtime.VSockPath) == "" {
return model.VMRecord{}, errors.New("vm has no vsock path")
}
if vm.Runtime.VSockCID == 0 {
return model.VMRecord{}, errors.New("vm has no vsock cid")
}
if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil {
return model.VMRecord{}, err
}
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
if err := firecracker.PingVSock(pingCtx, d.logger, vm.Runtime.VSockPath); err != nil {
return model.VMRecord{}, err
}
result.Alive = true
return vm, nil
})
return result, err
}
func (d *Daemon) getVMStatsLocked(ctx context.Context, vm model.VMRecord) (model.VMRecord, error) {
stats, err := d.collectStats(ctx, vm)
if err == nil {
@ -812,11 +864,14 @@ func (d *Daemon) firecrackerBinary() (string, error) {
return path, nil
}
func (d *Daemon) ensureSocketAccess(ctx context.Context, apiSock string) error {
if _, err := d.runner.RunSudo(ctx, "chown", fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()), apiSock); err != nil {
func (d *Daemon) ensureSocketAccess(ctx context.Context, socketPath, label string) error {
if err := waitForPath(ctx, socketPath, 5*time.Second, label); err != nil {
return err
}
_, err := d.runner.RunSudo(ctx, "chmod", "600", apiSock)
if _, err := d.runner.RunSudo(ctx, "chown", fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()), socketPath); err != nil {
return err
}
_, err := d.runner.RunSudo(ctx, "chmod", "600", socketPath)
return err
}
@ -841,7 +896,7 @@ func (d *Daemon) resolveFirecrackerPID(ctx context.Context, machine *firecracker
}
func (d *Daemon) sendCtrlAltDel(ctx context.Context, vm model.VMRecord) error {
if err := d.ensureSocketAccess(ctx, vm.Runtime.APISockPath); err != nil {
if err := d.ensureSocketAccess(ctx, vm.Runtime.APISockPath, "firecracker api socket"); err != nil {
return err
}
client := firecracker.New(vm.Runtime.APISockPath, d.logger)
@ -887,6 +942,9 @@ func (d *Daemon) cleanupRuntime(ctx context.Context, vm model.VMRecord, preserve
if vm.Runtime.APISockPath != "" {
_ = os.Remove(vm.Runtime.APISockPath)
}
if vm.Runtime.VSockPath != "" {
_ = os.Remove(vm.Runtime.VSockPath)
}
snapshotErr := d.cleanupDMSnapshot(ctx, dmSnapshotHandles{
BaseLoop: vm.Runtime.BaseLoop,
COWLoop: vm.Runtime.COWLoop,
@ -910,6 +968,37 @@ func clearRuntimeHandles(vm *model.VMRecord) {
vm.Runtime.DMDev = ""
}
func defaultVSockPath(runtimeDir, vmID string) string {
return filepath.Join(runtimeDir, "fc-"+system.ShortID(vmID)+".vsock")
}
func defaultVSockCID(guestIP string) (uint32, error) {
ip := net.ParseIP(strings.TrimSpace(guestIP)).To4()
if ip == nil {
return 0, fmt.Errorf("guest IP is not IPv4: %q", guestIP)
}
return 10000 + uint32(ip[3]), nil
}
func waitForPath(ctx context.Context, path string, timeout time.Duration, label string) error {
deadline := time.Now().Add(timeout)
for {
if _, err := os.Stat(path); err == nil {
return nil
} else if err != nil && !os.IsNotExist(err) {
return err
}
if time.Now().After(deadline) {
return fmt.Errorf("%s not ready: %s: %w", label, path, context.DeadlineExceeded)
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(100 * time.Millisecond):
}
}
}
func (d *Daemon) setDNS(ctx context.Context, vmName, guestIP string) error {
if d.vmDNS == nil {
return nil