From 86a56fedb379fb89662e27a432a60be90b92a400 Mon Sep 17 00:00:00 2001
From: Thales Maciel
Date: Thu, 23 Apr 2026 15:46:59 -0300
Subject: [PATCH] daemon: extract StatsService sibling; shrink VMService's surface
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This closes out commit 3 of the god-service decomposition. VMService
still owned 45+ methods after the startVMLocked extraction and the RPC
table landed in commits 1 and 2. Stats / ports / health / vsock-ping
sit in a corner of that surface that shares no state with lifecycle
orchestration — nothing about "what's this VM's CPU doing" belongs in
the same service as Create/Start/Stop/Delete/Set.

The new StatsService owns:

- GetVMStats / getVMStatsLocked / collectStats (stats collection)
- HealthVM / PingVM (vsock-agent health probe)
- PortsVM + buildVMPorts + probeWebListener + probeHTTPScheme +
  dedupeVMPorts (listening-port enumeration)
- pollStats (background ticker refresh)
- stopStaleVMs (auto-stop sweep past config.AutoStopStaleAfter)

The VMService touch-points stats genuinely needs — vmAlive, vmHandles,
the per-VM lock helpers, and cleanupRuntime for the stale-sweep
tear-down — come in as function-typed closures, not a *VMService
pointer, so StatsService holds no back-reference to its sibling
(sketched in miniature at the end of this message). This mirrors the
dependency-struct pattern WorkspaceService already uses.

Wiring: d.stats is populated in wireServices AFTER d.vm, because the
closures must see a non-nil d.vm at call time. The dispatch table's
four entries (vm.stats / vm.health / vm.ping / vm.ports) now resolve
through d.stats, and the background loop's pollStats / stopStaleVMs
tickers do the same. From the RPC client's perspective the dispatch
surface is byte-identical.

After this commit:

- vm_stats.go and ports.go are deleted; their content (plus the
  stats-specific fields) lives in stats_service.go.
- VMService loses 12 methods. It is still the biggest service
  (~30 methods, all lifecycle-supporting: handle cache, disk
  provisioning, preflight, create-ops registry, lock helpers, the
  lifecycle verbs themselves), but it is finally one coherent concern
  instead of five.

Tests:

- TestWireServicesInstantiatesStatsService — pins that the wiring
  order leaves d.stats non-nil with all five of its closures
  populated, preventing a silent background-loop regression.
- All existing tests that called d.vm.HealthVM / d.vm.PingVM /
  d.vm.PortsVM / d.vm.collectStats were re-pointed at d.stats.

Smoke: all 21 scenarios green, including vm ports (exercises the new
PortsVM entry end-to-end) and the long-running workspace scenarios
(exercise the background stats poller implicitly).
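
The closure wiring in miniature, for reviewers: toy stand-in types,
not the real daemon structs, just to show why capturing d (rather
than d.vm) keeps the sibling decoupled while still resolving d.vm
lazily at call time. The actual closure list is in wireServices below.

    package main

    import "fmt"

    // Toy stand-ins; the real structs live in internal/daemon.
    type vmService struct{}

    func (v *vmService) vmAlive() bool { return v != nil }

    type statsService struct {
    	// Function-typed dependency, mirroring StatsService.vmAlive.
    	vmAlive func() bool
    }

    type daemon struct {
    	vm    *vmService
    	stats *statsService
    }

    func wire(d *daemon) {
    	d.vm = &vmService{}
    	// The closure captures d, not d.vm, so d.vm is re-read on
    	// every call; the same trick wireServices uses below.
    	d.stats = &statsService{
    		vmAlive: func() bool { return d.vm.vmAlive() },
    	}
    }

    func main() {
    	d := &daemon{}
    	wire(d)
    	fmt.Println(d.stats.vmAlive()) // true; resolved through d at call time
    	d.vm = &vmService{}            // a test swapping d.vm still routes correctly
    	fmt.Println(d.stats.vmAlive())
    }

Because the closures go through d, a test that swaps d.vm after the
initial wire still hits the new instance; the only order constraint is
that d.vm is non-nil before anything on d.stats fires.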
Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/daemon/daemon.go | 38 ++- internal/daemon/dispatch.go | 8 +- internal/daemon/ports.go | 164 ----------- internal/daemon/stats_service.go | 387 ++++++++++++++++++++++++++ internal/daemon/stats_service_test.go | 51 ++++ internal/daemon/vm_stats.go | 157 ----------- internal/daemon/vm_test.go | 12 +- 7 files changed, 480 insertions(+), 337 deletions(-) delete mode 100644 internal/daemon/ports.go create mode 100644 internal/daemon/stats_service.go create mode 100644 internal/daemon/stats_service_test.go delete mode 100644 internal/daemon/vm_stats.go diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index daf0d70..67e78ac 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -34,10 +34,11 @@ type Daemon struct { runner system.CommandRunner logger *slog.Logger - net *HostNetwork - img *ImageService - ws *WorkspaceService - vm *VMService + net *HostNetwork + img *ImageService + ws *WorkspaceService + vm *VMService + stats *StatsService closing chan struct{} once sync.Once @@ -276,11 +277,11 @@ func (d *Daemon) backgroundLoop() { case <-d.closing: return case <-statsTicker.C: - if err := d.vm.pollStats(context.Background()); err != nil && d.logger != nil { + if err := d.stats.pollStats(context.Background()); err != nil && d.logger != nil { d.logger.Error("background stats poll failed", "error", err.Error()) } case <-staleTicker.C: - if err := d.vm.stopStaleVMs(context.Background()); err != nil && d.logger != nil { + if err := d.stats.stopStaleVMs(context.Background()); err != nil && d.logger != nil { d.logger.Error("background stale sweep failed", "error", err.Error()) } d.vm.pruneVMCreateOperations(time.Now().Add(-10 * time.Minute)) @@ -429,6 +430,31 @@ func wireServices(d *Daemon) { vsockHostDevice: defaultVsockHostDevice, }) } + if d.stats == nil { + // Closures capture d rather than d.vm directly, so they re-read + // d.vm at call time. Wire order (d.vm constructed above) makes + // the closures safe, but this pattern also protects against a + // future test that swaps d.vm after initial wire. 
+ d.stats = newStatsService(statsServiceDeps{ + runner: d.runner, + logger: d.logger, + config: d.config, + store: d.store, + net: d.net, + beginOperation: d.beginOperation, + vmAlive: func(vm model.VMRecord) bool { return d.vm.vmAlive(vm) }, + vmHandles: func(id string) model.VMHandles { return d.vm.vmHandles(id) }, + withVMLockByRef: func(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) { + return d.vm.withVMLockByRef(ctx, idOrName, fn) + }, + withVMLockByIDErr: func(ctx context.Context, id string, fn func(model.VMRecord) error) error { + return d.vm.withVMLockByIDErr(ctx, id, fn) + }, + cleanupRuntime: func(ctx context.Context, vm model.VMRecord, preserve bool) error { + return d.vm.cleanupRuntime(ctx, vm, preserve) + }, + }) + } if len(d.vmCaps) == 0 { d.vmCaps = d.defaultCapabilities() } diff --git a/internal/daemon/dispatch.go b/internal/daemon/dispatch.go index 5fd5c3d..a47647d 100644 --- a/internal/daemon/dispatch.go +++ b/internal/daemon/dispatch.go @@ -154,20 +154,20 @@ func vmSetDispatch(ctx context.Context, d *Daemon, p api.VMSetParams) (api.VMSho } func vmStatsDispatch(ctx context.Context, d *Daemon, p api.VMRefParams) (api.VMStatsResult, error) { - vm, stats, err := d.vm.GetVMStats(ctx, p.IDOrName) + vm, stats, err := d.stats.GetVMStats(ctx, p.IDOrName) return api.VMStatsResult{VM: vm, Stats: stats}, err } func vmHealthDispatch(ctx context.Context, d *Daemon, p api.VMRefParams) (api.VMHealthResult, error) { - return d.vm.HealthVM(ctx, p.IDOrName) + return d.stats.HealthVM(ctx, p.IDOrName) } func vmPingDispatch(ctx context.Context, d *Daemon, p api.VMRefParams) (api.VMPingResult, error) { - return d.vm.PingVM(ctx, p.IDOrName) + return d.stats.PingVM(ctx, p.IDOrName) } func vmPortsDispatch(ctx context.Context, d *Daemon, p api.VMRefParams) (api.VMPortsResult, error) { - return d.vm.PortsVM(ctx, p.IDOrName) + return d.stats.PortsVM(ctx, p.IDOrName) } func workspacePrepareDispatch(ctx context.Context, d *Daemon, p api.VMWorkspacePrepareParams) (api.VMWorkspacePrepareResult, error) { diff --git a/internal/daemon/ports.go b/internal/daemon/ports.go deleted file mode 100644 index e765c20..0000000 --- a/internal/daemon/ports.go +++ /dev/null @@ -1,164 +0,0 @@ -package daemon - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "net/http" - "sort" - "strconv" - "strings" - "time" - - "banger/internal/api" - "banger/internal/model" - "banger/internal/vmdns" - "banger/internal/vsockagent" -) - -const httpProbeTimeout = 750 * time.Millisecond - -func (s *VMService) PortsVM(ctx context.Context, idOrName string) (result api.VMPortsResult, err error) { - _, err = s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { - result.Name = vm.Name - result.DNSName = strings.TrimSpace(vm.Runtime.DNSName) - if result.DNSName == "" && strings.TrimSpace(vm.Name) != "" { - result.DNSName = vmdns.RecordName(vm.Name) - } - if !s.vmAlive(vm) { - return model.VMRecord{}, fmt.Errorf("vm %s is not running", vm.Name) - } - if strings.TrimSpace(vm.Runtime.GuestIP) == "" { - return model.VMRecord{}, errors.New("vm has no guest IP") - } - if strings.TrimSpace(vm.Runtime.VSockPath) == "" { - return model.VMRecord{}, errors.New("vm has no vsock path") - } - if vm.Runtime.VSockCID == 0 { - return model.VMRecord{}, errors.New("vm has no vsock cid") - } - if err := s.net.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { - return model.VMRecord{}, err - } - portsCtx, 
cancel := context.WithTimeout(ctx, 3*time.Second) - defer cancel() - listeners, err := vsockagent.Ports(portsCtx, s.logger, vm.Runtime.VSockPath) - if err != nil { - return model.VMRecord{}, err - } - result.Ports = buildVMPorts(vm, listeners) - return vm, nil - }) - return result, err -} - -func buildVMPorts(vm model.VMRecord, listeners []vsockagent.PortListener) []api.VMPort { - endpointHost := strings.TrimSpace(vm.Runtime.DNSName) - if endpointHost == "" { - endpointHost = strings.TrimSpace(vm.Runtime.GuestIP) - } - probeHost := strings.TrimSpace(vm.Runtime.GuestIP) - ports := make([]api.VMPort, 0, len(listeners)) - for _, listener := range listeners { - if listener.Port <= 0 { - continue - } - port := api.VMPort{ - Proto: strings.ToLower(strings.TrimSpace(listener.Proto)), - BindAddress: strings.TrimSpace(listener.BindAddress), - Port: listener.Port, - PID: listener.PID, - Process: strings.TrimSpace(listener.Process), - Command: strings.TrimSpace(listener.Command), - Endpoint: net.JoinHostPort(endpointHost, strconv.Itoa(listener.Port)), - } - if port.Command == "" { - port.Command = port.Process - } - if port.Proto == "tcp" && probeHost != "" && endpointHost != "" { - if scheme, ok := probeWebListener(probeHost, listener.Port); ok { - port.Proto = scheme - port.Endpoint = scheme + "://" + net.JoinHostPort(endpointHost, strconv.Itoa(listener.Port)) + "/" - } - } - ports = append(ports, port) - } - sort.Slice(ports, func(i, j int) bool { - if ports[i].Proto != ports[j].Proto { - return ports[i].Proto < ports[j].Proto - } - if ports[i].Port != ports[j].Port { - return ports[i].Port < ports[j].Port - } - if ports[i].PID != ports[j].PID { - return ports[i].PID < ports[j].PID - } - if ports[i].Process != ports[j].Process { - return ports[i].Process < ports[j].Process - } - if ports[i].Command != ports[j].Command { - return ports[i].Command < ports[j].Command - } - return ports[i].BindAddress < ports[j].BindAddress - }) - return dedupeVMPorts(ports) -} - -func probeWebListener(guestIP string, port int) (string, bool) { - if probeHTTPScheme("https", guestIP, port) { - return "https", true - } - if probeHTTPScheme("http", guestIP, port) { - return "http", true - } - return "", false -} - -func probeHTTPScheme(scheme, guestIP string, port int) bool { - if strings.TrimSpace(guestIP) == "" || port <= 0 { - return false - } - url := scheme + "://" + net.JoinHostPort(strings.TrimSpace(guestIP), strconv.Itoa(port)) + "/" - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return false - } - transport := &http.Transport{Proxy: nil} - if scheme == "https" { - transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - client := &http.Client{ - Timeout: httpProbeTimeout, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - Transport: transport, - } - resp, err := client.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1)) - return resp.ProtoMajor >= 1 -} - -func dedupeVMPorts(ports []api.VMPort) []api.VMPort { - if len(ports) < 2 { - return ports - } - deduped := make([]api.VMPort, 0, len(ports)) - seen := make(map[string]struct{}, len(ports)) - for _, port := range ports { - key := port.Proto + "\x00" + port.Endpoint - if _, ok := seen[key]; ok { - continue - } - seen[key] = struct{}{} - deduped = append(deduped, port) - } - return deduped -} diff --git a/internal/daemon/stats_service.go b/internal/daemon/stats_service.go new file mode 
100644 index 0000000..71ecb8e --- /dev/null +++ b/internal/daemon/stats_service.go @@ -0,0 +1,387 @@ +package daemon + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "sort" + "strconv" + "strings" + "time" + + "banger/internal/api" + "banger/internal/model" + "banger/internal/store" + "banger/internal/system" + "banger/internal/vmdns" + "banger/internal/vsockagent" +) + +// StatsService owns the "observe a VM" surface: stats collection +// (CPU / memory / disk), listening-port enumeration, vsock-agent +// health probes, the background poller that refreshes stats for every +// live VM, and the auto-stop-when-idle sweep. +// +// Split out from VMService (commit 3 of the god-service decomposition): +// nothing here orchestrates lifecycle. The three VMService touch +// points stats genuinely needs — vmAlive, vmHandles, the per-VM lock +// helpers, plus cleanupRuntime for the stale-VM sweep — come in as +// function-typed closures so StatsService has no back-reference to +// its sibling. Same pattern WorkspaceService already uses. +type StatsService struct { + runner system.CommandRunner + logger *slog.Logger + config model.DaemonConfig + store *store.Store + net *HostNetwork + beginOperation func(name string, attrs ...any) *operationLog + + // vmAlive / vmHandles are the minimum pair needed to answer "is + // this VM actually running right now?" + "what PID is it?". + // Closures over VMService so we re-read d.vm at call time — wire + // order in wireServices puts d.vm before d.stats, so these are + // safe by the time anything on StatsService fires. + vmAlive func(vm model.VMRecord) bool + vmHandles func(vmID string) model.VMHandles + + // Lock helpers: stats collection and the stale-sweep both mutate + // VM records (persist new stats, flip State to Stopped on auto- + // stop) and so need the same per-VM mutex lifecycle ops hold. + withVMLockByRef func(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) + withVMLockByIDErr func(ctx context.Context, id string, fn func(model.VMRecord) error) error + + // cleanupRuntime is the auto-stop-sweep's only call into the + // lifecycle side — forcibly tears down a VM that's been idle past + // AutoStopStaleAfter. Keeping it as a closure means StatsService + // never directly dereferences VMService. 
+ cleanupRuntime func(ctx context.Context, vm model.VMRecord, preserveDisks bool) error +} + +type statsServiceDeps struct { + runner system.CommandRunner + logger *slog.Logger + config model.DaemonConfig + store *store.Store + net *HostNetwork + beginOperation func(name string, attrs ...any) *operationLog + vmAlive func(vm model.VMRecord) bool + vmHandles func(vmID string) model.VMHandles + withVMLockByRef func(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) + withVMLockByIDErr func(ctx context.Context, id string, fn func(model.VMRecord) error) error + cleanupRuntime func(ctx context.Context, vm model.VMRecord, preserveDisks bool) error +} + +func newStatsService(deps statsServiceDeps) *StatsService { + return &StatsService{ + runner: deps.runner, + logger: deps.logger, + config: deps.config, + store: deps.store, + net: deps.net, + beginOperation: deps.beginOperation, + vmAlive: deps.vmAlive, + vmHandles: deps.vmHandles, + withVMLockByRef: deps.withVMLockByRef, + withVMLockByIDErr: deps.withVMLockByIDErr, + cleanupRuntime: deps.cleanupRuntime, + } +} + +// ---- stats ---- + +func (s *StatsService) GetVMStats(ctx context.Context, idOrName string) (model.VMRecord, model.VMStats, error) { + vm, err := s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { + return s.getVMStatsLocked(ctx, vm) + }) + if err != nil { + return model.VMRecord{}, model.VMStats{}, err + } + return vm, vm.Stats, nil +} + +func (s *StatsService) HealthVM(ctx context.Context, idOrName string) (result api.VMHealthResult, err error) { + _, err = s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { + result.Name = vm.Name + if !s.vmAlive(vm) { + result.Healthy = false + return vm, nil + } + if strings.TrimSpace(vm.Runtime.VSockPath) == "" { + return model.VMRecord{}, errors.New("vm has no vsock path") + } + if vm.Runtime.VSockCID == 0 { + return model.VMRecord{}, errors.New("vm has no vsock cid") + } + if err := s.net.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { + return model.VMRecord{}, err + } + pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + if err := vsockagent.Health(pingCtx, s.logger, vm.Runtime.VSockPath); err != nil { + return model.VMRecord{}, err + } + result.Healthy = true + return vm, nil + }) + return result, err +} + +func (s *StatsService) PingVM(ctx context.Context, idOrName string) (result api.VMPingResult, err error) { + health, err := s.HealthVM(ctx, idOrName) + if err != nil { + return api.VMPingResult{}, err + } + return api.VMPingResult{Name: health.Name, Alive: health.Healthy}, nil +} + +func (s *StatsService) getVMStatsLocked(ctx context.Context, vm model.VMRecord) (model.VMRecord, error) { + stats, err := s.collectStats(ctx, vm) + if err == nil { + vm.Stats = stats + vm.UpdatedAt = model.Now() + _ = s.store.UpsertVM(ctx, vm) + if s.logger != nil { + s.logger.Debug("vm stats collected", append(vmLogAttrs(vm), "rss_bytes", stats.RSSBytes, "vsz_bytes", stats.VSZBytes, "cpu_percent", stats.CPUPercent)...) + } + } + return vm, nil +} + +// pollStats runs on the daemon's background ticker; refreshes stats +// for every VM the store knows about, skipping ones that aren't alive. 
+func (s *StatsService) pollStats(ctx context.Context) error { + vms, err := s.store.ListVMs(ctx) + if err != nil { + return err + } + for _, vm := range vms { + if err := s.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error { + if !s.vmAlive(vm) { + return nil + } + stats, err := s.collectStats(ctx, vm) + if err != nil { + if s.logger != nil { + s.logger.Debug("vm stats collection failed", append(vmLogAttrs(vm), "error", err.Error())...) + } + return nil + } + vm.Stats = stats + vm.UpdatedAt = model.Now() + return s.store.UpsertVM(ctx, vm) + }); err != nil { + return err + } + } + return nil +} + +// stopStaleVMs auto-stops any running VM whose LastTouchedAt is older +// than config.AutoStopStaleAfter. This is the only path through +// StatsService that actually mutates VM lifecycle state — it needs +// cleanupRuntime to tear down the kernel + process side. +func (s *StatsService) stopStaleVMs(ctx context.Context) (err error) { + if s.config.AutoStopStaleAfter <= 0 { + return nil + } + op := s.beginOperation("vm.stop_stale") + defer func() { + if err != nil { + op.fail(err) + return + } + op.done() + }() + vms, err := s.store.ListVMs(ctx) + if err != nil { + return err + } + now := model.Now() + for _, vm := range vms { + if err := s.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error { + if !s.vmAlive(vm) { + return nil + } + if now.Sub(vm.LastTouchedAt) < s.config.AutoStopStaleAfter { + return nil + } + op.stage("stopping_vm", vmLogAttrs(vm)...) + _ = s.net.sendCtrlAltDel(ctx, vm.Runtime.APISockPath) + _ = s.net.waitForExit(ctx, s.vmHandles(vm.ID).PID, vm.Runtime.APISockPath, 10*time.Second) + _ = s.cleanupRuntime(ctx, vm, true) + vm.State = model.VMStateStopped + vm.Runtime.State = model.VMStateStopped + vm.Runtime.TapDevice = "" + vm.UpdatedAt = model.Now() + return s.store.UpsertVM(ctx, vm) + }); err != nil { + return err + } + } + return nil +} + +func (s *StatsService) collectStats(ctx context.Context, vm model.VMRecord) (model.VMStats, error) { + stats := model.VMStats{ + CollectedAt: model.Now(), + SystemOverlayBytes: system.AllocatedBytes(vm.Runtime.SystemOverlay), + WorkDiskBytes: system.AllocatedBytes(vm.Runtime.WorkDiskPath), + MetricsRaw: system.ParseMetricsFile(vm.Runtime.MetricsPath), + } + if s.vmAlive(vm) { + if ps, err := system.ReadProcessStats(ctx, s.vmHandles(vm.ID).PID); err == nil { + stats.CPUPercent = ps.CPUPercent + stats.RSSBytes = ps.RSSBytes + stats.VSZBytes = ps.VSZBytes + } + } + return stats, nil +} + +// ---- ports ---- + +const httpProbeTimeout = 750 * time.Millisecond + +func (s *StatsService) PortsVM(ctx context.Context, idOrName string) (result api.VMPortsResult, err error) { + _, err = s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { + result.Name = vm.Name + result.DNSName = strings.TrimSpace(vm.Runtime.DNSName) + if result.DNSName == "" && strings.TrimSpace(vm.Name) != "" { + result.DNSName = vmdns.RecordName(vm.Name) + } + if !s.vmAlive(vm) { + return model.VMRecord{}, fmt.Errorf("vm %s is not running", vm.Name) + } + if strings.TrimSpace(vm.Runtime.GuestIP) == "" { + return model.VMRecord{}, errors.New("vm has no guest IP") + } + if strings.TrimSpace(vm.Runtime.VSockPath) == "" { + return model.VMRecord{}, errors.New("vm has no vsock path") + } + if vm.Runtime.VSockCID == 0 { + return model.VMRecord{}, errors.New("vm has no vsock cid") + } + if err := s.net.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { + return model.VMRecord{}, err + } + portsCtx, cancel 
:= context.WithTimeout(ctx, 3*time.Second) + defer cancel() + listeners, err := vsockagent.Ports(portsCtx, s.logger, vm.Runtime.VSockPath) + if err != nil { + return model.VMRecord{}, err + } + result.Ports = buildVMPorts(vm, listeners) + return vm, nil + }) + return result, err +} + +func buildVMPorts(vm model.VMRecord, listeners []vsockagent.PortListener) []api.VMPort { + endpointHost := strings.TrimSpace(vm.Runtime.DNSName) + if endpointHost == "" { + endpointHost = strings.TrimSpace(vm.Runtime.GuestIP) + } + probeHost := strings.TrimSpace(vm.Runtime.GuestIP) + ports := make([]api.VMPort, 0, len(listeners)) + for _, listener := range listeners { + if listener.Port <= 0 { + continue + } + port := api.VMPort{ + Proto: strings.ToLower(strings.TrimSpace(listener.Proto)), + BindAddress: strings.TrimSpace(listener.BindAddress), + Port: listener.Port, + PID: listener.PID, + Process: strings.TrimSpace(listener.Process), + Command: strings.TrimSpace(listener.Command), + Endpoint: net.JoinHostPort(endpointHost, strconv.Itoa(listener.Port)), + } + if port.Command == "" { + port.Command = port.Process + } + if port.Proto == "tcp" && probeHost != "" && endpointHost != "" { + if scheme, ok := probeWebListener(probeHost, listener.Port); ok { + port.Proto = scheme + port.Endpoint = scheme + "://" + net.JoinHostPort(endpointHost, strconv.Itoa(listener.Port)) + "/" + } + } + ports = append(ports, port) + } + sort.Slice(ports, func(i, j int) bool { + if ports[i].Proto != ports[j].Proto { + return ports[i].Proto < ports[j].Proto + } + if ports[i].Port != ports[j].Port { + return ports[i].Port < ports[j].Port + } + if ports[i].PID != ports[j].PID { + return ports[i].PID < ports[j].PID + } + if ports[i].Process != ports[j].Process { + return ports[i].Process < ports[j].Process + } + return ports[i].BindAddress < ports[j].BindAddress + }) + return dedupeVMPorts(ports) +} + +func probeWebListener(guestIP string, port int) (string, bool) { + if probeHTTPScheme("https", guestIP, port) { + return "https", true + } + if probeHTTPScheme("http", guestIP, port) { + return "http", true + } + return "", false +} + +func probeHTTPScheme(scheme, guestIP string, port int) bool { + if strings.TrimSpace(guestIP) == "" || port <= 0 { + return false + } + url := scheme + "://" + net.JoinHostPort(strings.TrimSpace(guestIP), strconv.Itoa(port)) + "/" + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return false + } + transport := &http.Transport{Proxy: nil} + if scheme == "https" { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + client := &http.Client{ + Timeout: httpProbeTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + Transport: transport, + } + resp, err := client.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1)) + return resp.ProtoMajor >= 1 +} + +func dedupeVMPorts(ports []api.VMPort) []api.VMPort { + if len(ports) < 2 { + return ports + } + deduped := make([]api.VMPort, 0, len(ports)) + seen := make(map[string]struct{}, len(ports)) + for _, port := range ports { + key := port.Proto + "\x00" + port.Endpoint + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + deduped = append(deduped, port) + } + return deduped +} diff --git a/internal/daemon/stats_service_test.go b/internal/daemon/stats_service_test.go new file mode 100644 index 0000000..83a69e2 --- /dev/null +++ b/internal/daemon/stats_service_test.go 
@@ -0,0 +1,51 @@ +package daemon + +import ( + "testing" + + "banger/internal/model" + "banger/internal/paths" +) + +// TestWireServicesInstantiatesStatsService pins that wireServices +// leaves d.stats non-nil after construction. A wiring-order bug that +// left stats unset would silently break background stats polling and +// the vm.stats / vm.health / vm.ping / vm.ports RPC methods — none +// of those would nil-deref at cold boot because the daemon might +// not get a call for minutes, but the pollStats ticker would +// immediately panic on its first fire. +func TestWireServicesInstantiatesStatsService(t *testing.T) { + d := &Daemon{ + runner: &permissiveRunner{}, + config: model.DaemonConfig{BridgeIP: model.DefaultBridgeIP}, + layout: paths.Layout{ + StateDir: t.TempDir(), + ConfigDir: t.TempDir(), + RuntimeDir: t.TempDir(), + VMsDir: t.TempDir(), + }, + } + wireServices(d) + + if d.stats == nil { + t.Fatal("d.stats is nil after wireServices") + } + // Spot-check the three closures that back every stats method — + // a nil closure would be a less-obvious wiring regression than + // a nil service. + if d.stats.vmAlive == nil { + t.Fatal("d.stats.vmAlive closure is nil") + } + if d.stats.vmHandles == nil { + t.Fatal("d.stats.vmHandles closure is nil") + } + if d.stats.cleanupRuntime == nil { + t.Fatal("d.stats.cleanupRuntime closure is nil") + } + if d.stats.withVMLockByRef == nil { + t.Fatal("d.stats.withVMLockByRef closure is nil") + } + if d.stats.withVMLockByIDErr == nil { + t.Fatal("d.stats.withVMLockByIDErr closure is nil") + } +} diff --git a/internal/daemon/vm_stats.go b/internal/daemon/vm_stats.go deleted file mode 100644 index 62e8d85..0000000 --- a/internal/daemon/vm_stats.go +++ /dev/null @@ -1,157 +0,0 @@ -package daemon - -import ( - "context" - "errors" - "strings" - "time" - - "banger/internal/api" - "banger/internal/model" - "banger/internal/system" - "banger/internal/vsockagent" -) - -func (s *VMService) GetVMStats(ctx context.Context, idOrName string) (model.VMRecord, model.VMStats, error) { - vm, err := s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { - return s.getVMStatsLocked(ctx, vm) - }) - if err != nil { - return model.VMRecord{}, model.VMStats{}, err - } - return vm, vm.Stats, nil -} - -func (s *VMService) HealthVM(ctx context.Context, idOrName string) (result api.VMHealthResult, err error) { - _, err = s.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { - result.Name = vm.Name - if !s.vmAlive(vm) { - result.Healthy = false - return vm, nil - } - if strings.TrimSpace(vm.Runtime.VSockPath) == "" { - return model.VMRecord{}, errors.New("vm has no vsock path") - } - if vm.Runtime.VSockCID == 0 { - return model.VMRecord{}, errors.New("vm has no vsock cid") - } - if err := s.net.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { - return model.VMRecord{}, err - } - pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second) - defer cancel() - if err := vsockagent.Health(pingCtx, s.logger, vm.Runtime.VSockPath); err != nil { - return model.VMRecord{}, err - } - result.Healthy = true - return vm, nil - }) - return result, err -} - -func (s *VMService) PingVM(ctx context.Context, idOrName string) (result api.VMPingResult, err error) { - health, err := s.HealthVM(ctx, idOrName) - if err != nil { - return api.VMPingResult{}, err - } - return api.VMPingResult{Name: health.Name, Alive: health.Healthy}, nil -} - -func (s *VMService) getVMStatsLocked(ctx context.Context, vm 
model.VMRecord) (model.VMRecord, error) { - stats, err := s.collectStats(ctx, vm) - if err == nil { - vm.Stats = stats - vm.UpdatedAt = model.Now() - _ = s.store.UpsertVM(ctx, vm) - if s.logger != nil { - s.logger.Debug("vm stats collected", append(vmLogAttrs(vm), "rss_bytes", stats.RSSBytes, "vsz_bytes", stats.VSZBytes, "cpu_percent", stats.CPUPercent)...) - } - } - return vm, nil -} - -func (s *VMService) pollStats(ctx context.Context) error { - vms, err := s.store.ListVMs(ctx) - if err != nil { - return err - } - for _, vm := range vms { - if err := s.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error { - if !s.vmAlive(vm) { - return nil - } - stats, err := s.collectStats(ctx, vm) - if err != nil { - if s.logger != nil { - s.logger.Debug("vm stats collection failed", append(vmLogAttrs(vm), "error", err.Error())...) - } - return nil - } - vm.Stats = stats - vm.UpdatedAt = model.Now() - return s.store.UpsertVM(ctx, vm) - }); err != nil { - return err - } - } - return nil -} - -func (s *VMService) stopStaleVMs(ctx context.Context) (err error) { - if s.config.AutoStopStaleAfter <= 0 { - return nil - } - op := s.beginOperation("vm.stop_stale") - defer func() { - if err != nil { - op.fail(err) - return - } - op.done() - }() - vms, err := s.store.ListVMs(ctx) - if err != nil { - return err - } - now := model.Now() - for _, vm := range vms { - if err := s.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error { - if !s.vmAlive(vm) { - return nil - } - if now.Sub(vm.LastTouchedAt) < s.config.AutoStopStaleAfter { - return nil - } - op.stage("stopping_vm", vmLogAttrs(vm)...) - _ = s.net.sendCtrlAltDel(ctx, vm.Runtime.APISockPath) - _ = s.net.waitForExit(ctx, s.vmHandles(vm.ID).PID, vm.Runtime.APISockPath, 10*time.Second) - _ = s.cleanupRuntime(ctx, vm, true) - vm.State = model.VMStateStopped - vm.Runtime.State = model.VMStateStopped - vm.Runtime.TapDevice = "" - s.clearVMHandles(vm) - vm.UpdatedAt = model.Now() - return s.store.UpsertVM(ctx, vm) - }); err != nil { - return err - } - } - return nil -} - -func (s *VMService) collectStats(ctx context.Context, vm model.VMRecord) (model.VMStats, error) { - stats := model.VMStats{ - CollectedAt: model.Now(), - SystemOverlayBytes: system.AllocatedBytes(vm.Runtime.SystemOverlay), - WorkDiskBytes: system.AllocatedBytes(vm.Runtime.WorkDiskPath), - MetricsRaw: system.ParseMetricsFile(vm.Runtime.MetricsPath), - } - if s.vmAlive(vm) { - if ps, err := system.ReadProcessStats(ctx, s.vmHandles(vm.ID).PID); err == nil { - stats.CPUPercent = ps.CPUPercent - stats.RSSBytes = ps.RSSBytes - stats.VSZBytes = ps.VSZBytes - } - } - return stats, nil -} diff --git a/internal/daemon/vm_test.go b/internal/daemon/vm_test.go index 05b4713..fc7e92e 100644 --- a/internal/daemon/vm_test.go +++ b/internal/daemon/vm_test.go @@ -374,7 +374,7 @@ func TestHealthVMReturnsHealthyForRunningGuest(t *testing.T) { d := &Daemon{store: db, runner: runner} wireServices(d) d.vm.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: handlePID}) - result, err := d.vm.HealthVM(ctx, vm.Name) + result, err := d.stats.HealthVM(ctx, vm.Name) if err != nil { t.Fatalf("HealthVM: %v", err) } @@ -438,7 +438,7 @@ func TestPingVMAliasReturnsAliveForHealthyVM(t *testing.T) { d := &Daemon{store: db, runner: runner} wireServices(d) d.vm.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: fake.Process.Pid}) - result, err := d.vm.PingVM(ctx, vm.Name) + result, err := d.stats.PingVM(ctx, vm.Name) if err != nil { t.Fatalf("PingVM: %v", err) } @@ -538,7 +538,7 @@ func TestHealthVMReturnsFalseForStoppedVM(t 
*testing.T) { d := &Daemon{store: db} wireServices(d) - result, err := d.vm.HealthVM(ctx, vm.Name) + result, err := d.stats.HealthVM(ctx, vm.Name) if err != nil { t.Fatalf("HealthVM: %v", err) } @@ -639,7 +639,7 @@ func TestPortsVMReturnsEnrichedPortsAndWebSchemes(t *testing.T) { wireServices(d) d.vm.setVMHandlesInMemory(vm.ID, model.VMHandles{PID: fake.Process.Pid}) - result, err := d.vm.PortsVM(ctx, vm.Name) + result, err := d.stats.PortsVM(ctx, vm.Name) if err != nil { t.Fatalf("PortsVM: %v", err) } @@ -687,7 +687,7 @@ func TestPortsVMReturnsErrorForStoppedVM(t *testing.T) { d := &Daemon{store: db} wireServices(d) - _, err := d.vm.PortsVM(ctx, vm.Name) + _, err := d.stats.PortsVM(ctx, vm.Name) if err == nil || !strings.Contains(err.Error(), "is not running") { t.Fatalf("PortsVM error = %v, want not running", err) } @@ -1407,7 +1407,7 @@ func TestCollectStatsIgnoresMalformedMetricsFile(t *testing.T) { d := &Daemon{} wireServices(d) - stats, err := d.vm.collectStats(context.Background(), model.VMRecord{ + stats, err := d.stats.collectStats(context.Background(), model.VMRecord{ Runtime: model.VMRuntime{ SystemOverlay: overlay, WorkDiskPath: workDisk,