diff --git a/internal/daemon/vm.go b/internal/daemon/vm.go index 511f693..dfa4faf 100644 --- a/internal/daemon/vm.go +++ b/internal/daemon/vm.go @@ -18,6 +18,11 @@ import ( "banger/internal/vmdns" ) +var ( + errWaitForExitTimeout = errors.New("timed out waiting for VM to exit") + gracefulShutdownWait = 10 * time.Second +) + func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm model.VMRecord, err error) { d.mu.Lock() defer d.mu.Unlock() @@ -320,8 +325,11 @@ func (d *Daemon) StopVM(ctx context.Context, idOrName string) (vm model.VMRecord return model.VMRecord{}, err } op.stage("wait_for_exit", "pid", vm.Runtime.PID) - if err := d.waitForExit(ctx, vm.Runtime.PID, vm.Runtime.APISockPath, 30*time.Second); err != nil { - return model.VMRecord{}, err + if err := d.waitForExit(ctx, vm.Runtime.PID, vm.Runtime.APISockPath, gracefulShutdownWait); err != nil { + if !errors.Is(err, errWaitForExitTimeout) { + return model.VMRecord{}, err + } + op.stage("graceful_shutdown_timeout", "pid", vm.Runtime.PID) } op.stage("cleanup_runtime") if err := d.cleanupRuntime(ctx, vm, true); err != nil { @@ -377,7 +385,10 @@ func (d *Daemon) KillVM(ctx context.Context, params api.VMKillParams) (vm model. } op.stage("wait_for_exit", "pid", vm.Runtime.PID) if err := d.waitForExit(ctx, vm.Runtime.PID, vm.Runtime.APISockPath, 30*time.Second); err != nil { - return model.VMRecord{}, err + if !errors.Is(err, errWaitForExitTimeout) { + return model.VMRecord{}, err + } + op.stage("signal_timeout", "pid", vm.Runtime.PID, "signal", signal) } op.stage("cleanup_runtime") if err := d.cleanupRuntime(ctx, vm, true); err != nil { @@ -810,7 +821,7 @@ func (d *Daemon) waitForExit(ctx context.Context, pid int, apiSock string, timeo return nil } if time.Now().After(deadline) { - return fmt.Errorf("timed out waiting for VM to exit") + return errWaitForExitTimeout } select { case <-ctx.Done(): diff --git a/internal/daemon/vm_test.go b/internal/daemon/vm_test.go index c6f44ba..15ac8ba 100644 --- a/internal/daemon/vm_test.go +++ b/internal/daemon/vm_test.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "net" + "net/http" "os" "os/exec" "path/filepath" @@ -480,6 +482,62 @@ func TestCleanupRuntimeRediscoversLiveFirecrackerPID(t *testing.T) { } } +func TestStopVMFallsBackToForcedCleanupAfterGracefulTimeout(t *testing.T) { + ctx := context.Background() + db := openDaemonStore(t) + apiSock := filepath.Join(t.TempDir(), "fc.sock") + startFakeFirecrackerAPI(t, apiSock) + + fake := startFakeFirecrackerProcess(t, apiSock) + t.Cleanup(func() { + if fake.ProcessState == nil || !fake.ProcessState.Exited() { + _ = fake.Process.Kill() + _ = fake.Wait() + } + }) + + oldGracefulWait := gracefulShutdownWait + gracefulShutdownWait = 50 * time.Millisecond + t.Cleanup(func() { + gracefulShutdownWait = oldGracefulWait + }) + + vm := testVM("stubborn", "image-stubborn", "172.16.0.23") + vm.State = model.VMStateRunning + vm.Runtime.State = model.VMStateRunning + vm.Runtime.PID = fake.Process.Pid + vm.Runtime.APISockPath = apiSock + if err := db.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM: %v", err) + } + + runner := &processKillingRunner{ + scriptedRunner: &scriptedRunner{ + t: t, + steps: []runnerStep{ + sudoStep("", nil, "chown", fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()), apiSock), + sudoStep("", nil, "chmod", "600", apiSock), + {call: runnerCall{name: "pgrep", args: []string{"-n", "-f", apiSock}}, out: []byte(strconv.Itoa(fake.Process.Pid) + "\n")}, + sudoStep("", nil, "kill", "-KILL", strconv.Itoa(fake.Process.Pid)), + }, + }, + proc: fake, + } + d := &Daemon{store: db, runner: runner} + + got, err := d.StopVM(ctx, vm.ID) + if err != nil { + t.Fatalf("StopVM returned error: %v", err) + } + runner.assertExhausted() + if got.State != model.VMStateStopped || got.Runtime.State != model.VMStateStopped { + t.Fatalf("StopVM state = %s/%s, want stopped", got.State, got.Runtime.State) + } + if got.Runtime.PID != 0 || got.Runtime.APISockPath != "" { + t.Fatalf("runtime handles not cleared: %+v", got.Runtime) + } +} + func openDaemonStore(t *testing.T) *store.Store { t.Helper() db, err := store.Open(filepath.Join(t.TempDir(), "state.db")) @@ -552,6 +610,34 @@ func startFakeFirecrackerProcess(t *testing.T, apiSock string) *exec.Cmd { return nil } +func startFakeFirecrackerAPI(t *testing.T, apiSock string) { + t.Helper() + + if err := os.MkdirAll(filepath.Dir(apiSock), 0o755); err != nil { + t.Fatalf("MkdirAll(%s): %v", filepath.Dir(apiSock), err) + } + listener, err := net.Listen("unix", apiSock) + if err != nil { + t.Fatalf("listen unix %s: %v", apiSock, err) + } + mux := http.NewServeMux() + mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + w.WriteHeader(http.StatusNoContent) + }) + server := &http.Server{Handler: mux} + go func() { + _ = server.Serve(listener) + }() + t.Cleanup(func() { + _ = server.Close() + _ = os.Remove(apiSock) + }) +} + type processKillingRunner struct { *scriptedRunner proc *exec.Cmd