Add vsock-backed VM port inspection

Let the host ask the guest vsock agent to run ss so open ports can be surfaced without SSHing in manually.

Add a narrow /ports agent endpoint, a daemon vm.ports RPC that enriches listeners with <hostname>.vm endpoints and best-effort HTTP links, and a concurrent 'banger vm ports' CLI table for one or more VMs.

Update the guest package contract to include ss for rebuilt Debian images, allow the guest agent package in the shell-out policy, and cover the new parsing/RPC/CLI flow in tests.

Verified with GOCACHE=/tmp/banger-gocache go test ./... outside the sandbox, make build, bash -n customize.sh make-rootfs-void.sh verify.sh, and ./banger vm ports --help.
This commit is contained in:
Thales Maciel 2026-03-19 15:52:11 -03:00
parent 3ed78fdcfc
commit c298ed2fc1
No known key found for this signature in database
GPG key ID: 33112E6833C34679
11 changed files with 1029 additions and 23 deletions

View file

@ -110,6 +110,11 @@ banger vm ssh calm-otter
When the SSH session exits normally, `banger` checks the guest over vsock and
reminds you if the VM is still running.
Inspect host-reachable listening ports for one or more running VMs:
```bash
banger vm ports calm-otter buildbox
```
Stop, restart, kill, or delete it:
```bash
banger vm stop calm-otter
@ -246,6 +251,13 @@ for daemon-managed VMs. Known `A` records resolve `<vm-name>.vm` to the VM's
guest IPv4 address. Integrate your local resolver separately if you want
transparent `.vm` lookups on the host.
`banger vm ports` asks the guest-side `banger-vsock-agent` to run `ss`, then
prints host-usable `<hostname>.vm:port` endpoints plus the owning
process/command. TCP listeners get a short best-effort HTTP probe; when the
probe sees a real HTTP response, the command includes a clickable
`http://<hostname>.vm:port/` URL. Older images without `ss` may need rebuilding
before `vm ports` works.
## Storage Model
- VMs share a read-only base rootfs image.
- Each VM gets its own sparse writable system overlay for `/`.
@ -270,7 +282,8 @@ shell helpers treated as manual workflows rather than architecture drivers.
- Stopping a VM preserves its overlay and work disk.
## Rebuilding The Repo Default Rootfs
`packages.apt` controls the base apt packages baked into rebuilt images.
`packages.apt` controls the base apt packages baked into rebuilt images,
including guest tools such as `ss` used by `banger vm ports`.
To rebuild the source-checkout default image in `./runtime/rootfs-docker.ext4`:
```bash

View file

@ -73,6 +73,23 @@ type VMPingResult struct {
Alive bool `json:"alive"`
}
type VMPort struct {
Proto string `json:"proto"`
BindAddress string `json:"bind_address,omitempty"`
Port int `json:"port"`
PID int `json:"pid,omitempty"`
Process string `json:"process,omitempty"`
Command string `json:"command,omitempty"`
Endpoint string `json:"endpoint,omitempty"`
WebURL string `json:"web_url,omitempty"`
}
type VMPortsResult struct {
Name string `json:"name"`
DNSName string `json:"dns_name,omitempty"`
Ports []VMPort `json:"ports"`
}
type ImageBuildParams struct {
Name string `json:"name,omitempty"`
BaseRootfs string `json:"base_rootfs,omitempty"`

View file

@ -9,6 +9,7 @@ import (
"os"
"os/exec"
"path/filepath"
"sort"
"strings"
"sync"
"syscall"
@ -45,6 +46,9 @@ var (
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName})
}
vmPortsFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) {
return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName})
}
)
func NewBangerCommand() *cobra.Command {
@ -243,6 +247,7 @@ func newVMCommand() *cobra.Command {
newVMSSHCommand(),
newVMLogsCommand(),
newVMStatsCommand(),
newVMPortsCommand(),
)
return cmd
}
@ -542,6 +547,50 @@ func newVMStatsCommand() *cobra.Command {
}
}
func newVMPortsCommand() *cobra.Command {
return &cobra.Command{
Use: "ports <id-or-name>...",
Short: "Show host-reachable listening guest ports",
Args: minArgsUsage(1, "usage: banger vm ports <id-or-name>..."),
RunE: func(cmd *cobra.Command, args []string) error {
layout, _, err := ensureDaemon(cmd.Context())
if err != nil {
return err
}
listResult, err := rpc.Call[api.VMListResult](cmd.Context(), layout.SocketPath, "vm.list", api.Empty{})
if err != nil {
return err
}
targets, resolutionErrs := resolveVMTargets(listResult.VMs, args)
results := executeVMPortsBatch(cmd.Context(), layout.SocketPath, targets)
failed := false
for _, resolutionErr := range resolutionErrs {
if _, err := fmt.Fprintf(cmd.ErrOrStderr(), "%s: %v\n", resolutionErr.Ref, resolutionErr.Err); err != nil {
return err
}
failed = true
}
for _, result := range results {
if result.Err == nil {
continue
}
if _, err := fmt.Fprintf(cmd.ErrOrStderr(), "%s: %v\n", result.Target.Ref, result.Err); err != nil {
return err
}
failed = true
}
if err := printVMPortsTable(cmd.OutOrStdout(), results); err != nil {
return err
}
if failed {
return errors.New("one or more VM operations failed")
}
return nil
},
}
}
func newImageCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "image",
@ -744,6 +793,12 @@ type vmBatchActionResult struct {
Err error
}
type vmPortsBatchResult struct {
Target resolvedVMTarget
Result api.VMPortsResult
Err error
}
func runVMBatchAction(cmd *cobra.Command, socketPath string, refs []string, action func(context.Context, string) (model.VMRecord, error)) error {
listResult, err := rpc.Call[api.VMListResult](cmd.Context(), socketPath, "vm.list", api.Empty{})
if err != nil {
@ -852,6 +907,27 @@ func executeVMActionBatch(ctx context.Context, targets []resolvedVMTarget, actio
return results
}
func executeVMPortsBatch(ctx context.Context, socketPath string, targets []resolvedVMTarget) []vmPortsBatchResult {
results := make([]vmPortsBatchResult, len(targets))
var wg sync.WaitGroup
wg.Add(len(targets))
for index, target := range targets {
index := index
target := target
go func() {
defer wg.Done()
result, err := vmPortsFunc(ctx, socketPath, target.VM.ID)
results[index] = vmPortsBatchResult{
Target: target,
Result: result,
Err: err,
}
}()
}
wg.Wait()
return results
}
func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) {
layout, err := paths.Resolve()
if err != nil {
@ -1147,6 +1223,77 @@ func printImageSummary(out anyWriter, image model.Image) error {
return err
}
func printVMPortsTable(out anyWriter, results []vmPortsBatchResult) error {
type portRow struct {
VM string
Proto string
Endpoint string
Process string
Command string
WebURL string
Port int
}
rows := make([]portRow, 0)
for _, result := range results {
if result.Err != nil {
continue
}
vmName := strings.TrimSpace(result.Result.Name)
if vmName == "" {
vmName = result.Target.VM.Name
}
for _, port := range result.Result.Ports {
rows = append(rows, portRow{
VM: vmName,
Proto: port.Proto,
Endpoint: port.Endpoint,
Process: port.Process,
Command: port.Command,
WebURL: emptyDash(port.WebURL),
Port: port.Port,
})
}
}
sort.Slice(rows, func(i, j int) bool {
if rows[i].VM != rows[j].VM {
return rows[i].VM < rows[j].VM
}
if rows[i].Proto != rows[j].Proto {
return rows[i].Proto < rows[j].Proto
}
if rows[i].Port != rows[j].Port {
return rows[i].Port < rows[j].Port
}
if rows[i].Process != rows[j].Process {
return rows[i].Process < rows[j].Process
}
return rows[i].Command < rows[j].Command
})
if len(rows) == 0 {
return nil
}
w := tabwriter.NewWriter(out, 0, 8, 2, ' ', 0)
if _, err := fmt.Fprintln(w, "VM\tPROTO\tENDPOINT\tPROCESS\tCOMMAND\tWEB"); err != nil {
return err
}
for _, row := range rows {
if _, err := fmt.Fprintf(
w,
"%s\t%s\t%s\t%s\t%s\t%s\n",
row.VM,
row.Proto,
emptyDash(row.Endpoint),
emptyDash(row.Process),
emptyDash(row.Command),
row.WebURL,
); err != nil {
return err
}
}
return w.Flush()
}
func printDoctorReport(out anyWriter, report system.Report) error {
for _, check := range report.Checks {
status := strings.ToUpper(string(check.Status))
@ -1162,6 +1309,14 @@ func printDoctorReport(out anyWriter, report system.Report) error {
return nil
}
func emptyDash(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return "-"
}
return value
}
type anyWriter interface {
Write(p []byte) (n int, err error)
}

View file

@ -151,6 +151,17 @@ func TestVMKillFlagsExist(t *testing.T) {
}
}
func TestVMPortsCommandExists(t *testing.T) {
root := NewBangerCommand()
vm, _, err := root.Find([]string{"vm"})
if err != nil {
t.Fatalf("find vm: %v", err)
}
if _, _, err := vm.Find([]string{"ports"}); err != nil {
t.Fatalf("find ports: %v", err)
}
}
func TestVMSetParamsFromFlags(t *testing.T) {
params, err := vmSetParamsFromFlags("devbox", 4, 2048, "16G", true, false)
if err != nil {
@ -268,6 +279,59 @@ func TestAbsolutizeImageRegisterPaths(t *testing.T) {
}
}
func TestPrintVMPortsTableSortsAndRendersURLs(t *testing.T) {
results := []vmPortsBatchResult{
{
Target: resolvedVMTarget{Ref: "beta"},
Result: api.VMPortsResult{
Name: "beta",
Ports: []api.VMPort{{
Proto: "tcp",
Port: 8080,
Endpoint: "beta.vm:8080",
Process: "python3",
Command: "python3 -m http.server 8080",
WebURL: "http://beta.vm:8080/",
}},
},
},
{
Target: resolvedVMTarget{Ref: "alpha"},
Result: api.VMPortsResult{
Name: "alpha",
Ports: []api.VMPort{{
Proto: "udp",
Port: 53,
Endpoint: "alpha.vm:53",
Process: "dnsd",
Command: "dnsd --foreground",
}},
},
},
}
var out bytes.Buffer
if err := printVMPortsTable(&out, results); err != nil {
t.Fatalf("printVMPortsTable: %v", err)
}
lines := strings.Split(strings.TrimSpace(out.String()), "\n")
if len(lines) != 3 {
t.Fatalf("lines = %q, want header + 2 rows", lines)
}
if !strings.Contains(lines[0], "VM") || !strings.Contains(lines[0], "WEB") {
t.Fatalf("header = %q, want VM/WEB columns", lines[0])
}
if !strings.Contains(lines[1], "alpha") || !strings.Contains(lines[1], "alpha.vm:53") || !strings.Contains(lines[1], "\t-\n") {
// tabwriter output is space-expanded, so just require the dash placeholder.
if !strings.Contains(lines[1], "alpha") || !strings.Contains(lines[1], "alpha.vm:53") || !strings.HasSuffix(strings.TrimSpace(lines[1]), "-") {
t.Fatalf("first row = %q, want alpha row with dash web column", lines[1])
}
}
if !strings.Contains(lines[2], "beta") || !strings.Contains(lines[2], "http://beta.vm:8080/") {
t.Fatalf("second row = %q, want beta web url", lines[2])
}
}
func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) {
origSSHExec := sshExecFunc
origHealth := vmHealthFunc

View file

@ -345,6 +345,13 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
}
result, err := d.PingVM(ctx, params.IDOrName)
return marshalResultOrError(result, err)
case "vm.ports":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.PortsVM(ctx, params.IDOrName)
return marshalResultOrError(result, err)
case "image.list":
images, err := d.store.ListImages(ctx)
return marshalResultOrError(api.ImageListResult{Images: images}, err)

126
internal/daemon/ports.go Normal file
View file

@ -0,0 +1,126 @@
package daemon
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"sort"
"strconv"
"strings"
"time"
"banger/internal/api"
"banger/internal/model"
"banger/internal/system"
"banger/internal/vmdns"
"banger/internal/vsockagent"
)
const httpProbeTimeout = 750 * time.Millisecond
func (d *Daemon) PortsVM(ctx context.Context, idOrName string) (result api.VMPortsResult, err error) {
_, err = d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
result.Name = vm.Name
result.DNSName = strings.TrimSpace(vm.Runtime.DNSName)
if result.DNSName == "" && strings.TrimSpace(vm.Name) != "" {
result.DNSName = vmdns.RecordName(vm.Name)
}
if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
return model.VMRecord{}, fmt.Errorf("vm %s is not running", vm.Name)
}
if strings.TrimSpace(vm.Runtime.GuestIP) == "" {
return model.VMRecord{}, errors.New("vm has no guest IP")
}
if strings.TrimSpace(vm.Runtime.VSockPath) == "" {
return model.VMRecord{}, errors.New("vm has no vsock path")
}
if vm.Runtime.VSockCID == 0 {
return model.VMRecord{}, errors.New("vm has no vsock cid")
}
if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil {
return model.VMRecord{}, err
}
portsCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
listeners, err := vsockagent.Ports(portsCtx, d.logger, vm.Runtime.VSockPath)
if err != nil {
return model.VMRecord{}, err
}
result.Ports = buildVMPorts(vm, listeners)
return vm, nil
})
return result, err
}
func buildVMPorts(vm model.VMRecord, listeners []vsockagent.PortListener) []api.VMPort {
endpointHost := strings.TrimSpace(vm.Runtime.DNSName)
if endpointHost == "" {
endpointHost = strings.TrimSpace(vm.Runtime.GuestIP)
}
probeHost := strings.TrimSpace(vm.Runtime.GuestIP)
ports := make([]api.VMPort, 0, len(listeners))
for _, listener := range listeners {
if listener.Port <= 0 {
continue
}
port := api.VMPort{
Proto: strings.ToLower(strings.TrimSpace(listener.Proto)),
BindAddress: strings.TrimSpace(listener.BindAddress),
Port: listener.Port,
PID: listener.PID,
Process: strings.TrimSpace(listener.Process),
Command: strings.TrimSpace(listener.Command),
Endpoint: net.JoinHostPort(endpointHost, strconv.Itoa(listener.Port)),
}
if port.Command == "" {
port.Command = port.Process
}
if port.Proto == "tcp" && probeHost != "" && endpointHost != "" && probeHTTPListener(probeHost, listener.Port) {
port.WebURL = "http://" + net.JoinHostPort(endpointHost, strconv.Itoa(listener.Port)) + "/"
}
ports = append(ports, port)
}
sort.Slice(ports, func(i, j int) bool {
if ports[i].Proto != ports[j].Proto {
return ports[i].Proto < ports[j].Proto
}
if ports[i].Port != ports[j].Port {
return ports[i].Port < ports[j].Port
}
if ports[i].Process != ports[j].Process {
return ports[i].Process < ports[j].Process
}
return ports[i].Command < ports[j].Command
})
return ports
}
func probeHTTPListener(guestIP string, port int) bool {
if strings.TrimSpace(guestIP) == "" || port <= 0 {
return false
}
url := "http://" + net.JoinHostPort(strings.TrimSpace(guestIP), strconv.Itoa(port)) + "/"
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return false
}
client := &http.Client{
Timeout: httpProbeTimeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Transport: &http.Transport{
Proxy: nil,
},
}
resp, err := client.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1))
return resp.ProtoMajor >= 1
}

View file

@ -10,6 +10,7 @@ import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
@ -427,6 +428,145 @@ func TestHealthVMReturnsFalseForStoppedVM(t *testing.T) {
}
}
func TestPortsVMReturnsEnrichedPortsAndWebURL(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
apiSock := filepath.Join(t.TempDir(), "fc.sock")
fake := startFakeFirecrackerProcess(t, apiSock)
t.Cleanup(func() {
_ = fake.Process.Kill()
_ = fake.Wait()
})
webServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
t.Cleanup(webServer.Close)
webAddr, err := net.ResolveTCPAddr("tcp", strings.TrimPrefix(webServer.URL, "http://"))
if err != nil {
t.Fatalf("ResolveTCPAddr: %v", err)
}
vsockSock := filepath.Join(t.TempDir(), "fc.vsock")
listener, err := net.Listen("unix", vsockSock)
if err != nil {
t.Fatalf("listen vsock: %v", err)
}
t.Cleanup(func() {
_ = listener.Close()
_ = os.Remove(vsockSock)
})
serverDone := make(chan error, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
serverDone <- err
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
serverDone <- err
return
}
if got := string(buf[:n]); got != "CONNECT 42070\n" {
serverDone <- fmt.Errorf("unexpected connect message %q", got)
return
}
if _, err := conn.Write([]byte("OK 1\n")); err != nil {
serverDone <- err
return
}
reqBuf := make([]byte, 0, 1024)
for {
n, err = conn.Read(buf)
if err != nil {
serverDone <- err
return
}
reqBuf = append(reqBuf, buf[:n]...)
if strings.Contains(string(reqBuf), "\r\n\r\n") {
break
}
}
if got := string(reqBuf); !strings.Contains(got, "GET /ports HTTP/1.1\r\n") {
serverDone <- fmt.Errorf("unexpected ports payload %q", got)
return
}
body := fmt.Sprintf(`{"listeners":[{"proto":"tcp","bind_address":"0.0.0.0","port":%d,"pid":44,"process":"python3","command":"python3 -m http.server %d"},{"proto":"udp","bind_address":"0.0.0.0","port":53,"pid":1,"process":"dnsd","command":"dnsd --foreground"}]}`, webAddr.Port, webAddr.Port)
resp := fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: %d\r\n\r\n%s", len(body), body)
_, err = conn.Write([]byte(resp))
serverDone <- err
}()
vm := testVM("ports", "image-ports", "127.0.0.1")
vm.State = model.VMStateRunning
vm.Runtime.State = model.VMStateRunning
vm.Runtime.PID = fake.Process.Pid
vm.Runtime.APISockPath = apiSock
vm.Runtime.VSockPath = vsockSock
vm.Runtime.VSockCID = 10043
upsertDaemonVM(t, ctx, db, vm)
runner := &scriptedRunner{
t: t,
steps: []runnerStep{
sudoStep("", nil, "chown", fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()), vsockSock),
sudoStep("", nil, "chmod", "600", vsockSock),
},
}
d := &Daemon{store: db, runner: runner}
result, err := d.PortsVM(ctx, vm.Name)
if err != nil {
t.Fatalf("PortsVM: %v", err)
}
if result.Name != vm.Name || result.DNSName != vm.Runtime.DNSName {
t.Fatalf("result = %+v, want name/dns", result)
}
if len(result.Ports) != 2 {
t.Fatalf("ports = %+v, want 2 entries", result.Ports)
}
wantWeb := fmt.Sprintf("http://ports.vm:%d/", webAddr.Port)
var tcpPort, udpPort api.VMPort
for _, port := range result.Ports {
switch port.Proto {
case "tcp":
tcpPort = port
case "udp":
udpPort = port
}
}
if udpPort.Endpoint != "ports.vm:53" || udpPort.WebURL != "" {
t.Fatalf("udp port = %+v, want endpoint only", udpPort)
}
if tcpPort.Endpoint != net.JoinHostPort("ports.vm", strconv.Itoa(webAddr.Port)) || tcpPort.WebURL != wantWeb {
t.Fatalf("tcp port = %+v, want web url %q", tcpPort, wantWeb)
}
runner.assertExhausted()
if err := <-serverDone; err != nil {
t.Fatalf("server: %v", err)
}
}
func TestPortsVMReturnsErrorForStoppedVM(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := openDaemonStore(t)
vm := testVM("stopped-ports", "image-stopped", "172.16.0.50")
upsertDaemonVM(t, ctx, db, vm)
d := &Daemon{store: db}
_, err := d.PortsVM(ctx, vm.Name)
if err == nil || !strings.Contains(err.Error(), "is not running") {
t.Fatalf("PortsVM error = %v, want not running", err)
}
}
func TestSetVMDiskResizeFailsPreflightWhenToolsMissing(t *testing.T) {
ctx := context.Background()
db := openDaemonStore(t)

View file

@ -54,12 +54,13 @@ func TestExecImportsStayInsideApprovedPackages(t *testing.T) {
t.Fatalf("walk repo: %v", err)
}
if len(offenders) != 0 {
t.Fatalf("os/exec imports are only allowed in internal/cli, internal/firecracker, and internal/system; found %v", offenders)
t.Fatalf("os/exec imports are only allowed in internal/cli, internal/firecracker, internal/system, and internal/vsockagent; found %v", offenders)
}
}
func allowedExecImportPath(relPath string) bool {
return strings.HasPrefix(relPath, "internal/cli/") ||
strings.HasPrefix(relPath, "internal/firecracker/") ||
strings.HasPrefix(relPath, "internal/system/")
strings.HasPrefix(relPath, "internal/system/") ||
strings.HasPrefix(relPath, "internal/vsockagent/")
}

View file

@ -1,6 +1,7 @@
package vsockagent
import (
"bytes"
"context"
"encoding/json"
"errors"
@ -9,6 +10,12 @@ import (
"log/slog"
"net"
"net/http"
"os"
"os/exec"
"regexp"
"sort"
"strconv"
"strings"
"time"
sdkvsock "github.com/firecracker-microvm/firecracker-go-sdk/vsock"
@ -18,6 +25,7 @@ import (
const (
Port uint32 = 42070
HealthPath = "/healthz"
PortsPath = "/ports"
HealthyStatus = "ok"
GuestBinaryName = "banger-vsock-agent"
GuestInstallPath = "/usr/local/bin/" + GuestBinaryName
@ -38,10 +46,28 @@ WantedBy=multi-user.target
modulesLoadConfig = "vsock\nvmw_vsock_virtio_transport\n"
)
var (
portCollector = CollectPorts
processRe = regexp.MustCompile(`"([^"]+)",pid=(\d+)`)
)
type HealthResponse struct {
Status string `json:"status"`
}
type PortListener struct {
Proto string `json:"proto"`
BindAddress string `json:"bind_address"`
Port int `json:"port"`
PID int `json:"pid,omitempty"`
Process string `json:"process,omitempty"`
Command string `json:"command,omitempty"`
}
type PortsResponse struct {
Listeners []PortListener `json:"listeners"`
}
func NewHandler() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc(HealthPath, func(w http.ResponseWriter, r *http.Request) {
@ -52,30 +78,24 @@ func NewHandler() http.Handler {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(HealthResponse{Status: HealthyStatus})
})
mux.HandleFunc(PortsPath, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
listeners, err := portCollector(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(PortsResponse{Listeners: listeners})
})
return mux
}
func Health(ctx context.Context, logger *slog.Logger, socketPath string) error {
transport := &http.Transport{
DisableKeepAlives: true,
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
return sdkvsock.DialContext(
ctx,
socketPath,
Port,
sdkvsock.WithRetryTimeout(3*time.Second),
sdkvsock.WithRetryInterval(100*time.Millisecond),
sdkvsock.WithLogger(newLogger(logger)),
)
},
}
defer transport.CloseIdleConnections()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://vsock"+HealthPath, nil)
if err != nil {
return err
}
resp, err := (&http.Client{Transport: transport}).Do(req)
resp, err := doRequest(ctx, logger, socketPath, HealthPath)
if err != nil {
return err
}
@ -97,6 +117,35 @@ func Health(ctx context.Context, logger *slog.Logger, socketPath string) error {
return nil
}
func Ports(ctx context.Context, logger *slog.Logger, socketPath string) ([]PortListener, error) {
resp, err := doRequest(ctx, logger, socketPath, PortsPath)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return nil, fmt.Errorf("unexpected ports status %d: %s", resp.StatusCode, string(body))
}
var payload PortsResponse
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return nil, err
}
return payload.Listeners, nil
}
func CollectPorts(ctx context.Context) ([]PortListener, error) {
cmd := exec.CommandContext(ctx, "ss", "-H", "-lntup")
output, err := cmd.Output()
if err != nil {
if len(output) > 0 {
return nil, fmt.Errorf("run ss: %w: %s", err, bytes.TrimSpace(output))
}
return nil, fmt.Errorf("run ss: %w", err)
}
return parsePortListeners(output, readProcessCommandLine)
}
func ServiceUnit() string {
return serviceUnit
}
@ -156,3 +205,233 @@ func (h slogHook) Fire(entry *logrus.Entry) error {
func IsServerClosed(err error) bool {
return errors.Is(err, http.ErrServerClosed)
}
func doRequest(ctx context.Context, logger *slog.Logger, socketPath, path string) (*http.Response, error) {
transport := &http.Transport{
DisableKeepAlives: true,
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
return sdkvsock.DialContext(
ctx,
socketPath,
Port,
sdkvsock.WithRetryTimeout(3*time.Second),
sdkvsock.WithRetryInterval(100*time.Millisecond),
sdkvsock.WithLogger(newLogger(logger)),
)
},
}
client := &http.Client{Transport: transport}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://vsock"+path, nil)
if err != nil {
transport.CloseIdleConnections()
return nil, err
}
resp, err := client.Do(req)
if err != nil {
transport.CloseIdleConnections()
return nil, err
}
return wrappedResponse(resp, transport), nil
}
type responseCloser struct {
io.ReadCloser
transport *http.Transport
}
func (c responseCloser) Close() error {
err := c.ReadCloser.Close()
c.transport.CloseIdleConnections()
return err
}
func wrappedResponse(resp *http.Response, transport *http.Transport) *http.Response {
if resp == nil || resp.Body == nil || transport == nil {
return resp
}
resp.Body = responseCloser{ReadCloser: resp.Body, transport: transport}
return resp
}
func parsePortListeners(raw []byte, readCmdline func(int) string) ([]PortListener, error) {
lines := strings.Split(strings.TrimSpace(string(raw)), "\n")
listeners := make([]PortListener, 0, len(lines))
wildcards := make(map[string]struct{})
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parsed, err := parseSSLine(line, readCmdline)
if err != nil {
return nil, err
}
for _, listener := range parsed {
if isLoopbackAddress(listener.BindAddress) {
continue
}
if isWildcardAddress(listener.BindAddress) {
key := wildcardKey(listener)
if _, ok := wildcards[key]; ok {
continue
}
wildcards[key] = struct{}{}
}
listeners = append(listeners, listener)
}
}
sort.Slice(listeners, func(i, j int) bool {
if listeners[i].Proto != listeners[j].Proto {
return listeners[i].Proto < listeners[j].Proto
}
if listeners[i].Port != listeners[j].Port {
return listeners[i].Port < listeners[j].Port
}
if listeners[i].PID != listeners[j].PID {
return listeners[i].PID < listeners[j].PID
}
return listeners[i].BindAddress < listeners[j].BindAddress
})
return listeners, nil
}
func parseSSLine(line string, readCmdline func(int) string) ([]PortListener, error) {
fields := strings.Fields(line)
if len(fields) < 6 {
return nil, fmt.Errorf("parse ss line: expected at least 6 fields, got %d in %q", len(fields), line)
}
proto := strings.ToLower(strings.TrimSpace(fields[0]))
if proto != "tcp" && proto != "udp" {
return nil, nil
}
bindAddress, port, err := parseBindAddress(fields[4])
if err != nil {
return nil, fmt.Errorf("parse ss local address %q: %w", fields[4], err)
}
if bindAddress != "*" && net.ParseIP(bindAddress) == nil {
return nil, nil
}
processInfo := strings.Join(fields[6:], " ")
entries := parseProcessEntries(processInfo)
if len(entries) == 0 {
return []PortListener{{
Proto: proto,
BindAddress: bindAddress,
Port: port,
}}, nil
}
listeners := make([]PortListener, 0, len(entries))
for _, entry := range entries {
command := strings.TrimSpace(readCmdline(entry.PID))
if command == "" {
command = entry.Process
}
listeners = append(listeners, PortListener{
Proto: proto,
BindAddress: bindAddress,
Port: port,
PID: entry.PID,
Process: entry.Process,
Command: command,
})
}
return listeners, nil
}
type processEntry struct {
Process string
PID int
}
func parseProcessEntries(raw string) []processEntry {
matches := processRe.FindAllStringSubmatch(raw, -1)
if len(matches) == 0 {
return nil
}
entries := make([]processEntry, 0, len(matches))
for _, match := range matches {
pid, err := strconv.Atoi(match[2])
if err != nil {
continue
}
entries = append(entries, processEntry{
Process: match[1],
PID: pid,
})
}
return entries
}
func parseBindAddress(raw string) (string, int, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", 0, errors.New("empty address")
}
var host, portRaw string
if strings.HasPrefix(raw, "[") {
var err error
host, portRaw, err = net.SplitHostPort(raw)
if err != nil {
return "", 0, err
}
} else {
idx := strings.LastIndex(raw, ":")
if idx <= 0 || idx == len(raw)-1 {
return "", 0, fmt.Errorf("missing host:port in %q", raw)
}
host = raw[:idx]
portRaw = raw[idx+1:]
}
if zoneIdx := strings.Index(host, "%"); zoneIdx >= 0 {
host = host[:zoneIdx]
}
host = strings.Trim(host, "[]")
if host == "" {
host = "*"
}
port, err := strconv.Atoi(portRaw)
if err != nil {
return "", 0, err
}
return host, port, nil
}
func readProcessCommandLine(pid int) string {
if pid <= 0 {
return ""
}
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
if err != nil {
return ""
}
parts := strings.Split(string(data), "\x00")
filtered := parts[:0]
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
filtered = append(filtered, part)
}
return strings.Join(filtered, " ")
}
func isLoopbackAddress(host string) bool {
if host == "" || host == "*" {
return false
}
ip := net.ParseIP(host)
return ip != nil && ip.IsLoopback()
}
func isWildcardAddress(host string) bool {
switch host {
case "*", "0.0.0.0", "::":
return true
}
return false
}
func wildcardKey(listener PortListener) string {
return fmt.Sprintf("%s:%d:%d:%s", listener.Proto, listener.Port, listener.PID, listener.Process)
}

