banger/internal/daemon/stats_service.go
Thales Maciel d743a8ba4b
daemon: persist teardown fallbacks and reject unsafe import paths
Preserve cleanup after daemon restarts and harden OCI and tar imports
against filenames that debugfs cannot encode safely.

Mirror tap, loop, and dm teardown identity onto VM.Runtime, teach
cleanup and reconcile to fall back to those persisted fields when
handles.json is missing or corrupt, and clear the recovery state on
stop, error, and delete paths.

Reject debugfs-hostile entry names during flattening and in
ApplyOwnership itself, then add regression coverage for corrupt
handles.json recovery and unsafe import paths.

Verified with targeted go tests, make lint-go, make lint-shell, and
make build.
2026-04-23 16:21:59 -03:00

387 lines
12 KiB
Go

package daemon
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"sort"
"strconv"
"strings"
"time"
"banger/internal/api"
"banger/internal/model"
"banger/internal/store"
"banger/internal/system"
"banger/internal/vmdns"
"banger/internal/vsockagent"
)
// StatsService owns the "observe a VM" surface: stats collection
// (CPU / memory / disk), listening-port enumeration, vsock-agent
// health probes, the background poller that refreshes stats for every
// live VM, and the auto-stop-when-idle sweep.
//
// Split out from VMService (commit 3 of the god-service decomposition):
// nothing here orchestrates lifecycle. The VMService touch points
// stats genuinely needs — vmAlive, vmHandles, the per-VM lock
// helpers, plus cleanupRuntime for the stale-VM sweep — come in as
// function-typed closures so StatsService has no back-reference to
// its sibling. Same pattern WorkspaceService already uses.
type StatsService struct {
// runner is copied in by newStatsService; none of the StatsService
// methods in this file call it — NOTE(review): confirm it is still
// needed.
runner system.CommandRunner
// logger may be nil: methods nil-check it before emitting their own
// log lines. It is also handed as-is to the vsockagent calls.
logger *slog.Logger
// config is the daemon configuration; the only field read in this
// file is AutoStopStaleAfter (by stopStaleVMs).
config model.DaemonConfig
// store is where VM records are listed from (ListVMs) and where
// refreshed records are written back (UpsertVM).
store *store.Store
// net provides ensureSocketAccess for the vsock socket and the
// sendCtrlAltDel / waitForExit pair used by the stale sweep.
net *HostNetwork
// beginOperation opens the operation log that stopStaleVMs reports
// its stages and outcome to.
beginOperation func(name string, attrs ...any) *operationLog
// vmAlive / vmHandles are the minimum pair needed to answer "is
// this VM actually running right now?" + "what PID is it?".
// Closures over VMService so we re-read d.vm at call time — wire
// order in wireServices puts d.vm before d.stats, so these are
// safe by the time anything on StatsService fires.
vmAlive func(vm model.VMRecord) bool
vmHandles func(vmID string) model.VMHandles
// Lock helpers: stats collection and the stale-sweep both mutate
// VM records (persist new stats, flip State to Stopped on auto-
// stop) and so need the same per-VM mutex lifecycle ops hold.
withVMLockByRef func(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error)
withVMLockByIDErr func(ctx context.Context, id string, fn func(model.VMRecord) error) error
// cleanupRuntime is the auto-stop-sweep's only call into the
// lifecycle side — forcibly tears down a VM that's been idle past
// AutoStopStaleAfter. Keeping it as a closure means StatsService
// never directly dereferences VMService.
cleanupRuntime func(ctx context.Context, vm model.VMRecord, preserveDisks bool) error
}
// statsServiceDeps is the dependency bundle accepted by
// newStatsService. Its fields mirror StatsService one-for-one and are
// copied across by name; see the StatsService field comments for what
// each one is used for.
type statsServiceDeps struct {
runner system.CommandRunner
logger *slog.Logger
config model.DaemonConfig
store *store.Store
net *HostNetwork
beginOperation func(name string, attrs ...any) *operationLog
// Closures supplied by the lifecycle side.
vmAlive func(vm model.VMRecord) bool
vmHandles func(vmID string) model.VMHandles
withVMLockByRef func(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error)
withVMLockByIDErr func(ctx context.Context, id string, fn func(model.VMRecord) error) error
cleanupRuntime func(ctx context.Context, vm model.VMRecord, preserveDisks bool) error
}
// newStatsService builds a StatsService by copying every field of deps
// onto a freshly allocated service. No validation is performed; nil
// closures or a nil store surface only when a method first uses them.
func newStatsService(deps statsServiceDeps) *StatsService {
	svc := new(StatsService)

	// Plain dependencies.
	svc.runner = deps.runner
	svc.logger = deps.logger
	svc.config = deps.config
	svc.store = deps.store
	svc.net = deps.net
	svc.beginOperation = deps.beginOperation

	// Closures supplied by the lifecycle side.
	svc.vmAlive = deps.vmAlive
	svc.vmHandles = deps.vmHandles
	svc.withVMLockByRef = deps.withVMLockByRef
	svc.withVMLockByIDErr = deps.withVMLockByIDErr
	svc.cleanupRuntime = deps.cleanupRuntime

	return svc
}
// ---- stats ----
// GetVMStats runs getVMStatsLocked for idOrName under withVMLockByRef
// and returns the resulting record together with the Stats stored on
// it. If the locked call fails, both results are zero values and the
// error is passed through unchanged.
func (s *StatsService) GetVMStats(ctx context.Context, idOrName string) (model.VMRecord, model.VMStats, error) {
	refresh := func(locked model.VMRecord) (model.VMRecord, error) {
		return s.getVMStatsLocked(ctx, locked)
	}

	updated, err := s.withVMLockByRef(ctx, idOrName, refresh)
	if err != nil {
		return model.VMRecord{}, model.VMStats{}, err
	}
	return updated, updated.Stats, nil
}
// HealthVM probes the vsock agent of the VM named by idOrName while
// holding that VM's lock (via withVMLockByRef).
//
// A VM that vmAlive reports as not running yields Healthy=false with a
// nil error. For a running VM, a missing vsock path or CID, a failure
// to gain access to the vsock socket, or a failed agent health call
// (bounded to 3 seconds) is returned as an error. result.Name is filled
// in as soon as the record is available, so it may be set even when an
// error is returned.
func (s *StatsService) HealthVM(ctx context.Context, idOrName string) (result api.VMHealthResult, err error) {
	probe := func(vm model.VMRecord) (model.VMRecord, error) {
		result.Name = vm.Name

		if !s.vmAlive(vm) {
			result.Healthy = false
			return vm, nil
		}

		// Both vsock coordinates must be present before touching the socket.
		switch {
		case strings.TrimSpace(vm.Runtime.VSockPath) == "":
			return model.VMRecord{}, errors.New("vm has no vsock path")
		case vm.Runtime.VSockCID == 0:
			return model.VMRecord{}, errors.New("vm has no vsock cid")
		}

		sockPath := vm.Runtime.VSockPath
		if accessErr := s.net.ensureSocketAccess(ctx, sockPath, "firecracker vsock socket"); accessErr != nil {
			return model.VMRecord{}, accessErr
		}

		// Cap the agent round-trip so a wedged guest cannot hold the VM lock.
		pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
		defer cancel()
		if pingErr := vsockagent.Health(pingCtx, s.logger, sockPath); pingErr != nil {
			return model.VMRecord{}, pingErr
		}

		result.Healthy = true
		return vm, nil
	}

	_, err = s.withVMLockByRef(ctx, idOrName, probe)
	return result, err
}
// PingVM is a thin adapter over HealthVM: it maps the health result's
// Name and Healthy fields onto VMPingResult.Name and Alive. Any error
// from HealthVM is returned with a zero-valued result.
func (s *StatsService) PingVM(ctx context.Context, idOrName string) (result api.VMPingResult, err error) {
	var health api.VMHealthResult
	if health, err = s.HealthVM(ctx, idOrName); err != nil {
		return api.VMPingResult{}, err
	}

	result.Name = health.Name
	result.Alive = health.Healthy
	return result, nil
}
// getVMStatsLocked collects fresh stats for vm and, when collection
// succeeds, stores them on the record, bumps UpdatedAt, and upserts the
// record into the store. GetVMStats invokes it from inside
// withVMLockByRef, so it runs with the per-VM lock held.
//
// The call is best-effort and always returns a nil error:
//   - if collection fails, vm is returned unchanged;
//   - if the upsert fails, the refreshed in-memory record is still
//     returned so the caller sees current stats.
//
// Both failures are logged (when a logger is configured) instead of
// being dropped silently.
func (s *StatsService) getVMStatsLocked(ctx context.Context, vm model.VMRecord) (model.VMRecord, error) {
	stats, err := s.collectStats(ctx, vm)
	if err != nil {
		if s.logger != nil {
			s.logger.Debug("vm stats collection failed", append(vmLogAttrs(vm), "error", err.Error())...)
		}
		return vm, nil
	}

	vm.Stats = stats
	vm.UpdatedAt = model.Now()

	// Persisting is best-effort: a store failure must not hide the stats
	// that were just collected, but it should leave a trace.
	if err := s.store.UpsertVM(ctx, vm); err != nil && s.logger != nil {
		s.logger.Warn("vm stats persist failed", append(vmLogAttrs(vm), "error", err.Error())...)
	}

	if s.logger != nil {
		s.logger.Debug("vm stats collected", append(vmLogAttrs(vm), "rss_bytes", stats.RSSBytes, "vsz_bytes", stats.VSZBytes, "cpu_percent", stats.CPUPercent)...)
	}
	return vm, nil
}
// pollStats is the body of the daemon's background stats tick. It lists
// every VM in the store and, one VM at a time under that VM's lock,
// collects stats and upserts the refreshed record.
//
// VMs that vmAlive reports as not running are skipped. A stats
// collection failure is logged at debug level and does not stop the
// pass. An error from listing, from the lock helper, or from the upsert
// ends the pass immediately and is returned.
func (s *StatsService) pollStats(ctx context.Context) error {
	records, err := s.store.ListVMs(ctx)
	if err != nil {
		return err
	}

	// refresh runs with the per-VM lock held; vm is the record the lock
	// helper hands in, not the listing snapshot.
	refresh := func(vm model.VMRecord) error {
		if !s.vmAlive(vm) {
			return nil
		}
		collected, collectErr := s.collectStats(ctx, vm)
		if collectErr != nil {
			if s.logger != nil {
				s.logger.Debug("vm stats collection failed", append(vmLogAttrs(vm), "error", collectErr.Error())...)
			}
			return nil
		}
		vm.Stats = collected
		vm.UpdatedAt = model.Now()
		return s.store.UpsertVM(ctx, vm)
	}

	for _, listed := range records {
		if lockErr := s.withVMLockByIDErr(ctx, listed.ID, refresh); lockErr != nil {
			return lockErr
		}
	}
	return nil
}
// stopStaleVMs auto-stops every running VM whose LastTouchedAt lies at
// least config.AutoStopStaleAfter before the moment the sweep started.
// It is a no-op when AutoStopStaleAfter is zero or negative.
//
// Each VM is handled under its own lock. For a stale VM the sweep sends
// Ctrl-Alt-Del, waits up to 10 seconds for the process to exit, calls
// cleanupRuntime with preserveDisks=true, and then persists the record
// as stopped with its teardown state cleared. The three teardown calls
// are best-effort: their errors are ignored and the record is marked
// stopped regardless.
//
// The sweep is wrapped in a "vm.stop_stale" operation log that is
// marked failed or done according to the returned error. A listing,
// locking, or upsert error aborts the sweep at that VM.
func (s *StatsService) stopStaleVMs(ctx context.Context) (err error) {
	if s.config.AutoStopStaleAfter <= 0 {
		return nil
	}

	op := s.beginOperation("vm.stop_stale")
	defer func() {
		// err is the named result, so this sees the final outcome.
		if err == nil {
			op.done()
			return
		}
		op.fail(err)
	}()

	listed, err := s.store.ListVMs(ctx)
	if err != nil {
		return err
	}

	// Staleness is measured against a single timestamp for the whole sweep.
	sweepStart := model.Now()

	stopIfStale := func(vm model.VMRecord) error {
		if !s.vmAlive(vm) {
			return nil
		}
		if sweepStart.Sub(vm.LastTouchedAt) < s.config.AutoStopStaleAfter {
			return nil
		}

		op.stage("stopping_vm", vmLogAttrs(vm)...)

		// Best-effort teardown: ask the guest to shut down, give it a
		// bounded window to exit, then tear the runtime down regardless.
		apiSock := vm.Runtime.APISockPath
		_ = s.net.sendCtrlAltDel(ctx, apiSock)
		_ = s.net.waitForExit(ctx, s.vmHandles(vm.ID).PID, apiSock, 10*time.Second)
		_ = s.cleanupRuntime(ctx, vm, true)

		vm.State = model.VMStateStopped
		vm.Runtime.State = model.VMStateStopped
		clearRuntimeTeardownState(&vm)
		vm.UpdatedAt = model.Now()
		return s.store.UpsertVM(ctx, vm)
	}

	for _, candidate := range listed {
		if err = s.withVMLockByIDErr(ctx, candidate.ID, stopIfStale); err != nil {
			return err
		}
	}
	return nil
}
// collectStats assembles a VMStats snapshot for vm. The timestamp, the
// allocated sizes of the system overlay and work disk, and the parsed
// metrics file are always filled in. CPU, RSS and VSZ figures are added
// only when vmAlive reports the VM as running and the process stats for
// its PID can be read; a failed read leaves them at zero. The returned
// error is always nil.
func (s *StatsService) collectStats(ctx context.Context, vm model.VMRecord) (model.VMStats, error) {
	var snapshot model.VMStats
	snapshot.CollectedAt = model.Now()
	snapshot.SystemOverlayBytes = system.AllocatedBytes(vm.Runtime.SystemOverlay)
	snapshot.WorkDiskBytes = system.AllocatedBytes(vm.Runtime.WorkDiskPath)
	snapshot.MetricsRaw = system.ParseMetricsFile(vm.Runtime.MetricsPath)

	if !s.vmAlive(vm) {
		return snapshot, nil
	}

	proc, err := system.ReadProcessStats(ctx, s.vmHandles(vm.ID).PID)
	if err != nil {
		// Process figures are optional; keep the disk/metrics part.
		return snapshot, nil
	}
	snapshot.CPUPercent = proc.CPUPercent
	snapshot.RSSBytes = proc.RSSBytes
	snapshot.VSZBytes = proc.VSZBytes
	return snapshot, nil
}
// ---- ports ----
// httpProbeTimeout is the http.Client timeout applied to each single
// HTTP or HTTPS probe request made by probeHTTPScheme.
const httpProbeTimeout = 750 * time.Millisecond
// PortsVM lists the listening ports of the VM named by idOrName, as
// reported by its vsock agent, while holding that VM's lock.
//
// result.Name and result.DNSName are filled in first (DNSName falls
// back to vmdns.RecordName(vm.Name) when the runtime has none and the
// VM has a name), so they may be set even when an error is returned.
// An error is returned when the VM is not running, lacks a guest IP,
// vsock path or vsock CID, when access to the vsock socket cannot be
// ensured, or when the agent call (bounded to 3 seconds) fails.
func (s *StatsService) PortsVM(ctx context.Context, idOrName string) (result api.VMPortsResult, err error) {
	enumerate := func(vm model.VMRecord) (model.VMRecord, error) {
		result.Name = vm.Name

		dnsName := strings.TrimSpace(vm.Runtime.DNSName)
		if dnsName == "" && strings.TrimSpace(vm.Name) != "" {
			dnsName = vmdns.RecordName(vm.Name)
		}
		result.DNSName = dnsName

		// Preconditions, checked in the same order callers see errors for.
		switch {
		case !s.vmAlive(vm):
			return model.VMRecord{}, fmt.Errorf("vm %s is not running", vm.Name)
		case strings.TrimSpace(vm.Runtime.GuestIP) == "":
			return model.VMRecord{}, errors.New("vm has no guest IP")
		case strings.TrimSpace(vm.Runtime.VSockPath) == "":
			return model.VMRecord{}, errors.New("vm has no vsock path")
		case vm.Runtime.VSockCID == 0:
			return model.VMRecord{}, errors.New("vm has no vsock cid")
		}

		sockPath := vm.Runtime.VSockPath
		if accessErr := s.net.ensureSocketAccess(ctx, sockPath, "firecracker vsock socket"); accessErr != nil {
			return model.VMRecord{}, accessErr
		}

		portsCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
		defer cancel()
		listeners, listErr := vsockagent.Ports(portsCtx, s.logger, sockPath)
		if listErr != nil {
			return model.VMRecord{}, listErr
		}

		result.Ports = buildVMPorts(vm, listeners)
		return vm, nil
	}

	_, err = s.withVMLockByRef(ctx, idOrName, enumerate)
	return result, err
}
// buildVMPorts converts the agent's raw listener list into API port
// entries for vm.
//
// Listeners with a non-positive port are dropped. Each entry's Endpoint
// is "host:port", where host is the VM's DNS name or, failing that, its
// guest IP. Command falls back to Process when empty. TCP listeners are
// additionally probed on the guest IP via probeWebListener; on a hit,
// Proto becomes the detected scheme and Endpoint becomes a full
// "scheme://host:port/" URL. The result is sorted by Proto, Port, PID,
// Process and BindAddress, then passed through dedupeVMPorts.
func buildVMPorts(vm model.VMRecord, listeners []vsockagent.PortListener) []api.VMPort {
	guestIP := strings.TrimSpace(vm.Runtime.GuestIP)
	endpointHost := strings.TrimSpace(vm.Runtime.DNSName)
	if endpointHost == "" {
		endpointHost = guestIP
	}
	// Probing needs somewhere to connect (guest IP) and a host to
	// advertise in the resulting URL.
	canProbe := guestIP != "" && endpointHost != ""

	ports := make([]api.VMPort, 0, len(listeners))
	for _, l := range listeners {
		if l.Port <= 0 {
			continue
		}
		hostPort := net.JoinHostPort(endpointHost, strconv.Itoa(l.Port))

		entry := api.VMPort{
			Proto:       strings.ToLower(strings.TrimSpace(l.Proto)),
			BindAddress: strings.TrimSpace(l.BindAddress),
			Port:        l.Port,
			PID:         l.PID,
			Process:     strings.TrimSpace(l.Process),
			Command:     strings.TrimSpace(l.Command),
			Endpoint:    hostPort,
		}
		if entry.Command == "" {
			entry.Command = entry.Process
		}

		if canProbe && entry.Proto == "tcp" {
			if scheme, ok := probeWebListener(guestIP, l.Port); ok {
				entry.Proto = scheme
				entry.Endpoint = scheme + "://" + hostPort + "/"
			}
		}
		ports = append(ports, entry)
	}

	sort.Slice(ports, func(i, j int) bool {
		a, b := &ports[i], &ports[j]
		switch {
		case a.Proto != b.Proto:
			return a.Proto < b.Proto
		case a.Port != b.Port:
			return a.Port < b.Port
		case a.PID != b.PID:
			return a.PID < b.PID
		case a.Process != b.Process:
			return a.Process < b.Process
		default:
			return a.BindAddress < b.BindAddress
		}
	})

	return dedupeVMPorts(ports)
}
// probeWebListener reports which web scheme, if any, answers on
// guestIP:port. HTTPS is tried before HTTP; the first scheme for which
// probeHTTPScheme succeeds is returned with true. If neither answers,
// it returns "" and false.
func probeWebListener(guestIP string, port int) (string, bool) {
	for _, scheme := range [...]string{"https", "http"} {
		if probeHTTPScheme(scheme, guestIP, port) {
			return scheme, true
		}
	}
	return "", false
}
// probeHTTPScheme reports whether a GET of scheme://guestIP:port/
// produces any HTTP response within httpProbeTimeout. The status code
// is irrelevant: any parsed response counts as a hit. Redirects are not
// followed, no proxy is used, and for "https" the server certificate is
// not verified (the probe only detects that TLS+HTTP is spoken, it does
// not trust the peer).
//
// It returns false for a blank guestIP, a non-positive port, a request
// that cannot be built, or any transport-level failure.
//
// Each call uses its own throwaway transport. Keep-alives are disabled
// and idle connections are closed on return, because a pooled
// connection on a per-call transport could never be reused and would
// otherwise stay open (sockets plus their goroutines) after every probe.
func probeHTTPScheme(scheme, guestIP string, port int) bool {
	host := strings.TrimSpace(guestIP)
	if host == "" || port <= 0 {
		return false
	}

	target := scheme + "://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/"
	req, err := http.NewRequest(http.MethodGet, target, nil)
	if err != nil {
		return false
	}

	transport := &http.Transport{
		Proxy:             nil,
		DisableKeepAlives: true,
	}
	// Belt and braces: release anything the transport may still hold.
	defer transport.CloseIdleConnections()
	if scheme == "https" {
		transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
	}

	client := &http.Client{
		Timeout: httpProbeTimeout,
		// A redirect is itself proof of an HTTP server; do not chase it.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
		Transport: transport,
	}

	resp, err := client.Do(req)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	// Touch at most one byte of the body; the content is not needed.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1))
	return resp.ProtoMajor >= 1
}
// dedupeVMPorts removes entries that share both Proto and Endpoint with
// an earlier entry, keeping the first occurrence and preserving input
// order. Slices with fewer than two entries are returned as-is (the
// same slice, not a copy).
func dedupeVMPorts(ports []api.VMPort) []api.VMPort {
	if len(ports) <= 1 {
		return ports
	}

	seen := make(map[string]struct{}, len(ports))
	unique := make([]api.VMPort, 0, len(ports))
	for i := range ports {
		// NUL cannot collide with the separator-free Proto component.
		key := ports[i].Proto + "\x00" + ports[i].Endpoint
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		unique = append(unique, ports[i])
	}
	return unique
}