updater: download/stage/swap/rollback flow steps

The pure-logic core of `banger update`. No CLI yet; this commit
ships the steps the next commit's command will orchestrate.

  * download.go — DownloadRelease fetches SHA256SUMS, parses it,
    looks up the tarball's basename, then streams the tarball
    through download.FetchVerified so the hash is checked on the
    fly. Returns the SHA256SUMS bytes alongside so a future
    cosign-verification step can validate them against an embedded
    public key before trusting the hashes inside.
    Also: fetchBounded for small bounded GETs (manifest, sums file,
    future signature), DefaultStagingDir, EnsureStagingDir,
    PrepareCleanStaging.
  * stage.go — StageTarball reads gzip+tar, validates the entry
    set is exactly {banger, bangerd, banger-vsock-agent} (no
    extras, no missing, no path traversal, no non-regular files),
    extracts at mode 0755 regardless of what the tarball claims.
    StagedRelease records the resulting paths.
  * swap.go — InstallTargets pins the canonical install paths
    (/usr/local/bin/banger, /usr/local/bin/bangerd,
    /usr/local/lib/banger/banger-vsock-agent). Swap orders the
    three replacements vsock → bangerd → banger so the most
    impactful binary (the CLI) goes last; each step uses
    system.AtomicReplace and accumulates a SwapResult so partial
    failures can be rolled back cleanly. Rollback unwinds in
    reverse, joining errors so a half-rolled-back state surfaces
    enough info for an operator to fix manually. CleanupBackups
    removes the .previous trail after `banger doctor` confirms
    the new install is healthy.
  * installmeta.UpdateBuildInfo — small helper that refreshes
    Version/Commit/BuiltAt on /etc/banger/install.toml without
    re-running the full system install. Preserves OwnerUser/UID/
    GID/Home and the original InstalledAt timestamp.

Tests: stage rejects extra entries / missing entries / path
traversal / non-regular files; happy-path stages all three at 0755
with correct contents. Swap+Rollback covers the all-three-succeed
path (then verifies .previous backups exist + rollback restores
old contents) AND the partial-failure path (third swap blocked by
a non-dir parent → SwappedTargets = 2 → rollback unwinds those
two cleanly). DownloadRelease covers happy path, tarball-not-in-
SHA256SUMS, and propagated sha256 mismatch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Thales Maciel 2026-04-29 12:30:22 -03:00
parent fb6d2b1dae
commit 91af367208
No known key found for this signature in database
GPG key ID: 33112E6833C34679
5 changed files with 746 additions and 0 deletions

View file

