banger/internal/daemon/daemon.go
Thales Maciel 43982a4ae3
Phase B-1: ownership fixup via debugfs pass
imagepull.Flatten now captures per-file uid/gid/mode/type from the
tar headers as it walks layers, returning a Metadata map alongside
the extracted tree. Whiteouts correctly drop the victim's metadata.
The returned Metadata feeds the new imagepull.ApplyOwnership, which
pipes a batched `set_inode_field` script to `debugfs -w -f -`.

Why: mkfs.ext4 -d copies the runner's on-disk uids verbatim, so
without this pass setuid binaries become setuid-nonroot and sshd
refuses to start on the resulting image. With the pass, a pulled
debian:bookworm has /usr/bin/sudo with uid=0 + setuid bit surviving
intact.

imagepull.BuildExt4 signature unchanged; ownership is applied as a
separate step by the daemon orchestrator between BuildExt4 and
StageBootArtifacts, keeping each helper focused. The seam
(d.pullAndFlatten) now returns (Metadata, error) for test stubs to
feed synthetic metadata.

StdinRunner is a new duck-typed extension next to CommandRunner: the
real system.Runner implements RunStdin, while test mocks need not
implement it unless they exercise stdin. This keeps every existing mock
from having to grow a new method.

Tests:
 - TestFlattenCapturesHeaderMetadata: setuid bit + mode survive the
   tar-header walk
 - TestApplyOwnershipRewritesUidGidMode: real debugfs round-trip —
   create ext4 with runner's uid, apply synthetic metadata setting
   uid=0 + setuid mode, verify via `debugfs -R stat` that the
   inode now has uid=0 and mode 04755
 - TestBuildOwnershipScriptDeterministic: sorted, well-formed
   sif script output

Debugfs and mkfs.ext4 tests skip if the binaries aren't on PATH.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-16 18:04:22 -03:00

779 lines
24 KiB
Go