View file

@ -4,9 +4,12 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"net"
"net/http"
"path/filepath"
"reflect"
"strconv"
"strings"
"testing"
"time"
@ -37,6 +40,44 @@ func TestNewHandlerHealthz(t *testing.T) {
}
}
func TestNewHandlerPorts(t *testing.T) {
origCollector := portCollector
t.Cleanup(func() {
portCollector = origCollector
})
portCollector = func(context.Context) ([]PortListener, error) {
return []PortListener{{
Proto: "tcp",
BindAddress: "0.0.0.0",
Port: 8080,
PID: 42,
Process: "python3",
Command: "python3 -m http.server 8080",
}}, nil
}
req, err := http.NewRequest(http.MethodGet, PortsPath, nil)
if err != nil {
t.Fatalf("NewRequest: %v", err)
}
rr := newTestResponseRecorder()
NewHandler().ServeHTTP(rr, req)
if rr.status != http.StatusOK {
t.Fatalf("status = %d, want %d", rr.status, http.StatusOK)
}
var payload PortsResponse
if err := json.Unmarshal(rr.body.Bytes(), &payload); err != nil {
t.Fatalf("Unmarshal: %v", err)
}
if len(payload.Listeners) != 1 {
t.Fatalf("listeners = %d, want 1", len(payload.Listeners))
}
if got := payload.Listeners[0]; got.Command != "python3 -m http.server 8080" {
t.Fatalf("listener = %+v, want command", got)
}
}
func TestHealth(t *testing.T) {
t.Parallel()
@ -110,6 +151,168 @@ func TestHealth(t *testing.T) {
}
}
func TestPorts(t *testing.T) {
t.Parallel()
dir := t.TempDir()
socketPath := filepath.Join(dir, "fc.vsock")
listener, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatalf("Listen: %v", err)
}
defer listener.Close()
done := make(chan error, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
done <- err
return
}
defer conn.Close()
buf := make([]byte, 0, 256)
tmp := make([]byte, 256)
for {
n, err := conn.Read(tmp)
if err != nil {
done <- err
return
}
buf = append(buf, tmp[:n]...)
if strings.Contains(string(buf), "\n") {
break
}
}
if got := string(buf); got != "CONNECT 42070\n" {
done <- unexpectedStringError(got)
return
}
if _, err := conn.Write([]byte("OK 55\n")); err != nil {
done <- err
return
}
buf = buf[:0]
for {
n, err := conn.Read(tmp)
if err != nil {
done <- err
return
}
buf = append(buf, tmp[:n]...)
if strings.Contains(string(buf), "\r\n\r\n") {
break
}
}
req := string(buf)
if !strings.Contains(req, "GET /ports HTTP/1.1\r\n") {
done <- unexpectedStringError(req)
return
}
body := `{"listeners":[{"proto":"tcp","bind_address":"0.0.0.0","port":8080,"pid":42,"process":"python3","command":"python3 -m http.server 8080"}]}`
resp := "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: " + strconv.Itoa(len(body)) + "\r\n\r\n" + body
_, err = conn.Write([]byte(resp))
done <- err
}()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
listeners, err := Ports(ctx, nil, socketPath)
if err != nil {
t.Fatalf("Ports: %v", err)
}
if len(listeners) != 1 || listeners[0].Port != 8080 || listeners[0].Command != "python3 -m http.server 8080" {
t.Fatalf("listeners = %+v, want parsed port listener", listeners)
}
if err := <-done; err != nil {
t.Fatalf("server: %v", err)
}
}
func TestParsePortListenersFiltersLoopbackAndDedupesWildcards(t *testing.T) {
t.Parallel()
raw := strings.Join([]string{
`tcp LISTEN 0 4096 0.0.0.0:22 0.0.0.0:* users:(("sshd",pid=12,fd=3))`,
`tcp LISTEN 0 4096 [::]:22 [::]:* users:(("sshd",pid=12,fd=4))`,
`udp UNCONN 0 0 127.0.0.53%lo:53 0.0.0.0:* users:(("stubby",pid=99,fd=3))`,
`tcp LISTEN 0 4096 172.16.0.2:8080 0.0.0.0:* users:(("python3",pid=44,fd=6))`,
}, "\n")
readCmdline := func(pid int) string {
switch pid {
case 12:
return "/usr/sbin/sshd -D"
case 44:
return "python3 -m http.server 8080"
default:
return ""
}
}
listeners, err := parsePortListeners([]byte(raw), readCmdline)
if err != nil {
t.Fatalf("parsePortListeners: %v", err)
}
want := []PortListener{
{
Proto: "tcp",
BindAddress: "0.0.0.0",
Port: 22,
PID: 12,
Process: "sshd",
Command: "/usr/sbin/sshd -D",
},
{
Proto: "tcp",
BindAddress: "172.16.0.2",
Port: 8080,
PID: 44,
Process: "python3",
Command: "python3 -m http.server 8080",
},
}
if !reflect.DeepEqual(listeners, want) {
t.Fatalf("listeners = %#v, want %#v", listeners, want)
}
}
func TestParsePortListenersFallsBackToProcessName(t *testing.T) {
t.Parallel()
raw := `tcp LISTEN 0 128 0.0.0.0:5432 0.0.0.0:* users:(("postgres",pid=77,fd=5))`
listeners, err := parsePortListeners([]byte(raw), func(int) string { return "" })
if err != nil {
t.Fatalf("parsePortListeners: %v", err)
}
if len(listeners) != 1 {
t.Fatalf("listeners = %d, want 1", len(listeners))
}
if listeners[0].Command != "postgres" {
t.Fatalf("command = %q, want process fallback", listeners[0].Command)
}
}
func TestNewHandlerPortsReturnsServerErrorOnCollectorFailure(t *testing.T) {
origCollector := portCollector
t.Cleanup(func() {
portCollector = origCollector
})
portCollector = func(context.Context) ([]PortListener, error) {
return nil, errors.New("ss missing")
}
req, err := http.NewRequest(http.MethodGet, PortsPath, nil)
if err != nil {
t.Fatalf("NewRequest: %v", err)
}
rr := newTestResponseRecorder()
NewHandler().ServeHTTP(rr, req)
if rr.status != http.StatusInternalServerError {
t.Fatalf("status = %d, want %d", rr.status, http.StatusInternalServerError)
}
}
type testResponseRecorder struct {
headers http.Header
body bytes.Buffer

View file

@ -5,5 +5,6 @@ tree
ca-certificates
curl
wget
iproute2
vim
tmux