@ -97,6 +97,30 @@ func Save(path string, meta Metadata) error {
return os.WriteFile(path, data, 0o644)
}
// UpdateBuildInfo refreshes only the Version / Commit / BuiltAt
// fields on the install metadata, preserving everything else
// (OwnerUser/UID/GID/Home and the original InstalledAt timestamp).
// Used by `banger update` to record what's running after a
// successful binary swap; the install identity is unchanged so
// re-running `banger system install` is not required.
//
// Errors when path doesn't exist or can't be parsed — `banger
// update` runs in system mode where install.toml IS the source of
// truth; a missing file means we shouldn't be updating at all.
func UpdateBuildInfo(path, version, commit, builtAt string) error {
if strings.TrimSpace(path) == "" {
path = DefaultPath
}
meta, err := Load(path)
if err != nil {
return err
}
meta.Version = strings.TrimSpace(version)
meta.Commit = strings.TrimSpace(commit)
meta.BuiltAt = strings.TrimSpace(builtAt)
return Save(path, meta)
}
func (m Metadata) Validate() error {
if strings.TrimSpace(m.OwnerUser) == "" {
return fmt.Errorf("install metadata missing owner_user")

View file

@ -0,0 +1,117 @@
package updater
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path"
"path/filepath"
"banger/internal/download"
)
// DownloadRelease fetches the SHA256SUMS file for `release`, looks up
// the tarball's basename in it, then fetches the tarball with on-the-
// fly hash verification. The tarball lands at dstPath; the function
// errors on any verification failure and removes the partial file
// before returning.
//
// SHA256SUMS bytes are returned alongside so the caller can
// cosign-verify them against an embedded public key before trusting
// the hashes inside. Without that step this function is only as
// secure as TLS; see verify_signature.go for the cosign tie-in.
func DownloadRelease(ctx context.Context, client *http.Client, release Release, dstPath string) (sumsBody []byte, err error) {
if client == nil {
client = http.DefaultClient
}
sumsBody, err = fetchBounded(ctx, client, release.SHA256SumsURL, MaxSHA256SumsBytes)
if err != nil {
return nil, fmt.Errorf("fetch SHA256SUMS: %w", err)
}
sums, err := ParseSHA256Sums(sumsBody)
if err != nil {
return nil, fmt.Errorf("parse SHA256SUMS: %w", err)
}
tarballName := path.Base(release.TarballURL)
expected, ok := sums[tarballName]
if !ok {
return nil, fmt.Errorf("SHA256SUMS does not list %q", tarballName)
}
if _, err := download.FetchVerified(ctx, client, release.TarballURL, expected, MaxTarballBytes, dstPath); err != nil {
return nil, fmt.Errorf("fetch tarball: %w", err)
}
return sumsBody, nil
}
// fetchBounded does a small bounded GET — used for the manifest, the
// SHA256SUMS file, and (later) the cosign signature. Anything bigger
// goes through download.FetchVerified, which adds the on-the-fly
// hash check.
func fetchBounded(ctx context.Context, client *http.Client, url string, maxBytes int64) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("fetch %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("fetch %s: HTTP %s", url, resp.Status)
}
if resp.ContentLength > maxBytes {
return nil, fmt.Errorf("fetch %s: %d bytes exceeds %d-byte cap", url, resp.ContentLength, maxBytes)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, maxBytes+1))
if err != nil {
return nil, fmt.Errorf("read %s: %w", url, err)
}
if int64(len(body)) > maxBytes {
return nil, fmt.Errorf("%s exceeded %d-byte cap mid-stream", url, maxBytes)
}
return body, nil
}
// EnsureStagingDir creates the staging directory with restrictive
// permissions (0700, owned by the caller — typically root in system
// mode). Any pre-existing contents are NOT cleared; that's
// PrepareCleanStaging's job.
func EnsureStagingDir(stagingDir string) error {
return os.MkdirAll(stagingDir, 0o700)
}
// PrepareCleanStaging wipes anything left in the staging dir from a
// prior aborted update, then re-creates the directory. Distinct from
// EnsureStagingDir because we don't want to nuke the dir unless
// we're ABOUT to use it — having a leftover staged tree from a
// prior failed run is sometimes useful for diagnostics.
func PrepareCleanStaging(stagingDir string) error {
if err := os.RemoveAll(stagingDir); err != nil {
return fmt.Errorf("clear staging %s: %w", stagingDir, err)
}
return EnsureStagingDir(stagingDir)
}
// DefaultStagingDir is where the updater stages downloads +
// extracted binaries when no explicit dir is configured. Sits under
// banger's system CacheDir (typically /var/cache/banger/updates) so:
// - the systemd unit's CacheDirectory=banger keeps the path
// writable for the helper.
// - `banger system uninstall --purge` cleans it.
// - it sits beside the OCI and kernel caches without colliding.
//
// Atomicity caveat: we expect /var/cache and /usr/local to share a
// filesystem (default on essentially every Linux install). On a host
// with /usr split onto a separate volume, the swap step's os.Rename
// would fall through to a copy + delete and lose its atomicity
// guarantee. We document this rather than detect-and-error for
// v0.1.0; the worst-case symptom is a brief window where a binary is
// half-written, which `banger doctor` would catch in step 7.
func DefaultStagingDir(cacheDir string) string {
return filepath.Join(cacheDir, "updates")
}

View file

