Add concurrent multi-VM CLI actions
Teach the lifecycle and set commands to accept multiple VM refs, resolve them against a single vm list snapshot, dedupe repeated refs, and fan out the existing single-target RPCs concurrently. Valid targets still run when other refs are ambiguous or missing, and batch output stays in first-seen order.

Refactor the daemon off the single global VM mutation lock by adding per-VM locks for start/stop/restart/delete/kill/set, touch, reconcile, stale-stop, and stats updates. That keeps same-VM operations serialized while allowing different VMs to progress in parallel, including newly created VMs once their ID exists.

Verified with go test ./... and make build.
parent 2d5bcb5516
commit 4812693c1e
5 changed files with 542 additions and 118 deletions
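To make the batch semantics concrete, here is a hypothetical invocation (the VM names are invented for illustration; the command shape follows the usage strings in the diff below). Refs resolve against one vm list snapshot, duplicates collapse to a single action, and refs that fail to resolve are reported on stderr while the valid targets still run:

    # "web-2" is listed twice but deduped to one action; "missing" fails to
    # resolve, yet web-1 and web-2 are still stopped concurrently, and the
    # command exits non-zero with "one or more VM operations failed".
    $ banger vm stop web-1 web-2 web-2 missing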
@@ -9,6 +9,7 @@ import (
     "os/exec"
     "path/filepath"
     "strings"
+    "sync"
     "syscall"
     "text/tabwriter"
     "time"
@@ -187,9 +188,9 @@ func newVMCommand() *cobra.Command {
 func newVMKillCommand() *cobra.Command {
     var signal string
     cmd := &cobra.Command{
-        Use: "kill <id-or-name>",
+        Use: "kill <id-or-name>...",
         Short: "Send a signal to a VM process",
-        Args: exactArgsUsage(1, "usage: banger vm kill [--signal SIGTERM|SIGKILL|...] <id-or-name>"),
+        Args: minArgsUsage(1, "usage: banger vm kill [--signal SIGTERM|SIGKILL|...] <id-or-name>..."),
         RunE: func(cmd *cobra.Command, args []string) error {
             if err := system.EnsureSudo(cmd.Context()); err != nil {
                 return err
@@ -198,6 +199,20 @@ func newVMKillCommand() *cobra.Command {
             if err != nil {
                 return err
             }
+            if len(args) > 1 {
+                return runVMBatchAction(cmd, layout.SocketPath, args, func(ctx context.Context, id string) (model.VMRecord, error) {
+                    result, err := rpc.Call[api.VMShowResult](
+                        ctx,
+                        layout.SocketPath,
+                        "vm.kill",
+                        api.VMKillParams{IDOrName: id, Signal: signal},
+                    )
+                    if err != nil {
+                        return model.VMRecord{}, err
+                    }
+                    return result.VM, nil
+                })
+            }
             result, err := rpc.Call[api.VMShowResult](
                 cmd.Context(),
                 layout.SocketPath,
@@ -316,9 +331,9 @@ func newVMShowCommand() *cobra.Command {
 
 func newVMActionCommand(use, short, method string) *cobra.Command {
     return &cobra.Command{
-        Use: use + " <id-or-name>",
+        Use: use + " <id-or-name>...",
         Short: short,
-        Args: exactArgsUsage(1, fmt.Sprintf("usage: banger vm %s <id-or-name>", use)),
+        Args: minArgsUsage(1, fmt.Sprintf("usage: banger vm %s <id-or-name>...", use)),
         RunE: func(cmd *cobra.Command, args []string) error {
             if err := system.EnsureSudo(cmd.Context()); err != nil {
                 return err
@@ -327,6 +342,15 @@ func newVMActionCommand(use, short, method string) *cobra.Command {
             if err != nil {
                 return err
             }
+            if len(args) > 1 {
+                return runVMBatchAction(cmd, layout.SocketPath, args, func(ctx context.Context, id string) (model.VMRecord, error) {
+                    result, err := rpc.Call[api.VMShowResult](ctx, layout.SocketPath, method, api.VMRefParams{IDOrName: id})
+                    if err != nil {
+                        return model.VMRecord{}, err
+                    }
+                    return result.VM, nil
+                })
+            }
             result, err := rpc.Call[api.VMShowResult](cmd.Context(), layout.SocketPath, method, api.VMRefParams{IDOrName: args[0]})
             if err != nil {
                 return err
@@ -345,9 +369,9 @@ func newVMSetCommand() *cobra.Command {
         noNat bool
     )
     cmd := &cobra.Command{
-        Use: "set <id-or-name>",
+        Use: "set <id-or-name>...",
         Short: "Update stopped VM settings",
-        Args: exactArgsUsage(1, "usage: banger vm set [--vcpu N] [--memory MiB] [--disk-size SIZE] [--nat|--no-nat] <id-or-name>"),
+        Args: minArgsUsage(1, "usage: banger vm set [--vcpu N] [--memory MiB] [--disk-size SIZE] [--nat|--no-nat] <id-or-name>..."),
         RunE: func(cmd *cobra.Command, args []string) error {
            params, err := vmSetParamsFromFlags(args[0], vcpu, memory, diskSize, nat, noNat)
            if err != nil {
@@ -360,6 +384,17 @@ func newVMSetCommand() *cobra.Command {
             if err != nil {
                 return err
             }
+            if len(args) > 1 {
+                return runVMBatchAction(cmd, layout.SocketPath, args, func(ctx context.Context, id string) (model.VMRecord, error) {
+                    batchParams := params
+                    batchParams.IDOrName = id
+                    result, err := rpc.Call[api.VMShowResult](ctx, layout.SocketPath, "vm.set", batchParams)
+                    if err != nil {
+                        return model.VMRecord{}, err
+                    }
+                    return result.VM, nil
+                })
+            }
             result, err := rpc.Call[api.VMShowResult](cmd.Context(), layout.SocketPath, "vm.set", params)
             if err != nil {
                 return err
@@ -597,6 +632,132 @@ func minArgsUsage(n int, usage string) cobra.PositionalArgs {
     }
 }
 
+type resolvedVMTarget struct {
+    Index int
+    Ref   string
+    VM    model.VMRecord
+}
+
+type vmRefResolutionError struct {
+    Index int
+    Ref   string
+    Err   error
+}
+
+type vmBatchActionResult struct {
+    Target resolvedVMTarget
+    VM     model.VMRecord
+    Err    error
+}
+
+func runVMBatchAction(cmd *cobra.Command, socketPath string, refs []string, action func(context.Context, string) (model.VMRecord, error)) error {
+    listResult, err := rpc.Call[api.VMListResult](cmd.Context(), socketPath, "vm.list", api.Empty{})
+    if err != nil {
+        return err
+    }
+    targets, resolutionErrs := resolveVMTargets(listResult.VMs, refs)
+    results := executeVMActionBatch(cmd.Context(), targets, action)
+
+    failed := false
+    for _, resolutionErr := range resolutionErrs {
+        if _, err := fmt.Fprintf(cmd.ErrOrStderr(), "%s: %v\n", resolutionErr.Ref, resolutionErr.Err); err != nil {
+            return err
+        }
+        failed = true
+    }
+    for _, result := range results {
+        if result.Err != nil {
+            if _, err := fmt.Fprintf(cmd.ErrOrStderr(), "%s: %v\n", result.Target.Ref, result.Err); err != nil {
+                return err
+            }
+            failed = true
+            continue
+        }
+        if err := printVMSummary(cmd.OutOrStdout(), result.VM); err != nil {
+            return err
+        }
+    }
+    if failed {
+        return errors.New("one or more VM operations failed")
+    }
+    return nil
+}
+
+func resolveVMTargets(vms []model.VMRecord, refs []string) ([]resolvedVMTarget, []vmRefResolutionError) {
+    targets := make([]resolvedVMTarget, 0, len(refs))
+    resolutionErrs := make([]vmRefResolutionError, 0)
+    seen := make(map[string]struct{}, len(refs))
+    for index, ref := range refs {
+        vm, err := resolveVMRef(vms, ref)
+        if err != nil {
+            resolutionErrs = append(resolutionErrs, vmRefResolutionError{Index: index, Ref: ref, Err: err})
+            continue
+        }
+        if _, ok := seen[vm.ID]; ok {
+            continue
+        }
+        seen[vm.ID] = struct{}{}
+        targets = append(targets, resolvedVMTarget{Index: index, Ref: ref, VM: vm})
+    }
+    return targets, resolutionErrs
+}
+
+func resolveVMRef(vms []model.VMRecord, ref string) (model.VMRecord, error) {
+    ref = strings.TrimSpace(ref)
+    if ref == "" {
+        return model.VMRecord{}, errors.New("vm id or name is required")
+    }
+    exactMatches := make([]model.VMRecord, 0, 1)
+    for _, vm := range vms {
+        if vm.ID == ref || vm.Name == ref {
+            exactMatches = append(exactMatches, vm)
+        }
+    }
+    switch len(exactMatches) {
+    case 1:
+        return exactMatches[0], nil
+    case 0:
+    default:
+        return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", ref)
+    }
+
+    prefixMatches := make([]model.VMRecord, 0, 1)
+    for _, vm := range vms {
+        if strings.HasPrefix(vm.ID, ref) || strings.HasPrefix(vm.Name, ref) {
+            prefixMatches = append(prefixMatches, vm)
+        }
+    }
+    switch len(prefixMatches) {
+    case 1:
+        return prefixMatches[0], nil
+    case 0:
+        return model.VMRecord{}, fmt.Errorf("vm %q not found", ref)
+    default:
+        return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", ref)
+    }
+}
+
+func executeVMActionBatch(ctx context.Context, targets []resolvedVMTarget, action func(context.Context, string) (model.VMRecord, error)) []vmBatchActionResult {
+    results := make([]vmBatchActionResult, len(targets))
+    var wg sync.WaitGroup
+    wg.Add(len(targets))
+    for index, target := range targets {
+        index := index
+        target := target
+        go func() {
+            defer wg.Done()
+            vm, err := action(ctx, target.VM.ID)
+            results[index] = vmBatchActionResult{
+                Target: target,
+                VM:     vm,
+                Err:    err,
+            }
+        }()
+    }
+    wg.Wait()
+    return results
+}
+
 func ensureDaemon(ctx context.Context) (paths.Layout, model.DaemonConfig, error) {
     layout, err := paths.Resolve()
     if err != nil {
@@ -2,11 +2,13 @@ package cli
 
 import (
     "bytes"
+    "context"
     "os"
     "path/filepath"
     "reflect"
     "strings"
     "testing"
+    "time"
 
     "banger/internal/api"
     "banger/internal/model"
@@ -155,6 +157,97 @@ func TestVMSetParamsFromFlagsRejectsNonPositiveCPUAndMemory(t *testing.T) {
     }
 }
 
+func TestResolveVMTargetsDeduplicatesAndReportsErrors(t *testing.T) {
+    vms := []model.VMRecord{
+        testCLIResolvedVM("alpha-id", "alpha"),
+        testCLIResolvedVM("alpine-id", "alpine"),
+        testCLIResolvedVM("bravo-id", "bravo"),
+    }
+
+    targets, errs := resolveVMTargets(vms, []string{"alpha", "alpha-id", "al", "missing", "br"})
+
+    if len(targets) != 2 {
+        t.Fatalf("len(targets) = %d, want 2", len(targets))
+    }
+    if targets[0].VM.ID != "alpha-id" || targets[0].Ref != "alpha" {
+        t.Fatalf("targets[0] = %+v, want alpha target", targets[0])
+    }
+    if targets[1].VM.ID != "bravo-id" || targets[1].Ref != "br" {
+        t.Fatalf("targets[1] = %+v, want bravo target", targets[1])
+    }
+    if len(errs) != 2 {
+        t.Fatalf("len(errs) = %d, want 2", len(errs))
+    }
+    if errs[0].Ref != "al" || !strings.Contains(errs[0].Err.Error(), "multiple VMs match") {
+        t.Fatalf("errs[0] = %+v, want ambiguous prefix", errs[0])
+    }
+    if errs[1].Ref != "missing" || !strings.Contains(errs[1].Err.Error(), `vm "missing" not found`) {
+        t.Fatalf("errs[1] = %+v, want missing vm", errs[1])
+    }
+}
+
+func TestResolveVMRefPrefersExactMatchBeforePrefix(t *testing.T) {
+    vms := []model.VMRecord{
+        testCLIResolvedVM("1111111111111111111111111111111111111111111111111111111111111111", "alpha"),
+        testCLIResolvedVM("alpha222222222222222222222222222222222222222222222222222222222222", "bravo"),
+    }
+
+    vm, err := resolveVMRef(vms, "alpha")
+    if err != nil {
+        t.Fatalf("resolveVMRef(alpha): %v", err)
+    }
+    if vm.Name != "alpha" {
+        t.Fatalf("resolveVMRef(alpha) = %+v, want exact-name vm", vm)
+    }
+}
+
+func TestExecuteVMActionBatchRunsConcurrentlyAndPreservesOrder(t *testing.T) {
+    targets := []resolvedVMTarget{
+        {Ref: "alpha", VM: testCLIResolvedVM("alpha-id", "alpha")},
+        {Ref: "bravo", VM: testCLIResolvedVM("bravo-id", "bravo")},
+    }
+
+    started := make(chan string, len(targets))
+    release := make(chan struct{})
+    done := make(chan []vmBatchActionResult, 1)
+    go func() {
+        done <- executeVMActionBatch(context.Background(), targets, func(ctx context.Context, id string) (model.VMRecord, error) {
+            started <- id
+            <-release
+            return model.VMRecord{ID: id, Name: id}, nil
+        })
+    }()
+
+    for range targets {
+        select {
+        case <-started:
+        case <-time.After(500 * time.Millisecond):
+            t.Fatal("batch actions did not overlap")
+        }
+    }
+
+    close(release)
+
+    var results []vmBatchActionResult
+    select {
+    case results = <-done:
+    case <-time.After(500 * time.Millisecond):
+        t.Fatal("executeVMActionBatch did not finish")
+    }
+
+    if len(results) != len(targets) {
+        t.Fatalf("len(results) = %d, want %d", len(results), len(targets))
+    }
+    for index, result := range results {
+        if result.Target.Ref != targets[index].Ref {
+            t.Fatalf("results[%d].Target.Ref = %q, want %q", index, result.Target.Ref, targets[index].Ref)
+        }
+        if result.VM.ID != targets[index].VM.ID {
+            t.Fatalf("results[%d].VM.ID = %q, want %q", index, result.VM.ID, targets[index].VM.ID)
+        }
+    }
+}
+
 func TestSSHCommandArgs(t *testing.T) {
     args, err := sshCommandArgs(model.DaemonConfig{SSHKeyPath: "/bundle/id_ed25519"}, "172.16.0.2", []string{"--", "uname", "-a"})
     if err != nil {
@@ -312,3 +405,7 @@ func TestAbsolutizeImageBuildPaths(t *testing.T) {
         t.Fatalf("params = %+v, want %+v", params, want)
     }
 }
+
+func testCLIResolvedVM(id, name string) model.VMRecord {
+    return model.VMRecord{ID: id, Name: name}
+}
@@ -32,6 +32,8 @@ type Daemon struct {
     runner system.CommandRunner
     logger *slog.Logger
     mu sync.Mutex
+    vmLocksMu sync.Mutex
+    vmLocks   map[string]*sync.Mutex
     closing chan struct{}
     once sync.Once
     pid int
@@ -488,26 +490,22 @@ func (d *Daemon) reconcile(ctx context.Context) error {
         return op.fail(err)
     }
     for _, vm := range vms {
+        if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error {
             if vm.State != model.VMStateRunning {
-                continue
+                return nil
             }
             if system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
-                continue
+                return nil
             }
             op.stage("stale_vm", vmLogAttrs(vm)...)
             _ = d.cleanupRuntime(ctx, vm, true)
             vm.State = model.VMStateStopped
             vm.Runtime.State = model.VMStateStopped
-            vm.Runtime.PID = 0
-            vm.Runtime.TapDevice = ""
-            vm.Runtime.APISockPath = ""
-            vm.Runtime.BaseLoop = ""
-            vm.Runtime.COWLoop = ""
-            vm.Runtime.DMName = ""
-            vm.Runtime.DMDev = ""
+            clearRuntimeHandles(&vm)
             vm.UpdatedAt = model.Now()
-            if err := d.store.UpsertVM(ctx, vm); err != nil {
-                return op.fail(err, vmLogAttrs(vm)...)
-            }
+            return d.store.UpsertVM(ctx, vm)
+        }); err != nil {
+            return op.fail(err, "vm_id", vm.ID)
+        }
     }
     if err := d.rebuildDNS(ctx); err != nil {
@@ -577,17 +575,64 @@ func (d *Daemon) FindImage(ctx context.Context, idOrName string) (model.Image, error) {
 }
 
 func (d *Daemon) TouchVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
-    d.mu.Lock()
-    defer d.mu.Unlock()
-    vm, err := d.FindVM(ctx, idOrName)
-    if err != nil {
-        return model.VMRecord{}, err
-    }
+    return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
         system.TouchNow(&vm)
         if err := d.store.UpsertVM(ctx, vm); err != nil {
             return model.VMRecord{}, err
         }
         return vm, nil
+    })
+}
+
+func (d *Daemon) withVMLockByRef(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
+    vm, err := d.FindVM(ctx, idOrName)
+    if err != nil {
+        return model.VMRecord{}, err
+    }
+    return d.withVMLockByID(ctx, vm.ID, fn)
+}
+
+func (d *Daemon) withVMLockByID(ctx context.Context, id string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
+    if strings.TrimSpace(id) == "" {
+        return model.VMRecord{}, errors.New("vm id is required")
+    }
+    unlock := d.lockVMID(id)
+    defer unlock()
+
+    vm, err := d.store.GetVMByID(ctx, id)
+    if err != nil {
+        if errors.Is(err, sql.ErrNoRows) {
+            return model.VMRecord{}, fmt.Errorf("vm %q not found", id)
+        }
+        return model.VMRecord{}, err
+    }
+    return fn(vm)
+}
+
+func (d *Daemon) withVMLockByIDErr(ctx context.Context, id string, fn func(model.VMRecord) error) error {
+    _, err := d.withVMLockByID(ctx, id, func(vm model.VMRecord) (model.VMRecord, error) {
+        if err := fn(vm); err != nil {
+            return model.VMRecord{}, err
+        }
+        return vm, nil
+    })
+    return err
+}
+
+func (d *Daemon) lockVMID(id string) func() {
+    d.vmLocksMu.Lock()
+    if d.vmLocks == nil {
+        d.vmLocks = make(map[string]*sync.Mutex)
+    }
+    lock, ok := d.vmLocks[id]
+    if !ok {
+        lock = &sync.Mutex{}
+        d.vmLocks[id] = lock
+    }
+    d.vmLocksMu.Unlock()
+
+    lock.Lock()
+    return lock.Unlock
 }
 
 func marshalResultOrError(v any, err error) rpc.Response {
@@ -64,6 +64,8 @@ func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm model.VMRecord, err error) {
     if err != nil {
         return model.VMRecord{}, err
     }
+    unlockVM := d.lockVMID(id)
+    defer unlockVM()
     guestIP, err := d.store.NextGuestIP(ctx, bridgePrefix(d.config.BridgeIP))
     if err != nil {
         return model.VMRecord{}, err
@@ -130,12 +132,7 @@ func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm model.VMRecord, err error) {
 }
 
 func (d *Daemon) StartVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
-    d.mu.Lock()
-    defer d.mu.Unlock()
-    vm, err := d.FindVM(ctx, idOrName)
-    if err != nil {
-        return model.VMRecord{}, err
-    }
+    return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
         image, err := d.store.GetImageByID(ctx, vm.ImageID)
         if err != nil {
             return model.VMRecord{}, err
@@ -147,6 +144,7 @@ func (d *Daemon) StartVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
             return vm, nil
         }
+        return d.startVMLocked(ctx, vm, image)
     })
 }
 
 func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image model.Image) (_ model.VMRecord, err error) {
@@ -292,10 +290,15 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image model.Image) (_ model.VMRecord, err error) {
     return vm, nil
 }
 
-func (d *Daemon) StopVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) {
-    d.mu.Lock()
-    defer d.mu.Unlock()
-    op := d.beginOperation("vm.stop", "vm_ref", idOrName)
+func (d *Daemon) StopVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
+    return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+        return d.stopVMLocked(ctx, vm)
+    })
+}
+
+func (d *Daemon) stopVMLocked(ctx context.Context, current model.VMRecord) (vm model.VMRecord, err error) {
+    vm = current
+    op := d.beginOperation("vm.stop", "vm_ref", vm.ID)
     defer func() {
         if err != nil {
             op.fail(err, vmLogAttrs(vm)...)
@@ -303,10 +306,6 @@ func (d *Daemon) StopVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) {
         }
         op.done(vmLogAttrs(vm)...)
     }()
-    vm, err = d.FindVM(ctx, idOrName)
-    if err != nil {
-        return model.VMRecord{}, err
-    }
     if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
         op.stage("cleanup_stale_runtime")
         if err := d.cleanupRuntime(ctx, vm, true); err != nil {
@@ -345,10 +344,15 @@ func (d *Daemon) StopVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) {
     return vm, nil
 }
 
-func (d *Daemon) KillVM(ctx context.Context, params api.VMKillParams) (vm model.VMRecord, err error) {
-    d.mu.Lock()
-    defer d.mu.Unlock()
-    op := d.beginOperation("vm.kill", "vm_ref", params.IDOrName, "signal", params.Signal)
+func (d *Daemon) KillVM(ctx context.Context, params api.VMKillParams) (model.VMRecord, error) {
+    return d.withVMLockByRef(ctx, params.IDOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+        return d.killVMLocked(ctx, vm, params.Signal)
+    })
+}
+
+func (d *Daemon) killVMLocked(ctx context.Context, current model.VMRecord, signalValue string) (vm model.VMRecord, err error) {
+    vm = current
+    op := d.beginOperation("vm.kill", "vm_ref", vm.ID, "signal", signalValue)
     defer func() {
         if err != nil {
             op.fail(err, vmLogAttrs(vm)...)
@@ -356,11 +360,6 @@ func (d *Daemon) KillVM(ctx context.Context, params api.VMKillParams) (vm model.VMRecord, err error) {
         }
         op.done(vmLogAttrs(vm)...)
     }()
-
-    vm, err = d.FindVM(ctx, params.IDOrName)
-    if err != nil {
-        return model.VMRecord{}, err
-    }
     if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
         op.stage("cleanup_stale_runtime")
         if err := d.cleanupRuntime(ctx, vm, true); err != nil {
@@ -375,7 +374,7 @@ func (d *Daemon) KillVM(ctx context.Context, params api.VMKillParams) (vm model.VMRecord, err error) {
         return vm, nil
     }
 
-    signal := strings.TrimSpace(params.Signal)
+    signal := strings.TrimSpace(signalValue)
     if signal == "" {
         signal = "TERM"
     }
@@ -413,19 +412,34 @@ func (d *Daemon) RestartVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) {
         }
         op.done(vmLogAttrs(vm)...)
     }()
+    resolved, err := d.FindVM(ctx, idOrName)
+    if err != nil {
+        return model.VMRecord{}, err
+    }
+    return d.withVMLockByID(ctx, resolved.ID, func(vm model.VMRecord) (model.VMRecord, error) {
         op.stage("stop")
-        vm, err = d.StopVM(ctx, idOrName)
+        vm, err = d.stopVMLocked(ctx, vm)
         if err != nil {
             return model.VMRecord{}, err
         }
+        image, err := d.store.GetImageByID(ctx, vm.ImageID)
+        if err != nil {
+            return model.VMRecord{}, err
+        }
         op.stage("start", vmLogAttrs(vm)...)
-        return d.StartVM(ctx, vm.ID)
+        return d.startVMLocked(ctx, vm, image)
+    })
 }
 
-func (d *Daemon) DeleteVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) {
-    d.mu.Lock()
-    defer d.mu.Unlock()
-    op := d.beginOperation("vm.delete", "vm_ref", idOrName)
+func (d *Daemon) DeleteVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
+    return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+        return d.deleteVMLocked(ctx, vm)
+    })
+}
+
+func (d *Daemon) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm model.VMRecord, err error) {
+    vm = current
+    op := d.beginOperation("vm.delete", "vm_ref", vm.ID)
     defer func() {
         if err != nil {
             op.fail(err, vmLogAttrs(vm)...)
@@ -433,10 +447,6 @@ func (d *Daemon) DeleteVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) {
         }
         op.done(vmLogAttrs(vm)...)
     }()
-    vm, err = d.FindVM(ctx, idOrName)
-    if err != nil {
-        return model.VMRecord{}, err
-    }
     if vm.State == model.VMStateRunning && system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
         op.stage("kill_running_vm", "pid", vm.Runtime.PID)
         _ = d.killVMProcess(ctx, vm.Runtime.PID)
@@ -464,10 +474,15 @@ func (d *Daemon) DeleteVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) {
     return vm, nil
 }
 
-func (d *Daemon) SetVM(ctx context.Context, params api.VMSetParams) (vm model.VMRecord, err error) {
-    d.mu.Lock()
-    defer d.mu.Unlock()
-    op := d.beginOperation("vm.set", "vm_ref", params.IDOrName)
+func (d *Daemon) SetVM(ctx context.Context, params api.VMSetParams) (model.VMRecord, error) {
+    return d.withVMLockByRef(ctx, params.IDOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+        return d.setVMLocked(ctx, vm, params)
+    })
+}
+
+func (d *Daemon) setVMLocked(ctx context.Context, current model.VMRecord, params api.VMSetParams) (vm model.VMRecord, err error) {
+    vm = current
+    op := d.beginOperation("vm.set", "vm_ref", vm.ID)
     defer func() {
         if err != nil {
             op.fail(err, vmLogAttrs(vm)...)
@@ -475,10 +490,6 @@ func (d *Daemon) SetVM(ctx context.Context, params api.VMSetParams) (vm model.VMRecord, err error) {
         }
         op.done(vmLogAttrs(vm)...)
     }()
-    vm, err = d.FindVM(ctx, params.IDOrName)
-    if err != nil {
-        return model.VMRecord{}, err
-    }
     running := vm.State == model.VMStateRunning && system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath)
     if params.VCPUCount != nil {
         if err := validateOptionalPositiveSetting("vcpu", params.VCPUCount); err != nil {
@@ -541,12 +552,16 @@ func (d *Daemon) SetVM(ctx context.Context, params api.VMSetParams) (vm model.VMRecord, err error) {
 }
 
 func (d *Daemon) GetVMStats(ctx context.Context, idOrName string) (model.VMRecord, model.VMStats, error) {
-    d.mu.Lock()
-    defer d.mu.Unlock()
-    vm, err := d.FindVM(ctx, idOrName)
+    vm, err := d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
+        return d.getVMStatsLocked(ctx, vm)
+    })
     if err != nil {
         return model.VMRecord{}, model.VMStats{}, err
     }
+    return vm, vm.Stats, nil
+}
+
+func (d *Daemon) getVMStatsLocked(ctx context.Context, vm model.VMRecord) (model.VMRecord, error) {
     stats, err := d.collectStats(ctx, vm)
     if err == nil {
         vm.Stats = stats
@@ -556,30 +571,32 @@ func (d *Daemon) GetVMStats(ctx context.Context, idOrName string) (model.VMRecord, model.VMStats, error) {
             d.logger.Debug("vm stats collected", append(vmLogAttrs(vm), "rss_bytes", stats.RSSBytes, "vsz_bytes", stats.VSZBytes, "cpu_percent", stats.CPUPercent)...)
         }
     }
-    return vm, vm.Stats, nil
+    return vm, nil
 }
 
 func (d *Daemon) pollStats(ctx context.Context) error {
-    d.mu.Lock()
-    defer d.mu.Unlock()
     vms, err := d.store.ListVMs(ctx)
     if err != nil {
        return err
     }
     for _, vm := range vms {
+        if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error {
             if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
-                continue
+                return nil
             }
             stats, err := d.collectStats(ctx, vm)
             if err != nil {
                 if d.logger != nil {
                     d.logger.Debug("vm stats collection failed", append(vmLogAttrs(vm), "error", err.Error())...)
                 }
-                continue
+                return nil
             }
             vm.Stats = stats
             vm.UpdatedAt = model.Now()
-            _ = d.store.UpsertVM(ctx, vm)
+            return d.store.UpsertVM(ctx, vm)
+        }); err != nil {
+            return err
+        }
     }
     return nil
 }
@@ -596,19 +613,18 @@ func (d *Daemon) stopStaleVMs(ctx context.Context) (err error) {
         }
         op.done()
     }()
-    d.mu.Lock()
-    defer d.mu.Unlock()
     vms, err := d.store.ListVMs(ctx)
     if err != nil {
         return err
     }
     now := model.Now()
     for _, vm := range vms {
+        if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error {
             if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
-                continue
+                return nil
             }
             if now.Sub(vm.LastTouchedAt) < d.config.AutoStopStaleAfter {
-                continue
+                return nil
             }
             op.stage("stopping_vm", vmLogAttrs(vm)...)
             _ = d.sendCtrlAltDel(ctx, vm)
@@ -618,7 +634,10 @@ func (d *Daemon) stopStaleVMs(ctx context.Context) (err error) {
             vm.Runtime.State = model.VMStateStopped
             clearRuntimeHandles(&vm)
             vm.UpdatedAt = model.Now()
-            _ = d.store.UpsertVM(ctx, vm)
+            return d.store.UpsertVM(ctx, vm)
+        }); err != nil {
+            return err
+        }
     }
     return nil
 }
@@ -538,6 +538,108 @@ func TestStopVMFallsBackToForcedCleanupAfterGracefulTimeout(t *testing.T) {
     }
 }
 
+func TestWithVMLockByIDSerializesSameVM(t *testing.T) {
+    ctx := context.Background()
+    db := openDaemonStore(t)
+    vm := testVM("serial", "image-serial", "172.16.0.30")
+    if err := db.UpsertVM(ctx, vm); err != nil {
+        t.Fatalf("UpsertVM: %v", err)
+    }
+    d := &Daemon{store: db}
+
+    firstEntered := make(chan struct{})
+    releaseFirst := make(chan struct{})
+    secondEntered := make(chan struct{})
+    errCh := make(chan error, 2)
+
+    go func() {
+        _, err := d.withVMLockByID(ctx, vm.ID, func(vm model.VMRecord) (model.VMRecord, error) {
+            close(firstEntered)
+            <-releaseFirst
+            return vm, nil
+        })
+        errCh <- err
+    }()
+
+    select {
+    case <-firstEntered:
+    case <-time.After(500 * time.Millisecond):
+        t.Fatal("first lock holder did not enter")
+    }
+
+    go func() {
+        _, err := d.withVMLockByID(ctx, vm.ID, func(vm model.VMRecord) (model.VMRecord, error) {
+            close(secondEntered)
+            return vm, nil
+        })
+        errCh <- err
+    }()
+
+    select {
+    case <-secondEntered:
+        t.Fatal("second same-vm lock holder entered before release")
+    case <-time.After(150 * time.Millisecond):
+    }
+
+    close(releaseFirst)
+
+    select {
+    case <-secondEntered:
+    case <-time.After(500 * time.Millisecond):
+        t.Fatal("second same-vm lock holder never entered")
+    }
+
+    for i := 0; i < 2; i++ {
+        if err := <-errCh; err != nil {
+            t.Fatalf("withVMLockByID returned error: %v", err)
+        }
+    }
+}
+
+func TestWithVMLockByIDAllowsDifferentVMsConcurrently(t *testing.T) {
+    ctx := context.Background()
+    db := openDaemonStore(t)
+    vmA := testVM("alpha-lock", "image-alpha", "172.16.0.31")
+    vmB := testVM("bravo-lock", "image-bravo", "172.16.0.32")
+    for _, vm := range []model.VMRecord{vmA, vmB} {
+        if err := db.UpsertVM(ctx, vm); err != nil {
+            t.Fatalf("UpsertVM(%s): %v", vm.Name, err)
+        }
+    }
+    d := &Daemon{store: db}
+
+    started := make(chan string, 2)
+    release := make(chan struct{})
+    errCh := make(chan error, 2)
+    run := func(id string) {
+        _, err := d.withVMLockByID(ctx, id, func(vm model.VMRecord) (model.VMRecord, error) {
+            started <- vm.ID
+            <-release
+            return vm, nil
+        })
+        errCh <- err
+    }
+
+    go run(vmA.ID)
+    go run(vmB.ID)
+
+    for i := 0; i < 2; i++ {
+        select {
+        case <-started:
+        case <-time.After(500 * time.Millisecond):
+            t.Fatal("different VM locks did not overlap")
+        }
+    }
+
+    close(release)
+
+    for i := 0; i < 2; i++ {
+        if err := <-errCh; err != nil {
+            t.Fatalf("withVMLockByID returned error: %v", err)
+        }
+    }
+}
+
+func openDaemonStore(t *testing.T) *store.Store {
+    t.Helper()
+    db, err := store.Open(filepath.Join(t.TempDir(), "state.db"))