package daemon
import (
"bufio"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"banger/internal/api"
"banger/internal/buildinfo"
"banger/internal/config"
"banger/internal/daemon/opstate"
"banger/internal/imagepull"
"banger/internal/model"
"banger/internal/paths"
"banger/internal/rpc"
"banger/internal/store"
"banger/internal/system"
"banger/internal/vmdns"
)
// Daemon is the long-running banger control process. It owns the state
// store, the unix-socket RPC listener, the web server, the VM DNS server,
// and in-memory registries for guest sessions, tap devices and async
// create/build operations. Construct it with Open and release it with Close.
type Daemon struct {
	// Core dependencies, all set in Open: the resolved filesystem layout
	// (socket path, state dir, DB path), the loaded config (LogLevel is
	// normalized there), the persistent store (closed by Close), the
	// command runner (system.NewRunner), and the logger. Several methods
	// check logger for nil before use.
	layout paths.Layout
	config model.DaemonConfig
	store  *store.Store
	runner system.CommandRunner
	logger *slog.Logger

	// Serialization for image operations and VM creation.
	// NOTE(review): the exact critical sections are defined by call sites
	// outside this file — confirm before relying on their scope.
	imageOpsMu sync.Mutex
	createVMMu sync.Mutex

	// In-memory state: registries of async VM-create and image-build
	// operations (presumably backing the *.begin/*.status/*.cancel RPCs),
	// per-VM-ID locks (see lockVMID), guest sessions, and the tap pool
	// initialized in Open.
	createOps     opstate.Registry[*vmCreateOperationState]
	imageBuildOps opstate.Registry[*imageBuildOperationState]
	vmLocks       vmLockSet
	sessions      sessionRegistry
	tapPool       tapPool

	// Shutdown signalling: closing is closed exactly once, by Close
	// (guarded by once); backgroundLoop and Serve watch it.
	closing chan struct{}
	once    sync.Once

	// Process ID recorded in Open and reported by the ping RPC.
	pid int

	// Network endpoints: the RPC unix-socket listener (set by Serve), the
	// web listener/server and its URL (reported by ping), and the VM DNS
	// server (nil when not running; see stopVMDNS).
	listener    net.Listener
	webListener net.Listener
	webServer   *http.Server
	webURL      string
	vmDNS       *vmdns.Server

	// VM capabilities. NOTE(review): semantics are defined outside this
	// file — confirm.
	vmCaps []vmCapability

	// Overridable behaviour seams. requestHandler, when non-nil, replaces
	// the built-in method switch in dispatch. pullAndFlatten returns the
	// per-file ownership Metadata captured from image layers. The other
	// seams presumably fall back to default implementations when nil —
	// TODO confirm at their call sites.
	imageBuild               func(context.Context, imageBuildSpec) error
	pullAndFlatten           func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error)
	requestHandler           func(context.Context, rpc.Request) rpc.Response
	guestWaitForSSH          func(context.Context, string, string, time.Duration) error
	guestDial                func(context.Context, string, string) (guestSSHClient, error)
	waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error)
}
// Open performs the daemon's startup sequence:
//   - resolve and create the on-disk layout;
//   - load configuration and build the logger;
//   - open the state store;
//   - start the VM DNS server;
//   - reconcile persisted VM state;
//   - install resolver routing;
//   - initialize the tap pool.
//
// If any stage after the store is opened fails, everything acquired so
// far is released before returning: resolver routing, the VM DNS server
// and the store handle. A failed Open therefore leaks nothing. A
// successful Open must be paired with Close.
func Open(ctx context.Context) (d *Daemon, err error) {
	layout, err := paths.Resolve()
	if err != nil {
		return nil, err
	}
	if err := paths.Ensure(layout); err != nil {
		return nil, err
	}
	cfg, err := config.Load(layout)
	if err != nil {
		return nil, err
	}
	logger, normalizedLevel, err := newDaemonLogger(os.Stderr, cfg.LogLevel)
	if err != nil {
		return nil, err
	}
	cfg.LogLevel = normalizedLevel
	db, err := store.Open(layout.DBPath)
	if err != nil {
		return nil, err
	}
	// Release the store if any later stage fails. This relies on err being
	// the named result: every `return nil, err` below assigns it before
	// deferred functions run.
	defer func() {
		if err != nil {
			_ = db.Close()
		}
	}()
	d = &Daemon{
		layout:   layout,
		config:   cfg,
		store:    db,
		runner:   system.NewRunner(),
		logger:   logger,
		closing:  make(chan struct{}),
		pid:      os.Getpid(),
		sessions: newSessionRegistry(),
	}
	d.ensureVMSSHClientConfig()
	d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel)
	if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil {
		d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error())
		return nil, err
	}
	// Registered after the store cleanup, so on failure the DNS server is
	// stopped first and the store closed last (defers run LIFO).
	defer func() {
		if err != nil {
			_ = d.stopVMDNS()
		}
	}()
	if err = d.reconcile(ctx); err != nil {
		d.logger.Error("daemon open failed", "stage", "reconcile", "error", err.Error())
		return nil, err
	}
	d.ensureVMDNSResolverRouting(ctx)
	if err = d.initializeTapPool(ctx); err != nil {
		d.logger.Error("daemon open failed", "stage", "initialize_tap_pool", "error", err.Error())
		// Undo the routing installed above. This mirrors Close, which
		// always pairs clearVMDNSResolverRouting with stopVMDNS.
		_ = d.clearVMDNSResolverRouting(context.Background())
		return nil, err
	}
	// NOTE(review): this goroutine is not tied to d.closing; its lifetime
	// is governed by ensureTapPool itself.
	go d.ensureTapPool(context.Background())
	return d, nil
}
// Close shuts the daemon down. Only the first call does any work; every
// later call returns nil.
//
// Close proceeds in three steps:
//  1. Signal shutdown by closing d.closing.
//  2. Close the RPC listener and the web server/listener. Errors from
//     these closes are ignored.
//  3. Tear down resolver routing, the VM DNS server, guest session
//     controllers and the store, in that order.
//
// Close returns the errors from step 3 joined together.
func (d *Daemon) Close() error {
	var closeErr error
	d.once.Do(func() {
		if d.logger != nil {
			d.logger.Info("daemon closing")
		}
		close(d.closing)

		if l := d.listener; l != nil {
			_ = l.Close()
		}
		if srv := d.webServer; srv != nil {
			_ = srv.Close()
		}
		if wl := d.webListener; wl != nil {
			_ = wl.Close()
		}

		routingErr := d.clearVMDNSResolverRouting(context.Background())
		dnsErr := d.stopVMDNS()
		sessionErr := d.closeGuestSessionControllers()
		storeErr := d.store.Close()
		closeErr = errors.Join(routingErr, dnsErr, sessionErr, storeErr)
	})
	return closeErr
}
// Serve listens on the daemon's unix control socket and serves RPC
// connections, one goroutine per connection. It also starts the web
// server and the background maintenance loop.
//
// Serve returns nil on an orderly stop: ctx is canceled or Close is
// called. Otherwise it returns the listen/setup error or the permanent
// Accept error.
func (d *Daemon) Serve(ctx context.Context) error {
	// Remove a socket left behind by a previous run so Listen can bind.
	_ = os.Remove(d.layout.SocketPath)
	listener, err := net.Listen("unix", d.layout.SocketPath)
	if err != nil {
		if d.logger != nil {
			d.logger.Error("daemon listen failed", "socket", d.layout.SocketPath, "error", err.Error())
		}
		return err
	}
	d.listener = listener
	defer listener.Close()
	defer os.Remove(d.layout.SocketPath)
	// NOTE(review): between Listen and Chmod the socket briefly carries
	// umask-derived permissions.
	if err := os.Chmod(d.layout.SocketPath, 0o600); err != nil {
		return err
	}
	if d.logger != nil {
		d.logger.Info("daemon serving", "socket", d.layout.SocketPath, "pid", d.pid)
	}
	if err := d.startWebServer(); err != nil {
		return err
	}
	// Accept does not observe ctx, so close the listener on cancellation to
	// unblock it. Close() already closes the listener itself. serveDone
	// stops this watcher when Serve returns for any other reason, so it
	// never outlives Serve.
	serveDone := make(chan struct{})
	defer close(serveDone)
	go func() {
		select {
		case <-ctx.Done():
			_ = listener.Close()
		case <-d.closing:
		case <-serveDone:
		}
	}()
	go d.backgroundLoop()
	for {
		conn, err := listener.Accept()
		if err != nil {
			select {
			case <-ctx.Done():
				return nil
			case <-d.closing:
				return nil
			default:
			}
			// Back off and retry on transient network errors. A closed
			// listener never recovers, so it must not be retried: doing so
			// would spin forever.
			if _, ok := err.(net.Error); ok && !errors.Is(err, net.ErrClosed) {
				if d.logger != nil {
					d.logger.Warn("daemon accept temporary failure", "error", err.Error())
				}
				time.Sleep(100 * time.Millisecond)
				continue
			}
			if d.logger != nil {
				d.logger.Error("daemon accept failed", "error", err.Error())
			}
			return err
		}
		go d.handleConn(conn)
	}
}
// handleConn serves a single RPC exchange on conn. It decodes one JSON
// rpc.Request, dispatches it, and encodes one rpc.Response back. The
// connection is closed when the handler returns.
func (d *Daemon) handleConn(conn net.Conn) {
	defer conn.Close()
	reader := bufio.NewReader(conn)
	var req rpc.Request
	if err := json.NewDecoder(reader).Decode(&req); err != nil {
		if d.logger != nil {
			d.logger.Warn("daemon request decode failed", "remote", conn.RemoteAddr().String(), "error", err.Error())
		}
		// Best-effort error reply; the write error is deliberately ignored.
		_ = json.NewEncoder(conn).Encode(rpc.NewError("bad_request", err.Error()))
		return
	}
	// The request context derives from context.Background(), not from any
	// daemon-wide context. It is canceled when this handler returns, or
	// earlier by the watcher below.
	reqCtx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// The watcher keeps reading from the same bufio.Reader used for
	// decoding. Deferred calls run in reverse order, so stopWatch runs
	// before cancel and conn.Close.
	stopWatch := d.watchRequestDisconnect(conn, reader, req.Method, cancel)
	defer stopWatch()
	resp := d.dispatch(reqCtx, req)
	// Here reqCtx can only have been canceled by the watcher. That means a
	// read on conn failed, typically because the client disconnected, so
	// the response is dropped.
	if reqCtx.Err() != nil {
		return
	}
	if err := json.NewEncoder(conn).Encode(resp); err != nil && d.logger != nil {
		d.logger.Warn("daemon response encode failed", "method", req.Method, "remote", conn.RemoteAddr().String(), "error", err.Error())
	}
}
// watchRequestDisconnect starts a goroutine that keeps reading from reader
// while a request is in flight. Any bytes that arrive are discarded. The
// first read error calls cancel and is logged at Info level; an EOF when
// the client closes its side is one such error.
//
// It returns a stop function that is safe to call more than once.
// Calling stop closes done, which has two effects:
//   - a helper goroutine sets an immediate read deadline on conn (when conn
//     supports SetReadDeadline), forcing the blocked Read to return;
//   - any read error seen after that point is not treated as a disconnect,
//     so cancel is not called.
//
// If conn or reader is nil, no goroutine is started and stop is a no-op.
//
// NOTE(review): after stop, conn has an expired read deadline. Callers are
// expected to only write to or close conn afterwards.
func (d *Daemon) watchRequestDisconnect(conn net.Conn, reader *bufio.Reader, method string, cancel context.CancelFunc) func() {
	if conn == nil || reader == nil {
		return func() {}
	}
	done := make(chan struct{})
	var once sync.Once
	go func() {
		// Unblocks the Read loop below once stop is called.
		go func() {
			<-done
			if deadlineSetter, ok := conn.(interface{ SetReadDeadline(time.Time) error }); ok {
				_ = deadlineSetter.SetReadDeadline(time.Now())
			}
		}()
		var buf [1]byte
		for {
			_, err := reader.Read(buf[:])
			if err == nil {
				continue
			}
			// An error after stop is the forced deadline (or teardown), not
			// a client disconnect.
			select {
			case <-done:
				return
			default:
			}
			if d.logger != nil {
				d.logger.Info("daemon request canceled", "method", method, "remote", conn.RemoteAddr().String(), "error", err.Error())
			}
			cancel()
			return
		}
	}()
	return func() {
		once.Do(func() {
			close(done)
		})
	}
}
// dispatch routes a decoded RPC request to its handler and converts the
// outcome into an rpc.Response.
//
// Requests whose Version differs from rpc.Version are rejected with
// "bad_version". If d.requestHandler is set, it handles every request in
// place of the built-in method table. Parameter decoding failures yield
// "bad_request", and unknown methods yield "unknown_method". Handler
// results and errors go through marshalResultOrError, so encoding
// failures are reported as "marshal_failed" rather than silently dropped.
func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
	if req.Version != rpc.Version {
		return rpc.NewError("bad_version", fmt.Sprintf("unsupported version %d", req.Version))
	}
	if d.requestHandler != nil {
		return d.requestHandler(ctx, req)
	}
	switch req.Method {
	case "ping":
		info := buildinfo.Current()
		return marshalResultOrError(api.PingResult{
			Status:  "ok",
			PID:     d.pid,
			WebURL:  d.webURL,
			Version: info.Version,
			Commit:  info.Commit,
			BuiltAt: info.BuiltAt,
		}, nil)
	case "shutdown":
		// Close runs asynchronously so this response can still be written.
		go d.Close()
		return marshalResultOrError(api.ShutdownResult{Status: "stopping"}, nil)
	case "vm.create":
		params, err := rpc.DecodeParams[api.VMCreateParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.CreateVM(ctx, params)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.create.begin":
		params, err := rpc.DecodeParams[api.VMCreateParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		op, err := d.BeginVMCreate(ctx, params)
		return marshalResultOrError(api.VMCreateBeginResult{Operation: op}, err)
	case "vm.create.status":
		params, err := rpc.DecodeParams[api.VMCreateStatusParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		op, err := d.VMCreateStatus(ctx, params.ID)
		return marshalResultOrError(api.VMCreateStatusResult{Operation: op}, err)
	case "vm.create.cancel":
		params, err := rpc.DecodeParams[api.VMCreateStatusParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		err = d.CancelVMCreate(ctx, params.ID)
		return marshalResultOrError(api.Empty{}, err)
	case "vm.list":
		vms, err := d.store.ListVMs(ctx)
		return marshalResultOrError(api.VMListResult{VMs: vms}, err)
	case "vm.show":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.FindVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.start":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.StartVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.stop":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.StopVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.kill":
		params, err := rpc.DecodeParams[api.VMKillParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.KillVM(ctx, params)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.restart":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.RestartVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.delete":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.DeleteVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.set":
		params, err := rpc.DecodeParams[api.VMSetParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.SetVM(ctx, params)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.stats":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, stats, err := d.GetVMStats(ctx, params.IDOrName)
		return marshalResultOrError(api.VMStatsResult{VM: vm, Stats: stats}, err)
	case "vm.logs":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.FindVM(ctx, params.IDOrName)
		if err != nil {
			return rpc.NewError("not_found", err.Error())
		}
		return marshalResultOrError(api.VMLogsResult{LogPath: vm.Runtime.LogPath}, nil)
	case "vm.ssh":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.TouchVM(ctx, params.IDOrName)
		if err != nil {
			return rpc.NewError("not_found", err.Error())
		}
		// The stored state alone is not trusted; the process must be alive too.
		if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
			return rpc.NewError("not_running", fmt.Sprintf("vm %s is not running", vm.Name))
		}
		return marshalResultOrError(api.VMSSHResult{Name: vm.Name, GuestIP: vm.Runtime.GuestIP}, nil)
	case "vm.health":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.HealthVM(ctx, params.IDOrName)
		return marshalResultOrError(result, err)
	case "vm.ping":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.PingVM(ctx, params.IDOrName)
		return marshalResultOrError(result, err)
	case "vm.ports":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.PortsVM(ctx, params.IDOrName)
		return marshalResultOrError(result, err)
	case "vm.workspace.prepare":
		params, err := rpc.DecodeParams[api.VMWorkspacePrepareParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		workspace, err := d.PrepareVMWorkspace(ctx, params)
		return marshalResultOrError(api.VMWorkspacePrepareResult{Workspace: workspace}, err)
	case "vm.workspace.export":
		params, err := rpc.DecodeParams[api.WorkspaceExportParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.ExportVMWorkspace(ctx, params)
		return marshalResultOrError(result, err)
	case "guest.session.start":
		params, err := rpc.DecodeParams[api.GuestSessionStartParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		session, err := d.StartGuestSession(ctx, params)
		return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
	case "guest.session.get":
		params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		session, err := d.GetGuestSession(ctx, params)
		return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
	case "guest.session.list":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		sessions, err := d.ListGuestSessions(ctx, params)
		return marshalResultOrError(api.GuestSessionListResult{Sessions: sessions}, err)
	case "guest.session.stop":
		params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		session, err := d.StopGuestSession(ctx, params)
		return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
	case "guest.session.kill":
		params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		session, err := d.KillGuestSession(ctx, params)
		return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
	case "guest.session.logs":
		params, err := rpc.DecodeParams[api.GuestSessionLogsParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.GuestSessionLogs(ctx, params)
		return marshalResultOrError(result, err)
	case "guest.session.attach.begin":
		params, err := rpc.DecodeParams[api.GuestSessionAttachBeginParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.BeginGuestSessionAttach(ctx, params)
		return marshalResultOrError(result, err)
	case "guest.session.send":
		params, err := rpc.DecodeParams[api.GuestSessionSendParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.SendToGuestSession(ctx, params)
		return marshalResultOrError(result, err)
	case "image.list":
		images, err := d.store.ListImages(ctx)
		return marshalResultOrError(api.ImageListResult{Images: images}, err)
	case "image.show":
		params, err := rpc.DecodeParams[api.ImageRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.FindImage(ctx, params.IDOrName)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "image.build":
		params, err := rpc.DecodeParams[api.ImageBuildParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.BuildImage(ctx, params)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "image.build.begin":
		params, err := rpc.DecodeParams[api.ImageBuildParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		op, err := d.BeginImageBuild(ctx, params)
		return marshalResultOrError(api.ImageBuildBeginResult{Operation: op}, err)
	case "image.build.status":
		params, err := rpc.DecodeParams[api.ImageBuildStatusParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		op, err := d.ImageBuildStatus(ctx, params.ID)
		return marshalResultOrError(api.ImageBuildStatusResult{Operation: op}, err)
	case "image.build.cancel":
		params, err := rpc.DecodeParams[api.ImageBuildStatusParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		err = d.CancelImageBuild(ctx, params.ID)
		return marshalResultOrError(api.Empty{}, err)
	case "image.register":
		params, err := rpc.DecodeParams[api.ImageRegisterParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.RegisterImage(ctx, params)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "image.promote":
		params, err := rpc.DecodeParams[api.ImageRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.PromoteImage(ctx, params.IDOrName)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "image.delete":
		params, err := rpc.DecodeParams[api.ImageRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.DeleteImage(ctx, params.IDOrName)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "image.pull":
		params, err := rpc.DecodeParams[api.ImagePullParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.PullImage(ctx, params)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "kernel.list":
		return marshalResultOrError(d.KernelList(ctx))
	case "kernel.show":
		params, err := rpc.DecodeParams[api.KernelRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		entry, err := d.KernelShow(ctx, params.Name)
		return marshalResultOrError(api.KernelShowResult{Entry: entry}, err)
	case "kernel.delete":
		params, err := rpc.DecodeParams[api.KernelRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		err = d.KernelDelete(ctx, params.Name)
		return marshalResultOrError(api.Empty{}, err)
	case "kernel.import":
		params, err := rpc.DecodeParams[api.KernelImportParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		entry, err := d.KernelImport(ctx, params)
		return marshalResultOrError(api.KernelShowResult{Entry: entry}, err)
	case "kernel.pull":
		params, err := rpc.DecodeParams[api.KernelPullParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		entry, err := d.KernelPull(ctx, params)
		return marshalResultOrError(api.KernelShowResult{Entry: entry}, err)
	case "kernel.catalog":
		return marshalResultOrError(d.KernelCatalog(ctx))
	default:
		return rpc.NewError("unknown_method", req.Method)
	}
}
// backgroundLoop performs periodic maintenance until d.closing is closed.
// It runs two schedules:
//   - every config.StatsPollInterval, it calls pollStats;
//   - every model.DefaultStaleSweepInterval, it calls stopStaleVMs and then
//     prunes create/build operation records older than
//     finishedOperationRetention.
//
// Failures are logged and the loop keeps running.
func (d *Daemon) backgroundLoop() {
	// How long operation records are retained before pruning.
	const finishedOperationRetention = 10 * time.Minute

	// time.NewTicker panics on a non-positive interval, which would crash
	// the daemon from this goroutine. Treat such a value as "stats polling
	// disabled": receiving from a nil channel blocks forever, so that
	// select case never fires.
	var statsTick <-chan time.Time
	if interval := d.config.StatsPollInterval; interval > 0 {
		statsTicker := time.NewTicker(interval)
		defer statsTicker.Stop()
		statsTick = statsTicker.C
	} else if d.logger != nil {
		d.logger.Warn("background stats polling disabled", "stats_poll_interval", interval)
	}
	staleTicker := time.NewTicker(model.DefaultStaleSweepInterval)
	defer staleTicker.Stop()
	for {
		select {
		case <-d.closing:
			return
		case <-statsTick:
			if err := d.pollStats(context.Background()); err != nil && d.logger != nil {
				d.logger.Error("background stats poll failed", "error", err.Error())
			}
		case <-staleTicker.C:
			if err := d.stopStaleVMs(context.Background()); err != nil && d.logger != nil {
				d.logger.Error("background stale sweep failed", "error", err.Error())
			}
			// One cutoff for both registries so they prune consistently.
			cutoff := time.Now().Add(-finishedOperationRetention)
			d.pruneVMCreateOperations(cutoff)
			d.pruneImageBuildOperations(cutoff)
		}
	}
}
// startVMDNS creates the VM DNS server on addr and stores it in d.vmDNS.
// d.vmDNS is only assigned (and the address logged) when vmdns.New
// succeeds; otherwise its error is returned unchanged.
func (d *Daemon) startVMDNS(addr string) error {
	srv, err := vmdns.New(addr, d.logger)
	if err == nil {
		d.vmDNS = srv
		if d.logger != nil {
			d.logger.Info("vm dns serving", "dns_addr", srv.Addr())
		}
	}
	return err
}
// stopVMDNS closes the VM DNS server if one is running and clears d.vmDNS,
// so calling it again is a no-op that returns nil. The server's Close
// error is returned.
func (d *Daemon) stopVMDNS() error {
	srv := d.vmDNS
	if srv == nil {
		return nil
	}
	closeErr := srv.Close()
	d.vmDNS = nil
	return closeErr
}
// ensureDefaultImage currently does nothing and always returns nil.
// NOTE(review): presumably a placeholder for seeding a default image —
// confirm whether any caller still relies on it before removing it.
func (d *Daemon) ensureDefaultImage(context.Context) error {
	return nil
}
// reconcile repairs persisted VM state. Any VM recorded as running whose
// process (checked via system.ProcessRunning) is gone has its runtime
// cleaned up, is marked stopped with its runtime handles cleared, and is
// saved. The VM DNS records are then rebuilt. The first store or lock
// error aborts the pass.
func (d *Daemon) reconcile(ctx context.Context) error {
	op := d.beginOperation("daemon.reconcile")
	vms, err := d.store.ListVMs(ctx)
	if err != nil {
		return op.fail(err)
	}
	for _, listed := range vms {
		// The record is re-read under its per-VM lock; the listed copy is
		// only used for its ID.
		err := d.withVMLockByIDErr(ctx, listed.ID, func(vm model.VMRecord) error {
			if vm.State != model.VMStateRunning {
				return nil
			}
			if system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
				return nil
			}
			op.stage("stale_vm", vmLogAttrs(vm)...)
			// Cleanup is best-effort: the VM is marked stopped regardless.
			// The failure is logged rather than silently discarded, so
			// leftover runtime resources can be traced.
			if cleanupErr := d.cleanupRuntime(ctx, vm, true); cleanupErr != nil && d.logger != nil {
				d.logger.Warn("reconcile runtime cleanup failed", "vm_id", vm.ID, "error", cleanupErr.Error())
			}
			vm.State = model.VMStateStopped
			vm.Runtime.State = model.VMStateStopped
			clearRuntimeHandles(&vm)
			vm.UpdatedAt = model.Now()
			return d.store.UpsertVM(ctx, vm)
		})
		if err != nil {
			return op.fail(err, "vm_id", listed.ID)
		}
	}
	if err := d.rebuildDNS(ctx); err != nil {
		return op.fail(err)
	}
	op.done()
	return nil
}
// FindVM resolves idOrName to exactly one VM.
//
// It first tries an exact lookup via store.GetVM. If that fails for any
// reason, it lists every VM and treats idOrName as a prefix of either the
// VM's ID or its name. Exactly one prefix match succeeds; zero matches
// yields a "not found" error, and several yield a "multiple VMs match"
// error.
func (d *Daemon) FindVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
	if idOrName == "" {
		return model.VMRecord{}, errors.New("vm id or name is required")
	}
	if exact, lookupErr := d.store.GetVM(ctx, idOrName); lookupErr == nil {
		return exact, nil
	}
	all, err := d.store.ListVMs(ctx)
	if err != nil {
		return model.VMRecord{}, err
	}
	var candidates []model.VMRecord
	for _, candidate := range all {
		if strings.HasPrefix(candidate.ID, idOrName) || strings.HasPrefix(candidate.Name, idOrName) {
			candidates = append(candidates, candidate)
		}
	}
	switch len(candidates) {
	case 0:
		return model.VMRecord{}, fmt.Errorf("vm %q not found", idOrName)
	case 1:
		return candidates[0], nil
	default:
		return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", idOrName)
	}
}
// FindImage resolves idOrName to exactly one image.
//
// Exact lookups come first: by name, then by ID. If both fail, every image
// is listed and idOrName is matched as a prefix of either the ID or the
// name. Exactly one prefix match succeeds; zero matches yields a
// "not found" error, and several yield a "multiple images match" error.
func (d *Daemon) FindImage(ctx context.Context, idOrName string) (model.Image, error) {
	if idOrName == "" {
		return model.Image{}, errors.New("image id or name is required")
	}
	if byName, lookupErr := d.store.GetImageByName(ctx, idOrName); lookupErr == nil {
		return byName, nil
	}
	if byID, lookupErr := d.store.GetImageByID(ctx, idOrName); lookupErr == nil {
		return byID, nil
	}
	all, err := d.store.ListImages(ctx)
	if err != nil {
		return model.Image{}, err
	}
	var candidates []model.Image
	for _, candidate := range all {
		if strings.HasPrefix(candidate.ID, idOrName) || strings.HasPrefix(candidate.Name, idOrName) {
			candidates = append(candidates, candidate)
		}
	}
	switch len(candidates) {
	case 0:
		return model.Image{}, fmt.Errorf("image %q not found", idOrName)
	case 1:
		return candidates[0], nil
	default:
		return model.Image{}, fmt.Errorf("multiple images match %q", idOrName)
	}
}
// TouchVM resolves idOrName and, while holding that VM's lock, stamps the
// record with system.TouchNow and persists it. It returns the updated
// record, or a zero record plus the error if resolution or saving fails.
func (d *Daemon) TouchVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
	touch := func(current model.VMRecord) (model.VMRecord, error) {
		system.TouchNow(&current)
		if saveErr := d.store.UpsertVM(ctx, current); saveErr != nil {
			return model.VMRecord{}, saveErr
		}
		return current, nil
	}
	return d.withVMLockByRef(ctx, idOrName, touch)
}
// withVMLockByRef resolves idOrName with FindVM and then delegates to
// withVMLockByID for the resolved ID. withVMLockByID re-reads the record
// under the lock before calling fn.
func (d *Daemon) withVMLockByRef(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
	resolved, err := d.FindVM(ctx, idOrName)
	if err == nil {
		return d.withVMLockByID(ctx, resolved.ID, fn)
	}
	return model.VMRecord{}, err
}
// withVMLockByID takes the per-VM lock for id, loads the current record
// from the store while holding it, and calls fn with that record. The
// lock is held until fn returns.
//
// A blank or whitespace-only id is rejected. A missing row (sql.ErrNoRows)
// is reported as "vm %q not found". Other store errors are returned as is.
func (d *Daemon) withVMLockByID(ctx context.Context, id string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
	if strings.TrimSpace(id) == "" {
		return model.VMRecord{}, errors.New("vm id is required")
	}
	release := d.lockVMID(id)
	defer release()
	current, err := d.store.GetVMByID(ctx, id)
	switch {
	case err == nil:
		return fn(current)
	case errors.Is(err, sql.ErrNoRows):
		return model.VMRecord{}, fmt.Errorf("vm %q not found", id)
	default:
		return model.VMRecord{}, err
	}
}
// withVMLockByIDErr is the error-only form of withVMLockByID, for
// callbacks that do not produce a record. Only the error is returned.
func (d *Daemon) withVMLockByIDErr(ctx context.Context, id string, fn func(model.VMRecord) error) error {
	adapt := func(vm model.VMRecord) (model.VMRecord, error) {
		if fnErr := fn(vm); fnErr != nil {
			return model.VMRecord{}, fnErr
		}
		return vm, nil
	}
	_, lockErr := d.withVMLockByID(ctx, id, adapt)
	return lockErr
}
// lockVMID obtains the lock for id via d.vmLocks.lock and returns the
// function that releases it.
func (d *Daemon) lockVMID(id string) (unlock func()) {
	unlock = d.vmLocks.lock(id)
	return unlock
}
// marshalResultOrError converts a handler outcome into an rpc.Response:
//   - if err is non-nil, an "operation_failed" error carrying err's message;
//   - otherwise v encoded as a result via rpc.NewResult;
//   - or, if that encoding fails, a "marshal_failed" error.
func marshalResultOrError(v any, err error) rpc.Response {
	if err != nil {
		return rpc.NewError("operation_failed", err.Error())
	}
	resp, encodeErr := rpc.NewResult(v)
	if encodeErr == nil {
		return resp
	}
	return rpc.NewError("marshal_failed", encodeErr.Error())
}
// exists reports whether os.Stat succeeds for path. Any Stat error yields
// false: not only "does not exist", but also, for example, permission
// denied.
func exists(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return false
	}
	return true
}