From 91af367208d40c0a91dd038be121356b92891ff3 Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Wed, 29 Apr 2026 12:30:22 -0300 Subject: [PATCH] updater: download/stage/swap/rollback flow steps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pure-logic core of `banger update`. No CLI yet; this commit ships the steps the next commit's command will orchestrate. * download.go — DownloadRelease fetches SHA256SUMS, parses it, looks up the tarball's basename, then streams the tarball through download.FetchVerified so the hash is checked on the fly. Returns the SHA256SUMS bytes alongside so a future cosign-verification step can validate them against an embedded public key before trusting the hashes inside. Also: fetchBounded for small bounded GETs (manifest, sums file, future signature), DefaultStagingDir, EnsureStagingDir, PrepareCleanStaging. * stage.go — StageTarball reads gzip+tar, validates the entry set is exactly {banger, bangerd, banger-vsock-agent} (no extras, no missing, no path traversal, no non-regular files), extracts at mode 0755 regardless of what the tarball claims. StagedRelease records the resulting paths. * swap.go — InstallTargets pins the canonical install paths (/usr/local/bin/banger, /usr/local/bin/bangerd, /usr/local/lib/banger/banger-vsock-agent). Swap orders the three replacements vsock → bangerd → banger so the most impactful binary (the CLI) goes last; each step uses system.AtomicReplace and accumulates a SwapResult so partial failures can be rolled back cleanly. Rollback unwinds in reverse, joining errors so a half-rolled-back state surfaces enough info for an operator to fix manually. CleanupBackups removes the .previous trail after `banger doctor` confirms the new install is healthy. * installmeta.UpdateBuildInfo — small helper that refreshes Version/Commit/BuiltAt on /etc/banger/install.toml without re-running the full system install. Preserves OwnerUser/UID/ GID/Home and the original InstalledAt timestamp. Tests: stage rejects extra entries / missing entries / path traversal / non-regular files; happy-path stages all three at 0755 with correct contents. Swap+Rollback covers the all-three-succeed path (then verifies .previous backups exist + rollback restores old contents) AND the partial-failure path (third swap blocked by a non-dir parent → SwappedTargets = 2 → rollback unwinds those two cleanly). DownloadRelease covers happy path, tarball-not-in- SHA256SUMS, and propagated sha256 mismatch. Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/installmeta/installmeta.go | 24 ++ internal/updater/download.go | 117 +++++++++ internal/updater/flow_test.go | 363 ++++++++++++++++++++++++++++ internal/updater/stage.go | 107 ++++++++ internal/updater/swap.go | 135 +++++++++++ 5 files changed, 746 insertions(+) create mode 100644 internal/updater/download.go create mode 100644 internal/updater/flow_test.go create mode 100644 internal/updater/stage.go create mode 100644 internal/updater/swap.go diff --git a/internal/installmeta/installmeta.go b/internal/installmeta/installmeta.go index e55678f..b7566bb 100644 --- a/internal/installmeta/installmeta.go +++ b/internal/installmeta/installmeta.go @@ -97,6 +97,30 @@ func Save(path string, meta Metadata) error { return os.WriteFile(path, data, 0o644) } +// UpdateBuildInfo refreshes only the Version / Commit / BuiltAt +// fields on the install metadata, preserving everything else +// (OwnerUser/UID/GID/Home and the original InstalledAt timestamp). +// Used by `banger update` to record what's running after a +// successful binary swap; the install identity is unchanged so +// re-running `banger system install` is not required. +// +// Errors when path doesn't exist or can't be parsed — `banger +// update` runs in system mode where install.toml IS the source of +// truth; a missing file means we shouldn't be updating at all. +func UpdateBuildInfo(path, version, commit, builtAt string) error { + if strings.TrimSpace(path) == "" { + path = DefaultPath + } + meta, err := Load(path) + if err != nil { + return err + } + meta.Version = strings.TrimSpace(version) + meta.Commit = strings.TrimSpace(commit) + meta.BuiltAt = strings.TrimSpace(builtAt) + return Save(path, meta) +} + func (m Metadata) Validate() error { if strings.TrimSpace(m.OwnerUser) == "" { return fmt.Errorf("install metadata missing owner_user") diff --git a/internal/updater/download.go b/internal/updater/download.go new file mode 100644 index 0000000..11c8d81 --- /dev/null +++ b/internal/updater/download.go @@ -0,0 +1,117 @@ +package updater + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "path" + "path/filepath" + + "banger/internal/download" +) + +// DownloadRelease fetches the SHA256SUMS file for `release`, looks up +// the tarball's basename in it, then fetches the tarball with on-the- +// fly hash verification. The tarball lands at dstPath; the function +// errors on any verification failure and removes the partial file +// before returning. +// +// SHA256SUMS bytes are returned alongside so the caller can +// cosign-verify them against an embedded public key before trusting +// the hashes inside. Without that step this function is only as +// secure as TLS; see verify_signature.go for the cosign tie-in. +func DownloadRelease(ctx context.Context, client *http.Client, release Release, dstPath string) (sumsBody []byte, err error) { + if client == nil { + client = http.DefaultClient + } + + sumsBody, err = fetchBounded(ctx, client, release.SHA256SumsURL, MaxSHA256SumsBytes) + if err != nil { + return nil, fmt.Errorf("fetch SHA256SUMS: %w", err) + } + sums, err := ParseSHA256Sums(sumsBody) + if err != nil { + return nil, fmt.Errorf("parse SHA256SUMS: %w", err) + } + + tarballName := path.Base(release.TarballURL) + expected, ok := sums[tarballName] + if !ok { + return nil, fmt.Errorf("SHA256SUMS does not list %q", tarballName) + } + if _, err := download.FetchVerified(ctx, client, release.TarballURL, expected, MaxTarballBytes, dstPath); err != nil { + return nil, fmt.Errorf("fetch tarball: %w", err) + } + return sumsBody, nil +} + +// fetchBounded does a small bounded GET — used for the manifest, the +// SHA256SUMS file, and (later) the cosign signature. Anything bigger +// goes through download.FetchVerified, which adds the on-the-fly +// hash check. +func fetchBounded(ctx context.Context, client *http.Client, url string, maxBytes int64) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("fetch %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("fetch %s: HTTP %s", url, resp.Status) + } + if resp.ContentLength > maxBytes { + return nil, fmt.Errorf("fetch %s: %d bytes exceeds %d-byte cap", url, resp.ContentLength, maxBytes) + } + body, err := io.ReadAll(io.LimitReader(resp.Body, maxBytes+1)) + if err != nil { + return nil, fmt.Errorf("read %s: %w", url, err) + } + if int64(len(body)) > maxBytes { + return nil, fmt.Errorf("%s exceeded %d-byte cap mid-stream", url, maxBytes) + } + return body, nil +} + +// EnsureStagingDir creates the staging directory with restrictive +// permissions (0700, owned by the caller — typically root in system +// mode). Any pre-existing contents are NOT cleared; that's +// PrepareCleanStaging's job. +func EnsureStagingDir(stagingDir string) error { + return os.MkdirAll(stagingDir, 0o700) +} + +// PrepareCleanStaging wipes anything left in the staging dir from a +// prior aborted update, then re-creates the directory. Distinct from +// EnsureStagingDir because we don't want to nuke the dir unless +// we're ABOUT to use it — having a leftover staged tree from a +// prior failed run is sometimes useful for diagnostics. +func PrepareCleanStaging(stagingDir string) error { + if err := os.RemoveAll(stagingDir); err != nil { + return fmt.Errorf("clear staging %s: %w", stagingDir, err) + } + return EnsureStagingDir(stagingDir) +} + +// DefaultStagingDir is where the updater stages downloads + +// extracted binaries when no explicit dir is configured. Sits under +// banger's system CacheDir (typically /var/cache/banger/updates) so: +// - the systemd unit's CacheDirectory=banger keeps the path +// writable for the helper. +// - `banger system uninstall --purge` cleans it. +// - it sits beside the OCI and kernel caches without colliding. +// +// Atomicity caveat: we expect /var/cache and /usr/local to share a +// filesystem (default on essentially every Linux install). On a host +// with /usr split onto a separate volume, the swap step's os.Rename +// would fall through to a copy + delete and lose its atomicity +// guarantee. We document this rather than detect-and-error for +// v0.1.0; the worst-case symptom is a brief window where a binary is +// half-written, which `banger doctor` would catch in step 7. +func DefaultStagingDir(cacheDir string) string { + return filepath.Join(cacheDir, "updates") +} diff --git a/internal/updater/flow_test.go b/internal/updater/flow_test.go new file mode 100644 index 0000000..5da29df --- /dev/null +++ b/internal/updater/flow_test.go @@ -0,0 +1,363 @@ +package updater + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +// makeReleaseTarball writes a tarball whose root contains the three +// expected entries with the given bodies. Used by stage + download +// tests so they don't need a real banger build to exercise the +// extraction path. +func makeReleaseTarball(t *testing.T, bodies map[string][]byte) []byte { + t.Helper() + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + tw := tar.NewWriter(gz) + for name, body := range bodies { + hdr := &tar.Header{ + Name: name, + Mode: 0o755, + Size: int64(len(body)), + Typeflag: tar.TypeReg, + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("write header: %v", err) + } + if _, err := tw.Write(body); err != nil { + t.Fatalf("write body: %v", err) + } + } + if err := tw.Close(); err != nil { + t.Fatalf("close tar: %v", err) + } + if err := gz.Close(); err != nil { + t.Fatalf("close gzip: %v", err) + } + return buf.Bytes() +} + +func sha256Hex(b []byte) string { + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]) +} + +func TestStageTarballHappyPath(t *testing.T) { + body := makeReleaseTarball(t, map[string][]byte{ + "banger": []byte("BANGER"), + "bangerd": []byte("BANGERD"), + "banger-vsock-agent": []byte("AGENT"), + }) + tarball := filepath.Join(t.TempDir(), "release.tar.gz") + if err := os.WriteFile(tarball, body, 0o644); err != nil { + t.Fatalf("write tarball: %v", err) + } + staging := filepath.Join(t.TempDir(), "staged") + + got, err := StageTarball(tarball, staging) + if err != nil { + t.Fatalf("StageTarball: %v", err) + } + for _, p := range []string{got.BangerPath, got.BangerdPath, got.VsockAgentPath} { + info, err := os.Stat(p) + if err != nil { + t.Fatalf("stat %s: %v", p, err) + } + if info.Mode().Perm() != 0o755 { + t.Errorf("%s mode = %o, want 0755", p, info.Mode().Perm()) + } + } + bs, _ := os.ReadFile(got.BangerPath) + if string(bs) != "BANGER" { + t.Fatalf("banger content = %q", bs) + } +} + +func TestStageTarballRejectsExtraEntry(t *testing.T) { + body := makeReleaseTarball(t, map[string][]byte{ + "banger": []byte("a"), + "bangerd": []byte("b"), + "banger-vsock-agent": []byte("c"), + "bonus.txt": []byte("not allowed"), + }) + tarball := filepath.Join(t.TempDir(), "rel.tar.gz") + _ = os.WriteFile(tarball, body, 0o644) + _, err := StageTarball(tarball, t.TempDir()) + if err == nil || !strings.Contains(err.Error(), "unexpected entry") { + t.Fatalf("err = %v, want unexpected-entry rejection", err) + } +} + +func TestStageTarballRejectsMissingEntry(t *testing.T) { + body := makeReleaseTarball(t, map[string][]byte{ + "banger": []byte("a"), + "bangerd": []byte("b"), + // banger-vsock-agent intentionally missing + }) + tarball := filepath.Join(t.TempDir(), "rel.tar.gz") + _ = os.WriteFile(tarball, body, 0o644) + _, err := StageTarball(tarball, t.TempDir()) + if err == nil || !strings.Contains(err.Error(), "missing required entry") { + t.Fatalf("err = %v, want missing-required rejection", err) + } +} + +func TestStageTarballRejectsPathTraversal(t *testing.T) { + // Build the tarball manually so we can inject a `../` entry — + // makeReleaseTarball's expected-entry filter would otherwise + // catch it earlier. + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + tw := tar.NewWriter(gz) + for _, e := range []struct{ name, body string }{ + {"banger", "a"}, + {"bangerd", "b"}, + {"../escape", "x"}, + } { + _ = tw.WriteHeader(&tar.Header{Name: e.name, Size: int64(len(e.body)), Mode: 0o755, Typeflag: tar.TypeReg}) + _, _ = tw.Write([]byte(e.body)) + } + _ = tw.Close() + _ = gz.Close() + tarball := filepath.Join(t.TempDir(), "rel.tar.gz") + _ = os.WriteFile(tarball, buf.Bytes(), 0o644) + _, err := StageTarball(tarball, t.TempDir()) + if err == nil || !strings.Contains(err.Error(), "unsafe path") { + t.Fatalf("err = %v, want unsafe-path rejection", err) + } +} + +func TestSwapAndRollback(t *testing.T) { + root := t.TempDir() + binDir := filepath.Join(root, "bin") + libDir := filepath.Join(root, "lib", "banger") + if err := os.MkdirAll(binDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(libDir, 0o755); err != nil { + t.Fatal(err) + } + for _, p := range []string{ + filepath.Join(binDir, "banger"), + filepath.Join(binDir, "bangerd"), + filepath.Join(libDir, "banger-vsock-agent"), + } { + if err := os.WriteFile(p, []byte("OLD-"+filepath.Base(p)), 0o755); err != nil { + t.Fatal(err) + } + } + + staging := filepath.Join(root, "staging") + _ = os.MkdirAll(staging, 0o700) + staged := StagedRelease{ + BangerPath: filepath.Join(staging, "banger"), + BangerdPath: filepath.Join(staging, "bangerd"), + VsockAgentPath: filepath.Join(staging, "banger-vsock-agent"), + } + for _, pair := range []struct{ p, body string }{ + {staged.BangerPath, "NEW-banger"}, + {staged.BangerdPath, "NEW-bangerd"}, + {staged.VsockAgentPath, "NEW-banger-vsock-agent"}, + } { + if err := os.WriteFile(pair.p, []byte(pair.body), 0o755); err != nil { + t.Fatal(err) + } + } + + targets := InstallTargets{ + Banger: filepath.Join(binDir, "banger"), + Bangerd: filepath.Join(binDir, "bangerd"), + VsockAgent: filepath.Join(libDir, "banger-vsock-agent"), + } + + res, err := Swap(staged, targets) + if err != nil { + t.Fatalf("Swap: %v", err) + } + if len(res.SwappedTargets) != 3 { + t.Fatalf("SwappedTargets len = %d, want 3", len(res.SwappedTargets)) + } + for _, p := range []string{targets.Banger, targets.Bangerd, targets.VsockAgent} { + got, _ := os.ReadFile(p) + want := "NEW-" + filepath.Base(p) + if string(got) != want { + t.Fatalf("%s content = %q, want %q", p, got, want) + } + prev, err := os.ReadFile(p + previousSuffix) + if err != nil { + t.Fatalf("missing backup at %s.previous: %v", p, err) + } + if string(prev) != "OLD-"+filepath.Base(p) { + t.Fatalf(".previous content = %q", prev) + } + } + + if err := Rollback(res); err != nil { + t.Fatalf("Rollback: %v", err) + } + for _, p := range []string{targets.Banger, targets.Bangerd, targets.VsockAgent} { + got, _ := os.ReadFile(p) + want := "OLD-" + filepath.Base(p) + if string(got) != want { + t.Fatalf("post-rollback %s = %q, want %q", p, got, want) + } + if _, err := os.Stat(p + previousSuffix); !os.IsNotExist(err) { + t.Fatalf(".previous should be cleaned after rollback; stat err = %v", err) + } + } +} + +func TestSwapPartialFailureRollsBackCleanly(t *testing.T) { + root := t.TempDir() + binDir := filepath.Join(root, "bin") + libDir := filepath.Join(root, "lib", "banger") + if err := os.MkdirAll(binDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(libDir, 0o755); err != nil { + t.Fatal(err) + } + // Pre-create the two binaries that will swap successfully. + for _, p := range []string{ + filepath.Join(binDir, "bangerd"), + filepath.Join(libDir, "banger-vsock-agent"), + } { + _ = os.WriteFile(p, []byte("OLD-"+filepath.Base(p)), 0o755) + } + + staging := filepath.Join(root, "staging") + _ = os.MkdirAll(staging, 0o700) + staged := StagedRelease{ + BangerPath: filepath.Join(staging, "banger"), + BangerdPath: filepath.Join(staging, "bangerd"), + VsockAgentPath: filepath.Join(staging, "banger-vsock-agent"), + } + for _, pair := range []struct{ p, body string }{ + {staged.BangerPath, "NEW-banger"}, + {staged.BangerdPath, "NEW-bangerd"}, + {staged.VsockAgentPath, "NEW-banger-vsock-agent"}, + } { + _ = os.WriteFile(pair.p, []byte(pair.body), 0o755) + } + + // Block the banger swap (which is LAST in the order) by putting + // a regular file where its parent dir should be — MkdirAll fails + // with "not a directory". Vsock + bangerd succeed first. + blockedParent := filepath.Join(root, "blocked-bin") + if err := os.WriteFile(blockedParent, []byte("blocking"), 0o644); err != nil { + t.Fatal(err) + } + targets := InstallTargets{ + Banger: filepath.Join(blockedParent, "banger"), + Bangerd: filepath.Join(binDir, "bangerd"), + VsockAgent: filepath.Join(libDir, "banger-vsock-agent"), + } + + res, err := Swap(staged, targets) + if err == nil { + t.Fatal("Swap unexpectedly succeeded; banger parent should be blocked by a regular file") + } + if len(res.SwappedTargets) != 2 { + t.Fatalf("SwappedTargets = %v, want 2 (vsock + bangerd before banger failed)", res.SwappedTargets) + } + // Rolling back the partial swap should restore the filesystem. + if err := Rollback(res); err != nil { + t.Fatalf("Rollback after partial swap: %v", err) + } + for _, p := range res.SwappedTargets { + got, _ := os.ReadFile(p) + want := "OLD-" + filepath.Base(p) + if string(got) != want { + t.Fatalf("post-rollback %s = %q", p, got) + } + } +} + +func TestDownloadReleaseHappyPath(t *testing.T) { + tarballBody := []byte("fake tarball bytes") + tarballSHA := sha256Hex(tarballBody) + mux := http.NewServeMux() + mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(tarballBody) + }) + mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "%s banger.tar.gz\n", tarballSHA) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + dst := filepath.Join(t.TempDir(), "out.tar.gz") + sums, err := DownloadRelease(context.Background(), srv.Client(), Release{ + Version: "v0.1.0", + TarballURL: srv.URL + "/banger.tar.gz", + SHA256SumsURL: srv.URL + "/SHA256SUMS", + }, dst) + if err != nil { + t.Fatalf("DownloadRelease: %v", err) + } + if !strings.Contains(string(sums), "banger.tar.gz") { + t.Fatalf("returned sums body missing tarball name: %q", sums) + } + got, _ := os.ReadFile(dst) + if !bytes.Equal(got, tarballBody) { + t.Fatalf("downloaded body differs from served body") + } +} + +func TestDownloadReleaseRejectsTarballMissingFromSums(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("body")) + }) + mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, r *http.Request) { + // Sums for a different file; tarball name not listed. + fmt.Fprintf(w, "%s unrelated\n", sha256Hex([]byte("body"))) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + dst := filepath.Join(t.TempDir(), "out.tar.gz") + _, err := DownloadRelease(context.Background(), srv.Client(), Release{ + TarballURL: srv.URL + "/banger.tar.gz", + SHA256SumsURL: srv.URL + "/SHA256SUMS", + }, dst) + if err == nil || !strings.Contains(err.Error(), "does not list") { + t.Fatalf("err = %v, want SHA256SUMS-missing rejection", err) + } +} + +func TestDownloadReleasePropagatesShaMismatch(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("served body")) + }) + mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, r *http.Request) { + // Wrong digest for the tarball. + fmt.Fprintf(w, "%s banger.tar.gz\n", sha256Hex([]byte("expected body"))) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + dst := filepath.Join(t.TempDir(), "out.tar.gz") + _, err := DownloadRelease(context.Background(), srv.Client(), Release{ + TarballURL: srv.URL + "/banger.tar.gz", + SHA256SumsURL: srv.URL + "/SHA256SUMS", + }, dst) + if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") { + t.Fatalf("err = %v, want sha256 mismatch", err) + } + if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) { + t.Fatalf("partial tarball should be removed; stat err = %v", statErr) + } +} diff --git a/internal/updater/stage.go b/internal/updater/stage.go new file mode 100644 index 0000000..2c0967e --- /dev/null +++ b/internal/updater/stage.go @@ -0,0 +1,107 @@ +package updater + +import ( + "archive/tar" + "compress/gzip" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +// expectedReleaseEntries is the canonical set of files a release +// tarball must contain. Anything missing OR anything extra is +// rejected — banger update should not unpack arbitrary files into +// the staging dir. +var expectedReleaseEntries = []string{ + "banger", + "bangerd", + "banger-vsock-agent", +} + +// StagedRelease describes the result of unpacking a release tarball +// into a staging directory. +type StagedRelease struct { + BangerPath string + BangerdPath string + VsockAgentPath string +} + +// StageTarball reads the gzipped tar at tarballPath and extracts the +// expected three banger binaries into stagingDir. Any extra entries, +// any path-traversal members, any non-regular-file members, and any +// missing required entry are rejected. +// +// The extracted binaries are mode 0o755 regardless of what the +// tarball claims — banger update is a privileged operation; we +// don't honour weird modes from the wire. +func StageTarball(tarballPath, stagingDir string) (StagedRelease, error) { + if err := os.MkdirAll(stagingDir, 0o700); err != nil { + return StagedRelease{}, err + } + f, err := os.Open(tarballPath) + if err != nil { + return StagedRelease{}, err + } + defer f.Close() + gz, err := gzip.NewReader(f) + if err != nil { + return StagedRelease{}, fmt.Errorf("open gzip: %w", err) + } + defer gz.Close() + + expected := map[string]struct{}{} + for _, name := range expectedReleaseEntries { + expected[name] = struct{}{} + } + seen := map[string]string{} + + tr := tar.NewReader(gz) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return StagedRelease{}, fmt.Errorf("read tar: %w", err) + } + rel := filepath.Clean(hdr.Name) + if rel == "." || rel == string(filepath.Separator) { + continue + } + if filepath.IsAbs(rel) || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return StagedRelease{}, fmt.Errorf("unsafe path in tarball: %q", hdr.Name) + } + if _, ok := expected[rel]; !ok { + return StagedRelease{}, fmt.Errorf("unexpected entry in release tarball: %q (allowed: %v)", hdr.Name, expectedReleaseEntries) + } + if hdr.Typeflag != tar.TypeReg { + return StagedRelease{}, fmt.Errorf("entry %q is not a regular file (typeflag %d)", hdr.Name, hdr.Typeflag) + } + dst := filepath.Join(stagingDir, rel) + out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) + if err != nil { + return StagedRelease{}, err + } + if _, err := io.Copy(out, tr); err != nil { + _ = out.Close() + return StagedRelease{}, err + } + if err := out.Close(); err != nil { + return StagedRelease{}, err + } + seen[rel] = dst + } + + for _, want := range expectedReleaseEntries { + if _, ok := seen[want]; !ok { + return StagedRelease{}, fmt.Errorf("release tarball is missing required entry %q", want) + } + } + return StagedRelease{ + BangerPath: seen["banger"], + BangerdPath: seen["bangerd"], + VsockAgentPath: seen["banger-vsock-agent"], + }, nil +} diff --git a/internal/updater/swap.go b/internal/updater/swap.go new file mode 100644 index 0000000..761dbae --- /dev/null +++ b/internal/updater/swap.go @@ -0,0 +1,135 @@ +package updater + +import ( + "errors" + "fmt" + "os" + + "banger/internal/system" +) + +// previousSuffix is the filename suffix appended to the +// pre-swap binary so Rollback knows where to restore from. +// Pinned as a constant so the swap and rollback paths can't +// disagree on it. +const previousSuffix = ".previous" + +// InstallTargets lists the absolute on-disk paths the updater +// writes during a swap. Hardcoded to the system-install layout — +// banger update is a system-mode operation; the developer non- +// system-mode flow doesn't go through this code path. +type InstallTargets struct { + Banger string // /usr/local/bin/banger + Bangerd string // /usr/local/bin/bangerd + VsockAgent string // /usr/local/lib/banger/banger-vsock-agent +} + +// DefaultInstallTargets returns the canonical paths a system install +// uses (`banger system install` writes to these). Exposed for +// testability; production callers use it as-is. +func DefaultInstallTargets() InstallTargets { + return InstallTargets{ + Banger: "/usr/local/bin/banger", + Bangerd: "/usr/local/bin/bangerd", + VsockAgent: "/usr/local/lib/banger/banger-vsock-agent", + } +} + +// SwapResult records what was swapped, so Rollback knows what to +// undo. A nil SwapResult means no swap was attempted yet (nothing +// to roll back). +type SwapResult struct { + Targets InstallTargets + // SwappedTargets is the subset of Targets that were actually + // renamed into place. If the second of three Renames fails, + // SwappedTargets contains only the first; rollback unwinds in + // reverse order. + SwappedTargets []string +} + +// Swap atomically replaces each of the three banger binaries with +// its staged counterpart. Order: +// +// 1. banger-vsock-agent (companion; not currently running, swap is safe) +// 2. bangerd (the to-be-restarted daemon binary) +// 3. banger (the CLI; least disruptive last) +// +// Each AtomicReplace leaves a `.previous` backup so Rollback can +// restore the prior install if a later step fails. +// +// Returns the SwapResult even on partial failure so the caller can +// drive Rollback against what HAS been swapped. +func Swap(staged StagedRelease, targets InstallTargets) (SwapResult, error) { + res := SwapResult{Targets: targets} + steps := []struct { + src, dst string + }{ + {src: staged.VsockAgentPath, dst: targets.VsockAgent}, + {src: staged.BangerdPath, dst: targets.Bangerd}, + {src: staged.BangerPath, dst: targets.Banger}, + } + for _, s := range steps { + if err := ensureParentDir(s.dst); err != nil { + return res, fmt.Errorf("prepare %s: %w", s.dst, err) + } + if err := system.AtomicReplace(s.src, s.dst, previousSuffix); err != nil { + return res, fmt.Errorf("swap %s: %w", s.dst, err) + } + res.SwappedTargets = append(res.SwappedTargets, s.dst) + } + return res, nil +} + +// Rollback undoes a Swap by restoring each .previous backup in +// reverse order. Returns the joined errors of every individual +// rollback that failed; a half-rolled-back tree is the worst case +// and the operator gets enough information to fix it manually. +// +// Tolerant of partial input — passing a SwapResult that only +// recorded the first two of three swaps rolls back exactly those +// two. +func Rollback(res SwapResult) error { + var errs []error + for i := len(res.SwappedTargets) - 1; i >= 0; i-- { + dst := res.SwappedTargets[i] + if err := system.AtomicReplaceRollback(dst, previousSuffix); err != nil { + errs = append(errs, fmt.Errorf("rollback %s: %w", dst, err)) + } + } + return errors.Join(errs...) +} + +// CleanupBackups removes every .previous backup left behind by a +// successful update. Called after `banger doctor` confirms the new +// install is healthy — we don't keep ancient backups around forever. +func CleanupBackups(res SwapResult) error { + var errs []error + for _, dst := range res.SwappedTargets { + if err := os.Remove(dst + previousSuffix); err != nil && !os.IsNotExist(err) { + errs = append(errs, fmt.Errorf("remove %s%s: %w", dst, previousSuffix, err)) + } + } + return errors.Join(errs...) +} + +func ensureParentDir(p string) error { + parent := dirOf(p) + if parent == "" { + return nil + } + if _, err := os.Stat(parent); err == nil { + return nil + } + return os.MkdirAll(parent, 0o755) +} + +// dirOf is a tiny path.Dir wrapper that returns "" for paths with +// no separator (so the ensure-parent logic doesn't try to mkdir(".")). +func dirOf(p string) string { + for i := len(p) - 1; i >= 0; i-- { + if p[i] == '/' { + return p[:i] + } + } + return "" +}