@ -0,0 +1,363 @@
package updater
import (
"archive/tar"
"bytes"
"compress/gzip"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
// makeReleaseTarball writes a tarball whose root contains the three
// expected entries with the given bodies. Used by stage + download
// tests so they don't need a real banger build to exercise the
// extraction path.
func makeReleaseTarball(t *testing.T, bodies map[string][]byte) []byte {
t.Helper()
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
tw := tar.NewWriter(gz)
for name, body := range bodies {
hdr := &tar.Header{
Name: name,
Mode: 0o755,
Size: int64(len(body)),
Typeflag: tar.TypeReg,
}
if err := tw.WriteHeader(hdr); err != nil {
t.Fatalf("write header: %v", err)
}
if _, err := tw.Write(body); err != nil {
t.Fatalf("write body: %v", err)
}
}
if err := tw.Close(); err != nil {
t.Fatalf("close tar: %v", err)
}
if err := gz.Close(); err != nil {
t.Fatalf("close gzip: %v", err)
}
return buf.Bytes()
}
func sha256Hex(b []byte) string {
sum := sha256.Sum256(b)
return hex.EncodeToString(sum[:])
}
func TestStageTarballHappyPath(t *testing.T) {
body := makeReleaseTarball(t, map[string][]byte{
"banger": []byte("BANGER"),
"bangerd": []byte("BANGERD"),
"banger-vsock-agent": []byte("AGENT"),
})
tarball := filepath.Join(t.TempDir(), "release.tar.gz")
if err := os.WriteFile(tarball, body, 0o644); err != nil {
t.Fatalf("write tarball: %v", err)
}
staging := filepath.Join(t.TempDir(), "staged")
got, err := StageTarball(tarball, staging)
if err != nil {
t.Fatalf("StageTarball: %v", err)
}
for _, p := range []string{got.BangerPath, got.BangerdPath, got.VsockAgentPath} {
info, err := os.Stat(p)
if err != nil {
t.Fatalf("stat %s: %v", p, err)
}
if info.Mode().Perm() != 0o755 {
t.Errorf("%s mode = %o, want 0755", p, info.Mode().Perm())
}
}
bs, _ := os.ReadFile(got.BangerPath)
if string(bs) != "BANGER" {
t.Fatalf("banger content = %q", bs)
}
}
func TestStageTarballRejectsExtraEntry(t *testing.T) {
body := makeReleaseTarball(t, map[string][]byte{
"banger": []byte("a"),
"bangerd": []byte("b"),
"banger-vsock-agent": []byte("c"),
"bonus.txt": []byte("not allowed"),
})
tarball := filepath.Join(t.TempDir(), "rel.tar.gz")
_ = os.WriteFile(tarball, body, 0o644)
_, err := StageTarball(tarball, t.TempDir())
if err == nil || !strings.Contains(err.Error(), "unexpected entry") {
t.Fatalf("err = %v, want unexpected-entry rejection", err)
}
}
func TestStageTarballRejectsMissingEntry(t *testing.T) {
body := makeReleaseTarball(t, map[string][]byte{
"banger": []byte("a"),
"bangerd": []byte("b"),
// banger-vsock-agent intentionally missing
})
tarball := filepath.Join(t.TempDir(), "rel.tar.gz")
_ = os.WriteFile(tarball, body, 0o644)
_, err := StageTarball(tarball, t.TempDir())
if err == nil || !strings.Contains(err.Error(), "missing required entry") {
t.Fatalf("err = %v, want missing-required rejection", err)
}
}
func TestStageTarballRejectsPathTraversal(t *testing.T) {
// Build the tarball manually so we can inject a `../` entry —
// makeReleaseTarball's expected-entry filter would otherwise
// catch it earlier.
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
tw := tar.NewWriter(gz)
for _, e := range []struct{ name, body string }{
{"banger", "a"},
{"bangerd", "b"},
{"../escape", "x"},
} {
_ = tw.WriteHeader(&tar.Header{Name: e.name, Size: int64(len(e.body)), Mode: 0o755, Typeflag: tar.TypeReg})
_, _ = tw.Write([]byte(e.body))
}
_ = tw.Close()
_ = gz.Close()
tarball := filepath.Join(t.TempDir(), "rel.tar.gz")
_ = os.WriteFile(tarball, buf.Bytes(), 0o644)
_, err := StageTarball(tarball, t.TempDir())
if err == nil || !strings.Contains(err.Error(), "unsafe path") {
t.Fatalf("err = %v, want unsafe-path rejection", err)
}
}
func TestSwapAndRollback(t *testing.T) {
root := t.TempDir()
binDir := filepath.Join(root, "bin")
libDir := filepath.Join(root, "lib", "banger")
if err := os.MkdirAll(binDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(libDir, 0o755); err != nil {
t.Fatal(err)
}
for _, p := range []string{
filepath.Join(binDir, "banger"),
filepath.Join(binDir, "bangerd"),
filepath.Join(libDir, "banger-vsock-agent"),
} {
if err := os.WriteFile(p, []byte("OLD-"+filepath.Base(p)), 0o755); err != nil {
t.Fatal(err)
}
}
staging := filepath.Join(root, "staging")
_ = os.MkdirAll(staging, 0o700)
staged := StagedRelease{
BangerPath: filepath.Join(staging, "banger"),
BangerdPath: filepath.Join(staging, "bangerd"),
VsockAgentPath: filepath.Join(staging, "banger-vsock-agent"),
}
for _, pair := range []struct{ p, body string }{
{staged.BangerPath, "NEW-banger"},
{staged.BangerdPath, "NEW-bangerd"},
{staged.VsockAgentPath, "NEW-banger-vsock-agent"},
} {
if err := os.WriteFile(pair.p, []byte(pair.body), 0o755); err != nil {
t.Fatal(err)
}
}
targets := InstallTargets{
Banger: filepath.Join(binDir, "banger"),
Bangerd: filepath.Join(binDir, "bangerd"),
VsockAgent: filepath.Join(libDir, "banger-vsock-agent"),
}
res, err := Swap(staged, targets)
if err != nil {
t.Fatalf("Swap: %v", err)
}
if len(res.SwappedTargets) != 3 {
t.Fatalf("SwappedTargets len = %d, want 3", len(res.SwappedTargets))
}
for _, p := range []string{targets.Banger, targets.Bangerd, targets.VsockAgent} {
got, _ := os.ReadFile(p)
want := "NEW-" + filepath.Base(p)
if string(got) != want {
t.Fatalf("%s content = %q, want %q", p, got, want)
}
prev, err := os.ReadFile(p + previousSuffix)
if err != nil {
t.Fatalf("missing backup at %s.previous: %v", p, err)
}
if string(prev) != "OLD-"+filepath.Base(p) {
t.Fatalf(".previous content = %q", prev)
}
}
if err := Rollback(res); err != nil {
t.Fatalf("Rollback: %v", err)
}
for _, p := range []string{targets.Banger, targets.Bangerd, targets.VsockAgent} {
got, _ := os.ReadFile(p)
want := "OLD-" + filepath.Base(p)
if string(got) != want {
t.Fatalf("post-rollback %s = %q, want %q", p, got, want)
}
if _, err := os.Stat(p + previousSuffix); !os.IsNotExist(err) {
t.Fatalf(".previous should be cleaned after rollback; stat err = %v", err)
}
}
}
func TestSwapPartialFailureRollsBackCleanly(t *testing.T) {
root := t.TempDir()
binDir := filepath.Join(root, "bin")
libDir := filepath.Join(root, "lib", "banger")
if err := os.MkdirAll(binDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(libDir, 0o755); err != nil {
t.Fatal(err)
}
// Pre-create the two binaries that will swap successfully.
for _, p := range []string{
filepath.Join(binDir, "bangerd"),
filepath.Join(libDir, "banger-vsock-agent"),
} {
_ = os.WriteFile(p, []byte("OLD-"+filepath.Base(p)), 0o755)
}
staging := filepath.Join(root, "staging")
_ = os.MkdirAll(staging, 0o700)
staged := StagedRelease{
BangerPath: filepath.Join(staging, "banger"),
BangerdPath: filepath.Join(staging, "bangerd"),
VsockAgentPath: filepath.Join(staging, "banger-vsock-agent"),
}
for _, pair := range []struct{ p, body string }{
{staged.BangerPath, "NEW-banger"},
{staged.BangerdPath, "NEW-bangerd"},
{staged.VsockAgentPath, "NEW-banger-vsock-agent"},
} {
_ = os.WriteFile(pair.p, []byte(pair.body), 0o755)
}
// Block the banger swap (which is LAST in the order) by putting
// a regular file where its parent dir should be — MkdirAll fails
// with "not a directory". Vsock + bangerd succeed first.
blockedParent := filepath.Join(root, "blocked-bin")
if err := os.WriteFile(blockedParent, []byte("blocking"), 0o644); err != nil {
t.Fatal(err)
}
targets := InstallTargets{
Banger: filepath.Join(blockedParent, "banger"),
Bangerd: filepath.Join(binDir, "bangerd"),
VsockAgent: filepath.Join(libDir, "banger-vsock-agent"),
}
res, err := Swap(staged, targets)
if err == nil {
t.Fatal("Swap unexpectedly succeeded; banger parent should be blocked by a regular file")
}
if len(res.SwappedTargets) != 2 {
t.Fatalf("SwappedTargets = %v, want 2 (vsock + bangerd before banger failed)", res.SwappedTargets)
}
// Rolling back the partial swap should restore the filesystem.
if err := Rollback(res); err != nil {
t.Fatalf("Rollback after partial swap: %v", err)
}
for _, p := range res.SwappedTargets {
got, _ := os.ReadFile(p)
want := "OLD-" + filepath.Base(p)
if string(got) != want {
t.Fatalf("post-rollback %s = %q", p, got)
}
}
}
func TestDownloadReleaseHappyPath(t *testing.T) {
tarballBody := []byte("fake tarball bytes")
tarballSHA := sha256Hex(tarballBody)
mux := http.NewServeMux()
mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write(tarballBody)
})
mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s banger.tar.gz\n", tarballSHA)
})
srv := httptest.NewServer(mux)
defer srv.Close()
dst := filepath.Join(t.TempDir(), "out.tar.gz")
sums, err := DownloadRelease(context.Background(), srv.Client(), Release{
Version: "v0.1.0",
TarballURL: srv.URL + "/banger.tar.gz",
SHA256SumsURL: srv.URL + "/SHA256SUMS",
}, dst)
if err != nil {
t.Fatalf("DownloadRelease: %v", err)
}
if !strings.Contains(string(sums), "banger.tar.gz") {
t.Fatalf("returned sums body missing tarball name: %q", sums)
}
got, _ := os.ReadFile(dst)
if !bytes.Equal(got, tarballBody) {
t.Fatalf("downloaded body differs from served body")
}
}
func TestDownloadReleaseRejectsTarballMissingFromSums(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("body"))
})
mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, r *http.Request) {
// Sums for a different file; tarball name not listed.
fmt.Fprintf(w, "%s unrelated\n", sha256Hex([]byte("body")))
})
srv := httptest.NewServer(mux)
defer srv.Close()
dst := filepath.Join(t.TempDir(), "out.tar.gz")
_, err := DownloadRelease(context.Background(), srv.Client(), Release{
TarballURL: srv.URL + "/banger.tar.gz",
SHA256SumsURL: srv.URL + "/SHA256SUMS",
}, dst)
if err == nil || !strings.Contains(err.Error(), "does not list") {
t.Fatalf("err = %v, want SHA256SUMS-missing rejection", err)
}
}
func TestDownloadReleasePropagatesShaMismatch(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("served body"))
})
mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, r *http.Request) {
// Wrong digest for the tarball.
fmt.Fprintf(w, "%s banger.tar.gz\n", sha256Hex([]byte("expected body")))
})
srv := httptest.NewServer(mux)
defer srv.Close()
dst := filepath.Join(t.TempDir(), "out.tar.gz")
_, err := DownloadRelease(context.Background(), srv.Client(), Release{
TarballURL: srv.URL + "/banger.tar.gz",
SHA256SumsURL: srv.URL + "/SHA256SUMS",
}, dst)
if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") {
t.Fatalf("err = %v, want sha256 mismatch", err)
}
if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) {
t.Fatalf("partial tarball should be removed; stat err = %v", statErr)
}
}

