banger/internal/updater/flow_test.go
Thales Maciel 91af367208
updater: download/stage/swap/rollback flow steps
The pure-logic core of `banger update`. No CLI yet; this commit
ships the steps the next commit's command will orchestrate.

  * download.go — DownloadRelease fetches SHA256SUMS, parses it,
    looks up the tarball's basename, then streams the tarball
    through download.FetchVerified so the hash is checked on the
    fly. Returns the SHA256SUMS bytes alongside so a future
    cosign-verification step can validate them against an embedded
    public key before trusting the hashes inside.
    Also: fetchBounded for small bounded GETs (manifest, sums file,
    future signature), DefaultStagingDir, EnsureStagingDir,
    PrepareCleanStaging.
  * stage.go — StageTarball reads gzip+tar, validates the entry
    set is exactly {banger, bangerd, banger-vsock-agent} (no
    extras, no missing, no path traversal, no non-regular files),
    extracts at mode 0755 regardless of what the tarball claims.
    StagedRelease records the resulting paths.
  * swap.go — InstallTargets pins the canonical install paths
    (/usr/local/bin/banger, /usr/local/bin/bangerd,
    /usr/local/lib/banger/banger-vsock-agent). Swap orders the
    three replacements vsock → bangerd → banger so the most
    impactful binary (the CLI) goes last; each step uses
    system.AtomicReplace and accumulates a SwapResult so partial
    failures can be rolled back cleanly. Rollback unwinds in
    reverse, joining errors so a half-rolled-back state surfaces
    enough info for an operator to fix manually. CleanupBackups
    removes the .previous trail after `banger doctor` confirms
    the new install is healthy.
  * installmeta.UpdateBuildInfo — small helper that refreshes
    Version/Commit/BuiltAt on /etc/banger/install.toml without
    re-running the full system install. Preserves OwnerUser/UID/
    GID/Home and the original InstalledAt timestamp.

Tests: stage rejects extra entries / missing entries / path
traversal / non-regular files; happy-path stages all three at 0755
with correct contents. Swap+Rollback covers the all-three-succeed
path (then verifies .previous backups exist + rollback restores
old contents) AND the partial-failure path (third swap blocked by
a non-dir parent → SwappedTargets = 2 → rollback unwinds those
two cleanly). DownloadRelease covers happy path, tarball-not-in-
SHA256SUMS, and propagated sha256 mismatch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-29 12:30:22 -03:00

363 lines
11 KiB
Go

