banger/internal/daemon/daemon.go
Thales Maciel 6cd52d12f4
workspace prepare: release VM mutex before guest I/O
Previously withVMLockByRef held the per-VM mutex across InspectRepo,
waitForGuestSSH, dialGuest, ImportRepoToGuest (the tar stream!), and
the readonly chmod. A large repo could block `vm stop` / `vm delete`
/ `vm restart` on the same VM for however long the import took.

Split into two phases:

  1. VM mutex held briefly to validate state (running + PID alive)
     and snapshot the fields needed for SSH (guest IP, api sock).
  2. VM mutex released. Acquire workspaceLocks[id] — a separate
     per-VM mutex scoped to workspace.prepare / workspace.export —
     for the guest I/O phase.

Lifecycle ops (stop/delete/restart/set) only take vmLocks, so they
no longer queue behind a slow import. Two concurrent prepares on the
same VM still serialise via workspaceLocks so tar streams don't
interleave. ExportVMWorkspace also acquires workspaceLocks to avoid
snapshotting a half-streamed import.
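
Sketch of the resulting shape (importOverSSH is an illustrative
stand-in for the dialGuest + ImportRepoToGuest sequence, not a real
helper; the full method also runs InspectRepo and the readonly chmod):

    // Phase 1: validate and snapshot SSH coordinates under the VM mutex.
    var guestIP, apiSock string
    err := d.withVMLockByIDErr(ctx, id, func(vm model.VMRecord) error {
        if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
            return fmt.Errorf("vm %s is not running", vm.Name)
        }
        guestIP, apiSock = vm.Runtime.GuestIP, vm.Runtime.APISockPath
        return nil
    })
    if err != nil {
        return err
    }

    // Phase 2: VM mutex released; guest I/O holds only the workspace lock.
    unlock := d.workspaceLocks.lock(id)
    defer unlock()
    return importOverSSH(ctx, guestIP, apiSock)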

Two regression tests (sequential — they swap package-level seams):

  ReleasesVMLockDuringGuestIO: stall the import fake, assert the VM
  mutex is acquirable from another goroutine during the stall.

  SerialisesConcurrentPreparesOnSameVM: run three concurrent prepares,
  assert Import is only ever invoked one at a time per VM.

ARCHITECTURE.md documents the split + updated lock ordering.
2026-04-19 13:32:42 -03:00

757 lines
23 KiB
Go

package daemon

import (
	"bufio"
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"banger/internal/api"
	"banger/internal/buildinfo"
	"banger/internal/config"
	"banger/internal/daemon/opstate"
	"banger/internal/imagecat"
	"banger/internal/imagepull"
	"banger/internal/model"
	"banger/internal/paths"
	"banger/internal/rpc"
	"banger/internal/store"
	"banger/internal/system"
	"banger/internal/vmdns"
)
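
// Daemon is the long-lived control process behind the banger CLI: it owns
// the store, the unix-socket RPC listener, the web server, the VM DNS
// server, and the per-VM lock sets. The function-valued fields at the end
// are overridable seams (e.g. requestHandler replaces dispatch entirely)
// that tests swap for fakes.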
type Daemon struct {
	layout paths.Layout
	config model.DaemonConfig
	store *store.Store
	runner system.CommandRunner
	logger *slog.Logger
	imageOpsMu sync.Mutex
	createVMMu sync.Mutex
	createOps opstate.Registry[*vmCreateOperationState]
	vmLocks vmLockSet
	// workspaceLocks serialises workspace.prepare / workspace.export
	// calls on the same VM (two concurrent prepares would clobber each
	// other's tar streams). It is a SEPARATE scope from vmLocks so
	// slow guest I/O — SSH dial, tar upload, chmod — does not block
	// vm stop/delete/restart. See ARCHITECTURE.md.
	workspaceLocks vmLockSet
	sessions sessionRegistry
	tapPool tapPool
	closing chan struct{}
	once sync.Once
	pid int
	listener net.Listener
	webListener net.Listener
	webServer *http.Server
	webURL string
	vmDNS *vmdns.Server
	vmCaps []vmCapability
	pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error)
	finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error
	bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error)
	requestHandler func(context.Context, rpc.Request) rpc.Response
	guestWaitForSSH func(context.Context, string, string, time.Duration) error
	guestDial func(context.Context, string, string) (guestSSHClient, error)
	waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error)
}
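
// Open loads the daemon's layout, config, and store, then starts the VM
// DNS server, reconciles persisted VM state against live processes, and
// initialises the tap pool. If any step after the DNS server fails, the
// deferred cleanup stops it again before the error is returned.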
func Open(ctx context.Context) (d *Daemon, err error) {
	layout, err := paths.Resolve()
	if err != nil {
		return nil, err
	}
	if err := paths.Ensure(layout); err != nil {
		return nil, err
	}
	cfg, err := config.Load(layout)
	if err != nil {
		return nil, err
	}
	logger, normalizedLevel, err := newDaemonLogger(os.Stderr, cfg.LogLevel)
	if err != nil {
		return nil, err
	}
	cfg.LogLevel = normalizedLevel
	db, err := store.Open(layout.DBPath)
	if err != nil {
		return nil, err
	}
	d = &Daemon{
		layout: layout,
		config: cfg,
		store: db,
		runner: system.NewRunner(),
		logger: logger,
		closing: make(chan struct{}),
		pid: os.Getpid(),
		sessions: newSessionRegistry(),
	}
	d.ensureVMSSHClientConfig()
	d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel)
	if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil {
		d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error())
		return nil, err
	}
	defer func() {
		if err != nil {
			_ = d.stopVMDNS()
		}
	}()
	if err = d.reconcile(ctx); err != nil {
		d.logger.Error("daemon open failed", "stage", "reconcile", "error", err.Error())
		return nil, err
	}
	d.ensureVMDNSResolverRouting(ctx)
	if err = d.initializeTapPool(ctx); err != nil {
		d.logger.Error("daemon open failed", "stage", "initialize_tap_pool", "error", err.Error())
		return nil, err
	}
	go d.ensureTapPool(context.Background())
	return d, nil
}
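
// Close is idempotent: the first call closes the RPC and web listeners,
// then joins the errors from resolver-routing cleanup, the VM DNS server,
// guest session controllers, and the store.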
func (d *Daemon) Close() error {
	var err error
	d.once.Do(func() {
		if d.logger != nil {
			d.logger.Info("daemon closing")
		}
		close(d.closing)
		if d.listener != nil {
			_ = d.listener.Close()
		}
		if d.webServer != nil {
			_ = d.webServer.Close()
		}
		if d.webListener != nil {
			_ = d.webListener.Close()
		}
		err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.closeGuestSessionControllers(), d.store.Close())
	})
	return err
}
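
// Serve binds the daemon's unix socket with mode 0600, starts the web
// server and background loop, and then accepts connections until ctx is
// done or the daemon is closing, handling each connection on its own
// goroutine.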
func (d *Daemon) Serve(ctx context.Context) error {
	_ = os.Remove(d.layout.SocketPath)
	listener, err := net.Listen("unix", d.layout.SocketPath)
	if err != nil {
		if d.logger != nil {
			d.logger.Error("daemon listen failed", "socket", d.layout.SocketPath, "error", err.Error())
		}
		return err
	}
	d.listener = listener
	defer listener.Close()
	defer os.Remove(d.layout.SocketPath)
	if err := os.Chmod(d.layout.SocketPath, 0o600); err != nil {
		return err
	}
	if d.logger != nil {
		d.logger.Info("daemon serving", "socket", d.layout.SocketPath, "pid", d.pid)
	}
	if err := d.startWebServer(); err != nil {
		return err
	}
	go d.backgroundLoop()
	for {
		conn, err := listener.Accept()
		if err != nil {
			select {
			case <-ctx.Done():
				return nil
			case <-d.closing:
				return nil
			default:
			}
			if _, ok := err.(net.Error); ok {
				if d.logger != nil {
					d.logger.Warn("daemon accept temporary failure", "error", err.Error())
				}
				time.Sleep(100 * time.Millisecond)
				continue
			}
			if d.logger != nil {
				d.logger.Error("daemon accept failed", "error", err.Error())
			}
			return err
		}
		go d.handleConn(conn)
	}
}
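
// handleConn reads exactly one RPC request from the connection, dispatches
// it, and writes back a single response. If the disconnect watcher cancels
// the request context first, the response is dropped.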
func (d *Daemon) handleConn(conn net.Conn) {
	defer conn.Close()
	reader := bufio.NewReader(conn)
	var req rpc.Request
	if err := json.NewDecoder(reader).Decode(&req); err != nil {
		if d.logger != nil {
			d.logger.Warn("daemon request decode failed", "remote", conn.RemoteAddr().String(), "error", err.Error())
		}
		_ = json.NewEncoder(conn).Encode(rpc.NewError("bad_request", err.Error()))
		return
	}
	reqCtx, cancel := context.WithCancel(context.Background())
	defer cancel()
	stopWatch := d.watchRequestDisconnect(conn, reader, req.Method, cancel)
	defer stopWatch()
	resp := d.dispatch(reqCtx, req)
	if reqCtx.Err() != nil {
		return
	}
	if err := json.NewEncoder(conn).Encode(resp); err != nil && d.logger != nil {
		d.logger.Warn("daemon response encode failed", "method", req.Method, "remote", conn.RemoteAddr().String(), "error", err.Error())
	}
}
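
// watchRequestDisconnect watches the connection while a request is in
// flight: a read failure before the returned stop function is called is
// treated as a client disconnect and cancels the request context. stop
// unblocks the watcher by forcing an immediate read deadline.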
func (d *Daemon) watchRequestDisconnect(conn net.Conn, reader *bufio.Reader, method string, cancel context.CancelFunc) func() {
	if conn == nil || reader == nil {
		return func() {}
	}
	done := make(chan struct{})
	var once sync.Once
	go func() {
		go func() {
			// Once the response has been sent, unblock the pending Read
			// below so the watcher goroutine can exit.
			<-done
			_ = conn.SetReadDeadline(time.Now())
		}()
		var buf [1]byte
		for {
			_, err := reader.Read(buf[:])
			if err == nil {
				continue
			}
			select {
			case <-done:
				return
			default:
			}
			if d.logger != nil {
				d.logger.Info("daemon request canceled", "method", method, "remote", conn.RemoteAddr().String(), "error", err.Error())
			}
			cancel()
			return
		}
	}()
	return func() {
		once.Do(func() {
			close(done)
		})
	}
}
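
// dispatch routes a decoded request to the matching daemon method. The
// requestHandler seam, when non-nil, short-circuits the whole switch
// (tests swap it for fakes).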
func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
	if req.Version != rpc.Version {
		return rpc.NewError("bad_version", fmt.Sprintf("unsupported version %d", req.Version))
	}
	if d.requestHandler != nil {
		return d.requestHandler(ctx, req)
	}
	switch req.Method {
	case "ping":
		info := buildinfo.Current()
		result, _ := rpc.NewResult(api.PingResult{
			Status:  "ok",
			PID:     d.pid,
			WebURL:  d.webURL,
			Version: info.Version,
			Commit:  info.Commit,
			BuiltAt: info.BuiltAt,
		})
		return result
	case "shutdown":
		go d.Close()
		result, _ := rpc.NewResult(api.ShutdownResult{Status: "stopping"})
		return result
	case "vm.create":
		params, err := rpc.DecodeParams[api.VMCreateParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.CreateVM(ctx, params)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.create.begin":
		params, err := rpc.DecodeParams[api.VMCreateParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		op, err := d.BeginVMCreate(ctx, params)
		return marshalResultOrError(api.VMCreateBeginResult{Operation: op}, err)
	case "vm.create.status":
		params, err := rpc.DecodeParams[api.VMCreateStatusParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		op, err := d.VMCreateStatus(ctx, params.ID)
		return marshalResultOrError(api.VMCreateStatusResult{Operation: op}, err)
	case "vm.create.cancel":
		params, err := rpc.DecodeParams[api.VMCreateStatusParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		err = d.CancelVMCreate(ctx, params.ID)
		return marshalResultOrError(api.Empty{}, err)
	case "vm.list":
		vms, err := d.store.ListVMs(ctx)
		return marshalResultOrError(api.VMListResult{VMs: vms}, err)
case "vm.show":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
vm, err := d.FindVM(ctx, params.IDOrName)
return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.start":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
vm, err := d.StartVM(ctx, params.IDOrName)
return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.stop":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
vm, err := d.StopVM(ctx, params.IDOrName)
return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.kill":
params, err := rpc.DecodeParams[api.VMKillParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
vm, err := d.KillVM(ctx, params)
return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.restart":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
vm, err := d.RestartVM(ctx, params.IDOrName)
return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.delete":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
vm, err := d.DeleteVM(ctx, params.IDOrName)
return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.set":
params, err := rpc.DecodeParams[api.VMSetParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
vm, err := d.SetVM(ctx, params)
return marshalResultOrError(api.VMShowResult{VM: vm}, err)
case "vm.stats":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
vm, stats, err := d.GetVMStats(ctx, params.IDOrName)
return marshalResultOrError(api.VMStatsResult{VM: vm, Stats: stats}, err)
case "vm.logs":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
vm, err := d.FindVM(ctx, params.IDOrName)
if err != nil {
return rpc.NewError("not_found", err.Error())
}
return marshalResultOrError(api.VMLogsResult{LogPath: vm.Runtime.LogPath}, nil)
case "vm.ssh":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
vm, err := d.TouchVM(ctx, params.IDOrName)
if err != nil {
return rpc.NewError("not_found", err.Error())
}
if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
return rpc.NewError("not_running", fmt.Sprintf("vm %s is not running", vm.Name))
}
return marshalResultOrError(api.VMSSHResult{Name: vm.Name, GuestIP: vm.Runtime.GuestIP}, nil)
case "vm.health":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.HealthVM(ctx, params.IDOrName)
return marshalResultOrError(result, err)
case "vm.ping":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.PingVM(ctx, params.IDOrName)
return marshalResultOrError(result, err)
case "vm.ports":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.PortsVM(ctx, params.IDOrName)
return marshalResultOrError(result, err)
case "vm.workspace.prepare":
params, err := rpc.DecodeParams[api.VMWorkspacePrepareParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
workspace, err := d.PrepareVMWorkspace(ctx, params)
return marshalResultOrError(api.VMWorkspacePrepareResult{Workspace: workspace}, err)
case "vm.workspace.export":
params, err := rpc.DecodeParams[api.WorkspaceExportParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.ExportVMWorkspace(ctx, params)
return marshalResultOrError(result, err)
case "guest.session.start":
params, err := rpc.DecodeParams[api.GuestSessionStartParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.StartGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.get":
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.GetGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.list":
params, err := rpc.DecodeParams[api.VMRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
sessions, err := d.ListGuestSessions(ctx, params)
return marshalResultOrError(api.GuestSessionListResult{Sessions: sessions}, err)
case "guest.session.stop":
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.StopGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.kill":
params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
session, err := d.KillGuestSession(ctx, params)
return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
case "guest.session.logs":
params, err := rpc.DecodeParams[api.GuestSessionLogsParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.GuestSessionLogs(ctx, params)
return marshalResultOrError(result, err)
case "guest.session.attach.begin":
params, err := rpc.DecodeParams[api.GuestSessionAttachBeginParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.BeginGuestSessionAttach(ctx, params)
return marshalResultOrError(result, err)
case "guest.session.send":
params, err := rpc.DecodeParams[api.GuestSessionSendParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
result, err := d.SendToGuestSession(ctx, params)
return marshalResultOrError(result, err)
case "image.list":
images, err := d.store.ListImages(ctx)
return marshalResultOrError(api.ImageListResult{Images: images}, err)
case "image.show":
params, err := rpc.DecodeParams[api.ImageRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
image, err := d.FindImage(ctx, params.IDOrName)
return marshalResultOrError(api.ImageShowResult{Image: image}, err)
case "image.register":
params, err := rpc.DecodeParams[api.ImageRegisterParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
image, err := d.RegisterImage(ctx, params)
return marshalResultOrError(api.ImageShowResult{Image: image}, err)
case "image.promote":
params, err := rpc.DecodeParams[api.ImageRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
image, err := d.PromoteImage(ctx, params.IDOrName)
return marshalResultOrError(api.ImageShowResult{Image: image}, err)
case "image.delete":
params, err := rpc.DecodeParams[api.ImageRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
image, err := d.DeleteImage(ctx, params.IDOrName)
return marshalResultOrError(api.ImageShowResult{Image: image}, err)
case "image.pull":
params, err := rpc.DecodeParams[api.ImagePullParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
image, err := d.PullImage(ctx, params)
return marshalResultOrError(api.ImageShowResult{Image: image}, err)
case "kernel.list":
return marshalResultOrError(d.KernelList(ctx))
case "kernel.show":
params, err := rpc.DecodeParams[api.KernelRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
entry, err := d.KernelShow(ctx, params.Name)
return marshalResultOrError(api.KernelShowResult{Entry: entry}, err)
case "kernel.delete":
params, err := rpc.DecodeParams[api.KernelRefParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
err = d.KernelDelete(ctx, params.Name)
return marshalResultOrError(api.Empty{}, err)
case "kernel.import":
params, err := rpc.DecodeParams[api.KernelImportParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
entry, err := d.KernelImport(ctx, params)
return marshalResultOrError(api.KernelShowResult{Entry: entry}, err)
case "kernel.pull":
params, err := rpc.DecodeParams[api.KernelPullParams](req)
if err != nil {
return rpc.NewError("bad_request", err.Error())
}
entry, err := d.KernelPull(ctx, params)
return marshalResultOrError(api.KernelShowResult{Entry: entry}, err)
case "kernel.catalog":
return marshalResultOrError(d.KernelCatalog(ctx))
default:
return rpc.NewError("unknown_method", req.Method)
}
}
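
// backgroundLoop polls VM stats on the configured interval, sweeps stale
// VMs, and prunes vm.create operations older than ten minutes, until the
// daemon begins closing.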
func (d *Daemon) backgroundLoop() {
	statsTicker := time.NewTicker(d.config.StatsPollInterval)
	staleTicker := time.NewTicker(model.DefaultStaleSweepInterval)
	defer statsTicker.Stop()
	defer staleTicker.Stop()
	for {
		select {
		case <-d.closing:
			return
		case <-statsTicker.C:
			if err := d.pollStats(context.Background()); err != nil && d.logger != nil {
				d.logger.Error("background stats poll failed", "error", err.Error())
			}
		case <-staleTicker.C:
			if err := d.stopStaleVMs(context.Background()); err != nil && d.logger != nil {
				d.logger.Error("background stale sweep failed", "error", err.Error())
			}
			d.pruneVMCreateOperations(time.Now().Add(-10 * time.Minute))
		}
	}
}

func (d *Daemon) startVMDNS(addr string) error {
	server, err := vmdns.New(addr, d.logger)
	if err != nil {
		return err
	}
	d.vmDNS = server
	if d.logger != nil {
		d.logger.Info("vm dns serving", "dns_addr", server.Addr())
	}
	return nil
}

func (d *Daemon) stopVMDNS() error {
	if d.vmDNS == nil {
		return nil
	}
	err := d.vmDNS.Close()
	d.vmDNS = nil
	return err
}

// ensureDefaultImage currently does nothing.
func (d *Daemon) ensureDefaultImage(ctx context.Context) error {
	_ = ctx
	return nil
}
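
// reconcile runs at startup: any VM recorded as running whose process is
// gone is cleaned up and marked stopped, then the DNS records are rebuilt.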
func (d *Daemon) reconcile(ctx context.Context) error {
	op := d.beginOperation("daemon.reconcile")
	vms, err := d.store.ListVMs(ctx)
	if err != nil {
		return op.fail(err)
	}
	for _, vm := range vms {
		if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error {
			if vm.State != model.VMStateRunning {
				return nil
			}
			if system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
				return nil
			}
			op.stage("stale_vm", vmLogAttrs(vm)...)
			_ = d.cleanupRuntime(ctx, vm, true)
			vm.State = model.VMStateStopped
			vm.Runtime.State = model.VMStateStopped
			clearRuntimeHandles(&vm)
			vm.UpdatedAt = model.Now()
			return d.store.UpsertVM(ctx, vm)
		}); err != nil {
			return op.fail(err, "vm_id", vm.ID)
		}
	}
	if err := d.rebuildDNS(ctx); err != nil {
		return op.fail(err)
	}
	op.done()
	return nil
}
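
// FindVM resolves idOrName by exact lookup first, then by unique prefix of
// either the ID or the name; ambiguous prefixes are rejected.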
func (d *Daemon) FindVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
	if idOrName == "" {
		return model.VMRecord{}, errors.New("vm id or name is required")
	}
	if vm, err := d.store.GetVM(ctx, idOrName); err == nil {
		return vm, nil
	}
	vms, err := d.store.ListVMs(ctx)
	if err != nil {
		return model.VMRecord{}, err
	}
	matchCount := 0
	var match model.VMRecord
	for _, vm := range vms {
		if strings.HasPrefix(vm.ID, idOrName) || strings.HasPrefix(vm.Name, idOrName) {
			match = vm
			matchCount++
		}
	}
	if matchCount == 1 {
		return match, nil
	}
	if matchCount > 1 {
		return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", idOrName)
	}
	return model.VMRecord{}, fmt.Errorf("vm %q not found", idOrName)
}
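
// FindImage resolves idOrName by exact name, then exact ID, then unique
// ID or name prefix, mirroring FindVM.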
func (d *Daemon) FindImage(ctx context.Context, idOrName string) (model.Image, error) {
	if idOrName == "" {
		return model.Image{}, errors.New("image id or name is required")
	}
	if image, err := d.store.GetImageByName(ctx, idOrName); err == nil {
		return image, nil
	}
	if image, err := d.store.GetImageByID(ctx, idOrName); err == nil {
		return image, nil
	}
	images, err := d.store.ListImages(ctx)
	if err != nil {
		return model.Image{}, err
	}
	matchCount := 0
	var match model.Image
	for _, image := range images {
		if strings.HasPrefix(image.ID, idOrName) || strings.HasPrefix(image.Name, idOrName) {
			match = image
			matchCount++
		}
	}
	if matchCount == 1 {
		return match, nil
	}
	if matchCount > 1 {
		return model.Image{}, fmt.Errorf("multiple images match %q", idOrName)
	}
	return model.Image{}, fmt.Errorf("image %q not found", idOrName)
}
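
// TouchVM refreshes the VM's activity stamp via system.TouchNow under the
// per-VM lock and persists the updated record.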
func (d *Daemon) TouchVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
	return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
		system.TouchNow(&vm)
		if err := d.store.UpsertVM(ctx, vm); err != nil {
			return model.VMRecord{}, err
		}
		return vm, nil
	})
}
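
// withVMLockByRef resolves idOrName to an ID without holding any lock,
// then defers to withVMLockByID, which re-reads the record under the
// per-VM mutex so fn always sees the current state.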
func (d *Daemon) withVMLockByRef(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
	vm, err := d.FindVM(ctx, idOrName)
	if err != nil {
		return model.VMRecord{}, err
	}
	return d.withVMLockByID(ctx, vm.ID, fn)
}
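
// withVMLockByID runs fn with the per-VM mutex held and a freshly loaded
// record. Slow guest I/O must not happen inside fn; that belongs under
// workspaceLocks (see ARCHITECTURE.md).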
func (d *Daemon) withVMLockByID(ctx context.Context, id string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
	if strings.TrimSpace(id) == "" {
		return model.VMRecord{}, errors.New("vm id is required")
	}
	unlock := d.lockVMID(id)
	defer unlock()
	vm, err := d.store.GetVMByID(ctx, id)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return model.VMRecord{}, fmt.Errorf("vm %q not found", id)
		}
		return model.VMRecord{}, err
	}
	return fn(vm)
}

func (d *Daemon) withVMLockByIDErr(ctx context.Context, id string, fn func(model.VMRecord) error) error {
	_, err := d.withVMLockByID(ctx, id, func(vm model.VMRecord) (model.VMRecord, error) {
		if err := fn(vm); err != nil {
			return model.VMRecord{}, err
		}
		return vm, nil
	})
	return err
}

func (d *Daemon) lockVMID(id string) func() {
	return d.vmLocks.lock(id)
}
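
// marshalResultOrError turns a (result, error) pair into an RPC response,
// mapping failures to operation_failed and encoding problems to
// marshal_failed.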
func marshalResultOrError(v any, err error) rpc.Response {
	if err != nil {
		return rpc.NewError("operation_failed", err.Error())
	}
	resp, marshalErr := rpc.NewResult(v)
	if marshalErr != nil {
		return rpc.NewError("marshal_failed", marshalErr.Error())
	}
	return resp
}

func exists(path string) bool {
	_, err := os.Stat(path)
	return err == nil
}