Let the host ask the guest vsock agent to run ss so open ports can be surfaced without SSHing in manually. Add a narrow /ports agent endpoint, a daemon vm.ports RPC that enriches listeners with <hostname>.vm endpoints and best-effort HTTP links, and a concurrent 'banger vm ports' CLI table for one or more VMs. Update the guest package contract to include ss for rebuilt Debian images, allow the guest agent package in the shell-out policy, and cover the new parsing/RPC/CLI flow in tests. Verified with GOCACHE=/tmp/banger-gocache go test ./... outside the sandbox, make build, bash -n customize.sh make-rootfs-void.sh verify.sh, and ./banger vm ports --help.
437 lines
11 KiB
Go
437 lines
11 KiB
Go
package vsockagent
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"regexp"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
sdkvsock "github.com/firecracker-microvm/firecracker-go-sdk/vsock"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
const (
	// Port is the vsock port the guest agent listens on and the host dials.
	Port uint32 = 42070

	// HealthPath is the agent endpoint that reports liveness.
	HealthPath = "/healthz"
	// PortsPath is the agent endpoint that lists the guest's listening sockets.
	PortsPath = "/ports"
	// HealthyStatus is the status value served at HealthPath when the agent is up.
	HealthyStatus = "ok"

	// GuestBinaryName is the file name of the agent binary inside the guest.
	GuestBinaryName = "banger-vsock-agent"
	// GuestInstallPath is where the agent binary is installed in the guest.
	GuestInstallPath = "/usr/local/bin/" + GuestBinaryName
	// ServiceName is the systemd unit name of the agent service.
	ServiceName = GuestBinaryName + ".service"

	// serviceUnit is the systemd unit text that runs the agent. ExecStart is
	// built from GuestInstallPath so the unit can never point at a different
	// path than the one the binary is installed to.
	serviceUnit = `[Unit]
Description=Banger vsock agent
After=network.target

[Service]
Type=simple
ExecStart=` + GuestInstallPath + `
Restart=on-failure
RestartSec=1

[Install]
WantedBy=multi-user.target
`

	// modulesLoadConfig lists the kernel modules the guest needs for vsock,
	// one per line (presumably written to a modules-load.d file — confirm at
	// the call site of ModulesLoadConfig).
	modulesLoadConfig = "vsock\nvmw_vsock_virtio_transport\n"
)
|
|
|
|
var (
|
|
portCollector = CollectPorts
|
|
processRe = regexp.MustCompile(`"([^"]+)",pid=(\d+)`)
|
|
)
|
|
|
|
type (
	// HealthResponse is the JSON body served at HealthPath.
	HealthResponse struct {
		Status string `json:"status"`
	}

	// PortListener describes one listening socket reported by the guest agent.
	PortListener struct {
		// Proto is the lowercase transport name taken from ss ("tcp" or "udp").
		Proto string `json:"proto"`
		// BindAddress is the local IP without any zone suffix, or "*".
		BindAddress string `json:"bind_address"`
		Port        int    `json:"port"`
		// PID and Process identify the owning process when ss reports one.
		PID     int    `json:"pid,omitempty"`
		Process string `json:"process,omitempty"`
		// Command is the process's command line, falling back to Process.
		Command string `json:"command,omitempty"`
	}

	// PortsResponse is the JSON body served at PortsPath.
	PortsResponse struct {
		Listeners []PortListener `json:"listeners"`
	}
)
|
|
|
|
// NewHandler returns the HTTP handler the guest agent serves.
//
// It exposes two read-only JSON endpoints:
//   - HealthPath answers with HealthResponse{Status: HealthyStatus}.
//   - PortsPath answers with the listeners returned by portCollector. A
//     collector error becomes a 500 whose body is the error text.
//
// Both endpoints accept only GET. Any other method gets 405 Method Not Allowed
// together with the Allow header that RFC 9110 requires on that status.
func NewHandler() http.Handler {
	// getJSON wraps a payload producer with the method check and JSON
	// encoding that every endpoint shares.
	getJSON := func(produce func(r *http.Request) (any, error)) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			if r.Method != http.MethodGet {
				w.Header().Set("Allow", http.MethodGet)
				http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
				return
			}
			payload, err := produce(r)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			// Once encoding starts writing, the 200 status is already on the
			// wire, so there is no better response to send if it fails.
			_ = json.NewEncoder(w).Encode(payload)
		}
	}

	mux := http.NewServeMux()
	mux.HandleFunc(HealthPath, getJSON(func(*http.Request) (any, error) {
		return HealthResponse{Status: HealthyStatus}, nil
	}))
	mux.HandleFunc(PortsPath, getJSON(func(r *http.Request) (any, error) {
		// Tie the collection (an ss subprocess in production) to the request.
		listeners, err := portCollector(r.Context())
		if err != nil {
			return nil, err
		}
		return PortsResponse{Listeners: listeners}, nil
	}))
	return mux
}
|
|
|
|
// Health asks the guest agent, through the vsock socket at socketPath, for
// HealthPath and verifies that it reports HealthyStatus.
//
// It returns an error when:
//   - the request fails;
//   - the status is not 200 (the error includes up to 1 KiB of the body);
//   - the body is not a valid HealthResponse;
//   - the reported status differs from HealthyStatus.
//
// On success it logs at debug level when logger is non-nil.
func Health(ctx context.Context, logger *slog.Logger, socketPath string) error {
	resp, err := doRequest(ctx, logger, socketPath, HealthPath)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Bound how much of an error page we read. http.Error terminates its
		// body with a newline, so trim to keep the error on one line.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
		return fmt.Errorf("unexpected health status %d: %s", resp.StatusCode, bytes.TrimSpace(body))
	}

	var payload HealthResponse
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return fmt.Errorf("decode health response: %w", err)
	}
	if payload.Status != HealthyStatus {
		return fmt.Errorf("unexpected health response status %q", payload.Status)
	}
	if logger != nil {
		logger.Debug("vsock health ok", "vsock_path", socketPath, "vsock_port", Port)
	}
	return nil
}
|
|
|
|
// Ports asks the guest agent, through the vsock socket at socketPath, for the
// listening sockets served at PortsPath and returns them as reported. The
// result may be empty.
//
// A non-200 status becomes an error that includes up to 1 KiB of the body.
// A malformed body becomes a wrapped decode error.
func Ports(ctx context.Context, logger *slog.Logger, socketPath string) ([]PortListener, error) {
	resp, err := doRequest(ctx, logger, socketPath, PortsPath)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Bound how much of an error page we read. http.Error terminates its
		// body with a newline, so trim to keep the error on one line.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
		return nil, fmt.Errorf("unexpected ports status %d: %s", resp.StatusCode, bytes.TrimSpace(body))
	}

	var payload PortsResponse
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return nil, fmt.Errorf("decode ports response: %w", err)
	}
	return payload.Listeners, nil
}
|
|
|
|
func CollectPorts(ctx context.Context) ([]PortListener, error) {
|
|
cmd := exec.CommandContext(ctx, "ss", "-H", "-lntup")
|
|
output, err := cmd.Output()
|
|
if err != nil {
|
|
if len(output) > 0 {
|
|
return nil, fmt.Errorf("run ss: %w: %s", err, bytes.TrimSpace(output))
|
|
}
|
|
return nil, fmt.Errorf("run ss: %w", err)
|
|
}
|
|
return parsePortListeners(output, readProcessCommandLine)
|
|
}
|
|
|
|
func ServiceUnit() string {
|
|
return serviceUnit
|
|
}
|
|
|
|
func ModulesLoadConfig() string {
|
|
return modulesLoadConfig
|
|
}
|
|
|
|
func ReminderMessage(name string) string {
|
|
return fmt.Sprintf("session ended; %s is still running (stop it with 'banger vm stop %s')", name, name)
|
|
}
|
|
|
|
// WarningMessage builds the warning shown when checking whether the VM named
// name is still running failed with err. It returns the empty string when err
// is nil.
func WarningMessage(name string, err error) string {
	if err != nil {
		return fmt.Sprintf("warning: failed to check whether %s is still running: %v", name, err)
	}
	return ""
}
|
|
|
|
// newLogger builds a logrus entry whose only sink is a slogHook forwarding to
// base. It is used to hand a logrus logger to the vsock dialer.
func newLogger(base *slog.Logger) *logrus.Entry {
	l := logrus.New()
	// Let every level reach the hook; filtering happens on the slog side.
	l.SetLevel(logrus.DebugLevel)
	// Discard logrus's own formatted output so entries are only emitted via slog.
	l.SetOutput(io.Discard)
	l.AddHook(slogHook{logger: base})
	return logrus.NewEntry(l)
}
|
|
|
|
type slogHook struct {
|
|
logger *slog.Logger
|
|
}
|
|
|
|
func (h slogHook) Levels() []logrus.Level {
|
|
return logrus.AllLevels
|
|
}
|
|
|
|
func (h slogHook) Fire(entry *logrus.Entry) error {
|
|
if h.logger == nil {
|
|
return nil
|
|
}
|
|
level := slog.LevelDebug
|
|
switch entry.Level {
|
|
case logrus.ErrorLevel, logrus.FatalLevel, logrus.PanicLevel:
|
|
level = slog.LevelError
|
|
case logrus.WarnLevel:
|
|
level = slog.LevelWarn
|
|
case logrus.InfoLevel:
|
|
level = slog.LevelInfo
|
|
}
|
|
attrs := make([]any, 0, len(entry.Data)*2)
|
|
for key, value := range entry.Data {
|
|
attrs = append(attrs, key, value)
|
|
}
|
|
h.logger.Log(context.Background(), level, entry.Message, attrs...)
|
|
return nil
|
|
}
|
|
|
|
// IsServerClosed reports whether err is, or wraps, http.ErrServerClosed, the
// error a Serve call returns after a normal Shutdown or Close. A nil err is
// never a server-closed error.
func IsServerClosed(err error) bool {
	if err == nil {
		return false
	}
	return errors.Is(err, http.ErrServerClosed)
}
|
|
|
|
// doRequest sends a GET for path to the guest agent. The connection is dialed
// to Port through the Firecracker vsock socket at socketPath, with dial retry
// settings of a 3s timeout and a 100ms interval.
//
// Every call gets its own transport. On success, closing the returned body
// also releases that transport, so callers must close resp.Body. On failure,
// the transport is released before returning.
func doRequest(ctx context.Context, logger *slog.Logger, socketPath, path string) (*http.Response, error) {
	transport := &http.Transport{
		// The transport is single-use, so there is nothing to keep alive.
		DisableKeepAlives: true,
		// The network and address derived from the URL are ignored: every
		// connection goes to Port over the vsock socket.
		DialContext: func(dialCtx context.Context, _, _ string) (net.Conn, error) {
			return sdkvsock.DialContext(dialCtx, socketPath, Port,
				sdkvsock.WithRetryTimeout(3*time.Second),
				sdkvsock.WithRetryInterval(100*time.Millisecond),
				sdkvsock.WithLogger(newLogger(logger)),
			)
		},
	}
	fail := func(err error) (*http.Response, error) {
		transport.CloseIdleConnections()
		return nil, err
	}

	// The "vsock" host is a placeholder; DialContext above never looks at it.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://vsock"+path, nil)
	if err != nil {
		return fail(err)
	}
	client := &http.Client{Transport: transport}
	resp, err := client.Do(req)
	if err != nil {
		return fail(err)
	}
	return wrappedResponse(resp, transport), nil
}
|
|
|
|
type responseCloser struct {
|
|
io.ReadCloser
|
|
transport *http.Transport
|
|
}
|
|
|
|
func (c responseCloser) Close() error {
|
|
err := c.ReadCloser.Close()
|
|
c.transport.CloseIdleConnections()
|
|
return err
|
|
}
|
|
|
|
func wrappedResponse(resp *http.Response, transport *http.Transport) *http.Response {
|
|
if resp == nil || resp.Body == nil || transport == nil {
|
|
return resp
|
|
}
|
|
resp.Body = responseCloser{ReadCloser: resp.Body, transport: transport}
|
|
return resp
|
|
}
|
|
|
|
// parsePortListeners converts raw `ss -H -lntup` output into the listeners
// worth reporting to the host.
//
// Blank lines are skipped. Each remaining line is parsed by parseSSLine, and
// the first parse error aborts the whole call. Listeners are then filtered:
//   - Loopback listeners are dropped.
//   - Wildcard listeners (for example a 0.0.0.0 socket and its [::] twin) are
//     collapsed to the first one seen per wildcardKey.
//
// The result is sorted by protocol, port, PID and bind address, and is never
// nil on success. readCmdline maps a PID to its command line; CollectPorts
// passes readProcessCommandLine.
func parsePortListeners(raw []byte, readCmdline func(int) string) ([]PortListener, error) {
	rows := strings.Split(strings.TrimSpace(string(raw)), "\n")
	result := make([]PortListener, 0, len(rows))
	seenWildcard := make(map[string]struct{})

	for _, row := range rows {
		if row = strings.TrimSpace(row); row == "" {
			continue
		}
		parsed, err := parseSSLine(row, readCmdline)
		if err != nil {
			return nil, err
		}
		for _, l := range parsed {
			switch {
			case isLoopbackAddress(l.BindAddress):
				continue
			case isWildcardAddress(l.BindAddress):
				key := wildcardKey(l)
				if _, dup := seenWildcard[key]; dup {
					continue
				}
				seenWildcard[key] = struct{}{}
			}
			result = append(result, l)
		}
	}

	sort.Slice(result, func(i, j int) bool {
		a, b := result[i], result[j]
		switch {
		case a.Proto != b.Proto:
			return a.Proto < b.Proto
		case a.Port != b.Port:
			return a.Port < b.Port
		case a.PID != b.PID:
			return a.PID < b.PID
		default:
			return a.BindAddress < b.BindAddress
		}
	})
	return result, nil
}
|
|
|
|
// parseSSLine parses one line of `ss -H -lntup` output.
//
// The expected columns are: netid, state, recv-q, send-q, local address, peer
// address, then an optional process column. The line yields:
//   - nothing, for non-tcp/udp protocols or local hosts that are neither "*"
//     nor an IP;
//   - one listener per process found in the process column, with Command
//     taken from readCmdline and falling back to the process name;
//   - a single process-less listener when no process is listed.
//
// It returns an error for lines with fewer than 6 fields or an unparsable
// local address.
func parseSSLine(line string, readCmdline func(int) string) ([]PortListener, error) {
	fields := strings.Fields(line)
	if n := len(fields); n < 6 {
		return nil, fmt.Errorf("parse ss line: expected at least 6 fields, got %d in %q", n, line)
	}

	proto := strings.ToLower(strings.TrimSpace(fields[0]))
	switch proto {
	case "tcp", "udp":
	default:
		return nil, nil
	}

	local := fields[4]
	bindAddress, port, err := parseBindAddress(local)
	if err != nil {
		return nil, fmt.Errorf("parse ss local address %q: %w", local, err)
	}
	if bindAddress != "*" && net.ParseIP(bindAddress) == nil {
		return nil, nil
	}

	base := PortListener{Proto: proto, BindAddress: bindAddress, Port: port}
	entries := parseProcessEntries(strings.Join(fields[6:], " "))
	if len(entries) == 0 {
		return []PortListener{base}, nil
	}

	out := make([]PortListener, 0, len(entries))
	for _, e := range entries {
		l := base
		l.PID = e.PID
		l.Process = e.Process
		if l.Command = strings.TrimSpace(readCmdline(e.PID)); l.Command == "" {
			l.Command = e.Process
		}
		out = append(out, l)
	}
	return out, nil
}
|
|
|
|
type processEntry struct {
|
|
Process string
|
|
PID int
|
|
}
|
|
|
|
func parseProcessEntries(raw string) []processEntry {
|
|
matches := processRe.FindAllStringSubmatch(raw, -1)
|
|
if len(matches) == 0 {
|
|
return nil
|
|
}
|
|
entries := make([]processEntry, 0, len(matches))
|
|
for _, match := range matches {
|
|
pid, err := strconv.Atoi(match[2])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
entries = append(entries, processEntry{
|
|
Process: match[1],
|
|
PID: pid,
|
|
})
|
|
}
|
|
return entries
|
|
}
|
|
|
|
// parseBindAddress splits an ss local-address column into host and port.
//
// It accepts "host:port", "[v6]:port" and the older unbracketed IPv6 form
// (":::22"). For the unbracketed form, the last colon separates the port. A
// "%iface" zone suffix is removed from the host, and an empty host is reported
// as "*". It errors on empty input, a missing host or port, or a non-numeric
// port.
func parseBindAddress(raw string) (string, int, error) {
	addr := strings.TrimSpace(raw)
	if addr == "" {
		return "", 0, errors.New("empty address")
	}

	var host, portText string
	if addr[0] == '[' {
		h, p, err := net.SplitHostPort(addr)
		if err != nil {
			return "", 0, err
		}
		host, portText = h, p
	} else {
		sep := strings.LastIndexByte(addr, ':')
		if sep <= 0 || sep == len(addr)-1 {
			return "", 0, fmt.Errorf("missing host:port in %q", addr)
		}
		host, portText = addr[:sep], addr[sep+1:]
	}

	// Drop an interface zone such as "%eth0".
	host, _, _ = strings.Cut(host, "%")
	if host = strings.Trim(host, "[]"); host == "" {
		host = "*"
	}

	port, err := strconv.Atoi(portText)
	if err != nil {
		return "", 0, err
	}
	return host, port, nil
}
|
|
|
|
func readProcessCommandLine(pid int) string {
|
|
if pid <= 0 {
|
|
return ""
|
|
}
|
|
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
parts := strings.Split(string(data), "\x00")
|
|
filtered := parts[:0]
|
|
for _, part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if part == "" {
|
|
continue
|
|
}
|
|
filtered = append(filtered, part)
|
|
}
|
|
return strings.Join(filtered, " ")
|
|
}
|
|
|
|
// isLoopbackAddress reports whether host is a literal loopback IP (IPv4,
// IPv6, or IPv4-mapped IPv6). The wildcard "*", the empty string and
// non-IP strings are never loopback.
func isLoopbackAddress(host string) bool {
	switch host {
	case "", "*":
		return false
	}
	if ip := net.ParseIP(host); ip != nil {
		return ip.IsLoopback()
	}
	return false
}
|
|
|
|
// isWildcardAddress reports whether host is one of the exact "any address"
// spellings parseBindAddress produces: "*", "0.0.0.0" or "::".
func isWildcardAddress(host string) bool {
	return host == "*" || host == "0.0.0.0" || host == "::"
}
|
|
|
|
// wildcardKey identifies a wildcard listener by protocol, port, PID and process
// name, in the form "proto:port:pid:process". It deliberately leaves out the
// bind address, so the IPv4 and IPv6 wildcard sockets of one service share a
// key.
func wildcardKey(listener PortListener) string {
	return listener.Proto + ":" + strconv.Itoa(listener.Port) + ":" + strconv.Itoa(listener.PID) + ":" + listener.Process
}
|