package vsockagent import ( "context" "encoding/json" "errors" "fmt" "io" "log/slog" "net" "net/http" "time" sdkvsock "github.com/firecracker-microvm/firecracker-go-sdk/vsock" "github.com/sirupsen/logrus" ) const ( Port uint32 = 42070 HealthPath = "/healthz" HealthyStatus = "ok" GuestBinaryName = "banger-vsock-agent" GuestInstallPath = "/usr/local/bin/" + GuestBinaryName ServiceName = "banger-vsock-agent.service" serviceUnit = `[Unit] Description=Banger vsock agent After=network.target [Service] Type=simple ExecStart=/usr/local/bin/banger-vsock-agent Restart=on-failure RestartSec=1 [Install] WantedBy=multi-user.target ` modulesLoadConfig = "vsock\nvmw_vsock_virtio_transport\n" ) type HealthResponse struct { Status string `json:"status"` } func NewHandler() http.Handler { mux := http.NewServeMux() mux.HandleFunc(HealthPath, func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(HealthResponse{Status: HealthyStatus}) }) return mux } func Health(ctx context.Context, logger *slog.Logger, socketPath string) error { transport := &http.Transport{ DisableKeepAlives: true, DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { return sdkvsock.DialContext( ctx, socketPath, Port, sdkvsock.WithRetryTimeout(3*time.Second), sdkvsock.WithRetryInterval(100*time.Millisecond), sdkvsock.WithLogger(newLogger(logger)), ) }, } defer transport.CloseIdleConnections() req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://vsock"+HealthPath, nil) if err != nil { return err } resp, err := (&http.Client{Transport: transport}).Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) return fmt.Errorf("unexpected health status %d: %s", resp.StatusCode, string(body)) } var payload HealthResponse if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { return err } if payload.Status != HealthyStatus { return fmt.Errorf("unexpected health response status %q", payload.Status) } if logger != nil { logger.Debug("vsock health ok", "vsock_path", socketPath, "vsock_port", Port) } return nil } func ServiceUnit() string { return serviceUnit } func ModulesLoadConfig() string { return modulesLoadConfig } func ReminderMessage(name string) string { return fmt.Sprintf("session ended; %s is still running (stop it with 'banger vm stop %s')", name, name) } func WarningMessage(name string, err error) string { if err == nil { return "" } return fmt.Sprintf("warning: failed to check whether %s is still running: %v", name, err) } func newLogger(base *slog.Logger) *logrus.Entry { logger := logrus.New() logger.SetOutput(io.Discard) logger.SetLevel(logrus.DebugLevel) logger.AddHook(slogHook{logger: base}) return logrus.NewEntry(logger) } type slogHook struct { logger *slog.Logger } func (h slogHook) Levels() []logrus.Level { return logrus.AllLevels } func (h slogHook) Fire(entry *logrus.Entry) error { if h.logger == nil { return nil } level := slog.LevelDebug switch entry.Level { case logrus.ErrorLevel, logrus.FatalLevel, logrus.PanicLevel: level = slog.LevelError case logrus.WarnLevel: level = slog.LevelWarn case logrus.InfoLevel: level = slog.LevelInfo } attrs := make([]any, 0, len(entry.Data)*2) for key, value := range entry.Data { attrs = append(attrs, key, value) } h.logger.Log(context.Background(), level, entry.Message, attrs...) return nil } func IsServerClosed(err error) bool { return errors.Is(err, http.ErrServerClosed) }