package updater
import (
	"archive/tar"
	"bytes"
	"compress/gzip"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"testing"
)
// makeReleaseTarball builds an in-memory gzip-compressed tar archive
// holding one regular-file entry per key of bodies. Each entry is
// recorded with mode 0755 and the mapped bytes as its content.
//
// Entries are written in sorted name order so the archive layout is the
// same on every run; ranging over the map directly would shuffle them
// and make any order-dependent failure hard to reproduce.
//
// Any write or close error fails the calling test via t.Fatalf. Tests
// use this so they don't need a real banger build to exercise the
// extraction path.
func makeReleaseTarball(t *testing.T, bodies map[string][]byte) []byte {
	t.Helper()

	names := make([]string, 0, len(bodies))
	for name := range bodies {
		names = append(names, name)
	}
	sort.Strings(names)

	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	tw := tar.NewWriter(gz)
	for _, name := range names {
		body := bodies[name]
		hdr := &tar.Header{
			Name:     name,
			Mode:     0o755,
			Size:     int64(len(body)),
			Typeflag: tar.TypeReg,
		}
		if err := tw.WriteHeader(hdr); err != nil {
			t.Fatalf("write header: %v", err)
		}
		if _, err := tw.Write(body); err != nil {
			t.Fatalf("write body: %v", err)
		}
	}
	// Close the tar writer before the gzip writer so the tar trailer is
	// flushed into the compressed stream.
	if err := tw.Close(); err != nil {
		t.Fatalf("close tar: %v", err)
	}
	if err := gz.Close(); err != nil {
		t.Fatalf("close gzip: %v", err)
	}
	return buf.Bytes()
}
// sha256Hex returns the SHA-256 digest of b as a lowercase hex string.
func sha256Hex(b []byte) string {
	hasher := sha256.New()
	// hash.Hash documents that Write never returns an error.
	hasher.Write(b)
	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest)
}
// TestStageTarballHappyPath stages a tarball holding exactly the three
// expected entries and checks that every path reported by StageTarball
// exists with permission bits 0755 and carries the body that was
// archived under the matching name.
func TestStageTarballHappyPath(t *testing.T) {
	body := makeReleaseTarball(t, map[string][]byte{
		"banger":             []byte("BANGER"),
		"bangerd":            []byte("BANGERD"),
		"banger-vsock-agent": []byte("AGENT"),
	})
	tarball := filepath.Join(t.TempDir(), "release.tar.gz")
	if err := os.WriteFile(tarball, body, 0o644); err != nil {
		t.Fatalf("write tarball: %v", err)
	}
	staging := filepath.Join(t.TempDir(), "staged")
	got, err := StageTarball(tarball, staging)
	if err != nil {
		t.Fatalf("StageTarball: %v", err)
	}

	// Check all three staged files, not just the CLI binary, so a mix-up
	// between entries is caught.
	for _, staged := range []struct{ path, want string }{
		{got.BangerPath, "BANGER"},
		{got.BangerdPath, "BANGERD"},
		{got.VsockAgentPath, "AGENT"},
	} {
		info, err := os.Stat(staged.path)
		if err != nil {
			t.Fatalf("stat %s: %v", staged.path, err)
		}
		if info.Mode().Perm() != 0o755 {
			t.Errorf("%s mode = %o, want 0755", staged.path, info.Mode().Perm())
		}
		bs, err := os.ReadFile(staged.path)
		if err != nil {
			t.Fatalf("read %s: %v", staged.path, err)
		}
		if string(bs) != staged.want {
			t.Fatalf("%s content = %q, want %q", staged.path, bs, staged.want)
		}
	}
}
// TestStageTarballRejectsExtraEntry stages a tarball that carries a
// fourth entry ("bonus.txt") next to the three expected ones and
// expects StageTarball to fail with an "unexpected entry" error.
func TestStageTarballRejectsExtraEntry(t *testing.T) {
	body := makeReleaseTarball(t, map[string][]byte{
		"banger":             []byte("a"),
		"bangerd":            []byte("b"),
		"banger-vsock-agent": []byte("c"),
		"bonus.txt":          []byte("not allowed"),
	})
	tarball := filepath.Join(t.TempDir(), "rel.tar.gz")
	// A failed fixture write must not be mistaken for a rejection.
	if err := os.WriteFile(tarball, body, 0o644); err != nil {
		t.Fatalf("write tarball: %v", err)
	}
	_, err := StageTarball(tarball, t.TempDir())
	if err == nil || !strings.Contains(err.Error(), "unexpected entry") {
		t.Fatalf("err = %v, want unexpected-entry rejection", err)
	}
}
// TestStageTarballRejectsMissingEntry stages a tarball that lacks the
// banger-vsock-agent entry and expects StageTarball to fail with a
// "missing required entry" error.
func TestStageTarballRejectsMissingEntry(t *testing.T) {
	body := makeReleaseTarball(t, map[string][]byte{
		"banger":  []byte("a"),
		"bangerd": []byte("b"),
		// banger-vsock-agent intentionally missing
	})
	tarball := filepath.Join(t.TempDir(), "rel.tar.gz")
	// A failed fixture write must not be mistaken for a rejection.
	if err := os.WriteFile(tarball, body, 0o644); err != nil {
		t.Fatalf("write tarball: %v", err)
	}
	_, err := StageTarball(tarball, t.TempDir())
	if err == nil || !strings.Contains(err.Error(), "missing required entry") {
		t.Fatalf("err = %v, want missing-required rejection", err)
	}
}
// TestStageTarballRejectsPathTraversal feeds StageTarball an archive
// that contains a "../escape" entry and expects an "unsafe path" error.
func TestStageTarballRejectsPathTraversal(t *testing.T) {
	// Build the tarball by hand so the `../` entry sits at a known
	// position, after two legitimate entries.
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	tw := tar.NewWriter(gz)
	for _, e := range []struct{ name, body string }{
		{"banger", "a"},
		{"bangerd", "b"},
		{"../escape", "x"},
	} {
		hdr := &tar.Header{Name: e.name, Size: int64(len(e.body)), Mode: 0o755, Typeflag: tar.TypeReg}
		if err := tw.WriteHeader(hdr); err != nil {
			t.Fatalf("write header %q: %v", e.name, err)
		}
		if _, err := tw.Write([]byte(e.body)); err != nil {
			t.Fatalf("write body %q: %v", e.name, err)
		}
	}
	if err := tw.Close(); err != nil {
		t.Fatalf("close tar: %v", err)
	}
	if err := gz.Close(); err != nil {
		t.Fatalf("close gzip: %v", err)
	}
	tarball := filepath.Join(t.TempDir(), "rel.tar.gz")
	if err := os.WriteFile(tarball, buf.Bytes(), 0o644); err != nil {
		t.Fatalf("write tarball: %v", err)
	}
	_, err := StageTarball(tarball, t.TempDir())
	if err == nil || !strings.Contains(err.Error(), "unsafe path") {
		t.Fatalf("err = %v, want unsafe-path rejection", err)
	}
}
// TestSwapAndRollback runs Swap followed by Rollback against an install
// layout created under a temp dir.
//
// After Swap it expects three swapped targets, each holding its NEW-*
// body, with the OLD-* body preserved at target+previousSuffix. After
// Rollback it expects every target to hold its OLD-* body again and the
// backup files to be gone.
func TestSwapAndRollback(t *testing.T) {
	root := t.TempDir()
	binDir := filepath.Join(root, "bin")
	libDir := filepath.Join(root, "lib", "banger")
	if err := os.MkdirAll(binDir, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.MkdirAll(libDir, 0o755); err != nil {
		t.Fatal(err)
	}
	// Seed the "currently installed" binaries.
	for _, p := range []string{
		filepath.Join(binDir, "banger"),
		filepath.Join(binDir, "bangerd"),
		filepath.Join(libDir, "banger-vsock-agent"),
	} {
		if err := os.WriteFile(p, []byte("OLD-"+filepath.Base(p)), 0o755); err != nil {
			t.Fatal(err)
		}
	}

	staging := filepath.Join(root, "staging")
	if err := os.MkdirAll(staging, 0o700); err != nil {
		t.Fatal(err)
	}
	staged := StagedRelease{
		BangerPath:     filepath.Join(staging, "banger"),
		BangerdPath:    filepath.Join(staging, "bangerd"),
		VsockAgentPath: filepath.Join(staging, "banger-vsock-agent"),
	}
	for _, pair := range []struct{ p, body string }{
		{staged.BangerPath, "NEW-banger"},
		{staged.BangerdPath, "NEW-bangerd"},
		{staged.VsockAgentPath, "NEW-banger-vsock-agent"},
	} {
		if err := os.WriteFile(pair.p, []byte(pair.body), 0o755); err != nil {
			t.Fatal(err)
		}
	}

	targets := InstallTargets{
		Banger:     filepath.Join(binDir, "banger"),
		Bangerd:    filepath.Join(binDir, "bangerd"),
		VsockAgent: filepath.Join(libDir, "banger-vsock-agent"),
	}
	res, err := Swap(staged, targets)
	if err != nil {
		t.Fatalf("Swap: %v", err)
	}
	if len(res.SwappedTargets) != 3 {
		t.Fatalf("SwappedTargets len = %d, want 3", len(res.SwappedTargets))
	}
	for _, p := range []string{targets.Banger, targets.Bangerd, targets.VsockAgent} {
		got, err := os.ReadFile(p)
		if err != nil {
			t.Fatalf("read swapped target %s: %v", p, err)
		}
		want := "NEW-" + filepath.Base(p)
		if string(got) != want {
			t.Fatalf("%s content = %q, want %q", p, got, want)
		}
		prev, err := os.ReadFile(p + previousSuffix)
		if err != nil {
			t.Fatalf("missing backup at %s: %v", p+previousSuffix, err)
		}
		if string(prev) != "OLD-"+filepath.Base(p) {
			t.Fatalf(".previous content = %q", prev)
		}
	}

	if err := Rollback(res); err != nil {
		t.Fatalf("Rollback: %v", err)
	}
	for _, p := range []string{targets.Banger, targets.Bangerd, targets.VsockAgent} {
		got, err := os.ReadFile(p)
		if err != nil {
			t.Fatalf("read rolled-back target %s: %v", p, err)
		}
		want := "OLD-" + filepath.Base(p)
		if string(got) != want {
			t.Fatalf("post-rollback %s = %q, want %q", p, got, want)
		}
		if _, err := os.Stat(p + previousSuffix); !os.IsNotExist(err) {
			t.Fatalf(".previous should be cleaned after rollback; stat err = %v", err)
		}
	}
}
// TestSwapPartialFailureRollsBackCleanly makes the banger target
// unreachable, expects Swap to fail after recording exactly two swapped
// targets, and then expects Rollback to restore the OLD-* body of each
// recorded target.
func TestSwapPartialFailureRollsBackCleanly(t *testing.T) {
	root := t.TempDir()
	binDir := filepath.Join(root, "bin")
	libDir := filepath.Join(root, "lib", "banger")
	if err := os.MkdirAll(binDir, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.MkdirAll(libDir, 0o755); err != nil {
		t.Fatal(err)
	}
	// Pre-create the two binaries that will swap successfully.
	for _, p := range []string{
		filepath.Join(binDir, "bangerd"),
		filepath.Join(libDir, "banger-vsock-agent"),
	} {
		if err := os.WriteFile(p, []byte("OLD-"+filepath.Base(p)), 0o755); err != nil {
			t.Fatal(err)
		}
	}

	staging := filepath.Join(root, "staging")
	if err := os.MkdirAll(staging, 0o700); err != nil {
		t.Fatal(err)
	}
	staged := StagedRelease{
		BangerPath:     filepath.Join(staging, "banger"),
		BangerdPath:    filepath.Join(staging, "bangerd"),
		VsockAgentPath: filepath.Join(staging, "banger-vsock-agent"),
	}
	for _, pair := range []struct{ p, body string }{
		{staged.BangerPath, "NEW-banger"},
		{staged.BangerdPath, "NEW-bangerd"},
		{staged.VsockAgentPath, "NEW-banger-vsock-agent"},
	} {
		if err := os.WriteFile(pair.p, []byte(pair.body), 0o755); err != nil {
			t.Fatal(err)
		}
	}

	// Block the banger swap (which is LAST in the order) by putting
	// a regular file where its parent dir should be — MkdirAll fails
	// with "not a directory". Vsock + bangerd succeed first.
	blockedParent := filepath.Join(root, "blocked-bin")
	if err := os.WriteFile(blockedParent, []byte("blocking"), 0o644); err != nil {
		t.Fatal(err)
	}
	targets := InstallTargets{
		Banger:     filepath.Join(blockedParent, "banger"),
		Bangerd:    filepath.Join(binDir, "bangerd"),
		VsockAgent: filepath.Join(libDir, "banger-vsock-agent"),
	}
	res, err := Swap(staged, targets)
	if err == nil {
		t.Fatal("Swap unexpectedly succeeded; banger parent should be blocked by a regular file")
	}
	if len(res.SwappedTargets) != 2 {
		t.Fatalf("SwappedTargets = %v, want 2 (vsock + bangerd before banger failed)", res.SwappedTargets)
	}

	// Rolling back the partial swap should restore the filesystem.
	if err := Rollback(res); err != nil {
		t.Fatalf("Rollback after partial swap: %v", err)
	}
	for _, p := range res.SwappedTargets {
		got, err := os.ReadFile(p)
		if err != nil {
			t.Fatalf("read rolled-back target %s: %v", p, err)
		}
		want := "OLD-" + filepath.Base(p)
		if string(got) != want {
			t.Fatalf("post-rollback %s = %q", p, got)
		}
	}
}
// TestDownloadReleaseHappyPath serves a tarball and a SHA256SUMS file
// that lists the tarball's real digest. It expects DownloadRelease to
// succeed, to return sums text that mentions the tarball name, and to
// leave the served bytes at the destination path.
func TestDownloadReleaseHappyPath(t *testing.T) {
	tarballBody := []byte("fake tarball bytes")
	tarballSHA := sha256Hex(tarballBody)
	mux := http.NewServeMux()
	mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write(tarballBody)
	})
	mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "%s banger.tar.gz\n", tarballSHA)
	})
	srv := httptest.NewServer(mux)
	defer srv.Close()

	dst := filepath.Join(t.TempDir(), "out.tar.gz")
	sums, err := DownloadRelease(context.Background(), srv.Client(), Release{
		Version:       "v0.1.0",
		TarballURL:    srv.URL + "/banger.tar.gz",
		SHA256SumsURL: srv.URL + "/SHA256SUMS",
	}, dst)
	if err != nil {
		t.Fatalf("DownloadRelease: %v", err)
	}
	if !strings.Contains(string(sums), "banger.tar.gz") {
		t.Fatalf("returned sums body missing tarball name: %q", sums)
	}
	// Report a missing or unreadable download as such, not as a content
	// difference.
	got, err := os.ReadFile(dst)
	if err != nil {
		t.Fatalf("read downloaded tarball: %v", err)
	}
	if !bytes.Equal(got, tarballBody) {
		t.Fatalf("downloaded body differs from served body")
	}
}
// TestDownloadReleaseRejectsTarballMissingFromSums serves a SHA256SUMS
// file whose only line names a file other than the tarball and expects
// DownloadRelease to fail with an error containing "does not list".
func TestDownloadReleaseRejectsTarballMissingFromSums(t *testing.T) {
	routes := http.NewServeMux()
	routes.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, _ *http.Request) {
		// Sums for a different file; tarball name not listed.
		digest := sha256Hex([]byte("body"))
		fmt.Fprintf(w, "%s unrelated\n", digest)
	})
	routes.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("body"))
	})
	server := httptest.NewServer(routes)
	defer server.Close()

	out := filepath.Join(t.TempDir(), "out.tar.gz")
	rel := Release{
		TarballURL:    server.URL + "/banger.tar.gz",
		SHA256SumsURL: server.URL + "/SHA256SUMS",
	}
	_, err := DownloadRelease(context.Background(), server.Client(), rel, out)

	rejected := err != nil && strings.Contains(err.Error(), "does not list")
	if !rejected {
		t.Fatalf("err = %v, want SHA256SUMS-missing rejection", err)
	}
}
// TestDownloadReleasePropagatesShaMismatch serves a SHA256SUMS file
// whose digest for the tarball does not match the bytes actually
// served. It expects DownloadRelease to fail with an error containing
// "sha256 mismatch" and expects no file to exist at the destination
// path afterwards.
func TestDownloadReleasePropagatesShaMismatch(t *testing.T) {
	routes := http.NewServeMux()
	routes.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, _ *http.Request) {
		// Wrong digest for the tarball.
		wrong := sha256Hex([]byte("expected body"))
		fmt.Fprintf(w, "%s banger.tar.gz\n", wrong)
	})
	routes.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("served body"))
	})
	server := httptest.NewServer(routes)
	defer server.Close()

	out := filepath.Join(t.TempDir(), "out.tar.gz")
	rel := Release{
		TarballURL:    server.URL + "/banger.tar.gz",
		SHA256SumsURL: server.URL + "/SHA256SUMS",
	}
	_, err := DownloadRelease(context.Background(), server.Client(), rel, out)

	mismatch := err != nil && strings.Contains(err.Error(), "sha256 mismatch")
	if !mismatch {
		t.Fatalf("err = %v, want sha256 mismatch", err)
	}

	_, statErr := os.Stat(out)
	if !os.IsNotExist(statErr) {
		t.Fatalf("partial tarball should be removed; stat err = %v", statErr)
	}
}