107
internal/updater/stage.go Normal file
View file

@ -0,0 +1,107 @@
package updater
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// expectedReleaseEntries is the canonical set of files a release
// tarball must contain. Anything missing OR anything extra is
// rejected — banger update should not unpack arbitrary files into
// the staging dir.
var expectedReleaseEntries = []string{
"banger",
"bangerd",
"banger-vsock-agent",
}
// StagedRelease describes the result of unpacking a release tarball
// into a staging directory.
type StagedRelease struct {
BangerPath string
BangerdPath string
VsockAgentPath string
}
// StageTarball reads the gzipped tar at tarballPath and extracts the
// expected three banger binaries into stagingDir. Any extra entries,
// any path-traversal members, any non-regular-file members, and any
// missing required entry are rejected.
//
// The extracted binaries are mode 0o755 regardless of what the
// tarball claims — banger update is a privileged operation; we
// don't honour weird modes from the wire.
func StageTarball(tarballPath, stagingDir string) (StagedRelease, error) {
if err := os.MkdirAll(stagingDir, 0o700); err != nil {
return StagedRelease{}, err
}
f, err := os.Open(tarballPath)
if err != nil {
return StagedRelease{}, err
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
return StagedRelease{}, fmt.Errorf("open gzip: %w", err)
}
defer gz.Close()
expected := map[string]struct{}{}
for _, name := range expectedReleaseEntries {
expected[name] = struct{}{}
}
seen := map[string]string{}
tr := tar.NewReader(gz)
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return StagedRelease{}, fmt.Errorf("read tar: %w", err)
}
rel := filepath.Clean(hdr.Name)
if rel == "." || rel == string(filepath.Separator) {
continue
}
if filepath.IsAbs(rel) || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
return StagedRelease{}, fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
}
if _, ok := expected[rel]; !ok {
return StagedRelease{}, fmt.Errorf("unexpected entry in release tarball: %q (allowed: %v)", hdr.Name, expectedReleaseEntries)
}
if hdr.Typeflag != tar.TypeReg {
return StagedRelease{}, fmt.Errorf("entry %q is not a regular file (typeflag %d)", hdr.Name, hdr.Typeflag)
}
dst := filepath.Join(stagingDir, rel)
out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755)
if err != nil {
return StagedRelease{}, err
}
if _, err := io.Copy(out, tr); err != nil {
_ = out.Close()
return StagedRelease{}, err
}
if err := out.Close(); err != nil {
return StagedRelease{}, err
}
seen[rel] = dst
}
for _, want := range expectedReleaseEntries {
if _, ok := seen[want]; !ok {
return StagedRelease{}, fmt.Errorf("release tarball is missing required entry %q", want)
}
}
return StagedRelease{
BangerPath: seen["banger"],
BangerdPath: seen["bangerd"],
VsockAgentPath: seen["banger-vsock-agent"],
}, nil
}

