Stop long-running daemon operations from running under context.Background by
threading a request-scoped context from handleConn into dispatch. The daemon
now cancels in-flight handlers when the client socket goes away, and the RPC
client closes its Unix connection when the caller context is canceled so that
interrupts actually reach the daemon boundary.

Add regression coverage for both sides of the path: canceled dispatch calls,
client disconnects during handleConn, watcher EOF cancellation, and context
cancellation without an RPC deadline.

Validated with GOCACHE=/tmp/banger-gocache go test ./... and
GOCACHE=/tmp/banger-gocache make build.
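
The client-side half of the change is not in this file. As a rough sketch of the pattern (the package and helper names below are illustrative, not the actual banger RPC client), canceling the caller's context closes the Unix connection, which the daemon's disconnect watcher sees as a read error and turns into cancellation of the in-flight handler:

package rpcclient // hypothetical package, for illustration only

import (
	"context"
	"io"
	"net"
)

// callWithContext sketches a context-aware Unix-socket RPC call: a watcher
// goroutine closes the connection as soon as ctx is canceled, so the daemon
// side observes EOF/ECONNRESET and cancels the in-flight handler.
func callWithContext(ctx context.Context, socketPath string, payload []byte) ([]byte, error) {
	conn, err := net.Dial("unix", socketPath)
	if err != nil {
		return nil, err
	}
	defer conn.Close()

	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
			_ = conn.Close() // unblocks any pending read or write
		case <-done:
		}
	}()

	if _, err := conn.Write(payload); err != nil {
		return nil, err
	}
	resp, err := io.ReadAll(conn)
	if ctx.Err() != nil {
		return nil, ctx.Err() // surface cancellation rather than the closed-conn error
	}
	return resp, err
}
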
571 lines · 16 KiB · Go

package daemon

import (
	"bufio"
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"banger/internal/api"
	"banger/internal/config"
	"banger/internal/model"
	"banger/internal/paths"
	"banger/internal/rpc"
	"banger/internal/store"
	"banger/internal/system"
)

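// Daemon is the banger control daemon: it owns the persistent store, the
// Unix socket listener, and the background polling loops. When requestHandler
// is non-nil it is invoked in place of the built-in dispatch table.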
type Daemon struct {
	layout         paths.Layout
	config         model.DaemonConfig
	store          *store.Store
	runner         system.CommandRunner
	logger         *slog.Logger
	mu             sync.Mutex
	closing        chan struct{}
	once           sync.Once
	pid            int
	listener       net.Listener
	requestHandler func(context.Context, rpc.Request) rpc.Response
}

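// Open resolves the on-disk layout, loads the daemon configuration, opens the
// state store, registers the configured default image, and reconciles VMs whose
// backing processes have gone away.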
func Open(ctx context.Context) (*Daemon, error) {
	layout, err := paths.Resolve()
	if err != nil {
		return nil, err
	}
	if err := paths.Ensure(layout); err != nil {
		return nil, err
	}
	cfg, err := config.Load(layout)
	if err != nil {
		return nil, err
	}
	logger, normalizedLevel, err := newDaemonLogger(os.Stderr, cfg.LogLevel)
	if err != nil {
		return nil, err
	}
	cfg.LogLevel = normalizedLevel
	db, err := store.Open(layout.DBPath)
	if err != nil {
		return nil, err
	}
	d := &Daemon{
		layout:  layout,
		config:  cfg,
		store:   db,
		runner:  system.NewRunner(),
		logger:  logger,
		closing: make(chan struct{}),
		pid:     os.Getpid(),
	}
	d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "runtime_dir", cfg.RuntimeDir, "log_level", cfg.LogLevel)
	if err := d.ensureDefaultImage(ctx); err != nil {
		d.logger.Error("daemon open failed", "stage", "ensure_default_image", "error", err.Error())
		return nil, err
	}
	if err := d.reconcile(ctx); err != nil {
		d.logger.Error("daemon open failed", "stage", "reconcile", "error", err.Error())
		return nil, err
	}
	return d, nil
}

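// Close shuts the daemon down at most once: it signals the background loop via
// the closing channel, closes the listener so Serve's accept loop unblocks, and
// closes the store.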
func (d *Daemon) Close() error {
	var err error
	d.once.Do(func() {
		if d.logger != nil {
			d.logger.Info("daemon closing")
		}
		close(d.closing)
		if d.listener != nil {
			_ = d.listener.Close()
		}
		err = d.store.Close()
	})
	return err
}

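// Serve listens on the daemon's Unix socket (mode 0600), starts the background
// loop, and handles each accepted connection in its own goroutine until the
// context is canceled or Close is called.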
func (d *Daemon) Serve(ctx context.Context) error {
	_ = os.Remove(d.layout.SocketPath)
	listener, err := net.Listen("unix", d.layout.SocketPath)
	if err != nil {
		if d.logger != nil {
			d.logger.Error("daemon listen failed", "socket", d.layout.SocketPath, "error", err.Error())
		}
		return err
	}
	d.listener = listener
	defer listener.Close()
	defer os.Remove(d.layout.SocketPath)
	if err := os.Chmod(d.layout.SocketPath, 0o600); err != nil {
		return err
	}
	if d.logger != nil {
		d.logger.Info("daemon serving", "socket", d.layout.SocketPath, "pid", d.pid)
	}

	go d.backgroundLoop()

	for {
		conn, err := listener.Accept()
		if err != nil {
			select {
			case <-ctx.Done():
				return nil
			case <-d.closing:
				return nil
			default:
			}
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				if d.logger != nil {
					d.logger.Warn("daemon accept temporary failure", "error", err.Error())
				}
				time.Sleep(100 * time.Millisecond)
				continue
			}
			if d.logger != nil {
				d.logger.Error("daemon accept failed", "error", err.Error())
			}
			return err
		}
		go d.handleConn(conn)
	}
}

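// handleConn decodes a single request from the connection, dispatches it under
// a request-scoped context that is canceled if the client disconnects mid-flight,
// and writes the response back unless the request was already canceled.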
func (d *Daemon) handleConn(conn net.Conn) {
	defer conn.Close()
	reader := bufio.NewReader(conn)
	var req rpc.Request
	if err := json.NewDecoder(reader).Decode(&req); err != nil {
		if d.logger != nil {
			d.logger.Warn("daemon request decode failed", "remote", conn.RemoteAddr().String(), "error", err.Error())
		}
		_ = json.NewEncoder(conn).Encode(rpc.NewError("bad_request", err.Error()))
		return
	}
	reqCtx, cancel := context.WithCancel(context.Background())
	defer cancel()
	stopWatch := d.watchRequestDisconnect(conn, reader, req.Method, cancel)
	defer stopWatch()
	resp := d.dispatch(reqCtx, req)
	if reqCtx.Err() != nil {
		return
	}
	if err := json.NewEncoder(conn).Encode(resp); err != nil && d.logger != nil {
		d.logger.Warn("daemon response encode failed", "method", req.Method, "remote", conn.RemoteAddr().String(), "error", err.Error())
	}
}

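// watchRequestDisconnect keeps reading from the client connection while a
// request is in flight; a read error before the returned stop function is called
// is treated as a client disconnect and cancels the request context. The stop
// function wakes the blocked read by forcing an immediate read deadline.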
func (d *Daemon) watchRequestDisconnect(conn net.Conn, reader *bufio.Reader, method string, cancel context.CancelFunc) func() {
	if conn == nil || reader == nil {
		return func() {}
	}
	done := make(chan struct{})
	var once sync.Once
	go func() {
		go func() {
			<-done
			if deadlineSetter, ok := conn.(interface{ SetReadDeadline(time.Time) error }); ok {
				_ = deadlineSetter.SetReadDeadline(time.Now())
			}
		}()
		var buf [1]byte
		for {
			_, err := reader.Read(buf[:])
			if err == nil {
				continue
			}
			select {
			case <-done:
				return
			default:
			}
			if d.logger != nil {
				d.logger.Info("daemon request canceled", "method", method, "remote", conn.RemoteAddr().String(), "error", err.Error())
			}
			cancel()
			return
		}
	}()
	return func() {
		once.Do(func() {
			close(done)
		})
	}
}

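// dispatch validates the RPC version and routes the request to the matching
// daemon operation. The supplied context is request-scoped and is canceled when
// the client goes away.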
func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
	if req.Version != rpc.Version {
		return rpc.NewError("bad_version", fmt.Sprintf("unsupported version %d", req.Version))
	}
	if d.requestHandler != nil {
		return d.requestHandler(ctx, req)
	}
	switch req.Method {
	case "ping":
		result, _ := rpc.NewResult(api.PingResult{Status: "ok", PID: d.pid})
		return result
	case "shutdown":
		go d.Close()
		result, _ := rpc.NewResult(api.ShutdownResult{Status: "stopping"})
		return result
	case "vm.create":
		params, err := rpc.DecodeParams[api.VMCreateParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.CreateVM(ctx, params)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.list":
		vms, err := d.store.ListVMs(ctx)
		return marshalResultOrError(api.VMListResult{VMs: vms}, err)
	case "vm.show":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.FindVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.start":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.StartVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.stop":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.StopVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.kill":
		params, err := rpc.DecodeParams[api.VMKillParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.KillVM(ctx, params)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.restart":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.RestartVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.delete":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.DeleteVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.set":
		params, err := rpc.DecodeParams[api.VMSetParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.SetVM(ctx, params)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.stats":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, stats, err := d.GetVMStats(ctx, params.IDOrName)
		return marshalResultOrError(api.VMStatsResult{VM: vm, Stats: stats}, err)
	case "vm.logs":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.FindVM(ctx, params.IDOrName)
		if err != nil {
			return rpc.NewError("not_found", err.Error())
		}
		return marshalResultOrError(api.VMLogsResult{LogPath: vm.Runtime.LogPath}, nil)
	case "vm.ssh":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.TouchVM(ctx, params.IDOrName)
		if err != nil {
			return rpc.NewError("not_found", err.Error())
		}
		if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
			return rpc.NewError("not_running", fmt.Sprintf("vm %s is not running", vm.Name))
		}
		return marshalResultOrError(api.VMSSHResult{Name: vm.Name, GuestIP: vm.Runtime.GuestIP}, nil)
	case "image.list":
		images, err := d.store.ListImages(ctx)
		return marshalResultOrError(api.ImageListResult{Images: images}, err)
	case "image.show":
		params, err := rpc.DecodeParams[api.ImageRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.FindImage(ctx, params.IDOrName)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "image.build":
		params, err := rpc.DecodeParams[api.ImageBuildParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.BuildImage(ctx, params)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "image.delete":
		params, err := rpc.DecodeParams[api.ImageRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.DeleteImage(ctx, params.IDOrName)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	default:
		return rpc.NewError("unknown_method", req.Method)
	}
}

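// backgroundLoop periodically polls VM stats and sweeps stale VMs until the
// daemon begins closing.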
func (d *Daemon) backgroundLoop() {
	statsTicker := time.NewTicker(d.config.StatsPollInterval)
	staleTicker := time.NewTicker(model.DefaultStaleSweepInterval)
	defer statsTicker.Stop()
	defer staleTicker.Stop()
	for {
		select {
		case <-d.closing:
			return
		case <-statsTicker.C:
			if err := d.pollStats(context.Background()); err != nil && d.logger != nil {
				d.logger.Error("background stats poll failed", "error", err.Error())
			}
		case <-staleTicker.C:
			if err := d.stopStaleVMs(context.Background()); err != nil && d.logger != nil {
				d.logger.Error("background stale sweep failed", "error", err.Error())
			}
		}
	}
}

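// ensureDefaultImage registers or refreshes the configured default image record:
// managed images are left untouched, unmanaged records are updated in place when
// the configured paths change, and a new record is created when none exists.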
func (d *Daemon) ensureDefaultImage(ctx context.Context) error {
	if d.config.DefaultImageName == "" {
		return nil
	}
	desired, ok := d.desiredDefaultImage()
	if !ok {
		if d.logger != nil {
			d.logger.Debug("default image skipped", "image_name", d.config.DefaultImageName, "rootfs_path", d.config.DefaultRootfs, "kernel_path", d.config.DefaultKernel)
		}
		return nil
	}

	image, err := d.store.GetImageByName(ctx, d.config.DefaultImageName)
	switch {
	case err == nil:
		if image.Managed {
			if d.logger != nil {
				d.logger.Debug("managed default image left untouched", append(imageLogAttrs(image), "managed", image.Managed)...)
			}
			return nil
		}
		if defaultImageMatches(image, desired) {
			if d.logger != nil {
				d.logger.Debug("default image already current", imageLogAttrs(image)...)
			}
			return nil
		}
		updated := desired
		updated.ID = image.ID
		updated.CreatedAt = image.CreatedAt
		updated.UpdatedAt = model.Now()
		if err := d.store.UpsertImage(ctx, updated); err != nil {
			return err
		}
		if d.logger != nil {
			d.logger.Info("default image reconciled", append(imageLogAttrs(updated), "previous_rootfs_path", image.RootfsPath, "previous_kernel_path", image.KernelPath)...)
		}
		return nil
	case errors.Is(err, sql.ErrNoRows):
		id, err := model.NewID()
		if err != nil {
			return err
		}
		now := model.Now()
		desired.ID = id
		desired.CreatedAt = now
		desired.UpdatedAt = now
		if err := d.store.UpsertImage(ctx, desired); err != nil {
			return err
		}
		if d.logger != nil {
			d.logger.Info("default image registered", append(imageLogAttrs(desired), "managed", desired.Managed)...)
		}
		return nil
	default:
		return err
	}
}

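// desiredDefaultImage builds the image record implied by the daemon
// configuration, or reports false when the configured rootfs or kernel is
// missing on disk.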
func (d *Daemon) desiredDefaultImage() (model.Image, bool) {
	rootfs := d.config.DefaultRootfs
	kernel := d.config.DefaultKernel
	if !exists(rootfs) || !exists(kernel) {
		return model.Image{}, false
	}
	return model.Image{
		Name:         d.config.DefaultImageName,
		Managed:      false,
		ArtifactDir:  "",
		RootfsPath:   rootfs,
		KernelPath:   kernel,
		InitrdPath:   d.config.DefaultInitrd,
		ModulesDir:   d.config.DefaultModulesDir,
		PackagesPath: d.config.DefaultPackagesFile,
		Docker:       strings.Contains(filepath.Base(rootfs), "docker"),
	}, true
}

func defaultImageMatches(current, desired model.Image) bool {
	return current.Name == desired.Name &&
		current.Managed == desired.Managed &&
		current.ArtifactDir == desired.ArtifactDir &&
		current.RootfsPath == desired.RootfsPath &&
		current.KernelPath == desired.KernelPath &&
		current.InitrdPath == desired.InitrdPath &&
		current.ModulesDir == desired.ModulesDir &&
		current.PackagesPath == desired.PackagesPath &&
		current.Docker == desired.Docker
}

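// reconcile marks VMs that are recorded as running but whose processes have
// disappeared as stopped, clearing their runtime state after cleaning up any
// leftover resources.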
func (d *Daemon) reconcile(ctx context.Context) error {
	op := d.beginOperation("daemon.reconcile")
	vms, err := d.store.ListVMs(ctx)
	if err != nil {
		return op.fail(err)
	}
	for _, vm := range vms {
		if vm.State != model.VMStateRunning {
			continue
		}
		if system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
			continue
		}
		op.stage("stale_vm", vmLogAttrs(vm)...)
		_ = d.cleanupRuntime(ctx, vm, true)
		vm.State = model.VMStateStopped
		vm.Runtime.State = model.VMStateStopped
		vm.Runtime.PID = 0
		vm.Runtime.TapDevice = ""
		vm.Runtime.APISockPath = ""
		vm.Runtime.BaseLoop = ""
		vm.Runtime.COWLoop = ""
		vm.Runtime.DMName = ""
		vm.Runtime.DMDev = ""
		vm.UpdatedAt = model.Now()
		if err := d.store.UpsertVM(ctx, vm); err != nil {
			return op.fail(err, vmLogAttrs(vm)...)
		}
	}
	op.done()
	return nil
}

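// FindVM resolves a VM by direct store lookup first, then by a unique ID or
// name prefix; ambiguous or unknown references return an error.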
func (d *Daemon) FindVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
	if idOrName == "" {
		return model.VMRecord{}, errors.New("vm id or name is required")
	}
	if vm, err := d.store.GetVM(ctx, idOrName); err == nil {
		return vm, nil
	}
	vms, err := d.store.ListVMs(ctx)
	if err != nil {
		return model.VMRecord{}, err
	}
	matchCount := 0
	var match model.VMRecord
	for _, vm := range vms {
		if strings.HasPrefix(vm.ID, idOrName) || strings.HasPrefix(vm.Name, idOrName) {
			match = vm
			matchCount++
		}
	}
	if matchCount == 1 {
		return match, nil
	}
	if matchCount > 1 {
		return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", idOrName)
	}
	return model.VMRecord{}, fmt.Errorf("vm %q not found", idOrName)
}

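// FindImage resolves an image by exact name, exact ID, or a unique ID/name
// prefix; ambiguous or unknown references return an error.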
func (d *Daemon) FindImage(ctx context.Context, idOrName string) (model.Image, error) {
	if idOrName == "" {
		return model.Image{}, errors.New("image id or name is required")
	}
	if image, err := d.store.GetImageByName(ctx, idOrName); err == nil {
		return image, nil
	}
	if image, err := d.store.GetImageByID(ctx, idOrName); err == nil {
		return image, nil
	}
	images, err := d.store.ListImages(ctx)
	if err != nil {
		return model.Image{}, err
	}
	matchCount := 0
	var match model.Image
	for _, image := range images {
		if strings.HasPrefix(image.ID, idOrName) || strings.HasPrefix(image.Name, idOrName) {
			match = image
			matchCount++
		}
	}
	if matchCount == 1 {
		return match, nil
	}
	if matchCount > 1 {
		return model.Image{}, fmt.Errorf("multiple images match %q", idOrName)
	}
	return model.Image{}, fmt.Errorf("image %q not found", idOrName)
}

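// TouchVM looks up a VM and refreshes its activity timestamp under the daemon
// lock, persisting the update to the store.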
func (d *Daemon) TouchVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	vm, err := d.FindVM(ctx, idOrName)
	if err != nil {
		return model.VMRecord{}, err
	}
	system.TouchNow(&vm)
	if err := d.store.UpsertVM(ctx, vm); err != nil {
		return model.VMRecord{}, err
	}
	return vm, nil
}

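// marshalResultOrError wraps a successful payload in an RPC result, or converts
// a failure into an operation_failed or marshal_failed error response.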
func marshalResultOrError(v any, err error) rpc.Response {
	if err != nil {
		return rpc.NewError("operation_failed", err.Error())
	}
	resp, marshalErr := rpc.NewResult(v)
	if marshalErr != nil {
		return rpc.NewError("marshal_failed", marshalErr.Error())
	}
	return resp
}

func exists(path string) bool {
	_, err := os.Stat(path)
	return err == nil
}