diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 0da8756..ea3f0fc 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -100,17 +100,24 @@ func Open(ctx context.Context) (d *Daemon, err error) { handles: newHandleCache(), sessions: newSessionRegistry(), } + // From here on, every failure path must run Close() so the host + // state we touched (DNS listener goroutine, resolvectl routing, + // SQLite handle, future side effects) gets unwound. Close is + // idempotent + nil-guarded so it's safe to call on a partially + // initialised daemon — `d.vmDNS == nil` and friends short-circuit + // the teardown of components we never set up. + defer func() { + if err != nil { + _ = d.Close() + } + }() + d.ensureVMSSHClientConfig() d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel) if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil { d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error()) return nil, err } - defer func() { - if err != nil { - _ = d.stopVMDNS() - } - }() if err = d.reconcile(ctx); err != nil { d.logger.Error("daemon open failed", "stage", "reconcile", "error", err.Error()) return nil, err diff --git a/internal/daemon/open_close_test.go b/internal/daemon/open_close_test.go new file mode 100644 index 0000000..57d70e4 --- /dev/null +++ b/internal/daemon/open_close_test.go @@ -0,0 +1,147 @@ +package daemon + +import ( + "errors" + "io" + "log/slog" + "sync/atomic" + "testing" + + "banger/internal/model" + "banger/internal/vmdns" +) + +// TestCloseOnPartiallyInitialisedDaemon pins the contract that Open's +// error-path defer relies on: Close must be safe to call when a +// startup step failed before every subsystem was set up. If this +// breaks, `defer d.Close() on err != nil` in Open() starts panicking +// on zero-valued fields. 
+func TestCloseOnPartiallyInitialisedDaemon(t *testing.T) { + cases := []struct { + name string + build func(t *testing.T) *Daemon + verify func(t *testing.T, d *Daemon) + }{ + { + name: "only store + closing channel (early failure)", + build: func(t *testing.T) *Daemon { + return &Daemon{ + store: openDaemonStore(t), + closing: make(chan struct{}), + sessions: newSessionRegistry(), + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + }, + verify: func(t *testing.T, d *Daemon) { + // closing channel should have been closed. + select { + case <-d.closing: + default: + t.Error("closing channel not closed by Close") + } + }, + }, + { + name: "with vmDNS listener (fail after startVMDNS)", + build: func(t *testing.T) *Daemon { + server, err := vmdns.New("127.0.0.1:0", nil) + if err != nil { + t.Fatalf("vmdns.New: %v", err) + } + return &Daemon{ + store: openDaemonStore(t), + closing: make(chan struct{}), + sessions: newSessionRegistry(), + vmDNS: server, + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + }, + verify: func(t *testing.T, d *Daemon) { + if d.vmDNS != nil { + t.Error("vmDNS not cleared by Close") + } + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + d := tc.build(t) + if err := d.Close(); err != nil { + t.Fatalf("Close returned error: %v", err) + } + tc.verify(t, d) + + // Second Close must be a no-op (sync.Once) — must not + // panic on channel or re-close. + if err := d.Close(); err != nil { + t.Fatalf("second Close error: %v", err) + } + }) + } +} + +// TestCloseIdempotentUnderConcurrency catches regressions of the +// sync.Once guard that makes repeated Close calls safe. The open- +// failure defer relies on this: if the user cancels before Open +// returns and also calls Close afterwards, both paths must survive. 
func TestCloseIdempotentUnderConcurrency(t *testing.T) { + d := &Daemon{ + store: openDaemonStore(t), + closing: make(chan struct{}), + sessions: newSessionRegistry(), + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + config: model.DaemonConfig{BridgeName: ""}, + } + + var count atomic.Int32 + done := make(chan struct{}) + for i := 0; i < 5; i++ { + go func() { + if err := d.Close(); err != nil { + t.Errorf("Close error: %v", err) + } + n := count.Add(1) + if n == 5 { + close(done) + } + }() + } + <-done + + // Channel must be closed exactly once (sync.Once covers the + // inner close(d.closing)). Reading from a closed channel is + // non-blocking; panicking here would mean the channel wasn't + // closed or was double-closed (close panics are uncatchable). + select { + case <-d.closing: + default: + t.Fatal("closing channel not closed after concurrent Close calls") + } +} + +// TestOpenFailureRunsCloseCleanup is a structural check: confirms +// the deferred rollback in Open actually fires. Can't easily run +// Open() end-to-end (hits paths.Resolve + sudo), but we can simulate +// the pattern by threading a named-return err through the same +// defer and asserting Close runs. +func TestOpenFailureRunsCloseCleanup(t *testing.T) { + closed := false + fakeClose := func() { closed = true } + + runOpen := func() (err error) { + defer func() { + if err != nil { + fakeClose() + } + }() + err = errors.New("simulated late-stage startup failure") + return err + } + + if err := runOpen(); err == nil { + t.Fatal("expected simulated error") + } + if !closed { + t.Fatal("deferred cleanup did not fire on err != nil") + } +}