135
internal/updater/swap.go Normal file
View file

@ -0,0 +1,135 @@
package updater
import (
"errors"
"fmt"
"os"
"banger/internal/system"
)
// previousSuffix is the filename suffix appended to the
// pre-swap binary so Rollback knows where to restore from.
// Pinned as a constant so the swap and rollback paths can't
// disagree on it.
const previousSuffix = ".previous"
// InstallTargets lists the absolute on-disk paths the updater
// writes during a swap. Hardcoded to the system-install layout —
// banger update is a system-mode operation; the developer non-
// system-mode flow doesn't go through this code path.
type InstallTargets struct {
Banger string // /usr/local/bin/banger
Bangerd string // /usr/local/bin/bangerd
VsockAgent string // /usr/local/lib/banger/banger-vsock-agent
}
// DefaultInstallTargets returns the canonical paths a system install
// uses (`banger system install` writes to these). Exposed for
// testability; production callers use it as-is.
func DefaultInstallTargets() InstallTargets {
return InstallTargets{
Banger: "/usr/local/bin/banger",
Bangerd: "/usr/local/bin/bangerd",
VsockAgent: "/usr/local/lib/banger/banger-vsock-agent",
}
}
// SwapResult records what was swapped, so Rollback knows what to
// undo. A nil SwapResult means no swap was attempted yet (nothing
// to roll back).
type SwapResult struct {
Targets InstallTargets
// SwappedTargets is the subset of Targets that were actually
// renamed into place. If the second of three Renames fails,
// SwappedTargets contains only the first; rollback unwinds in
// reverse order.
SwappedTargets []string
}
// Swap atomically replaces each of the three banger binaries with
// its staged counterpart. Order:
//
// 1. banger-vsock-agent (companion; not currently running, swap is safe)
// 2. bangerd (the to-be-restarted daemon binary)
// 3. banger (the CLI; least disruptive last)
//
// Each AtomicReplace leaves a `.previous` backup so Rollback can
// restore the prior install if a later step fails.
//
// Returns the SwapResult even on partial failure so the caller can
// drive Rollback against what HAS been swapped.
func Swap(staged StagedRelease, targets InstallTargets) (SwapResult, error) {
res := SwapResult{Targets: targets}
steps := []struct {
src, dst string
}{
{src: staged.VsockAgentPath, dst: targets.VsockAgent},
{src: staged.BangerdPath, dst: targets.Bangerd},
{src: staged.BangerPath, dst: targets.Banger},
}
for _, s := range steps {
if err := ensureParentDir(s.dst); err != nil {
return res, fmt.Errorf("prepare %s: %w", s.dst, err)
}
if err := system.AtomicReplace(s.src, s.dst, previousSuffix); err != nil {
return res, fmt.Errorf("swap %s: %w", s.dst, err)
}
res.SwappedTargets = append(res.SwappedTargets, s.dst)
}
return res, nil
}
// Rollback undoes a Swap by restoring each .previous backup in
// reverse order. Returns the joined errors of every individual
// rollback that failed; a half-rolled-back tree is the worst case
// and the operator gets enough information to fix it manually.
//
// Tolerant of partial input — passing a SwapResult that only
// recorded the first two of three swaps rolls back exactly those
// two.
func Rollback(res SwapResult) error {
var errs []error
for i := len(res.SwappedTargets) - 1; i >= 0; i-- {
dst := res.SwappedTargets[i]
if err := system.AtomicReplaceRollback(dst, previousSuffix); err != nil {
errs = append(errs, fmt.Errorf("rollback %s: %w", dst, err))
}
}
return errors.Join(errs...)
}
// CleanupBackups removes every .previous backup left behind by a
// successful update. Called after `banger doctor` confirms the new
// install is healthy — we don't keep ancient backups around forever.
func CleanupBackups(res SwapResult) error {
var errs []error
for _, dst := range res.SwappedTargets {
if err := os.Remove(dst + previousSuffix); err != nil && !os.IsNotExist(err) {
errs = append(errs, fmt.Errorf("remove %s%s: %w", dst, previousSuffix, err))
}
}
return errors.Join(errs...)
}
func ensureParentDir(p string) error {
parent := dirOf(p)
if parent == "" {
return nil
}
if _, err := os.Stat(parent); err == nil {
return nil
}
return os.MkdirAll(parent, 0o755)
}
// dirOf is a tiny path.Dir wrapper that returns "" for paths with
// no separator (so the ensure-parent logic doesn't try to mkdir(".")).
func dirOf(p string) string {
for i := len(p) - 1; i >= 0; i-- {
if p[i] == '/' {
return p[:i]
}
}
return ""
}