package vsockagent import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "log/slog" "net" "net/http" "os" "os/exec" "regexp" "sort" "strconv" "strings" "time" sdkvsock "github.com/firecracker-microvm/firecracker-go-sdk/vsock" "github.com/sirupsen/logrus" ) const ( Port uint32 = 42070 HealthPath = "/healthz" PortsPath = "/ports" HealthyStatus = "ok" GuestBinaryName = "banger-vsock-agent" GuestInstallPath = "/usr/local/bin/" + GuestBinaryName ServiceName = "banger-vsock-agent.service" serviceUnit = `[Unit] Description=Banger vsock agent After=network.target [Service] Type=simple ExecStart=/usr/local/bin/banger-vsock-agent Restart=on-failure RestartSec=1 [Install] WantedBy=multi-user.target ` modulesLoadConfig = "vsock\nvmw_vsock_virtio_transport\n" ) var ( portCollector = CollectPorts processRe = regexp.MustCompile(`"([^"]+)",pid=(\d+)`) ) type HealthResponse struct { Status string `json:"status"` } type PortListener struct { Proto string `json:"proto"` BindAddress string `json:"bind_address"` Port int `json:"port"` PID int `json:"pid,omitempty"` Process string `json:"process,omitempty"` Command string `json:"command,omitempty"` } type PortsResponse struct { Listeners []PortListener `json:"listeners"` } func NewHandler() http.Handler { mux := http.NewServeMux() mux.HandleFunc(HealthPath, func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(HealthResponse{Status: HealthyStatus}) }) mux.HandleFunc(PortsPath, func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } listeners, err := portCollector(r.Context()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(PortsResponse{Listeners: listeners}) }) return mux } func Health(ctx context.Context, logger *slog.Logger, socketPath string) error { resp, err := doRequest(ctx, logger, socketPath, HealthPath) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) return fmt.Errorf("unexpected health status %d: %s", resp.StatusCode, string(body)) } var payload HealthResponse if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { return err } if payload.Status != HealthyStatus { return fmt.Errorf("unexpected health response status %q", payload.Status) } if logger != nil { logger.Debug("vsock health ok", "vsock_path", socketPath, "vsock_port", Port) } return nil } func Ports(ctx context.Context, logger *slog.Logger, socketPath string) ([]PortListener, error) { resp, err := doRequest(ctx, logger, socketPath, PortsPath) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) return nil, fmt.Errorf("unexpected ports status %d: %s", resp.StatusCode, string(body)) } var payload PortsResponse if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { return nil, err } return payload.Listeners, nil } func CollectPorts(ctx context.Context) ([]PortListener, error) { cmd := exec.CommandContext(ctx, "ss", "-H", "-lntup") output, err := cmd.Output() if err != nil { if len(output) > 0 { return nil, fmt.Errorf("run ss: %w: %s", err, bytes.TrimSpace(output)) } return nil, fmt.Errorf("run ss: %w", err) } return parsePortListeners(output, readProcessCommandLine) } func ServiceUnit() string { return serviceUnit } func ModulesLoadConfig() string { return modulesLoadConfig } func ReminderMessage(name string) string { return fmt.Sprintf("session ended; %s is still running (stop it with 'banger vm stop %s')", name, name) } func WarningMessage(name string, err error) string { if err == nil { return "" } return fmt.Sprintf("warning: failed to check whether %s is still running: %v", name, err) } func newLogger(base *slog.Logger) *logrus.Entry { logger := logrus.New() logger.SetOutput(io.Discard) logger.SetLevel(logrus.DebugLevel) logger.AddHook(slogHook{logger: base}) return logrus.NewEntry(logger) } type slogHook struct { logger *slog.Logger } func (h slogHook) Levels() []logrus.Level { return logrus.AllLevels } func (h slogHook) Fire(entry *logrus.Entry) error { if h.logger == nil { return nil } level := slog.LevelDebug switch entry.Level { case logrus.ErrorLevel, logrus.FatalLevel, logrus.PanicLevel: level = slog.LevelError case logrus.WarnLevel: level = slog.LevelWarn case logrus.InfoLevel: level = slog.LevelInfo } attrs := make([]any, 0, len(entry.Data)*2) for key, value := range entry.Data { attrs = append(attrs, key, value) } h.logger.Log(context.Background(), level, entry.Message, attrs...) return nil } func IsServerClosed(err error) bool { return errors.Is(err, http.ErrServerClosed) } func doRequest(ctx context.Context, logger *slog.Logger, socketPath, path string) (*http.Response, error) { transport := &http.Transport{ DisableKeepAlives: true, DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { return sdkvsock.DialContext( ctx, socketPath, Port, sdkvsock.WithRetryTimeout(3*time.Second), sdkvsock.WithRetryInterval(100*time.Millisecond), sdkvsock.WithLogger(newLogger(logger)), ) }, } client := &http.Client{Transport: transport} req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://vsock"+path, nil) if err != nil { transport.CloseIdleConnections() return nil, err } resp, err := client.Do(req) if err != nil { transport.CloseIdleConnections() return nil, err } return wrappedResponse(resp, transport), nil } type responseCloser struct { io.ReadCloser transport *http.Transport } func (c responseCloser) Close() error { err := c.ReadCloser.Close() c.transport.CloseIdleConnections() return err } func wrappedResponse(resp *http.Response, transport *http.Transport) *http.Response { if resp == nil || resp.Body == nil || transport == nil { return resp } resp.Body = responseCloser{ReadCloser: resp.Body, transport: transport} return resp } func parsePortListeners(raw []byte, readCmdline func(int) string) ([]PortListener, error) { lines := strings.Split(strings.TrimSpace(string(raw)), "\n") listeners := make([]PortListener, 0, len(lines)) wildcards := make(map[string]struct{}) for _, line := range lines { line = strings.TrimSpace(line) if line == "" { continue } parsed, err := parseSSLine(line, readCmdline) if err != nil { return nil, err } for _, listener := range parsed { if isLoopbackAddress(listener.BindAddress) { continue } if isWildcardAddress(listener.BindAddress) { key := wildcardKey(listener) if _, ok := wildcards[key]; ok { continue } wildcards[key] = struct{}{} } listeners = append(listeners, listener) } } sort.Slice(listeners, func(i, j int) bool { if listeners[i].Proto != listeners[j].Proto { return listeners[i].Proto < listeners[j].Proto } if listeners[i].Port != listeners[j].Port { return listeners[i].Port < listeners[j].Port } if listeners[i].PID != listeners[j].PID { return listeners[i].PID < listeners[j].PID } return listeners[i].BindAddress < listeners[j].BindAddress }) return listeners, nil } func parseSSLine(line string, readCmdline func(int) string) ([]PortListener, error) { fields := strings.Fields(line) if len(fields) < 6 { return nil, fmt.Errorf("parse ss line: expected at least 6 fields, got %d in %q", len(fields), line) } proto := strings.ToLower(strings.TrimSpace(fields[0])) if proto != "tcp" && proto != "udp" { return nil, nil } bindAddress, port, err := parseBindAddress(fields[4]) if err != nil { return nil, fmt.Errorf("parse ss local address %q: %w", fields[4], err) } if bindAddress != "*" && net.ParseIP(bindAddress) == nil { return nil, nil } processInfo := strings.Join(fields[6:], " ") entries := parseProcessEntries(processInfo) if len(entries) == 0 { return []PortListener{{ Proto: proto, BindAddress: bindAddress, Port: port, }}, nil } listeners := make([]PortListener, 0, len(entries)) for _, entry := range entries { command := strings.TrimSpace(readCmdline(entry.PID)) if command == "" { command = entry.Process } listeners = append(listeners, PortListener{ Proto: proto, BindAddress: bindAddress, Port: port, PID: entry.PID, Process: entry.Process, Command: command, }) } return listeners, nil } type processEntry struct { Process string PID int } func parseProcessEntries(raw string) []processEntry { matches := processRe.FindAllStringSubmatch(raw, -1) if len(matches) == 0 { return nil } entries := make([]processEntry, 0, len(matches)) for _, match := range matches { pid, err := strconv.Atoi(match[2]) if err != nil { continue } entries = append(entries, processEntry{ Process: match[1], PID: pid, }) } return entries } func parseBindAddress(raw string) (string, int, error) { raw = strings.TrimSpace(raw) if raw == "" { return "", 0, errors.New("empty address") } var host, portRaw string if strings.HasPrefix(raw, "[") { var err error host, portRaw, err = net.SplitHostPort(raw) if err != nil { return "", 0, err } } else { idx := strings.LastIndex(raw, ":") if idx <= 0 || idx == len(raw)-1 { return "", 0, fmt.Errorf("missing host:port in %q", raw) } host = raw[:idx] portRaw = raw[idx+1:] } if zoneIdx := strings.Index(host, "%"); zoneIdx >= 0 { host = host[:zoneIdx] } host = strings.Trim(host, "[]") if host == "" { host = "*" } port, err := strconv.Atoi(portRaw) if err != nil { return "", 0, err } return host, port, nil } func readProcessCommandLine(pid int) string { if pid <= 0 { return "" } data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid)) if err != nil { return "" } parts := strings.Split(string(data), "\x00") filtered := parts[:0] for _, part := range parts { part = strings.TrimSpace(part) if part == "" { continue } filtered = append(filtered, part) } return strings.Join(filtered, " ") } func isLoopbackAddress(host string) bool { if host == "" || host == "*" { return false } ip := net.ParseIP(host) return ip != nil && ip.IsLoopback() } func isWildcardAddress(host string) bool { switch host { case "*", "0.0.0.0", "::": return true } return false } func wildcardKey(listener PortListener) string { return fmt.Sprintf("%s:%d:%d:%s", listener.Proto, listener.Port, listener.PID, listener.Process) }