Add vsock-backed VM port inspection
Let the host ask the guest vsock agent to run ss so open ports can be surfaced without SSHing in manually. Add a narrow /ports agent endpoint, a daemon vm.ports RPC that enriches listeners with <hostname>.vm endpoints and best-effort HTTP links, and a concurrent 'banger vm ports' CLI table for one or more VMs. Update the guest package contract to include ss for rebuilt Debian images, allow the guest agent package in the shell-out policy, and cover the new parsing/RPC/CLI flow in tests. Verified with GOCACHE=/tmp/banger-gocache go test ./... outside the sandbox, make build, bash -n customize.sh make-rootfs-void.sh verify.sh, and ./banger vm ports --help.
This commit is contained in:
parent
3ed78fdcfc
commit
c298ed2fc1
11 changed files with 1029 additions and 23 deletions
15
README.md
15
README.md
|
|
@ -110,6 +110,11 @@ banger vm ssh calm-otter
|
|||
When the SSH session exits normally, `banger` checks the guest over vsock and
|
||||
reminds you if the VM is still running.
|
||||
|
||||
Inspect host-reachable listening ports for one or more running VMs:
|
||||
```bash
|
||||
banger vm ports calm-otter buildbox
|
||||
```
|
||||
|
||||
Stop, restart, kill, or delete it:
|
||||
```bash
|
||||
banger vm stop calm-otter
|
||||
|
|
@ -246,6 +251,13 @@ for daemon-managed VMs. Known `A` records resolve `<vm-name>.vm` to the VM's
|
|||
guest IPv4 address. Integrate your local resolver separately if you want
|
||||
transparent `.vm` lookups on the host.
|
||||
|
||||
`banger vm ports` asks the guest-side `banger-vsock-agent` to run `ss`, then
|
||||
prints host-usable `<hostname>.vm:port` endpoints plus the owning
|
||||
process/command. TCP listeners get a short best-effort HTTP probe; when the
|
||||
probe sees a real HTTP response, the command includes a clickable
|
||||
`http://<hostname>.vm:port/` URL. Older images without `ss` may need rebuilding
|
||||
before `vm ports` works.
|
||||
|
||||
## Storage Model
|
||||
- VMs share a read-only base rootfs image.
|
||||
- Each VM gets its own sparse writable system overlay for `/`.
|
||||
|
|
@ -270,7 +282,8 @@ shell helpers treated as manual workflows rather than architecture drivers.
|
|||
- Stopping a VM preserves its overlay and work disk.
|
||||
|
||||
## Rebuilding The Repo Default Rootfs
|
||||
`packages.apt` controls the base apt packages baked into rebuilt images.
|
||||
`packages.apt` controls the base apt packages baked into rebuilt images,
|
||||
including guest tools such as `ss` used by `banger vm ports`.
|
||||
|
||||
To rebuild the source-checkout default image in `./runtime/rootfs-docker.ext4`:
|
||||
```bash
|
||||
|
|
|
|||
|
|
@ -73,6 +73,23 @@ type VMPingResult struct {
|
|||
Alive bool `json:"alive"`
|
||||
}
|
||||
|
||||
type VMPort struct {
|
||||
Proto string `json:"proto"`
|
||||
BindAddress string `json:"bind_address,omitempty"`
|
||||
Port int `json:"port"`
|
||||
PID int `json:"pid,omitempty"`
|
||||
Process string `json:"process,omitempty"`
|
||||
Command string `json:"command,omitempty"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
WebURL string `json:"web_url,omitempty"`
|
||||
}
|
||||
|
||||
type VMPortsResult struct {
|
||||
Name string `json:"name"`
|
||||
DNSName string `json:"dns_name,omitempty"`
|
||||
Ports []VMPort `json:"ports"`
|
||||
}
|
||||
|
||||
type ImageBuildParams struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
BaseRootfs string `json:"base_rootfs,omitempty"`
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import (
|
|||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
|
@ -45,6 +46,9 @@ var (
|
|||
vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) {
|
||||
return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName})
|
||||
}
|
||||
vmPortsFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) {
|
||||
return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName})
|
||||
}
|
||||
)
|
||||
|
||||
func NewBangerCommand() *cobra.Command {
|
||||
|
|
@ -243,6 +247,7 @@ func newVMCommand() *cobra.Command {
|
|||
newVMSSHCommand(),
|
||||
newVMLogsCommand(),
|
||||
newVMStatsCommand(),
|
||||
newVMPortsCommand(),
|
||||
)
|
||||
return cmd
|
||||
}
|
||||
|
|
@ -542,6 +547,50 @@ func newVMStatsCommand() *cobra.Command {
|
|||
}
|
||||
}
|
||||
|
||||
func newVMPortsCommand() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "ports <id-or-name>...",
|
||||
Short: "Show host-reachable listening guest ports",
|
||||
Args: minArgsUsage(1, "usage: banger vm ports <id-or-name>..."),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
listResult, err := rpc.Call[api.VMListResult](cmd.Context(), layout.SocketPath, "vm.list", api.Empty{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
targets, resolutionErrs := resolveVMTargets(listResult.VMs, args)
|
||||
results := executeVMPortsBatch(cmd.Context(), layout.SocketPath, targets)
|
||||
|
||||
failed := false
|
||||
for _, resolutionErr := range resolutionErrs {
|
||||
if _, err := fmt.Fprintf(cmd.ErrOrStderr(), "%s: %v\n", resolutionErr.Ref, resolutionErr.Err); err != nil {
|
||||
return err
|
||||
}
|
||||
failed = true
|
||||
}
|
||||
for _, result := range results {
|
||||
if result.Err == nil {
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprintf(cmd.ErrOrStderr(), "%s: %v\n", result.Target.Ref, result.Err); err != nil {
|
||||
return err
|
||||
}
|
||||
failed = true
|
||||
}
|
||||
if err := printVMPortsTable(cmd.OutOrStdout(), results); err != nil {
|
||||
return err
|
||||
}
|
||||
if failed {
|
||||
return errors.New("one or more VM operations failed")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newImageCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "image",
|
||||
|
|
@ -744,6 +793,12 @@ type vmBatchActionResult struct {
|
|||
Err error
|
||||
}
|
||||
|
||||
type vmPortsBatchResult struct {
|
||||
Target resolvedVMTarget
|
||||
Result api.VMPortsResult
|
||||
Err error
|
||||
}
|
||||
|
||||
func runVMBatchAction(cmd *cobra.Command, socketPath string, refs []string, action func(context.Context, string) (model.VMRecord, error)) error {
|
||||
listResult, err := rpc.Call[api.VMListResult](cmd.Context(), socketPath, "vm.list", api.Empty{})
|
||||
if err != nil {
|
||||
|
|
@ -852,6 +907,27 @@ func executeVMActionBatch(ctx context.Context, targets []resolvedVMTarget, actio
|
|||
return results
|
||||
}
|
||||
|
||||
func executeVMPortsBatch(ctx context.Context, socketPath string, targets []resolvedVMTarget) []vmPortsBatchResult {
|
||||
results := make([]vmPortsBatchResult, len(targets))
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(targets))
|
||||
for index, target := range targets {
|
||||
index := index
|
||||
target := target
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result, err := vmPortsFunc(ctx, socketPath, target.VM.ID)
|
||||
results[index] = vmPortsBatchResult{
|
||||
Target: target,
|
||||
Result: result,
|
||||
Err: err,
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
return results
|
||||
}
|
||||
|
||||
func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) {
|
||||
layout, err := paths.Resolve()
|
||||
if err != nil {
|
||||
|
|
@ -1147,6 +1223,77 @@ func printImageSummary(out anyWriter, image model.Image) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func printVMPortsTable(out anyWriter, results []vmPortsBatchResult) error {
|
||||
type portRow struct {
|
||||
VM string
|
||||
Proto string
|
||||
Endpoint string
|
||||
Process string
|
||||
Command string
|
||||
WebURL string
|
||||
Port int
|
||||
}
|
||||
rows := make([]portRow, 0)
|
||||
for _, result := range results {
|
||||
if result.Err != nil {
|
||||
continue
|
||||
}
|
||||
vmName := strings.TrimSpace(result.Result.Name)
|
||||
if vmName == "" {
|
||||
vmName = result.Target.VM.Name
|
||||
}
|
||||
for _, port := range result.Result.Ports {
|
||||
rows = append(rows, portRow{
|
||||
VM: vmName,
|
||||
Proto: port.Proto,
|
||||
Endpoint: port.Endpoint,
|
||||
Process: port.Process,
|
||||
Command: port.Command,
|
||||
WebURL: emptyDash(port.WebURL),
|
||||
Port: port.Port,
|
||||
})
|
||||
}
|
||||
}
|
||||
sort.Slice(rows, func(i, j int) bool {
|
||||
if rows[i].VM != rows[j].VM {
|
||||
return rows[i].VM < rows[j].VM
|
||||
}
|
||||
if rows[i].Proto != rows[j].Proto {
|
||||
return rows[i].Proto < rows[j].Proto
|
||||
}
|
||||
if rows[i].Port != rows[j].Port {
|
||||
return rows[i].Port < rows[j].Port
|
||||
}
|
||||
if rows[i].Process != rows[j].Process {
|
||||
return rows[i].Process < rows[j].Process
|
||||
}
|
||||
return rows[i].Command < rows[j].Command
|
||||
})
|
||||
if len(rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(out, 0, 8, 2, ' ', 0)
|
||||
if _, err := fmt.Fprintln(w, "VM\tPROTO\tENDPOINT\tPROCESS\tCOMMAND\tWEB"); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, row := range rows {
|
||||
if _, err := fmt.Fprintf(
|
||||
w,
|
||||
"%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
row.VM,
|
||||
row.Proto,
|
||||
emptyDash(row.Endpoint),
|
||||
emptyDash(row.Process),
|
||||
emptyDash(row.Command),
|
||||
row.WebURL,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
func printDoctorReport(out anyWriter, report system.Report) error {
|
||||
for _, check := range report.Checks {
|
||||
status := strings.ToUpper(string(check.Status))
|
||||
|
|
@ -1162,6 +1309,14 @@ func printDoctorReport(out anyWriter, report system.Report) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func emptyDash(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return "-"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
type anyWriter interface {
|
||||
Write(p []byte) (n int, err error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -151,6 +151,17 @@ func TestVMKillFlagsExist(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestVMPortsCommandExists(t *testing.T) {
|
||||
root := NewBangerCommand()
|
||||
vm, _, err := root.Find([]string{"vm"})
|
||||
if err != nil {
|
||||
t.Fatalf("find vm: %v", err)
|
||||
}
|
||||
if _, _, err := vm.Find([]string{"ports"}); err != nil {
|
||||
t.Fatalf("find ports: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVMSetParamsFromFlags(t *testing.T) {
|
||||
params, err := vmSetParamsFromFlags("devbox", 4, 2048, "16G", true, false)
|
||||
if err != nil {
|
||||
|
|
@ -268,6 +279,59 @@ func TestAbsolutizeImageRegisterPaths(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPrintVMPortsTableSortsAndRendersURLs(t *testing.T) {
|
||||
results := []vmPortsBatchResult{
|
||||
{
|
||||
Target: resolvedVMTarget{Ref: "beta"},
|
||||
Result: api.VMPortsResult{
|
||||
Name: "beta",
|
||||
Ports: []api.VMPort{{
|
||||
Proto: "tcp",
|
||||
Port: 8080,
|
||||
Endpoint: "beta.vm:8080",
|
||||
Process: "python3",
|
||||
Command: "python3 -m http.server 8080",
|
||||
WebURL: "http://beta.vm:8080/",
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
Target: resolvedVMTarget{Ref: "alpha"},
|
||||
Result: api.VMPortsResult{
|
||||
Name: "alpha",
|
||||
Ports: []api.VMPort{{
|
||||
Proto: "udp",
|
||||
Port: 53,
|
||||
Endpoint: "alpha.vm:53",
|
||||
Process: "dnsd",
|
||||
Command: "dnsd --foreground",
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var out bytes.Buffer
|
||||
if err := printVMPortsTable(&out, results); err != nil {
|
||||
t.Fatalf("printVMPortsTable: %v", err)
|
||||
}
|
||||
lines := strings.Split(strings.TrimSpace(out.String()), "\n")
|
||||
if len(lines) != 3 {
|
||||
t.Fatalf("lines = %q, want header + 2 rows", lines)
|
||||
}
|
||||
if !strings.Contains(lines[0], "VM") || !strings.Contains(lines[0], "WEB") {
|
||||
t.Fatalf("header = %q, want VM/WEB columns", lines[0])
|
||||
}
|
||||
if !strings.Contains(lines[1], "alpha") || !strings.Contains(lines[1], "alpha.vm:53") || !strings.Contains(lines[1], "\t-\n") {
|
||||
// tabwriter output is space-expanded, so just require the dash placeholder.
|
||||
if !strings.Contains(lines[1], "alpha") || !strings.Contains(lines[1], "alpha.vm:53") || !strings.HasSuffix(strings.TrimSpace(lines[1]), "-") {
|
||||
t.Fatalf("first row = %q, want alpha row with dash web column", lines[1])
|
||||
}
|
||||
}
|
||||
if !strings.Contains(lines[2], "beta") || !strings.Contains(lines[2], "http://beta.vm:8080/") {
|
||||
t.Fatalf("second row = %q, want beta web url", lines[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) {
|
||||
origSSHExec := sshExecFunc
|
||||
origHealth := vmHealthFunc
|
||||
|
|
|
|||
|
|
@ -345,6 +345,13 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
|
|||
}
|
||||
result, err := d.PingVM(ctx, params.IDOrName)
|
||||
return marshalResultOrError(result, err)
|
||||
case "vm.ports":
|
||||
params, err := rpc.DecodeParams[api.VMRefParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
result, err := d.PortsVM(ctx, params.IDOrName)
|
||||
return marshalResultOrError(result, err)
|
||||
case "image.list":
|
||||
images, err := d.store.ListImages(ctx)
|
||||
return marshalResultOrError(api.ImageListResult{Images: images}, err)
|
||||
|
|
|
|||
126
internal/daemon/ports.go
Normal file
126
internal/daemon/ports.go
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"banger/internal/api"
|
||||
"banger/internal/model"
|
||||
"banger/internal/system"
|
||||
"banger/internal/vmdns"
|
||||
"banger/internal/vsockagent"
|
||||
)
|
||||
|
||||
const httpProbeTimeout = 750 * time.Millisecond
|
||||
|
||||
func (d *Daemon) PortsVM(ctx context.Context, idOrName string) (result api.VMPortsResult, err error) {
|
||||
_, err = d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
|
||||
result.Name = vm.Name
|
||||
result.DNSName = strings.TrimSpace(vm.Runtime.DNSName)
|
||||
if result.DNSName == "" && strings.TrimSpace(vm.Name) != "" {
|
||||
result.DNSName = vmdns.RecordName(vm.Name)
|
||||
}
|
||||
if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
|
||||
return model.VMRecord{}, fmt.Errorf("vm %s is not running", vm.Name)
|
||||
}
|
||||
if strings.TrimSpace(vm.Runtime.GuestIP) == "" {
|
||||
return model.VMRecord{}, errors.New("vm has no guest IP")
|
||||
}
|
||||
if strings.TrimSpace(vm.Runtime.VSockPath) == "" {
|
||||
return model.VMRecord{}, errors.New("vm has no vsock path")
|
||||
}
|
||||
if vm.Runtime.VSockCID == 0 {
|
||||
return model.VMRecord{}, errors.New("vm has no vsock cid")
|
||||
}
|
||||
if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil {
|
||||
return model.VMRecord{}, err
|
||||
}
|
||||
portsCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
listeners, err := vsockagent.Ports(portsCtx, d.logger, vm.Runtime.VSockPath)
|
||||
if err != nil {
|
||||
return model.VMRecord{}, err
|
||||
}
|
||||
result.Ports = buildVMPorts(vm, listeners)
|
||||
return vm, nil
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
func buildVMPorts(vm model.VMRecord, listeners []vsockagent.PortListener) []api.VMPort {
|
||||
endpointHost := strings.TrimSpace(vm.Runtime.DNSName)
|
||||
if endpointHost == "" {
|
||||
endpointHost = strings.TrimSpace(vm.Runtime.GuestIP)
|
||||
}
|
||||
probeHost := strings.TrimSpace(vm.Runtime.GuestIP)
|
||||
ports := make([]api.VMPort, 0, len(listeners))
|
||||
for _, listener := range listeners {
|
||||
if listener.Port <= 0 {
|
||||
continue
|
||||
}
|
||||
port := api.VMPort{
|
||||
Proto: strings.ToLower(strings.TrimSpace(listener.Proto)),
|
||||
BindAddress: strings.TrimSpace(listener.BindAddress),
|
||||
Port: listener.Port,
|
||||
PID: listener.PID,
|
||||
Process: strings.TrimSpace(listener.Process),
|
||||
Command: strings.TrimSpace(listener.Command),
|
||||
Endpoint: net.JoinHostPort(endpointHost, strconv.Itoa(listener.Port)),
|
||||
}
|
||||
if port.Command == "" {
|
||||
port.Command = port.Process
|
||||
}
|
||||
if port.Proto == "tcp" && probeHost != "" && endpointHost != "" && probeHTTPListener(probeHost, listener.Port) {
|
||||
port.WebURL = "http://" + net.JoinHostPort(endpointHost, strconv.Itoa(listener.Port)) + "/"
|
||||
}
|
||||
ports = append(ports, port)
|
||||
}
|
||||
sort.Slice(ports, func(i, j int) bool {
|
||||
if ports[i].Proto != ports[j].Proto {
|
||||
return ports[i].Proto < ports[j].Proto
|
||||
}
|
||||
if ports[i].Port != ports[j].Port {
|
||||
return ports[i].Port < ports[j].Port
|
||||
}
|
||||
if ports[i].Process != ports[j].Process {
|
||||
return ports[i].Process < ports[j].Process
|
||||
}
|
||||
return ports[i].Command < ports[j].Command
|
||||
})
|
||||
return ports
|
||||
}
|
||||
|
||||
func probeHTTPListener(guestIP string, port int) bool {
|
||||
if strings.TrimSpace(guestIP) == "" || port <= 0 {
|
||||
return false
|
||||
}
|
||||
url := "http://" + net.JoinHostPort(strings.TrimSpace(guestIP), strconv.Itoa(port)) + "/"
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: httpProbeTimeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
Transport: &http.Transport{
|
||||
Proxy: nil,
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1))
|
||||
return resp.ProtoMajor >= 1
|
||||
}
|
||||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
|
@ -427,6 +428,145 @@ func TestHealthVMReturnsFalseForStoppedVM(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPortsVMReturnsEnrichedPortsAndWebURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
apiSock := filepath.Join(t.TempDir(), "fc.sock")
|
||||
fake := startFakeFirecrackerProcess(t, apiSock)
|
||||
t.Cleanup(func() {
|
||||
_ = fake.Process.Kill()
|
||||
_ = fake.Wait()
|
||||
})
|
||||
|
||||
webServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
t.Cleanup(webServer.Close)
|
||||
webAddr, err := net.ResolveTCPAddr("tcp", strings.TrimPrefix(webServer.URL, "http://"))
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveTCPAddr: %v", err)
|
||||
}
|
||||
|
||||
vsockSock := filepath.Join(t.TempDir(), "fc.vsock")
|
||||
listener, err := net.Listen("unix", vsockSock)
|
||||
if err != nil {
|
||||
t.Fatalf("listen vsock: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = listener.Close()
|
||||
_ = os.Remove(vsockSock)
|
||||
})
|
||||
serverDone := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
if got := string(buf[:n]); got != "CONNECT 42070\n" {
|
||||
serverDone <- fmt.Errorf("unexpected connect message %q", got)
|
||||
return
|
||||
}
|
||||
if _, err := conn.Write([]byte("OK 1\n")); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
reqBuf := make([]byte, 0, 1024)
|
||||
for {
|
||||
n, err = conn.Read(buf)
|
||||
if err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
reqBuf = append(reqBuf, buf[:n]...)
|
||||
if strings.Contains(string(reqBuf), "\r\n\r\n") {
|
||||
break
|
||||
}
|
||||
}
|
||||
if got := string(reqBuf); !strings.Contains(got, "GET /ports HTTP/1.1\r\n") {
|
||||
serverDone <- fmt.Errorf("unexpected ports payload %q", got)
|
||||
return
|
||||
}
|
||||
body := fmt.Sprintf(`{"listeners":[{"proto":"tcp","bind_address":"0.0.0.0","port":%d,"pid":44,"process":"python3","command":"python3 -m http.server %d"},{"proto":"udp","bind_address":"0.0.0.0","port":53,"pid":1,"process":"dnsd","command":"dnsd --foreground"}]}`, webAddr.Port, webAddr.Port)
|
||||
resp := fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: %d\r\n\r\n%s", len(body), body)
|
||||
_, err = conn.Write([]byte(resp))
|
||||
serverDone <- err
|
||||
}()
|
||||
|
||||
vm := testVM("ports", "image-ports", "127.0.0.1")
|
||||
vm.State = model.VMStateRunning
|
||||
vm.Runtime.State = model.VMStateRunning
|
||||
vm.Runtime.PID = fake.Process.Pid
|
||||
vm.Runtime.APISockPath = apiSock
|
||||
vm.Runtime.VSockPath = vsockSock
|
||||
vm.Runtime.VSockCID = 10043
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
runner := &scriptedRunner{
|
||||
t: t,
|
||||
steps: []runnerStep{
|
||||
sudoStep("", nil, "chown", fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()), vsockSock),
|
||||
sudoStep("", nil, "chmod", "600", vsockSock),
|
||||
},
|
||||
}
|
||||
d := &Daemon{store: db, runner: runner}
|
||||
|
||||
result, err := d.PortsVM(ctx, vm.Name)
|
||||
if err != nil {
|
||||
t.Fatalf("PortsVM: %v", err)
|
||||
}
|
||||
if result.Name != vm.Name || result.DNSName != vm.Runtime.DNSName {
|
||||
t.Fatalf("result = %+v, want name/dns", result)
|
||||
}
|
||||
if len(result.Ports) != 2 {
|
||||
t.Fatalf("ports = %+v, want 2 entries", result.Ports)
|
||||
}
|
||||
wantWeb := fmt.Sprintf("http://ports.vm:%d/", webAddr.Port)
|
||||
var tcpPort, udpPort api.VMPort
|
||||
for _, port := range result.Ports {
|
||||
switch port.Proto {
|
||||
case "tcp":
|
||||
tcpPort = port
|
||||
case "udp":
|
||||
udpPort = port
|
||||
}
|
||||
}
|
||||
if udpPort.Endpoint != "ports.vm:53" || udpPort.WebURL != "" {
|
||||
t.Fatalf("udp port = %+v, want endpoint only", udpPort)
|
||||
}
|
||||
if tcpPort.Endpoint != net.JoinHostPort("ports.vm", strconv.Itoa(webAddr.Port)) || tcpPort.WebURL != wantWeb {
|
||||
t.Fatalf("tcp port = %+v, want web url %q", tcpPort, wantWeb)
|
||||
}
|
||||
runner.assertExhausted()
|
||||
if err := <-serverDone; err != nil {
|
||||
t.Fatalf("server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortsVMReturnsErrorForStoppedVM(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
vm := testVM("stopped-ports", "image-stopped", "172.16.0.50")
|
||||
upsertDaemonVM(t, ctx, db, vm)
|
||||
|
||||
d := &Daemon{store: db}
|
||||
_, err := d.PortsVM(ctx, vm.Name)
|
||||
if err == nil || !strings.Contains(err.Error(), "is not running") {
|
||||
t.Fatalf("PortsVM error = %v, want not running", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetVMDiskResizeFailsPreflightWhenToolsMissing(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := openDaemonStore(t)
|
||||
|
|
|
|||
|
|
@ -54,12 +54,13 @@ func TestExecImportsStayInsideApprovedPackages(t *testing.T) {
|
|||
t.Fatalf("walk repo: %v", err)
|
||||
}
|
||||
if len(offenders) != 0 {
|
||||
t.Fatalf("os/exec imports are only allowed in internal/cli, internal/firecracker, and internal/system; found %v", offenders)
|
||||
t.Fatalf("os/exec imports are only allowed in internal/cli, internal/firecracker, internal/system, and internal/vsockagent; found %v", offenders)
|
||||
}
|
||||
}
|
||||
|
||||
func allowedExecImportPath(relPath string) bool {
|
||||
return strings.HasPrefix(relPath, "internal/cli/") ||
|
||||
strings.HasPrefix(relPath, "internal/firecracker/") ||
|
||||
strings.HasPrefix(relPath, "internal/system/")
|
||||
strings.HasPrefix(relPath, "internal/system/") ||
|
||||
strings.HasPrefix(relPath, "internal/vsockagent/")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package vsockagent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
|
@ -9,6 +10,12 @@ import (
|
|||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sdkvsock "github.com/firecracker-microvm/firecracker-go-sdk/vsock"
|
||||
|
|
@ -18,6 +25,7 @@ import (
|
|||
const (
|
||||
Port uint32 = 42070
|
||||
HealthPath = "/healthz"
|
||||
PortsPath = "/ports"
|
||||
HealthyStatus = "ok"
|
||||
GuestBinaryName = "banger-vsock-agent"
|
||||
GuestInstallPath = "/usr/local/bin/" + GuestBinaryName
|
||||
|
|
@ -38,10 +46,28 @@ WantedBy=multi-user.target
|
|||
modulesLoadConfig = "vsock\nvmw_vsock_virtio_transport\n"
|
||||
)
|
||||
|
||||
var (
|
||||
portCollector = CollectPorts
|
||||
processRe = regexp.MustCompile(`"([^"]+)",pid=(\d+)`)
|
||||
)
|
||||
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type PortListener struct {
|
||||
Proto string `json:"proto"`
|
||||
BindAddress string `json:"bind_address"`
|
||||
Port int `json:"port"`
|
||||
PID int `json:"pid,omitempty"`
|
||||
Process string `json:"process,omitempty"`
|
||||
Command string `json:"command,omitempty"`
|
||||
}
|
||||
|
||||
type PortsResponse struct {
|
||||
Listeners []PortListener `json:"listeners"`
|
||||
}
|
||||
|
||||
func NewHandler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(HealthPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -52,30 +78,24 @@ func NewHandler() http.Handler {
|
|||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(HealthResponse{Status: HealthyStatus})
|
||||
})
|
||||
mux.HandleFunc(PortsPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
listeners, err := portCollector(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(PortsResponse{Listeners: listeners})
|
||||
})
|
||||
return mux
|
||||
}
|
||||
|
||||
func Health(ctx context.Context, logger *slog.Logger, socketPath string) error {
|
||||
transport := &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return sdkvsock.DialContext(
|
||||
ctx,
|
||||
socketPath,
|
||||
Port,
|
||||
sdkvsock.WithRetryTimeout(3*time.Second),
|
||||
sdkvsock.WithRetryInterval(100*time.Millisecond),
|
||||
sdkvsock.WithLogger(newLogger(logger)),
|
||||
)
|
||||
},
|
||||
}
|
||||
defer transport.CloseIdleConnections()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://vsock"+HealthPath, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := (&http.Client{Transport: transport}).Do(req)
|
||||
resp, err := doRequest(ctx, logger, socketPath, HealthPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -97,6 +117,35 @@ func Health(ctx context.Context, logger *slog.Logger, socketPath string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func Ports(ctx context.Context, logger *slog.Logger, socketPath string) ([]PortListener, error) {
|
||||
resp, err := doRequest(ctx, logger, socketPath, PortsPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("unexpected ports status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
var payload PortsResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload.Listeners, nil
|
||||
}
|
||||
|
||||
func CollectPorts(ctx context.Context) ([]PortListener, error) {
|
||||
cmd := exec.CommandContext(ctx, "ss", "-H", "-lntup")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
if len(output) > 0 {
|
||||
return nil, fmt.Errorf("run ss: %w: %s", err, bytes.TrimSpace(output))
|
||||
}
|
||||
return nil, fmt.Errorf("run ss: %w", err)
|
||||
}
|
||||
return parsePortListeners(output, readProcessCommandLine)
|
||||
}
|
||||
|
||||
func ServiceUnit() string {
|
||||
return serviceUnit
|
||||
}
|
||||
|
|
@ -156,3 +205,233 @@ func (h slogHook) Fire(entry *logrus.Entry) error {
|
|||
func IsServerClosed(err error) bool {
|
||||
return errors.Is(err, http.ErrServerClosed)
|
||||
}
|
||||
|
||||
func doRequest(ctx context.Context, logger *slog.Logger, socketPath, path string) (*http.Response, error) {
|
||||
transport := &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return sdkvsock.DialContext(
|
||||
ctx,
|
||||
socketPath,
|
||||
Port,
|
||||
sdkvsock.WithRetryTimeout(3*time.Second),
|
||||
sdkvsock.WithRetryInterval(100*time.Millisecond),
|
||||
sdkvsock.WithLogger(newLogger(logger)),
|
||||
)
|
||||
},
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://vsock"+path, nil)
|
||||
if err != nil {
|
||||
transport.CloseIdleConnections()
|
||||
return nil, err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
transport.CloseIdleConnections()
|
||||
return nil, err
|
||||
}
|
||||
return wrappedResponse(resp, transport), nil
|
||||
}
|
||||
|
||||
type responseCloser struct {
|
||||
io.ReadCloser
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
func (c responseCloser) Close() error {
|
||||
err := c.ReadCloser.Close()
|
||||
c.transport.CloseIdleConnections()
|
||||
return err
|
||||
}
|
||||
|
||||
func wrappedResponse(resp *http.Response, transport *http.Transport) *http.Response {
|
||||
if resp == nil || resp.Body == nil || transport == nil {
|
||||
return resp
|
||||
}
|
||||
resp.Body = responseCloser{ReadCloser: resp.Body, transport: transport}
|
||||
return resp
|
||||
}
|
||||
|
||||
func parsePortListeners(raw []byte, readCmdline func(int) string) ([]PortListener, error) {
|
||||
lines := strings.Split(strings.TrimSpace(string(raw)), "\n")
|
||||
listeners := make([]PortListener, 0, len(lines))
|
||||
wildcards := make(map[string]struct{})
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
parsed, err := parseSSLine(line, readCmdline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, listener := range parsed {
|
||||
if isLoopbackAddress(listener.BindAddress) {
|
||||
continue
|
||||
}
|
||||
if isWildcardAddress(listener.BindAddress) {
|
||||
key := wildcardKey(listener)
|
||||
if _, ok := wildcards[key]; ok {
|
||||
continue
|
||||
}
|
||||
wildcards[key] = struct{}{}
|
||||
}
|
||||
listeners = append(listeners, listener)
|
||||
}
|
||||
}
|
||||
sort.Slice(listeners, func(i, j int) bool {
|
||||
if listeners[i].Proto != listeners[j].Proto {
|
||||
return listeners[i].Proto < listeners[j].Proto
|
||||
}
|
||||
if listeners[i].Port != listeners[j].Port {
|
||||
return listeners[i].Port < listeners[j].Port
|
||||
}
|
||||
if listeners[i].PID != listeners[j].PID {
|
||||
return listeners[i].PID < listeners[j].PID
|
||||
}
|
||||
return listeners[i].BindAddress < listeners[j].BindAddress
|
||||
})
|
||||
return listeners, nil
|
||||
}
|
||||
|
||||
func parseSSLine(line string, readCmdline func(int) string) ([]PortListener, error) {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 6 {
|
||||
return nil, fmt.Errorf("parse ss line: expected at least 6 fields, got %d in %q", len(fields), line)
|
||||
}
|
||||
proto := strings.ToLower(strings.TrimSpace(fields[0]))
|
||||
if proto != "tcp" && proto != "udp" {
|
||||
return nil, nil
|
||||
}
|
||||
bindAddress, port, err := parseBindAddress(fields[4])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ss local address %q: %w", fields[4], err)
|
||||
}
|
||||
if bindAddress != "*" && net.ParseIP(bindAddress) == nil {
|
||||
return nil, nil
|
||||
}
|
||||
processInfo := strings.Join(fields[6:], " ")
|
||||
entries := parseProcessEntries(processInfo)
|
||||
if len(entries) == 0 {
|
||||
return []PortListener{{
|
||||
Proto: proto,
|
||||
BindAddress: bindAddress,
|
||||
Port: port,
|
||||
}}, nil
|
||||
}
|
||||
listeners := make([]PortListener, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
command := strings.TrimSpace(readCmdline(entry.PID))
|
||||
if command == "" {
|
||||
command = entry.Process
|
||||
}
|
||||
listeners = append(listeners, PortListener{
|
||||
Proto: proto,
|
||||
BindAddress: bindAddress,
|
||||
Port: port,
|
||||
PID: entry.PID,
|
||||
Process: entry.Process,
|
||||
Command: command,
|
||||
})
|
||||
}
|
||||
return listeners, nil
|
||||
}
|
||||
|
||||
type processEntry struct {
|
||||
Process string
|
||||
PID int
|
||||
}
|
||||
|
||||
func parseProcessEntries(raw string) []processEntry {
|
||||
matches := processRe.FindAllStringSubmatch(raw, -1)
|
||||
if len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
entries := make([]processEntry, 0, len(matches))
|
||||
for _, match := range matches {
|
||||
pid, err := strconv.Atoi(match[2])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, processEntry{
|
||||
Process: match[1],
|
||||
PID: pid,
|
||||
})
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func parseBindAddress(raw string) (string, int, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", 0, errors.New("empty address")
|
||||
}
|
||||
var host, portRaw string
|
||||
if strings.HasPrefix(raw, "[") {
|
||||
var err error
|
||||
host, portRaw, err = net.SplitHostPort(raw)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
} else {
|
||||
idx := strings.LastIndex(raw, ":")
|
||||
if idx <= 0 || idx == len(raw)-1 {
|
||||
return "", 0, fmt.Errorf("missing host:port in %q", raw)
|
||||
}
|
||||
host = raw[:idx]
|
||||
portRaw = raw[idx+1:]
|
||||
}
|
||||
if zoneIdx := strings.Index(host, "%"); zoneIdx >= 0 {
|
||||
host = host[:zoneIdx]
|
||||
}
|
||||
host = strings.Trim(host, "[]")
|
||||
if host == "" {
|
||||
host = "*"
|
||||
}
|
||||
port, err := strconv.Atoi(portRaw)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
return host, port, nil
|
||||
}
|
||||
|
||||
func readProcessCommandLine(pid int) string {
|
||||
if pid <= 0 {
|
||||
return ""
|
||||
}
|
||||
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(string(data), "\x00")
|
||||
filtered := parts[:0]
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, part)
|
||||
}
|
||||
return strings.Join(filtered, " ")
|
||||
}
|
||||
|
||||
func isLoopbackAddress(host string) bool {
|
||||
if host == "" || host == "*" {
|
||||
return false
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func isWildcardAddress(host string) bool {
|
||||
switch host {
|
||||
case "*", "0.0.0.0", "::":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func wildcardKey(listener PortListener) string {
|
||||
return fmt.Sprintf("%s:%d:%d:%s", listener.Proto, listener.Port, listener.PID, listener.Process)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,9 +4,12 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -37,6 +40,44 @@ func TestNewHandlerHealthz(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNewHandlerPorts(t *testing.T) {
|
||||
origCollector := portCollector
|
||||
t.Cleanup(func() {
|
||||
portCollector = origCollector
|
||||
})
|
||||
portCollector = func(context.Context) ([]PortListener, error) {
|
||||
return []PortListener{{
|
||||
Proto: "tcp",
|
||||
BindAddress: "0.0.0.0",
|
||||
Port: 8080,
|
||||
PID: 42,
|
||||
Process: "python3",
|
||||
Command: "python3 -m http.server 8080",
|
||||
}}, nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, PortsPath, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
rr := newTestResponseRecorder()
|
||||
NewHandler().ServeHTTP(rr, req)
|
||||
|
||||
if rr.status != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rr.status, http.StatusOK)
|
||||
}
|
||||
var payload PortsResponse
|
||||
if err := json.Unmarshal(rr.body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("Unmarshal: %v", err)
|
||||
}
|
||||
if len(payload.Listeners) != 1 {
|
||||
t.Fatalf("listeners = %d, want 1", len(payload.Listeners))
|
||||
}
|
||||
if got := payload.Listeners[0]; got.Command != "python3 -m http.server 8080" {
|
||||
t.Fatalf("listener = %+v, want command", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
@ -110,6 +151,168 @@ func TestHealth(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPorts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
socketPath := filepath.Join(dir, "fc.vsock")
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Listen: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
buf := make([]byte, 0, 256)
|
||||
tmp := make([]byte, 256)
|
||||
for {
|
||||
n, err := conn.Read(tmp)
|
||||
if err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
buf = append(buf, tmp[:n]...)
|
||||
if strings.Contains(string(buf), "\n") {
|
||||
break
|
||||
}
|
||||
}
|
||||
if got := string(buf); got != "CONNECT 42070\n" {
|
||||
done <- unexpectedStringError(got)
|
||||
return
|
||||
}
|
||||
if _, err := conn.Write([]byte("OK 55\n")); err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
|
||||
buf = buf[:0]
|
||||
for {
|
||||
n, err := conn.Read(tmp)
|
||||
if err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
buf = append(buf, tmp[:n]...)
|
||||
if strings.Contains(string(buf), "\r\n\r\n") {
|
||||
break
|
||||
}
|
||||
}
|
||||
req := string(buf)
|
||||
if !strings.Contains(req, "GET /ports HTTP/1.1\r\n") {
|
||||
done <- unexpectedStringError(req)
|
||||
return
|
||||
}
|
||||
body := `{"listeners":[{"proto":"tcp","bind_address":"0.0.0.0","port":8080,"pid":42,"process":"python3","command":"python3 -m http.server 8080"}]}`
|
||||
resp := "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: " + strconv.Itoa(len(body)) + "\r\n\r\n" + body
|
||||
_, err = conn.Write([]byte(resp))
|
||||
done <- err
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
listeners, err := Ports(ctx, nil, socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Ports: %v", err)
|
||||
}
|
||||
if len(listeners) != 1 || listeners[0].Port != 8080 || listeners[0].Command != "python3 -m http.server 8080" {
|
||||
t.Fatalf("listeners = %+v, want parsed port listener", listeners)
|
||||
}
|
||||
if err := <-done; err != nil {
|
||||
t.Fatalf("server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortListenersFiltersLoopbackAndDedupesWildcards(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
raw := strings.Join([]string{
|
||||
`tcp LISTEN 0 4096 0.0.0.0:22 0.0.0.0:* users:(("sshd",pid=12,fd=3))`,
|
||||
`tcp LISTEN 0 4096 [::]:22 [::]:* users:(("sshd",pid=12,fd=4))`,
|
||||
`udp UNCONN 0 0 127.0.0.53%lo:53 0.0.0.0:* users:(("stubby",pid=99,fd=3))`,
|
||||
`tcp LISTEN 0 4096 172.16.0.2:8080 0.0.0.0:* users:(("python3",pid=44,fd=6))`,
|
||||
}, "\n")
|
||||
readCmdline := func(pid int) string {
|
||||
switch pid {
|
||||
case 12:
|
||||
return "/usr/sbin/sshd -D"
|
||||
case 44:
|
||||
return "python3 -m http.server 8080"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
listeners, err := parsePortListeners([]byte(raw), readCmdline)
|
||||
if err != nil {
|
||||
t.Fatalf("parsePortListeners: %v", err)
|
||||
}
|
||||
want := []PortListener{
|
||||
{
|
||||
Proto: "tcp",
|
||||
BindAddress: "0.0.0.0",
|
||||
Port: 22,
|
||||
PID: 12,
|
||||
Process: "sshd",
|
||||
Command: "/usr/sbin/sshd -D",
|
||||
},
|
||||
{
|
||||
Proto: "tcp",
|
||||
BindAddress: "172.16.0.2",
|
||||
Port: 8080,
|
||||
PID: 44,
|
||||
Process: "python3",
|
||||
Command: "python3 -m http.server 8080",
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(listeners, want) {
|
||||
t.Fatalf("listeners = %#v, want %#v", listeners, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortListenersFallsBackToProcessName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
raw := `tcp LISTEN 0 128 0.0.0.0:5432 0.0.0.0:* users:(("postgres",pid=77,fd=5))`
|
||||
listeners, err := parsePortListeners([]byte(raw), func(int) string { return "" })
|
||||
if err != nil {
|
||||
t.Fatalf("parsePortListeners: %v", err)
|
||||
}
|
||||
if len(listeners) != 1 {
|
||||
t.Fatalf("listeners = %d, want 1", len(listeners))
|
||||
}
|
||||
if listeners[0].Command != "postgres" {
|
||||
t.Fatalf("command = %q, want process fallback", listeners[0].Command)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHandlerPortsReturnsServerErrorOnCollectorFailure(t *testing.T) {
|
||||
origCollector := portCollector
|
||||
t.Cleanup(func() {
|
||||
portCollector = origCollector
|
||||
})
|
||||
portCollector = func(context.Context) ([]PortListener, error) {
|
||||
return nil, errors.New("ss missing")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, PortsPath, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
rr := newTestResponseRecorder()
|
||||
NewHandler().ServeHTTP(rr, req)
|
||||
if rr.status != http.StatusInternalServerError {
|
||||
t.Fatalf("status = %d, want %d", rr.status, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
type testResponseRecorder struct {
|
||||
headers http.Header
|
||||
body bytes.Buffer
|
||||
|
|
|
|||
|
|
@ -5,5 +5,6 @@ tree
|
|||
ca-certificates
|
||||
curl
|
||||
wget
|
||||
iproute2
|
||||
vim
|
||||
tmux
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue