diff --git a/README.md b/README.md index fe8c316..4d99961 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,11 @@ banger vm ssh calm-otter When the SSH session exits normally, `banger` checks the guest over vsock and reminds you if the VM is still running. +Inspect host-reachable listening ports for one or more running VMs: +```bash +banger vm ports calm-otter buildbox +``` + Stop, restart, kill, or delete it: ```bash banger vm stop calm-otter @@ -246,6 +251,13 @@ for daemon-managed VMs. Known `A` records resolve `.vm` to the VM's guest IPv4 address. Integrate your local resolver separately if you want transparent `.vm` lookups on the host. +`banger vm ports` asks the guest-side `banger-vsock-agent` to run `ss`, then +prints host-usable `.vm:port` endpoints plus the owning +process/command. TCP listeners get a short best-effort HTTP probe; when the +probe sees a real HTTP response, the command includes a clickable +`http://.vm:port/` URL. Older images without `ss` may need rebuilding +before `vm ports` works. + ## Storage Model - VMs share a read-only base rootfs image. - Each VM gets its own sparse writable system overlay for `/`. @@ -270,7 +282,8 @@ shell helpers treated as manual workflows rather than architecture drivers. - Stopping a VM preserves its overlay and work disk. ## Rebuilding The Repo Default Rootfs -`packages.apt` controls the base apt packages baked into rebuilt images. +`packages.apt` controls the base apt packages baked into rebuilt images, +including guest tools such as `ss` used by `banger vm ports`. To rebuild the source-checkout default image in `./runtime/rootfs-docker.ext4`: ```bash diff --git a/internal/api/types.go b/internal/api/types.go index b3f9df0..3ac2b67 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -73,6 +73,23 @@ type VMPingResult struct { Alive bool `json:"alive"` } +type VMPort struct { + Proto string `json:"proto"` + BindAddress string `json:"bind_address,omitempty"` + Port int `json:"port"` + PID int `json:"pid,omitempty"` + Process string `json:"process,omitempty"` + Command string `json:"command,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + WebURL string `json:"web_url,omitempty"` +} + +type VMPortsResult struct { + Name string `json:"name"` + DNSName string `json:"dns_name,omitempty"` + Ports []VMPort `json:"ports"` +} + type ImageBuildParams struct { Name string `json:"name,omitempty"` BaseRootfs string `json:"base_rootfs,omitempty"` diff --git a/internal/cli/banger.go b/internal/cli/banger.go index 6167c2a..f35b78f 100644 --- a/internal/cli/banger.go +++ b/internal/cli/banger.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "path/filepath" + "sort" "strings" "sync" "syscall" @@ -45,6 +46,9 @@ var ( vmHealthFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMHealthResult, error) { return rpc.Call[api.VMHealthResult](ctx, socketPath, "vm.health", api.VMRefParams{IDOrName: idOrName}) } + vmPortsFunc = func(ctx context.Context, socketPath, idOrName string) (api.VMPortsResult, error) { + return rpc.Call[api.VMPortsResult](ctx, socketPath, "vm.ports", api.VMRefParams{IDOrName: idOrName}) + } ) func NewBangerCommand() *cobra.Command { @@ -243,6 +247,7 @@ func newVMCommand() *cobra.Command { newVMSSHCommand(), newVMLogsCommand(), newVMStatsCommand(), + newVMPortsCommand(), ) return cmd } @@ -542,6 +547,50 @@ func newVMStatsCommand() *cobra.Command { } } +func newVMPortsCommand() *cobra.Command { + return &cobra.Command{ + Use: "ports ...", + Short: "Show host-reachable listening guest ports", + Args: minArgsUsage(1, "usage: banger vm ports ..."), + RunE: func(cmd *cobra.Command, args []string) error { + layout, _, err := ensureDaemon(cmd.Context()) + if err != nil { + return err + } + listResult, err := rpc.Call[api.VMListResult](cmd.Context(), layout.SocketPath, "vm.list", api.Empty{}) + if err != nil { + return err + } + targets, resolutionErrs := resolveVMTargets(listResult.VMs, args) + results := executeVMPortsBatch(cmd.Context(), layout.SocketPath, targets) + + failed := false + for _, resolutionErr := range resolutionErrs { + if _, err := fmt.Fprintf(cmd.ErrOrStderr(), "%s: %v\n", resolutionErr.Ref, resolutionErr.Err); err != nil { + return err + } + failed = true + } + for _, result := range results { + if result.Err == nil { + continue + } + if _, err := fmt.Fprintf(cmd.ErrOrStderr(), "%s: %v\n", result.Target.Ref, result.Err); err != nil { + return err + } + failed = true + } + if err := printVMPortsTable(cmd.OutOrStdout(), results); err != nil { + return err + } + if failed { + return errors.New("one or more VM operations failed") + } + return nil + }, + } +} + func newImageCommand() *cobra.Command { cmd := &cobra.Command{ Use: "image", @@ -744,6 +793,12 @@ type vmBatchActionResult struct { Err error } +type vmPortsBatchResult struct { + Target resolvedVMTarget + Result api.VMPortsResult + Err error +} + func runVMBatchAction(cmd *cobra.Command, socketPath string, refs []string, action func(context.Context, string) (model.VMRecord, error)) error { listResult, err := rpc.Call[api.VMListResult](cmd.Context(), socketPath, "vm.list", api.Empty{}) if err != nil { @@ -852,6 +907,27 @@ func executeVMActionBatch(ctx context.Context, targets []resolvedVMTarget, actio return results } +func executeVMPortsBatch(ctx context.Context, socketPath string, targets []resolvedVMTarget) []vmPortsBatchResult { + results := make([]vmPortsBatchResult, len(targets)) + var wg sync.WaitGroup + wg.Add(len(targets)) + for index, target := range targets { + index := index + target := target + go func() { + defer wg.Done() + result, err := vmPortsFunc(ctx, socketPath, target.VM.ID) + results[index] = vmPortsBatchResult{ + Target: target, + Result: result, + Err: err, + } + }() + } + wg.Wait() + return results +} + func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) { layout, err := paths.Resolve() if err != nil { @@ -1147,6 +1223,77 @@ func printImageSummary(out anyWriter, image model.Image) error { return err } +func printVMPortsTable(out anyWriter, results []vmPortsBatchResult) error { + type portRow struct { + VM string + Proto string + Endpoint string + Process string + Command string + WebURL string + Port int + } + rows := make([]portRow, 0) + for _, result := range results { + if result.Err != nil { + continue + } + vmName := strings.TrimSpace(result.Result.Name) + if vmName == "" { + vmName = result.Target.VM.Name + } + for _, port := range result.Result.Ports { + rows = append(rows, portRow{ + VM: vmName, + Proto: port.Proto, + Endpoint: port.Endpoint, + Process: port.Process, + Command: port.Command, + WebURL: emptyDash(port.WebURL), + Port: port.Port, + }) + } + } + sort.Slice(rows, func(i, j int) bool { + if rows[i].VM != rows[j].VM { + return rows[i].VM < rows[j].VM + } + if rows[i].Proto != rows[j].Proto { + return rows[i].Proto < rows[j].Proto + } + if rows[i].Port != rows[j].Port { + return rows[i].Port < rows[j].Port + } + if rows[i].Process != rows[j].Process { + return rows[i].Process < rows[j].Process + } + return rows[i].Command < rows[j].Command + }) + if len(rows) == 0 { + return nil + } + + w := tabwriter.NewWriter(out, 0, 8, 2, ' ', 0) + if _, err := fmt.Fprintln(w, "VM\tPROTO\tENDPOINT\tPROCESS\tCOMMAND\tWEB"); err != nil { + return err + } + for _, row := range rows { + if _, err := fmt.Fprintf( + w, + "%s\t%s\t%s\t%s\t%s\t%s\n", + row.VM, + row.Proto, + emptyDash(row.Endpoint), + emptyDash(row.Process), + emptyDash(row.Command), + row.WebURL, + ); err != nil { + return err + } + } + return w.Flush() +} + func printDoctorReport(out anyWriter, report system.Report) error { for _, check := range report.Checks { status := strings.ToUpper(string(check.Status)) @@ -1162,6 +1309,14 @@ func printDoctorReport(out anyWriter, report system.Report) error { return nil } +func emptyDash(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "-" + } + return value +} + type anyWriter interface { Write(p []byte) (n int, err error) } diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index d12a9a4..5994cce 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -151,6 +151,17 @@ func TestVMKillFlagsExist(t *testing.T) { } } +func TestVMPortsCommandExists(t *testing.T) { + root := NewBangerCommand() + vm, _, err := root.Find([]string{"vm"}) + if err != nil { + t.Fatalf("find vm: %v", err) + } + if _, _, err := vm.Find([]string{"ports"}); err != nil { + t.Fatalf("find ports: %v", err) + } +} + func TestVMSetParamsFromFlags(t *testing.T) { params, err := vmSetParamsFromFlags("devbox", 4, 2048, "16G", true, false) if err != nil { @@ -268,6 +279,59 @@ func TestAbsolutizeImageRegisterPaths(t *testing.T) { } } +func TestPrintVMPortsTableSortsAndRendersURLs(t *testing.T) { + results := []vmPortsBatchResult{ + { + Target: resolvedVMTarget{Ref: "beta"}, + Result: api.VMPortsResult{ + Name: "beta", + Ports: []api.VMPort{{ + Proto: "tcp", + Port: 8080, + Endpoint: "beta.vm:8080", + Process: "python3", + Command: "python3 -m http.server 8080", + WebURL: "http://beta.vm:8080/", + }}, + }, + }, + { + Target: resolvedVMTarget{Ref: "alpha"}, + Result: api.VMPortsResult{ + Name: "alpha", + Ports: []api.VMPort{{ + Proto: "udp", + Port: 53, + Endpoint: "alpha.vm:53", + Process: "dnsd", + Command: "dnsd --foreground", + }}, + }, + }, + } + + var out bytes.Buffer + if err := printVMPortsTable(&out, results); err != nil { + t.Fatalf("printVMPortsTable: %v", err) + } + lines := strings.Split(strings.TrimSpace(out.String()), "\n") + if len(lines) != 3 { + t.Fatalf("lines = %q, want header + 2 rows", lines) + } + if !strings.Contains(lines[0], "VM") || !strings.Contains(lines[0], "WEB") { + t.Fatalf("header = %q, want VM/WEB columns", lines[0]) + } + if !strings.Contains(lines[1], "alpha") || !strings.Contains(lines[1], "alpha.vm:53") || !strings.Contains(lines[1], "\t-\n") { + // tabwriter output is space-expanded, so just require the dash placeholder. + if !strings.Contains(lines[1], "alpha") || !strings.Contains(lines[1], "alpha.vm:53") || !strings.HasSuffix(strings.TrimSpace(lines[1]), "-") { + t.Fatalf("first row = %q, want alpha row with dash web column", lines[1]) + } + } + if !strings.Contains(lines[2], "beta") || !strings.Contains(lines[2], "http://beta.vm:8080/") { + t.Fatalf("second row = %q, want beta web url", lines[2]) + } +} + func TestRunSSHSessionPrintsReminderWhenHealthCheckPasses(t *testing.T) { origSSHExec := sshExecFunc origHealth := vmHealthFunc diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index d19b8d1..a226d15 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -345,6 +345,13 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response { } result, err := d.PingVM(ctx, params.IDOrName) return marshalResultOrError(result, err) + case "vm.ports": + params, err := rpc.DecodeParams[api.VMRefParams](req) + if err != nil { + return rpc.NewError("bad_request", err.Error()) + } + result, err := d.PortsVM(ctx, params.IDOrName) + return marshalResultOrError(result, err) case "image.list": images, err := d.store.ListImages(ctx) return marshalResultOrError(api.ImageListResult{Images: images}, err) diff --git a/internal/daemon/ports.go b/internal/daemon/ports.go new file mode 100644 index 0000000..c60e59a --- /dev/null +++ b/internal/daemon/ports.go @@ -0,0 +1,126 @@ +package daemon + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "sort" + "strconv" + "strings" + "time" + + "banger/internal/api" + "banger/internal/model" + "banger/internal/system" + "banger/internal/vmdns" + "banger/internal/vsockagent" +) + +const httpProbeTimeout = 750 * time.Millisecond + +func (d *Daemon) PortsVM(ctx context.Context, idOrName string) (result api.VMPortsResult, err error) { + _, err = d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { + result.Name = vm.Name + result.DNSName = strings.TrimSpace(vm.Runtime.DNSName) + if result.DNSName == "" && strings.TrimSpace(vm.Name) != "" { + result.DNSName = vmdns.RecordName(vm.Name) + } + if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) { + return model.VMRecord{}, fmt.Errorf("vm %s is not running", vm.Name) + } + if strings.TrimSpace(vm.Runtime.GuestIP) == "" { + return model.VMRecord{}, errors.New("vm has no guest IP") + } + if strings.TrimSpace(vm.Runtime.VSockPath) == "" { + return model.VMRecord{}, errors.New("vm has no vsock path") + } + if vm.Runtime.VSockCID == 0 { + return model.VMRecord{}, errors.New("vm has no vsock cid") + } + if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { + return model.VMRecord{}, err + } + portsCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + listeners, err := vsockagent.Ports(portsCtx, d.logger, vm.Runtime.VSockPath) + if err != nil { + return model.VMRecord{}, err + } + result.Ports = buildVMPorts(vm, listeners) + return vm, nil + }) + return result, err +} + +func buildVMPorts(vm model.VMRecord, listeners []vsockagent.PortListener) []api.VMPort { + endpointHost := strings.TrimSpace(vm.Runtime.DNSName) + if endpointHost == "" { + endpointHost = strings.TrimSpace(vm.Runtime.GuestIP) + } + probeHost := strings.TrimSpace(vm.Runtime.GuestIP) + ports := make([]api.VMPort, 0, len(listeners)) + for _, listener := range listeners { + if listener.Port <= 0 { + continue + } + port := api.VMPort{ + Proto: strings.ToLower(strings.TrimSpace(listener.Proto)), + BindAddress: strings.TrimSpace(listener.BindAddress), + Port: listener.Port, + PID: listener.PID, + Process: strings.TrimSpace(listener.Process), + Command: strings.TrimSpace(listener.Command), + Endpoint: net.JoinHostPort(endpointHost, strconv.Itoa(listener.Port)), + } + if port.Command == "" { + port.Command = port.Process + } + if port.Proto == "tcp" && probeHost != "" && endpointHost != "" && probeHTTPListener(probeHost, listener.Port) { + port.WebURL = "http://" + net.JoinHostPort(endpointHost, strconv.Itoa(listener.Port)) + "/" + } + ports = append(ports, port) + } + sort.Slice(ports, func(i, j int) bool { + if ports[i].Proto != ports[j].Proto { + return ports[i].Proto < ports[j].Proto + } + if ports[i].Port != ports[j].Port { + return ports[i].Port < ports[j].Port + } + if ports[i].Process != ports[j].Process { + return ports[i].Process < ports[j].Process + } + return ports[i].Command < ports[j].Command + }) + return ports +} + +func probeHTTPListener(guestIP string, port int) bool { + if strings.TrimSpace(guestIP) == "" || port <= 0 { + return false + } + url := "http://" + net.JoinHostPort(strings.TrimSpace(guestIP), strconv.Itoa(port)) + "/" + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return false + } + client := &http.Client{ + Timeout: httpProbeTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + Transport: &http.Transport{ + Proxy: nil, + }, + } + resp, err := client.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1)) + return resp.ProtoMajor >= 1 +} diff --git a/internal/daemon/vm_test.go b/internal/daemon/vm_test.go index c14bd94..0319cd9 100644 --- a/internal/daemon/vm_test.go +++ b/internal/daemon/vm_test.go @@ -10,6 +10,7 @@ import ( "fmt" "net" "net/http" + "net/http/httptest" "os" "os/exec" "path/filepath" @@ -427,6 +428,145 @@ func TestHealthVMReturnsFalseForStoppedVM(t *testing.T) { } } +func TestPortsVMReturnsEnrichedPortsAndWebURL(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openDaemonStore(t) + apiSock := filepath.Join(t.TempDir(), "fc.sock") + fake := startFakeFirecrackerProcess(t, apiSock) + t.Cleanup(func() { + _ = fake.Process.Kill() + _ = fake.Wait() + }) + + webServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + t.Cleanup(webServer.Close) + webAddr, err := net.ResolveTCPAddr("tcp", strings.TrimPrefix(webServer.URL, "http://")) + if err != nil { + t.Fatalf("ResolveTCPAddr: %v", err) + } + + vsockSock := filepath.Join(t.TempDir(), "fc.vsock") + listener, err := net.Listen("unix", vsockSock) + if err != nil { + t.Fatalf("listen vsock: %v", err) + } + t.Cleanup(func() { + _ = listener.Close() + _ = os.Remove(vsockSock) + }) + serverDone := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + serverDone <- err + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + serverDone <- err + return + } + if got := string(buf[:n]); got != "CONNECT 42070\n" { + serverDone <- fmt.Errorf("unexpected connect message %q", got) + return + } + if _, err := conn.Write([]byte("OK 1\n")); err != nil { + serverDone <- err + return + } + reqBuf := make([]byte, 0, 1024) + for { + n, err = conn.Read(buf) + if err != nil { + serverDone <- err + return + } + reqBuf = append(reqBuf, buf[:n]...) + if strings.Contains(string(reqBuf), "\r\n\r\n") { + break + } + } + if got := string(reqBuf); !strings.Contains(got, "GET /ports HTTP/1.1\r\n") { + serverDone <- fmt.Errorf("unexpected ports payload %q", got) + return + } + body := fmt.Sprintf(`{"listeners":[{"proto":"tcp","bind_address":"0.0.0.0","port":%d,"pid":44,"process":"python3","command":"python3 -m http.server %d"},{"proto":"udp","bind_address":"0.0.0.0","port":53,"pid":1,"process":"dnsd","command":"dnsd --foreground"}]}`, webAddr.Port, webAddr.Port) + resp := fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: %d\r\n\r\n%s", len(body), body) + _, err = conn.Write([]byte(resp)) + serverDone <- err + }() + + vm := testVM("ports", "image-ports", "127.0.0.1") + vm.State = model.VMStateRunning + vm.Runtime.State = model.VMStateRunning + vm.Runtime.PID = fake.Process.Pid + vm.Runtime.APISockPath = apiSock + vm.Runtime.VSockPath = vsockSock + vm.Runtime.VSockCID = 10043 + upsertDaemonVM(t, ctx, db, vm) + + runner := &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("", nil, "chown", fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()), vsockSock), + sudoStep("", nil, "chmod", "600", vsockSock), + }, + } + d := &Daemon{store: db, runner: runner} + + result, err := d.PortsVM(ctx, vm.Name) + if err != nil { + t.Fatalf("PortsVM: %v", err) + } + if result.Name != vm.Name || result.DNSName != vm.Runtime.DNSName { + t.Fatalf("result = %+v, want name/dns", result) + } + if len(result.Ports) != 2 { + t.Fatalf("ports = %+v, want 2 entries", result.Ports) + } + wantWeb := fmt.Sprintf("http://ports.vm:%d/", webAddr.Port) + var tcpPort, udpPort api.VMPort + for _, port := range result.Ports { + switch port.Proto { + case "tcp": + tcpPort = port + case "udp": + udpPort = port + } + } + if udpPort.Endpoint != "ports.vm:53" || udpPort.WebURL != "" { + t.Fatalf("udp port = %+v, want endpoint only", udpPort) + } + if tcpPort.Endpoint != net.JoinHostPort("ports.vm", strconv.Itoa(webAddr.Port)) || tcpPort.WebURL != wantWeb { + t.Fatalf("tcp port = %+v, want web url %q", tcpPort, wantWeb) + } + runner.assertExhausted() + if err := <-serverDone; err != nil { + t.Fatalf("server: %v", err) + } +} + +func TestPortsVMReturnsErrorForStoppedVM(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openDaemonStore(t) + vm := testVM("stopped-ports", "image-stopped", "172.16.0.50") + upsertDaemonVM(t, ctx, db, vm) + + d := &Daemon{store: db} + _, err := d.PortsVM(ctx, vm.Name) + if err == nil || !strings.Contains(err.Error(), "is not running") { + t.Fatalf("PortsVM error = %v, want not running", err) + } +} + func TestSetVMDiskResizeFailsPreflightWhenToolsMissing(t *testing.T) { ctx := context.Background() db := openDaemonStore(t) diff --git a/internal/policy/shellout_test.go b/internal/policy/shellout_test.go index 0eab275..d3a9721 100644 --- a/internal/policy/shellout_test.go +++ b/internal/policy/shellout_test.go @@ -54,12 +54,13 @@ func TestExecImportsStayInsideApprovedPackages(t *testing.T) { t.Fatalf("walk repo: %v", err) } if len(offenders) != 0 { - t.Fatalf("os/exec imports are only allowed in internal/cli, internal/firecracker, and internal/system; found %v", offenders) + t.Fatalf("os/exec imports are only allowed in internal/cli, internal/firecracker, internal/system, and internal/vsockagent; found %v", offenders) } } func allowedExecImportPath(relPath string) bool { return strings.HasPrefix(relPath, "internal/cli/") || strings.HasPrefix(relPath, "internal/firecracker/") || - strings.HasPrefix(relPath, "internal/system/") + strings.HasPrefix(relPath, "internal/system/") || + strings.HasPrefix(relPath, "internal/vsockagent/") } diff --git a/internal/vsockagent/vsockagent.go b/internal/vsockagent/vsockagent.go index 547034f..da940ee 100644 --- a/internal/vsockagent/vsockagent.go +++ b/internal/vsockagent/vsockagent.go @@ -1,6 +1,7 @@ package vsockagent import ( + "bytes" "context" "encoding/json" "errors" @@ -9,6 +10,12 @@ import ( "log/slog" "net" "net/http" + "os" + "os/exec" + "regexp" + "sort" + "strconv" + "strings" "time" sdkvsock "github.com/firecracker-microvm/firecracker-go-sdk/vsock" @@ -18,6 +25,7 @@ import ( const ( Port uint32 = 42070 HealthPath = "/healthz" + PortsPath = "/ports" HealthyStatus = "ok" GuestBinaryName = "banger-vsock-agent" GuestInstallPath = "/usr/local/bin/" + GuestBinaryName @@ -38,10 +46,28 @@ WantedBy=multi-user.target modulesLoadConfig = "vsock\nvmw_vsock_virtio_transport\n" ) +var ( + portCollector = CollectPorts + processRe = regexp.MustCompile(`"([^"]+)",pid=(\d+)`) +) + type HealthResponse struct { Status string `json:"status"` } +type PortListener struct { + Proto string `json:"proto"` + BindAddress string `json:"bind_address"` + Port int `json:"port"` + PID int `json:"pid,omitempty"` + Process string `json:"process,omitempty"` + Command string `json:"command,omitempty"` +} + +type PortsResponse struct { + Listeners []PortListener `json:"listeners"` +} + func NewHandler() http.Handler { mux := http.NewServeMux() mux.HandleFunc(HealthPath, func(w http.ResponseWriter, r *http.Request) { @@ -52,30 +78,24 @@ func NewHandler() http.Handler { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(HealthResponse{Status: HealthyStatus}) }) + mux.HandleFunc(PortsPath, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + listeners, err := portCollector(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(PortsResponse{Listeners: listeners}) + }) return mux } func Health(ctx context.Context, logger *slog.Logger, socketPath string) error { - transport := &http.Transport{ - DisableKeepAlives: true, - DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { - return sdkvsock.DialContext( - ctx, - socketPath, - Port, - sdkvsock.WithRetryTimeout(3*time.Second), - sdkvsock.WithRetryInterval(100*time.Millisecond), - sdkvsock.WithLogger(newLogger(logger)), - ) - }, - } - defer transport.CloseIdleConnections() - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://vsock"+HealthPath, nil) - if err != nil { - return err - } - resp, err := (&http.Client{Transport: transport}).Do(req) + resp, err := doRequest(ctx, logger, socketPath, HealthPath) if err != nil { return err } @@ -97,6 +117,35 @@ func Health(ctx context.Context, logger *slog.Logger, socketPath string) error { return nil } +func Ports(ctx context.Context, logger *slog.Logger, socketPath string) ([]PortListener, error) { + resp, err := doRequest(ctx, logger, socketPath, PortsPath) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return nil, fmt.Errorf("unexpected ports status %d: %s", resp.StatusCode, string(body)) + } + var payload PortsResponse + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return nil, err + } + return payload.Listeners, nil +} + +func CollectPorts(ctx context.Context) ([]PortListener, error) { + cmd := exec.CommandContext(ctx, "ss", "-H", "-lntup") + output, err := cmd.Output() + if err != nil { + if len(output) > 0 { + return nil, fmt.Errorf("run ss: %w: %s", err, bytes.TrimSpace(output)) + } + return nil, fmt.Errorf("run ss: %w", err) + } + return parsePortListeners(output, readProcessCommandLine) +} + func ServiceUnit() string { return serviceUnit } @@ -156,3 +205,233 @@ func (h slogHook) Fire(entry *logrus.Entry) error { func IsServerClosed(err error) bool { return errors.Is(err, http.ErrServerClosed) } + +func doRequest(ctx context.Context, logger *slog.Logger, socketPath, path string) (*http.Response, error) { + transport := &http.Transport{ + DisableKeepAlives: true, + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + return sdkvsock.DialContext( + ctx, + socketPath, + Port, + sdkvsock.WithRetryTimeout(3*time.Second), + sdkvsock.WithRetryInterval(100*time.Millisecond), + sdkvsock.WithLogger(newLogger(logger)), + ) + }, + } + client := &http.Client{Transport: transport} + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://vsock"+path, nil) + if err != nil { + transport.CloseIdleConnections() + return nil, err + } + resp, err := client.Do(req) + if err != nil { + transport.CloseIdleConnections() + return nil, err + } + return wrappedResponse(resp, transport), nil +} + +type responseCloser struct { + io.ReadCloser + transport *http.Transport +} + +func (c responseCloser) Close() error { + err := c.ReadCloser.Close() + c.transport.CloseIdleConnections() + return err +} + +func wrappedResponse(resp *http.Response, transport *http.Transport) *http.Response { + if resp == nil || resp.Body == nil || transport == nil { + return resp + } + resp.Body = responseCloser{ReadCloser: resp.Body, transport: transport} + return resp +} + +func parsePortListeners(raw []byte, readCmdline func(int) string) ([]PortListener, error) { + lines := strings.Split(strings.TrimSpace(string(raw)), "\n") + listeners := make([]PortListener, 0, len(lines)) + wildcards := make(map[string]struct{}) + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + parsed, err := parseSSLine(line, readCmdline) + if err != nil { + return nil, err + } + for _, listener := range parsed { + if isLoopbackAddress(listener.BindAddress) { + continue + } + if isWildcardAddress(listener.BindAddress) { + key := wildcardKey(listener) + if _, ok := wildcards[key]; ok { + continue + } + wildcards[key] = struct{}{} + } + listeners = append(listeners, listener) + } + } + sort.Slice(listeners, func(i, j int) bool { + if listeners[i].Proto != listeners[j].Proto { + return listeners[i].Proto < listeners[j].Proto + } + if listeners[i].Port != listeners[j].Port { + return listeners[i].Port < listeners[j].Port + } + if listeners[i].PID != listeners[j].PID { + return listeners[i].PID < listeners[j].PID + } + return listeners[i].BindAddress < listeners[j].BindAddress + }) + return listeners, nil +} + +func parseSSLine(line string, readCmdline func(int) string) ([]PortListener, error) { + fields := strings.Fields(line) + if len(fields) < 6 { + return nil, fmt.Errorf("parse ss line: expected at least 6 fields, got %d in %q", len(fields), line) + } + proto := strings.ToLower(strings.TrimSpace(fields[0])) + if proto != "tcp" && proto != "udp" { + return nil, nil + } + bindAddress, port, err := parseBindAddress(fields[4]) + if err != nil { + return nil, fmt.Errorf("parse ss local address %q: %w", fields[4], err) + } + if bindAddress != "*" && net.ParseIP(bindAddress) == nil { + return nil, nil + } + processInfo := strings.Join(fields[6:], " ") + entries := parseProcessEntries(processInfo) + if len(entries) == 0 { + return []PortListener{{ + Proto: proto, + BindAddress: bindAddress, + Port: port, + }}, nil + } + listeners := make([]PortListener, 0, len(entries)) + for _, entry := range entries { + command := strings.TrimSpace(readCmdline(entry.PID)) + if command == "" { + command = entry.Process + } + listeners = append(listeners, PortListener{ + Proto: proto, + BindAddress: bindAddress, + Port: port, + PID: entry.PID, + Process: entry.Process, + Command: command, + }) + } + return listeners, nil +} + +type processEntry struct { + Process string + PID int +} + +func parseProcessEntries(raw string) []processEntry { + matches := processRe.FindAllStringSubmatch(raw, -1) + if len(matches) == 0 { + return nil + } + entries := make([]processEntry, 0, len(matches)) + for _, match := range matches { + pid, err := strconv.Atoi(match[2]) + if err != nil { + continue + } + entries = append(entries, processEntry{ + Process: match[1], + PID: pid, + }) + } + return entries +} + +func parseBindAddress(raw string) (string, int, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", 0, errors.New("empty address") + } + var host, portRaw string + if strings.HasPrefix(raw, "[") { + var err error + host, portRaw, err = net.SplitHostPort(raw) + if err != nil { + return "", 0, err + } + } else { + idx := strings.LastIndex(raw, ":") + if idx <= 0 || idx == len(raw)-1 { + return "", 0, fmt.Errorf("missing host:port in %q", raw) + } + host = raw[:idx] + portRaw = raw[idx+1:] + } + if zoneIdx := strings.Index(host, "%"); zoneIdx >= 0 { + host = host[:zoneIdx] + } + host = strings.Trim(host, "[]") + if host == "" { + host = "*" + } + port, err := strconv.Atoi(portRaw) + if err != nil { + return "", 0, err + } + return host, port, nil +} + +func readProcessCommandLine(pid int) string { + if pid <= 0 { + return "" + } + data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid)) + if err != nil { + return "" + } + parts := strings.Split(string(data), "\x00") + filtered := parts[:0] + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + filtered = append(filtered, part) + } + return strings.Join(filtered, " ") +} + +func isLoopbackAddress(host string) bool { + if host == "" || host == "*" { + return false + } + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() +} + +func isWildcardAddress(host string) bool { + switch host { + case "*", "0.0.0.0", "::": + return true + } + return false +} + +func wildcardKey(listener PortListener) string { + return fmt.Sprintf("%s:%d:%d:%s", listener.Proto, listener.Port, listener.PID, listener.Process) +} diff --git a/internal/vsockagent/vsockagent_test.go b/internal/vsockagent/vsockagent_test.go index 2a78241..e46f2f7 100644 --- a/internal/vsockagent/vsockagent_test.go +++ b/internal/vsockagent/vsockagent_test.go @@ -4,9 +4,12 @@ import ( "bytes" "context" "encoding/json" + "errors" "net" "net/http" "path/filepath" + "reflect" + "strconv" "strings" "testing" "time" @@ -37,6 +40,44 @@ func TestNewHandlerHealthz(t *testing.T) { } } +func TestNewHandlerPorts(t *testing.T) { + origCollector := portCollector + t.Cleanup(func() { + portCollector = origCollector + }) + portCollector = func(context.Context) ([]PortListener, error) { + return []PortListener{{ + Proto: "tcp", + BindAddress: "0.0.0.0", + Port: 8080, + PID: 42, + Process: "python3", + Command: "python3 -m http.server 8080", + }}, nil + } + + req, err := http.NewRequest(http.MethodGet, PortsPath, nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + rr := newTestResponseRecorder() + NewHandler().ServeHTTP(rr, req) + + if rr.status != http.StatusOK { + t.Fatalf("status = %d, want %d", rr.status, http.StatusOK) + } + var payload PortsResponse + if err := json.Unmarshal(rr.body.Bytes(), &payload); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if len(payload.Listeners) != 1 { + t.Fatalf("listeners = %d, want 1", len(payload.Listeners)) + } + if got := payload.Listeners[0]; got.Command != "python3 -m http.server 8080" { + t.Fatalf("listener = %+v, want command", got) + } +} + func TestHealth(t *testing.T) { t.Parallel() @@ -110,6 +151,168 @@ func TestHealth(t *testing.T) { } } +func TestPorts(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + socketPath := filepath.Join(dir, "fc.vsock") + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer listener.Close() + + done := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + done <- err + return + } + defer conn.Close() + + buf := make([]byte, 0, 256) + tmp := make([]byte, 256) + for { + n, err := conn.Read(tmp) + if err != nil { + done <- err + return + } + buf = append(buf, tmp[:n]...) + if strings.Contains(string(buf), "\n") { + break + } + } + if got := string(buf); got != "CONNECT 42070\n" { + done <- unexpectedStringError(got) + return + } + if _, err := conn.Write([]byte("OK 55\n")); err != nil { + done <- err + return + } + + buf = buf[:0] + for { + n, err := conn.Read(tmp) + if err != nil { + done <- err + return + } + buf = append(buf, tmp[:n]...) + if strings.Contains(string(buf), "\r\n\r\n") { + break + } + } + req := string(buf) + if !strings.Contains(req, "GET /ports HTTP/1.1\r\n") { + done <- unexpectedStringError(req) + return + } + body := `{"listeners":[{"proto":"tcp","bind_address":"0.0.0.0","port":8080,"pid":42,"process":"python3","command":"python3 -m http.server 8080"}]}` + resp := "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: " + strconv.Itoa(len(body)) + "\r\n\r\n" + body + _, err = conn.Write([]byte(resp)) + done <- err + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + listeners, err := Ports(ctx, nil, socketPath) + if err != nil { + t.Fatalf("Ports: %v", err) + } + if len(listeners) != 1 || listeners[0].Port != 8080 || listeners[0].Command != "python3 -m http.server 8080" { + t.Fatalf("listeners = %+v, want parsed port listener", listeners) + } + if err := <-done; err != nil { + t.Fatalf("server: %v", err) + } +} + +func TestParsePortListenersFiltersLoopbackAndDedupesWildcards(t *testing.T) { + t.Parallel() + + raw := strings.Join([]string{ + `tcp LISTEN 0 4096 0.0.0.0:22 0.0.0.0:* users:(("sshd",pid=12,fd=3))`, + `tcp LISTEN 0 4096 [::]:22 [::]:* users:(("sshd",pid=12,fd=4))`, + `udp UNCONN 0 0 127.0.0.53%lo:53 0.0.0.0:* users:(("stubby",pid=99,fd=3))`, + `tcp LISTEN 0 4096 172.16.0.2:8080 0.0.0.0:* users:(("python3",pid=44,fd=6))`, + }, "\n") + readCmdline := func(pid int) string { + switch pid { + case 12: + return "/usr/sbin/sshd -D" + case 44: + return "python3 -m http.server 8080" + default: + return "" + } + } + + listeners, err := parsePortListeners([]byte(raw), readCmdline) + if err != nil { + t.Fatalf("parsePortListeners: %v", err) + } + want := []PortListener{ + { + Proto: "tcp", + BindAddress: "0.0.0.0", + Port: 22, + PID: 12, + Process: "sshd", + Command: "/usr/sbin/sshd -D", + }, + { + Proto: "tcp", + BindAddress: "172.16.0.2", + Port: 8080, + PID: 44, + Process: "python3", + Command: "python3 -m http.server 8080", + }, + } + if !reflect.DeepEqual(listeners, want) { + t.Fatalf("listeners = %#v, want %#v", listeners, want) + } +} + +func TestParsePortListenersFallsBackToProcessName(t *testing.T) { + t.Parallel() + + raw := `tcp LISTEN 0 128 0.0.0.0:5432 0.0.0.0:* users:(("postgres",pid=77,fd=5))` + listeners, err := parsePortListeners([]byte(raw), func(int) string { return "" }) + if err != nil { + t.Fatalf("parsePortListeners: %v", err) + } + if len(listeners) != 1 { + t.Fatalf("listeners = %d, want 1", len(listeners)) + } + if listeners[0].Command != "postgres" { + t.Fatalf("command = %q, want process fallback", listeners[0].Command) + } +} + +func TestNewHandlerPortsReturnsServerErrorOnCollectorFailure(t *testing.T) { + origCollector := portCollector + t.Cleanup(func() { + portCollector = origCollector + }) + portCollector = func(context.Context) ([]PortListener, error) { + return nil, errors.New("ss missing") + } + + req, err := http.NewRequest(http.MethodGet, PortsPath, nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + rr := newTestResponseRecorder() + NewHandler().ServeHTTP(rr, req) + if rr.status != http.StatusInternalServerError { + t.Fatalf("status = %d, want %d", rr.status, http.StatusInternalServerError) + } +} + type testResponseRecorder struct { headers http.Header body bytes.Buffer diff --git a/packages.apt b/packages.apt index 8d35af0..54a5159 100644 --- a/packages.apt +++ b/packages.apt @@ -5,5 +5,6 @@ tree ca-certificates curl wget +iproute2 vim tmux