diff --git a/internal/cli/banger.go b/internal/cli/banger.go
index 7184c61..f412af8 100644
--- a/internal/cli/banger.go
+++ b/internal/cli/banger.go
@@ -9,6 +9,7 @@ import (
 	"os/exec"
 	"path/filepath"
 	"strings"
+	"sync"
 	"syscall"
 	"text/tabwriter"
 	"time"
@@ -187,9 +188,9 @@ func newVMCommand() *cobra.Command {
 func newVMKillCommand() *cobra.Command {
 	var signal string
 	cmd := &cobra.Command{
-		Use:   "kill <id|name>",
+		Use:   "kill <id|name>...",
 		Short: "Send a signal to a VM process",
-		Args:  exactArgsUsage(1, "usage: banger vm kill [--signal SIGTERM|SIGKILL|...] <id|name>"),
+		Args:  minArgsUsage(1, "usage: banger vm kill [--signal SIGTERM|SIGKILL|...] <id|name>..."),
 		RunE: func(cmd *cobra.Command, args []string) error {
 			if err := system.EnsureSudo(cmd.Context()); err != nil {
 				return err
@@ -198,6 +199,20 @@ func newVMKillCommand() *cobra.Command {
 			if err != nil {
 				return err
 			}
+			if len(args) > 1 {
+				return runVMBatchAction(cmd, layout.SocketPath, args, func(ctx context.Context, id string) (model.VMRecord, error) {
+					result, err := rpc.Call[api.VMShowResult](
+						ctx,
+						layout.SocketPath,
+						"vm.kill",
+						api.VMKillParams{IDOrName: id, Signal: signal},
+					)
+					if err != nil {
+						return model.VMRecord{}, err
+					}
+					return result.VM, nil
+				})
+			}
 			result, err := rpc.Call[api.VMShowResult](
 				cmd.Context(),
 				layout.SocketPath,
@@ -316,9 +331,9 @@ func newVMShowCommand() *cobra.Command {
 func newVMActionCommand(use, short, method string) *cobra.Command {
 	return &cobra.Command{
-		Use:   use + " <id|name>",
+		Use:   use + " <id|name>...",
 		Short: short,
-		Args:  exactArgsUsage(1, fmt.Sprintf("usage: banger vm %s <id|name>", use)),
+		Args:  minArgsUsage(1, fmt.Sprintf("usage: banger vm %s <id|name>...", use)),
 		RunE: func(cmd *cobra.Command, args []string) error {
 			if err := system.EnsureSudo(cmd.Context()); err != nil {
 				return err
@@ -327,6 +342,15 @@ func newVMActionCommand(use, short, method string) *cobra.Command {
 			if err != nil {
 				return err
 			}
+			if len(args) > 1 {
+				return runVMBatchAction(cmd, layout.SocketPath, args, func(ctx context.Context, id string) (model.VMRecord, error) {
+					result, err := rpc.Call[api.VMShowResult](ctx, layout.SocketPath, method, api.VMRefParams{IDOrName: id})
+					if err != nil {
+						return model.VMRecord{}, err
+					}
+					return result.VM, nil
+				})
+			}
 			result, err := rpc.Call[api.VMShowResult](cmd.Context(), layout.SocketPath, method, api.VMRefParams{IDOrName: args[0]})
 			if err != nil {
 				return err
@@ -345,9 +369,9 @@ func newVMSetCommand() *cobra.Command {
 		noNat    bool
 	)
 	cmd := &cobra.Command{
-		Use:   "set <id|name>",
+		Use:   "set <id|name>...",
 		Short: "Update stopped VM settings",
-		Args:  exactArgsUsage(1, "usage: banger vm set [--vcpu N] [--memory MiB] [--disk-size SIZE] [--nat|--no-nat] <id|name>"),
+		Args:  minArgsUsage(1, "usage: banger vm set [--vcpu N] [--memory MiB] [--disk-size SIZE] [--nat|--no-nat] <id|name>..."),
 		RunE: func(cmd *cobra.Command, args []string) error {
 			params, err := vmSetParamsFromFlags(args[0], vcpu, memory, diskSize, nat, noNat)
 			if err != nil {
@@ -360,6 +384,17 @@ func newVMSetCommand() *cobra.Command {
 			if err != nil {
 				return err
 			}
+			if len(args) > 1 {
+				return runVMBatchAction(cmd, layout.SocketPath, args, func(ctx context.Context, id string) (model.VMRecord, error) {
+					batchParams := params
+					batchParams.IDOrName = id
+					result, err := rpc.Call[api.VMShowResult](ctx, layout.SocketPath, "vm.set", batchParams)
+					if err != nil {
+						return model.VMRecord{}, err
+					}
+					return result.VM, nil
+				})
+			}
 			result, err := rpc.Call[api.VMShowResult](cmd.Context(), layout.SocketPath, "vm.set", params)
 			if err != nil {
 				return err
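Note: the `exactArgsUsage` to `minArgsUsage` swap above relies on a helper whose body never appears in this patch (only its signature shows up as hunk context in the next hunk). A minimal sketch, assuming it mirrors `exactArgsUsage` by returning the hand-written usage line as the error:

```go
// Sketch only, not part of the diff: enforce a minimum positional-arg count
// and surface the usage string when the check fails.
func minArgsUsage(n int, usage string) cobra.PositionalArgs {
	return func(cmd *cobra.Command, args []string) error {
		if len(args) < n {
			return errors.New(usage)
		}
		return nil
	}
}
```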
@@ -597,6 +632,132 @@ func minArgsUsage(n int, usage string) cobra.PositionalArgs {
 	}
 }
 
+type resolvedVMTarget struct {
+	Index int
+	Ref   string
+	VM    model.VMRecord
+}
+
+type vmRefResolutionError struct {
+	Index int
+	Ref   string
+	Err   error
+}
+
+type vmBatchActionResult struct {
+	Target resolvedVMTarget
+	VM     model.VMRecord
+	Err    error
+}
+
+func runVMBatchAction(cmd *cobra.Command, socketPath string, refs []string, action func(context.Context, string) (model.VMRecord, error)) error {
+	listResult, err := rpc.Call[api.VMListResult](cmd.Context(), socketPath, "vm.list", api.Empty{})
+	if err != nil {
+		return err
+	}
+	targets, resolutionErrs := resolveVMTargets(listResult.VMs, refs)
+	results := executeVMActionBatch(cmd.Context(), targets, action)
+
+	failed := false
+	for _, resolutionErr := range resolutionErrs {
+		if _, err := fmt.Fprintf(cmd.ErrOrStderr(), "%s: %v\n", resolutionErr.Ref, resolutionErr.Err); err != nil {
+			return err
+		}
+		failed = true
+	}
+	for _, result := range results {
+		if result.Err != nil {
+			if _, err := fmt.Fprintf(cmd.ErrOrStderr(), "%s: %v\n", result.Target.Ref, result.Err); err != nil {
+				return err
+			}
+			failed = true
+			continue
+		}
+		if err := printVMSummary(cmd.OutOrStdout(), result.VM); err != nil {
+			return err
+		}
+	}
+	if failed {
+		return errors.New("one or more VM operations failed")
+	}
+	return nil
+}
+
+func resolveVMTargets(vms []model.VMRecord, refs []string) ([]resolvedVMTarget, []vmRefResolutionError) {
+	targets := make([]resolvedVMTarget, 0, len(refs))
+	resolutionErrs := make([]vmRefResolutionError, 0)
+	seen := make(map[string]struct{}, len(refs))
+	for index, ref := range refs {
+		vm, err := resolveVMRef(vms, ref)
+		if err != nil {
+			resolutionErrs = append(resolutionErrs, vmRefResolutionError{Index: index, Ref: ref, Err: err})
+			continue
+		}
+		if _, ok := seen[vm.ID]; ok {
+			continue
+		}
+		seen[vm.ID] = struct{}{}
+		targets = append(targets, resolvedVMTarget{Index: index, Ref: ref, VM: vm})
+	}
+	return targets, resolutionErrs
+}
+
+func resolveVMRef(vms []model.VMRecord, ref string) (model.VMRecord, error) {
+	ref = strings.TrimSpace(ref)
+	if ref == "" {
+		return model.VMRecord{}, errors.New("vm id or name is required")
+	}
+	exactMatches := make([]model.VMRecord, 0, 1)
+	for _, vm := range vms {
+		if vm.ID == ref || vm.Name == ref {
+			exactMatches = append(exactMatches, vm)
+		}
+	}
+	switch len(exactMatches) {
+	case 1:
+		return exactMatches[0], nil
+	case 0:
+	default:
+		return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", ref)
+	}
+
+	prefixMatches := make([]model.VMRecord, 0, 1)
+	for _, vm := range vms {
+		if strings.HasPrefix(vm.ID, ref) || strings.HasPrefix(vm.Name, ref) {
+			prefixMatches = append(prefixMatches, vm)
+		}
+	}
+	switch len(prefixMatches) {
+	case 1:
+		return prefixMatches[0], nil
+	case 0:
+		return model.VMRecord{}, fmt.Errorf("vm %q not found", ref)
+	default:
+		return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", ref)
+	}
+}
+
+func executeVMActionBatch(ctx context.Context, targets []resolvedVMTarget, action func(context.Context, string) (model.VMRecord, error)) []vmBatchActionResult {
+	results := make([]vmBatchActionResult, len(targets))
+	var wg sync.WaitGroup
+	wg.Add(len(targets))
+	for index, target := range targets {
+		index := index
+		target := target
+		go func() {
+			defer wg.Done()
+			vm, err := action(ctx, target.VM.ID)
+			results[index] = vmBatchActionResult{
+				Target: target,
+				VM:     vm,
+				Err:    err,
+			}
+		}()
+	}
+	wg.Wait()
+	return results
+}
+
 func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) {
 	layout, err := paths.Resolve()
 	if err != nil {
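Taken together, the batch path works like this: `runVMBatchAction` resolves every ref against a single `vm.list` snapshot, deduplicates refs that hit the same VM, runs the action concurrently, then reports in input order — resolution failures and per-VM errors on stderr, successes as normal summaries on stdout, plus one aggregate error if anything failed. The resolution precedence is worth spelling out: exact ID/name matches win outright, and prefix matching is consulted only when there is no exact match. Illustrative calls (not from the diff):

```go
// Illustrative only: resolveVMRef precedence on a two-VM list.
vms := []model.VMRecord{
	{ID: "alpha-id", Name: "alpha"},
	{ID: "alpine-id", Name: "alpine"},
}
vm, _ := resolveVMRef(vms, "alpha") // exact name match wins: alpha-id (even though "alpha" also prefixes "alpha-id")
vm, _ = resolveVMRef(vms, "alpi")   // unique prefix of alpine/alpine-id: alpine-id
_, err := resolveVMRef(vms, "al")   // prefixes both VMs: `multiple VMs match "al"`
_, _ = vm, err
```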
diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go
index 6a60f8b..67c7165 100644
--- a/internal/cli/cli_test.go
+++ b/internal/cli/cli_test.go
@@ -2,11 +2,13 @@ package cli
 
 import (
 	"bytes"
+	"context"
 	"os"
 	"path/filepath"
 	"reflect"
 	"strings"
 	"testing"
+	"time"
 
 	"banger/internal/api"
 	"banger/internal/model"
@@ -155,6 +157,97 @@ func TestVMSetParamsFromFlagsRejectsNonPositiveCPUAndMemory(t *testing.T) {
 	}
 }
 
+func TestResolveVMTargetsDeduplicatesAndReportsErrors(t *testing.T) {
+	vms := []model.VMRecord{
+		testCLIResolvedVM("alpha-id", "alpha"),
+		testCLIResolvedVM("alpine-id", "alpine"),
+		testCLIResolvedVM("bravo-id", "bravo"),
+	}
+
+	targets, errs := resolveVMTargets(vms, []string{"alpha", "alpha-id", "al", "missing", "br"})
+
+	if len(targets) != 2 {
+		t.Fatalf("len(targets) = %d, want 2", len(targets))
+	}
+	if targets[0].VM.ID != "alpha-id" || targets[0].Ref != "alpha" {
+		t.Fatalf("targets[0] = %+v, want alpha target", targets[0])
+	}
+	if targets[1].VM.ID != "bravo-id" || targets[1].Ref != "br" {
+		t.Fatalf("targets[1] = %+v, want bravo target", targets[1])
+	}
+	if len(errs) != 2 {
+		t.Fatalf("len(errs) = %d, want 2", len(errs))
+	}
+	if errs[0].Ref != "al" || !strings.Contains(errs[0].Err.Error(), "multiple VMs match") {
+		t.Fatalf("errs[0] = %+v, want ambiguous prefix", errs[0])
+	}
+	if errs[1].Ref != "missing" || !strings.Contains(errs[1].Err.Error(), `vm "missing" not found`) {
+		t.Fatalf("errs[1] = %+v, want missing vm", errs[1])
+	}
+}
+
+func TestResolveVMRefPrefersExactMatchBeforePrefix(t *testing.T) {
+	vms := []model.VMRecord{
+		testCLIResolvedVM("1111111111111111111111111111111111111111111111111111111111111111", "alpha"),
+		testCLIResolvedVM("alpha222222222222222222222222222222222222222222222222222222222222", "bravo"),
+	}
+
+	vm, err := resolveVMRef(vms, "alpha")
+	if err != nil {
+		t.Fatalf("resolveVMRef(alpha): %v", err)
+	}
+	if vm.Name != "alpha" {
+		t.Fatalf("resolveVMRef(alpha) = %+v, want exact-name vm", vm)
+	}
+}
+
+func TestExecuteVMActionBatchRunsConcurrentlyAndPreservesOrder(t *testing.T) {
+	targets := []resolvedVMTarget{
+		{Ref: "alpha", VM: testCLIResolvedVM("alpha-id", "alpha")},
+		{Ref: "bravo", VM: testCLIResolvedVM("bravo-id", "bravo")},
+	}
+
+	started := make(chan string, len(targets))
+	release := make(chan struct{})
+	done := make(chan []vmBatchActionResult, 1)
+	go func() {
+		done <- executeVMActionBatch(context.Background(), targets, func(ctx context.Context, id string) (model.VMRecord, error) {
+			started <- id
+			<-release
+			return model.VMRecord{ID: id, Name: id}, nil
+		})
+	}()
+
+	for range targets {
+		select {
+		case <-started:
+		case <-time.After(500 * time.Millisecond):
+			t.Fatal("batch actions did not overlap")
+		}
+	}
+
+	close(release)
+
+	var results []vmBatchActionResult
+	select {
+	case results = <-done:
+	case <-time.After(500 * time.Millisecond):
+		t.Fatal("executeVMActionBatch did not finish")
+	}
+
+	if len(results) != len(targets) {
+		t.Fatalf("len(results) = %d, want %d", len(results), len(targets))
+	}
+	for index, result := range results {
+		if result.Target.Ref != targets[index].Ref {
+			t.Fatalf("results[%d].Target.Ref = %q, want %q", index, result.Target.Ref, targets[index].Ref)
+		}
+		if result.VM.ID != targets[index].VM.ID {
+			t.Fatalf("results[%d].VM.ID = %q, want %q", index, result.VM.ID, targets[index].VM.ID)
+		}
+	}
+}
+
 func TestSSHCommandArgs(t *testing.T) {
 	args, err := sshCommandArgs(model.DaemonConfig{SSHKeyPath: "/bundle/id_ed25519"}, "172.16.0.2", []string{"--", "uname", "-a"})
 	if err != nil {
@@ -312,3 +405,7 @@ func TestAbsolutizeImageBuildPaths(t *testing.T) {
 		t.Fatalf("params = %+v, want %+v", params, want)
 	}
 }
+
+func testCLIResolvedVM(id, name string) model.VMRecord {
+	return model.VMRecord{ID: id, Name: name}
+}
diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go
index 48ac7f3..3420936 100644
--- a/internal/daemon/daemon.go
+++ b/internal/daemon/daemon.go
@@ -32,6 +32,8 @@ type Daemon struct {
 	runner    system.CommandRunner
 	logger    *slog.Logger
 	mu        sync.Mutex
+	vmLocksMu sync.Mutex
+	vmLocks   map[string]*sync.Mutex
 	closing   chan struct{}
 	once      sync.Once
 	pid       int
@@ -488,26 +490,22 @@ func (d *Daemon) reconcile(ctx context.Context) error {
 		return op.fail(err)
 	}
 	for _, vm := range vms {
-		if vm.State != model.VMStateRunning {
-			continue
-		}
-		if system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
-			continue
-		}
-		op.stage("stale_vm", vmLogAttrs(vm)...)
-		_ = d.cleanupRuntime(ctx, vm, true)
-		vm.State = model.VMStateStopped
-		vm.Runtime.State = model.VMStateStopped
-		vm.Runtime.PID = 0
-		vm.Runtime.TapDevice = ""
-		vm.Runtime.APISockPath = ""
-		vm.Runtime.BaseLoop = ""
-		vm.Runtime.COWLoop = ""
-		vm.Runtime.DMName = ""
-		vm.Runtime.DMDev = ""
-		vm.UpdatedAt = model.Now()
-		if err := d.store.UpsertVM(ctx, vm); err != nil {
-			return op.fail(err, vmLogAttrs(vm)...)
+		if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error {
+			if vm.State != model.VMStateRunning {
+				return nil
+			}
+			if system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
+				return nil
+			}
+			op.stage("stale_vm", vmLogAttrs(vm)...)
+			_ = d.cleanupRuntime(ctx, vm, true)
+			vm.State = model.VMStateStopped
+			vm.Runtime.State = model.VMStateStopped
+			clearRuntimeHandles(&vm)
+			vm.UpdatedAt = model.Now()
+			return d.store.UpsertVM(ctx, vm)
+		}); err != nil {
+			return op.fail(err, "vm_id", vm.ID)
 		}
 	}
 	if err := d.rebuildDNS(ctx); err != nil {
@@ -577,17 +575,64 @@ func (d *Daemon) FindImage(ctx context.Context, idOrName string) (model.Image, error) {
 }
 
 func (d *Daemon) TouchVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
-	d.mu.Lock()
-	defer d.mu.Unlock()
+	return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+		system.TouchNow(&vm)
+		if err := d.store.UpsertVM(ctx, vm); err != nil {
+			return model.VMRecord{}, err
+		}
+		return vm, nil
+	})
+}
+
+func (d *Daemon) withVMLockByRef(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
 	vm, err := d.FindVM(ctx, idOrName)
 	if err != nil {
 		return model.VMRecord{}, err
 	}
-	system.TouchNow(&vm)
-	if err := d.store.UpsertVM(ctx, vm); err != nil {
+	return d.withVMLockByID(ctx, vm.ID, fn)
+}
+
+func (d *Daemon) withVMLockByID(ctx context.Context, id string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
+	if strings.TrimSpace(id) == "" {
+		return model.VMRecord{}, errors.New("vm id is required")
+	}
+	unlock := d.lockVMID(id)
+	defer unlock()
+
+	vm, err := d.store.GetVMByID(ctx, id)
+	if err != nil {
+		if errors.Is(err, sql.ErrNoRows) {
+			return model.VMRecord{}, fmt.Errorf("vm %q not found", id)
+		}
 		return model.VMRecord{}, err
 	}
-	return vm, nil
+	return fn(vm)
+}
+
+func (d *Daemon) withVMLockByIDErr(ctx context.Context, id string, fn func(model.VMRecord) error) error {
+	_, err := d.withVMLockByID(ctx, id, func(vm model.VMRecord) (model.VMRecord, error) {
+		if err := fn(vm); err != nil {
+			return model.VMRecord{}, err
+		}
+		return vm, nil
+	})
+	return err
+}
+
+func (d *Daemon) lockVMID(id string) func() {
+	d.vmLocksMu.Lock()
+	if d.vmLocks == nil {
+		d.vmLocks = make(map[string]*sync.Mutex)
+	}
+	lock, ok := d.vmLocks[id]
+	if !ok {
+		lock = &sync.Mutex{}
+		d.vmLocks[id] = lock
+	}
+	d.vmLocksMu.Unlock()
+
+	lock.Lock()
+	return lock.Unlock
 }
 
 func marshalResultOrError(v any, err error) rpc.Response {
diff --git a/internal/daemon/vm.go b/internal/daemon/vm.go
index dfa4faf..3cc951b 100644
--- a/internal/daemon/vm.go
+++ b/internal/daemon/vm.go
@@ -64,6 +64,8 @@ func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm model.VMRecord, err error) {
 	if err != nil {
 		return model.VMRecord{}, err
 	}
+	unlockVM := d.lockVMID(id)
+	defer unlockVM()
 	guestIP, err := d.store.NextGuestIP(ctx, bridgePrefix(d.config.BridgeIP))
 	if err != nil {
 		return model.VMRecord{}, err
 	}
@@ -130,23 +132,19 @@ func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm model.VMRecord, err error) {
 }
 
 func (d *Daemon) StartVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
-	d.mu.Lock()
-	defer d.mu.Unlock()
-	vm, err := d.FindVM(ctx, idOrName)
-	if err != nil {
-		return model.VMRecord{}, err
-	}
-	image, err := d.store.GetImageByID(ctx, vm.ImageID)
-	if err != nil {
-		return model.VMRecord{}, err
-	}
-	if vm.State == model.VMStateRunning && system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
-		if d.logger != nil {
-			d.logger.Info("vm already running", vmLogAttrs(vm)...)
+	return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+		image, err := d.store.GetImageByID(ctx, vm.ImageID)
+		if err != nil {
+			return model.VMRecord{}, err
 		}
-		return vm, nil
-	}
-	return d.startVMLocked(ctx, vm, image)
+		if vm.State == model.VMStateRunning && system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
+			if d.logger != nil {
+				d.logger.Info("vm already running", vmLogAttrs(vm)...)
+			}
+			return vm, nil
+		}
+		return d.startVMLocked(ctx, vm, image)
+	})
 }
 
 func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image model.Image) (_ model.VMRecord, err error) {
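The `StartVM`/`startVMLocked` split sets the convention the rest of vm.go follows: exported methods acquire the per-VM lock (and re-read the record under it), while `*Locked` variants assume the caller already holds it. Two cautions follow from the daemon.go implementation above: the per-ID mutex is a plain `sync.Mutex`, i.e. not reentrant, so nesting an exported method inside a lock callback on the same VM self-deadlocks; and entries in `vmLocks` are created on demand but never deleted, so the map grows by one mutex per VM ID ever locked (harmless for modest fleets, but worth knowing).

```go
// Anti-pattern, shown for illustration only: StopVM tries to re-acquire the
// per-VM lock that withVMLockByID already holds, and blocks forever.
_, _ = d.withVMLockByID(ctx, id, func(vm model.VMRecord) (model.VMRecord, error) {
	return d.StopVM(ctx, vm.ID) // deadlock: lockVMID(id) is held by this caller
})
```

This is exactly why `RestartVM` below calls `stopVMLocked` and `startVMLocked` inside a single critical section rather than `StopVM`/`StartVM`.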
@@ -292,10 +290,15 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image model.Image) (_ model.VMRecord, err error) {
 	return vm, nil
 }
 
-func (d *Daemon) StopVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) {
-	d.mu.Lock()
-	defer d.mu.Unlock()
-	op := d.beginOperation("vm.stop", "vm_ref", idOrName)
+func (d *Daemon) StopVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
+	return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+		return d.stopVMLocked(ctx, vm)
+	})
+}
+
+func (d *Daemon) stopVMLocked(ctx context.Context, current model.VMRecord) (vm model.VMRecord, err error) {
+	vm = current
+	op := d.beginOperation("vm.stop", "vm_ref", vm.ID)
 	defer func() {
 		if err != nil {
 			op.fail(err, vmLogAttrs(vm)...)
@@ -303,10 +306,6 @@
 		}
 		op.done(vmLogAttrs(vm)...)
 	}()
-	vm, err = d.FindVM(ctx, idOrName)
-	if err != nil {
-		return model.VMRecord{}, err
-	}
 	if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
 		op.stage("cleanup_stale_runtime")
 		if err := d.cleanupRuntime(ctx, vm, true); err != nil {
@@ -345,10 +344,15 @@
 	return vm, nil
 }
 
-func (d *Daemon) KillVM(ctx context.Context, params api.VMKillParams) (vm model.VMRecord, err error) {
-	d.mu.Lock()
-	defer d.mu.Unlock()
-	op := d.beginOperation("vm.kill", "vm_ref", params.IDOrName, "signal", params.Signal)
+func (d *Daemon) KillVM(ctx context.Context, params api.VMKillParams) (model.VMRecord, error) {
+	return d.withVMLockByRef(ctx, params.IDOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+		return d.killVMLocked(ctx, vm, params.Signal)
+	})
+}
+
+func (d *Daemon) killVMLocked(ctx context.Context, current model.VMRecord, signalValue string) (vm model.VMRecord, err error) {
+	vm = current
+	op := d.beginOperation("vm.kill", "vm_ref", vm.ID, "signal", signalValue)
 	defer func() {
 		if err != nil {
 			op.fail(err, vmLogAttrs(vm)...)
@@ -356,11 +360,6 @@
 		}
 		op.done(vmLogAttrs(vm)...)
 	}()
-
-	vm, err = d.FindVM(ctx, params.IDOrName)
-	if err != nil {
-		return model.VMRecord{}, err
-	}
 	if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
 		op.stage("cleanup_stale_runtime")
 		if err := d.cleanupRuntime(ctx, vm, true); err != nil {
@@ -375,7 +374,7 @@
 		return vm, nil
 	}
 
-	signal := strings.TrimSpace(params.Signal)
+	signal := strings.TrimSpace(signalValue)
 	if signal == "" {
 		signal = "TERM"
 	}
@@ -413,19 +412,34 @@ func (d *Daemon) RestartVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) {
 		}
 		op.done(vmLogAttrs(vm)...)
 	}()
-	op.stage("stop")
-	vm, err = d.StopVM(ctx, idOrName)
+	resolved, err := d.FindVM(ctx, idOrName)
 	if err != nil {
 		return model.VMRecord{}, err
 	}
-	op.stage("start", vmLogAttrs(vm)...)
-	return d.StartVM(ctx, vm.ID)
+	return d.withVMLockByID(ctx, resolved.ID, func(vm model.VMRecord) (model.VMRecord, error) {
+		op.stage("stop")
+		vm, err = d.stopVMLocked(ctx, vm)
+		if err != nil {
+			return model.VMRecord{}, err
+		}
+		image, err := d.store.GetImageByID(ctx, vm.ImageID)
+		if err != nil {
+			return model.VMRecord{}, err
+		}
+		op.stage("start", vmLogAttrs(vm)...)
+		return d.startVMLocked(ctx, vm, image)
+	})
 }
 
-func (d *Daemon) DeleteVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) {
-	d.mu.Lock()
-	defer d.mu.Unlock()
-	op := d.beginOperation("vm.delete", "vm_ref", idOrName)
+func (d *Daemon) DeleteVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
+	return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+		return d.deleteVMLocked(ctx, vm)
+	})
+}
+
+func (d *Daemon) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm model.VMRecord, err error) {
+	vm = current
+	op := d.beginOperation("vm.delete", "vm_ref", vm.ID)
 	defer func() {
 		if err != nil {
 			op.fail(err, vmLogAttrs(vm)...)
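`killVMLocked` only normalizes the signal string (trim whitespace, default to `TERM`); the actual name-to-number translation happens in code this patch does not touch. For readers wondering how such a lookup is typically done in Go, one plausible sketch using `golang.org/x/sys/unix` — an assumption, not necessarily what this codebase does:

```go
// Hypothetical lookup: accept "TERM", "SIGTERM", "kill", etc. and map to a
// syscall.Signal. unix.SignalNum returns 0 for names it does not recognize.
name := strings.ToUpper(strings.TrimSpace(signal))
if !strings.HasPrefix(name, "SIG") {
	name = "SIG" + name
}
sig := unix.SignalNum(name)
if sig == 0 {
	return fmt.Errorf("unknown signal %q", signal)
}
```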
@@ -433,10 +447,6 @@
 		}
 		op.done(vmLogAttrs(vm)...)
 	}()
-	vm, err = d.FindVM(ctx, idOrName)
-	if err != nil {
-		return model.VMRecord{}, err
-	}
 	if vm.State == model.VMStateRunning && system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
 		op.stage("kill_running_vm", "pid", vm.Runtime.PID)
 		_ = d.killVMProcess(ctx, vm.Runtime.PID)
@@ -464,10 +474,15 @@
 	return vm, nil
 }
 
-func (d *Daemon) SetVM(ctx context.Context, params api.VMSetParams) (vm model.VMRecord, err error) {
-	d.mu.Lock()
-	defer d.mu.Unlock()
-	op := d.beginOperation("vm.set", "vm_ref", params.IDOrName)
+func (d *Daemon) SetVM(ctx context.Context, params api.VMSetParams) (model.VMRecord, error) {
+	return d.withVMLockByRef(ctx, params.IDOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+		return d.setVMLocked(ctx, vm, params)
+	})
+}
+
+func (d *Daemon) setVMLocked(ctx context.Context, current model.VMRecord, params api.VMSetParams) (vm model.VMRecord, err error) {
+	vm = current
+	op := d.beginOperation("vm.set", "vm_ref", vm.ID)
 	defer func() {
 		if err != nil {
 			op.fail(err, vmLogAttrs(vm)...)
@@ -475,10 +490,6 @@
 		}
 		op.done(vmLogAttrs(vm)...)
 	}()
-	vm, err = d.FindVM(ctx, params.IDOrName)
-	if err != nil {
-		return model.VMRecord{}, err
-	}
 	running := vm.State == model.VMStateRunning && system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath)
 	if params.VCPUCount != nil {
 		if err := validateOptionalPositiveSetting("vcpu", params.VCPUCount); err != nil {
@@ -541,12 +552,16 @@
 }
 
 func (d *Daemon) GetVMStats(ctx context.Context, idOrName string) (model.VMRecord, model.VMStats, error) {
-	d.mu.Lock()
-	defer d.mu.Unlock()
-	vm, err := d.FindVM(ctx, idOrName)
+	vm, err := d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+		return d.getVMStatsLocked(ctx, vm)
+	})
 	if err != nil {
 		return model.VMRecord{}, model.VMStats{}, err
 	}
+	return vm, vm.Stats, nil
+}
+
+func (d *Daemon) getVMStatsLocked(ctx context.Context, vm model.VMRecord) (model.VMRecord, error) {
 	stats, err := d.collectStats(ctx, vm)
 	if err == nil {
 		vm.Stats = stats
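One consequence of moving stats collection under the per-VM lock: a hung `collectStats` now stalls start/stop/delete for that one VM (though no longer for the whole daemon, which is the point of per-VM locking). If that ever matters, bounding the collection is a small change — a sketch with an assumed two-second budget:

```go
// Hypothetical guard inside getVMStatsLocked; the budget is an assumption.
statsCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
stats, err := d.collectStats(statsCtx, vm)
```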
@@ -556,30 +571,32 @@ func (d *Daemon) GetVMStats(ctx context.Context, idOrName string) (model.VMRecord, model.VMStats, error) {
 			d.logger.Debug("vm stats collected", append(vmLogAttrs(vm), "rss_bytes", stats.RSSBytes, "vsz_bytes", stats.VSZBytes, "cpu_percent", stats.CPUPercent)...)
 		}
 	}
-	return vm, vm.Stats, nil
+	return vm, nil
 }
 
 func (d *Daemon) pollStats(ctx context.Context) error {
-	d.mu.Lock()
-	defer d.mu.Unlock()
 	vms, err := d.store.ListVMs(ctx)
 	if err != nil {
 		return err
 	}
 	for _, vm := range vms {
-		if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
-			continue
-		}
-		stats, err := d.collectStats(ctx, vm)
-		if err != nil {
-			if d.logger != nil {
-				d.logger.Debug("vm stats collection failed", append(vmLogAttrs(vm), "error", err.Error())...)
+		if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error {
+			if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
+				return nil
 			}
-			continue
+			stats, err := d.collectStats(ctx, vm)
+			if err != nil {
+				if d.logger != nil {
+					d.logger.Debug("vm stats collection failed", append(vmLogAttrs(vm), "error", err.Error())...)
+				}
+				return nil
+			}
+			vm.Stats = stats
+			vm.UpdatedAt = model.Now()
+			return d.store.UpsertVM(ctx, vm)
+		}); err != nil {
+			return err
 		}
-		vm.Stats = stats
-		vm.UpdatedAt = model.Now()
-		_ = d.store.UpsertVM(ctx, vm)
 	}
 	return nil
 }
@@ -596,29 +613,31 @@ func (d *Daemon) stopStaleVMs(ctx context.Context) (err error) {
 		}
 		op.done()
 	}()
-	d.mu.Lock()
-	defer d.mu.Unlock()
 	vms, err := d.store.ListVMs(ctx)
 	if err != nil {
 		return err
 	}
 	now := model.Now()
 	for _, vm := range vms {
-		if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
-			continue
+		if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error {
+			if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
+				return nil
+			}
+			if now.Sub(vm.LastTouchedAt) < d.config.AutoStopStaleAfter {
+				return nil
+			}
+			op.stage("stopping_vm", vmLogAttrs(vm)...)
+			_ = d.sendCtrlAltDel(ctx, vm)
+			_ = d.waitForExit(ctx, vm.Runtime.PID, vm.Runtime.APISockPath, 10*time.Second)
+			_ = d.cleanupRuntime(ctx, vm, true)
+			vm.State = model.VMStateStopped
+			vm.Runtime.State = model.VMStateStopped
+			clearRuntimeHandles(&vm)
+			vm.UpdatedAt = model.Now()
+			return d.store.UpsertVM(ctx, vm)
+		}); err != nil {
+			return err
 		}
-		if now.Sub(vm.LastTouchedAt) < d.config.AutoStopStaleAfter {
-			continue
-		}
-		op.stage("stopping_vm", vmLogAttrs(vm)...)
-		_ = d.sendCtrlAltDel(ctx, vm)
-		_ = d.waitForExit(ctx, vm.Runtime.PID, vm.Runtime.APISockPath, 10*time.Second)
-		_ = d.cleanupRuntime(ctx, vm, true)
-		vm.State = model.VMStateStopped
-		vm.Runtime.State = model.VMStateStopped
-		clearRuntimeHandles(&vm)
-		vm.UpdatedAt = model.Now()
-		_ = d.store.UpsertVM(ctx, vm)
 	}
 	return nil
 }
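`clearRuntimeHandles` is pre-existing (the removed `stopStaleVMs` body above was already calling it); the reconcile hunk in daemon.go now reuses it in place of the field-by-field reset it deleted. Judging by those removed lines, the helper presumably amounts to:

```go
// Sketch inferred from the runtime-field zeroing removed from reconcile;
// the real helper lives elsewhere in the package and is not in this diff.
func clearRuntimeHandles(vm *model.VMRecord) {
	vm.Runtime.PID = 0
	vm.Runtime.TapDevice = ""
	vm.Runtime.APISockPath = ""
	vm.Runtime.BaseLoop = ""
	vm.Runtime.COWLoop = ""
	vm.Runtime.DMName = ""
	vm.Runtime.DMDev = ""
}
```

Separately, two behavior changes in these loops are worth noting: `UpsertVM` failures were previously ignored (`_ = d.store.UpsertVM(...)`) but now abort the whole pass, and a VM deleted between `ListVMs` and the lock acquisition will surface `withVMLockByIDErr`'s "not found" error and likewise abort the loop, since deletes can now proceed concurrently.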
diff --git a/internal/daemon/vm_test.go b/internal/daemon/vm_test.go
index 15ac8ba..79dbe22 100644
--- a/internal/daemon/vm_test.go
+++ b/internal/daemon/vm_test.go
@@ -538,6 +538,108 @@ func TestStopVMFallsBackToForcedCleanupAfterGracefulTimeout(t *testing.T) {
 	}
 }
 
+func TestWithVMLockByIDSerializesSameVM(t *testing.T) {
+	ctx := context.Background()
+	db := openDaemonStore(t)
+	vm := testVM("serial", "image-serial", "172.16.0.30")
+	if err := db.UpsertVM(ctx, vm); err != nil {
+		t.Fatalf("UpsertVM: %v", err)
+	}
+	d := &Daemon{store: db}
+
+	firstEntered := make(chan struct{})
+	releaseFirst := make(chan struct{})
+	secondEntered := make(chan struct{})
+	errCh := make(chan error, 2)
+
+	go func() {
+		_, err := d.withVMLockByID(ctx, vm.ID, func(vm model.VMRecord) (model.VMRecord, error) {
+			close(firstEntered)
+			<-releaseFirst
+			return vm, nil
+		})
+		errCh <- err
+	}()
+
+	select {
+	case <-firstEntered:
+	case <-time.After(500 * time.Millisecond):
+		t.Fatal("first lock holder did not enter")
+	}
+
+	go func() {
+		_, err := d.withVMLockByID(ctx, vm.ID, func(vm model.VMRecord) (model.VMRecord, error) {
+			close(secondEntered)
+			return vm, nil
+		})
+		errCh <- err
+	}()
+
+	select {
+	case <-secondEntered:
+		t.Fatal("second same-vm lock holder entered before release")
+	case <-time.After(150 * time.Millisecond):
+	}
+
+	close(releaseFirst)
+
+	select {
+	case <-secondEntered:
+	case <-time.After(500 * time.Millisecond):
+		t.Fatal("second same-vm lock holder never entered")
+	}
+
+	for i := 0; i < 2; i++ {
+		if err := <-errCh; err != nil {
+			t.Fatalf("withVMLockByID returned error: %v", err)
+		}
+	}
+}
+
+func TestWithVMLockByIDAllowsDifferentVMsConcurrently(t *testing.T) {
+	ctx := context.Background()
+	db := openDaemonStore(t)
+	vmA := testVM("alpha-lock", "image-alpha", "172.16.0.31")
+	vmB := testVM("bravo-lock", "image-bravo", "172.16.0.32")
+	for _, vm := range []model.VMRecord{vmA, vmB} {
+		if err := db.UpsertVM(ctx, vm); err != nil {
+			t.Fatalf("UpsertVM(%s): %v", vm.Name, err)
+		}
+	}
+	d := &Daemon{store: db}
+
+	started := make(chan string, 2)
+	release := make(chan struct{})
+	errCh := make(chan error, 2)
+	run := func(id string) {
+		_, err := d.withVMLockByID(ctx, id, func(vm model.VMRecord) (model.VMRecord, error) {
+			started <- vm.ID
+			<-release
+			return vm, nil
+		})
+		errCh <- err
+	}
+
+	go run(vmA.ID)
+	go run(vmB.ID)
+
+	for i := 0; i < 2; i++ {
+		select {
+		case <-started:
+		case <-time.After(500 * time.Millisecond):
+			t.Fatal("different VM locks did not overlap")
+		}
+	}
+
+	close(release)
+
+	for i := 0; i < 2; i++ {
+		if err := <-errCh; err != nil {
+			t.Fatalf("withVMLockByID returned error: %v", err)
+		}
+	}
+}
+
 func openDaemonStore(t *testing.T) *store.Store {
 	t.Helper()
 	db, err := store.Open(filepath.Join(t.TempDir(), "state.db"))
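`testVM` and `openDaemonStore` are pre-existing test scaffolding; only the latter's opening lines appear here as context. Judging by the call sites above (`testVM("serial", "image-serial", "172.16.0.30")`, with `vm.ID` and `vm.Name` used afterwards), `testVM` presumably builds a minimal record along these lines — every field name beyond `ID` and `Name` is a guess:

```go
// Hypothetical reconstruction, for reading these tests in isolation only.
func testVM(name, imageID, ip string) model.VMRecord {
	return model.VMRecord{
		ID:      name + "-id", // guess: the real helper may derive IDs differently
		Name:    name,
		ImageID: imageID, // guess
		GuestIP: ip,      // guess
	}
}
```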