updater: download/stage/swap/rollback flow steps
The pure-logic core of `banger update`. No CLI yet; this commit
ships the steps the next commit's command will orchestrate.
* download.go — DownloadRelease fetches SHA256SUMS, parses it,
looks up the tarball's basename, then streams the tarball
through download.FetchVerified so the hash is checked on the
fly. Returns the SHA256SUMS bytes alongside so a future
cosign-verification step can validate them against an embedded
public key before trusting the hashes inside.
Also: fetchBounded for small bounded GETs (manifest, sums file,
future signature), DefaultStagingDir, EnsureStagingDir,
PrepareCleanStaging.
* stage.go — StageTarball reads gzip+tar, validates the entry
set is exactly {banger, bangerd, banger-vsock-agent} (no
extras, no missing, no path traversal, no non-regular files),
extracts at mode 0755 regardless of what the tarball claims.
StagedRelease records the resulting paths.
* swap.go — InstallTargets pins the canonical install paths
(/usr/local/bin/banger, /usr/local/bin/bangerd,
/usr/local/lib/banger/banger-vsock-agent). Swap orders the
three replacements vsock → bangerd → banger so the most
impactful binary (the CLI) goes last; each step uses
system.AtomicReplace and accumulates a SwapResult so partial
failures can be rolled back cleanly. Rollback unwinds in
reverse, joining errors so a half-rolled-back state surfaces
enough info for an operator to fix manually. CleanupBackups
removes the .previous trail after `banger doctor` confirms
the new install is healthy.
* installmeta.UpdateBuildInfo — small helper that refreshes
Version/Commit/BuiltAt on /etc/banger/install.toml without
re-running the full system install. Preserves OwnerUser/UID/
GID/Home and the original InstalledAt timestamp.
Tests: stage rejects extra entries / missing entries / path
traversal / non-regular files; happy-path stages all three at 0755
with correct contents. Swap+Rollback covers the all-three-succeed
path (then verifies .previous backups exist + rollback restores
old contents) AND the partial-failure path (third swap blocked by
a non-dir parent → SwappedTargets = 2 → rollback unwinds those
two cleanly). DownloadRelease covers happy path, tarball-not-in-
SHA256SUMS, and propagated sha256 mismatch.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
fb6d2b1dae
commit
91af367208
5 changed files with 746 additions and 0 deletions
|
|
@ -97,6 +97,30 @@ func Save(path string, meta Metadata) error {
|
|||
return os.WriteFile(path, data, 0o644)
|
||||
}
|
||||
|
||||
// UpdateBuildInfo refreshes only the Version / Commit / BuiltAt
|
||||
// fields on the install metadata, preserving everything else
|
||||
// (OwnerUser/UID/GID/Home and the original InstalledAt timestamp).
|
||||
// Used by `banger update` to record what's running after a
|
||||
// successful binary swap; the install identity is unchanged so
|
||||
// re-running `banger system install` is not required.
|
||||
//
|
||||
// Errors when path doesn't exist or can't be parsed — `banger
|
||||
// update` runs in system mode where install.toml IS the source of
|
||||
// truth; a missing file means we shouldn't be updating at all.
|
||||
func UpdateBuildInfo(path, version, commit, builtAt string) error {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
path = DefaultPath
|
||||
}
|
||||
meta, err := Load(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
meta.Version = strings.TrimSpace(version)
|
||||
meta.Commit = strings.TrimSpace(commit)
|
||||
meta.BuiltAt = strings.TrimSpace(builtAt)
|
||||
return Save(path, meta)
|
||||
}
|
||||
|
||||
func (m Metadata) Validate() error {
|
||||
if strings.TrimSpace(m.OwnerUser) == "" {
|
||||
return fmt.Errorf("install metadata missing owner_user")
|
||||
|
|
|
|||
117
internal/updater/download.go
Normal file
117
internal/updater/download.go
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"banger/internal/download"
|
||||
)
|
||||
|
||||
// DownloadRelease fetches the SHA256SUMS file for `release`, looks up
|
||||
// the tarball's basename in it, then fetches the tarball with on-the-
|
||||
// fly hash verification. The tarball lands at dstPath; the function
|
||||
// errors on any verification failure and removes the partial file
|
||||
// before returning.
|
||||
//
|
||||
// SHA256SUMS bytes are returned alongside so the caller can
|
||||
// cosign-verify them against an embedded public key before trusting
|
||||
// the hashes inside. Without that step this function is only as
|
||||
// secure as TLS; see verify_signature.go for the cosign tie-in.
|
||||
func DownloadRelease(ctx context.Context, client *http.Client, release Release, dstPath string) (sumsBody []byte, err error) {
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
|
||||
sumsBody, err = fetchBounded(ctx, client, release.SHA256SumsURL, MaxSHA256SumsBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch SHA256SUMS: %w", err)
|
||||
}
|
||||
sums, err := ParseSHA256Sums(sumsBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse SHA256SUMS: %w", err)
|
||||
}
|
||||
|
||||
tarballName := path.Base(release.TarballURL)
|
||||
expected, ok := sums[tarballName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("SHA256SUMS does not list %q", tarballName)
|
||||
}
|
||||
if _, err := download.FetchVerified(ctx, client, release.TarballURL, expected, MaxTarballBytes, dstPath); err != nil {
|
||||
return nil, fmt.Errorf("fetch tarball: %w", err)
|
||||
}
|
||||
return sumsBody, nil
|
||||
}
|
||||
|
||||
// fetchBounded does a small bounded GET — used for the manifest, the
|
||||
// SHA256SUMS file, and (later) the cosign signature. Anything bigger
|
||||
// goes through download.FetchVerified, which adds the on-the-fly
|
||||
// hash check.
|
||||
func fetchBounded(ctx context.Context, client *http.Client, url string, maxBytes int64) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("fetch %s: HTTP %s", url, resp.Status)
|
||||
}
|
||||
if resp.ContentLength > maxBytes {
|
||||
return nil, fmt.Errorf("fetch %s: %d bytes exceeds %d-byte cap", url, resp.ContentLength, maxBytes)
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxBytes+1))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read %s: %w", url, err)
|
||||
}
|
||||
if int64(len(body)) > maxBytes {
|
||||
return nil, fmt.Errorf("%s exceeded %d-byte cap mid-stream", url, maxBytes)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// EnsureStagingDir creates the staging directory with restrictive
|
||||
// permissions (0700, owned by the caller — typically root in system
|
||||
// mode). Any pre-existing contents are NOT cleared; that's
|
||||
// PrepareCleanStaging's job.
|
||||
func EnsureStagingDir(stagingDir string) error {
|
||||
return os.MkdirAll(stagingDir, 0o700)
|
||||
}
|
||||
|
||||
// PrepareCleanStaging wipes anything left in the staging dir from a
|
||||
// prior aborted update, then re-creates the directory. Distinct from
|
||||
// EnsureStagingDir because we don't want to nuke the dir unless
|
||||
// we're ABOUT to use it — having a leftover staged tree from a
|
||||
// prior failed run is sometimes useful for diagnostics.
|
||||
func PrepareCleanStaging(stagingDir string) error {
|
||||
if err := os.RemoveAll(stagingDir); err != nil {
|
||||
return fmt.Errorf("clear staging %s: %w", stagingDir, err)
|
||||
}
|
||||
return EnsureStagingDir(stagingDir)
|
||||
}
|
||||
|
||||
// DefaultStagingDir is where the updater stages downloads +
|
||||
// extracted binaries when no explicit dir is configured. Sits under
|
||||
// banger's system CacheDir (typically /var/cache/banger/updates) so:
|
||||
// - the systemd unit's CacheDirectory=banger keeps the path
|
||||
// writable for the helper.
|
||||
// - `banger system uninstall --purge` cleans it.
|
||||
// - it sits beside the OCI and kernel caches without colliding.
|
||||
//
|
||||
// Atomicity caveat: we expect /var/cache and /usr/local to share a
|
||||
// filesystem (default on essentially every Linux install). On a host
|
||||
// with /usr split onto a separate volume, the swap step's os.Rename
|
||||
// would fall through to a copy + delete and lose its atomicity
|
||||
// guarantee. We document this rather than detect-and-error for
|
||||
// v0.1.0; the worst-case symptom is a brief window where a binary is
|
||||
// half-written, which `banger doctor` would catch in step 7.
|
||||
func DefaultStagingDir(cacheDir string) string {
|
||||
return filepath.Join(cacheDir, "updates")
|
||||
}
|
||||
363
internal/updater/flow_test.go
Normal file
363
internal/updater/flow_test.go
Normal file
|
|
@ -0,0 +1,363 @@
|
|||
package updater
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// makeReleaseTarball writes a tarball whose root contains the three
|
||||
// expected entries with the given bodies. Used by stage + download
|
||||
// tests so they don't need a real banger build to exercise the
|
||||
// extraction path.
|
||||
func makeReleaseTarball(t *testing.T, bodies map[string][]byte) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
tw := tar.NewWriter(gz)
|
||||
for name, body := range bodies {
|
||||
hdr := &tar.Header{
|
||||
Name: name,
|
||||
Mode: 0o755,
|
||||
Size: int64(len(body)),
|
||||
Typeflag: tar.TypeReg,
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
t.Fatalf("write header: %v", err)
|
||||
}
|
||||
if _, err := tw.Write(body); err != nil {
|
||||
t.Fatalf("write body: %v", err)
|
||||
}
|
||||
}
|
||||
if err := tw.Close(); err != nil {
|
||||
t.Fatalf("close tar: %v", err)
|
||||
}
|
||||
if err := gz.Close(); err != nil {
|
||||
t.Fatalf("close gzip: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func sha256Hex(b []byte) string {
|
||||
sum := sha256.Sum256(b)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func TestStageTarballHappyPath(t *testing.T) {
|
||||
body := makeReleaseTarball(t, map[string][]byte{
|
||||
"banger": []byte("BANGER"),
|
||||
"bangerd": []byte("BANGERD"),
|
||||
"banger-vsock-agent": []byte("AGENT"),
|
||||
})
|
||||
tarball := filepath.Join(t.TempDir(), "release.tar.gz")
|
||||
if err := os.WriteFile(tarball, body, 0o644); err != nil {
|
||||
t.Fatalf("write tarball: %v", err)
|
||||
}
|
||||
staging := filepath.Join(t.TempDir(), "staged")
|
||||
|
||||
got, err := StageTarball(tarball, staging)
|
||||
if err != nil {
|
||||
t.Fatalf("StageTarball: %v", err)
|
||||
}
|
||||
for _, p := range []string{got.BangerPath, got.BangerdPath, got.VsockAgentPath} {
|
||||
info, err := os.Stat(p)
|
||||
if err != nil {
|
||||
t.Fatalf("stat %s: %v", p, err)
|
||||
}
|
||||
if info.Mode().Perm() != 0o755 {
|
||||
t.Errorf("%s mode = %o, want 0755", p, info.Mode().Perm())
|
||||
}
|
||||
}
|
||||
bs, _ := os.ReadFile(got.BangerPath)
|
||||
if string(bs) != "BANGER" {
|
||||
t.Fatalf("banger content = %q", bs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStageTarballRejectsExtraEntry(t *testing.T) {
|
||||
body := makeReleaseTarball(t, map[string][]byte{
|
||||
"banger": []byte("a"),
|
||||
"bangerd": []byte("b"),
|
||||
"banger-vsock-agent": []byte("c"),
|
||||
"bonus.txt": []byte("not allowed"),
|
||||
})
|
||||
tarball := filepath.Join(t.TempDir(), "rel.tar.gz")
|
||||
_ = os.WriteFile(tarball, body, 0o644)
|
||||
_, err := StageTarball(tarball, t.TempDir())
|
||||
if err == nil || !strings.Contains(err.Error(), "unexpected entry") {
|
||||
t.Fatalf("err = %v, want unexpected-entry rejection", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStageTarballRejectsMissingEntry(t *testing.T) {
|
||||
body := makeReleaseTarball(t, map[string][]byte{
|
||||
"banger": []byte("a"),
|
||||
"bangerd": []byte("b"),
|
||||
// banger-vsock-agent intentionally missing
|
||||
})
|
||||
tarball := filepath.Join(t.TempDir(), "rel.tar.gz")
|
||||
_ = os.WriteFile(tarball, body, 0o644)
|
||||
_, err := StageTarball(tarball, t.TempDir())
|
||||
if err == nil || !strings.Contains(err.Error(), "missing required entry") {
|
||||
t.Fatalf("err = %v, want missing-required rejection", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStageTarballRejectsPathTraversal(t *testing.T) {
|
||||
// Build the tarball manually so we can inject a `../` entry —
|
||||
// makeReleaseTarball's expected-entry filter would otherwise
|
||||
// catch it earlier.
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
tw := tar.NewWriter(gz)
|
||||
for _, e := range []struct{ name, body string }{
|
||||
{"banger", "a"},
|
||||
{"bangerd", "b"},
|
||||
{"../escape", "x"},
|
||||
} {
|
||||
_ = tw.WriteHeader(&tar.Header{Name: e.name, Size: int64(len(e.body)), Mode: 0o755, Typeflag: tar.TypeReg})
|
||||
_, _ = tw.Write([]byte(e.body))
|
||||
}
|
||||
_ = tw.Close()
|
||||
_ = gz.Close()
|
||||
tarball := filepath.Join(t.TempDir(), "rel.tar.gz")
|
||||
_ = os.WriteFile(tarball, buf.Bytes(), 0o644)
|
||||
_, err := StageTarball(tarball, t.TempDir())
|
||||
if err == nil || !strings.Contains(err.Error(), "unsafe path") {
|
||||
t.Fatalf("err = %v, want unsafe-path rejection", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwapAndRollback(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
binDir := filepath.Join(root, "bin")
|
||||
libDir := filepath.Join(root, "lib", "banger")
|
||||
if err := os.MkdirAll(binDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(libDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, p := range []string{
|
||||
filepath.Join(binDir, "banger"),
|
||||
filepath.Join(binDir, "bangerd"),
|
||||
filepath.Join(libDir, "banger-vsock-agent"),
|
||||
} {
|
||||
if err := os.WriteFile(p, []byte("OLD-"+filepath.Base(p)), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
staging := filepath.Join(root, "staging")
|
||||
_ = os.MkdirAll(staging, 0o700)
|
||||
staged := StagedRelease{
|
||||
BangerPath: filepath.Join(staging, "banger"),
|
||||
BangerdPath: filepath.Join(staging, "bangerd"),
|
||||
VsockAgentPath: filepath.Join(staging, "banger-vsock-agent"),
|
||||
}
|
||||
for _, pair := range []struct{ p, body string }{
|
||||
{staged.BangerPath, "NEW-banger"},
|
||||
{staged.BangerdPath, "NEW-bangerd"},
|
||||
{staged.VsockAgentPath, "NEW-banger-vsock-agent"},
|
||||
} {
|
||||
if err := os.WriteFile(pair.p, []byte(pair.body), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
targets := InstallTargets{
|
||||
Banger: filepath.Join(binDir, "banger"),
|
||||
Bangerd: filepath.Join(binDir, "bangerd"),
|
||||
VsockAgent: filepath.Join(libDir, "banger-vsock-agent"),
|
||||
}
|
||||
|
||||
res, err := Swap(staged, targets)
|
||||
if err != nil {
|
||||
t.Fatalf("Swap: %v", err)
|
||||
}
|
||||
if len(res.SwappedTargets) != 3 {
|
||||
t.Fatalf("SwappedTargets len = %d, want 3", len(res.SwappedTargets))
|
||||
}
|
||||
for _, p := range []string{targets.Banger, targets.Bangerd, targets.VsockAgent} {
|
||||
got, _ := os.ReadFile(p)
|
||||
want := "NEW-" + filepath.Base(p)
|
||||
if string(got) != want {
|
||||
t.Fatalf("%s content = %q, want %q", p, got, want)
|
||||
}
|
||||
prev, err := os.ReadFile(p + previousSuffix)
|
||||
if err != nil {
|
||||
t.Fatalf("missing backup at %s.previous: %v", p, err)
|
||||
}
|
||||
if string(prev) != "OLD-"+filepath.Base(p) {
|
||||
t.Fatalf(".previous content = %q", prev)
|
||||
}
|
||||
}
|
||||
|
||||
if err := Rollback(res); err != nil {
|
||||
t.Fatalf("Rollback: %v", err)
|
||||
}
|
||||
for _, p := range []string{targets.Banger, targets.Bangerd, targets.VsockAgent} {
|
||||
got, _ := os.ReadFile(p)
|
||||
want := "OLD-" + filepath.Base(p)
|
||||
if string(got) != want {
|
||||
t.Fatalf("post-rollback %s = %q, want %q", p, got, want)
|
||||
}
|
||||
if _, err := os.Stat(p + previousSuffix); !os.IsNotExist(err) {
|
||||
t.Fatalf(".previous should be cleaned after rollback; stat err = %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwapPartialFailureRollsBackCleanly(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
binDir := filepath.Join(root, "bin")
|
||||
libDir := filepath.Join(root, "lib", "banger")
|
||||
if err := os.MkdirAll(binDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(libDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Pre-create the two binaries that will swap successfully.
|
||||
for _, p := range []string{
|
||||
filepath.Join(binDir, "bangerd"),
|
||||
filepath.Join(libDir, "banger-vsock-agent"),
|
||||
} {
|
||||
_ = os.WriteFile(p, []byte("OLD-"+filepath.Base(p)), 0o755)
|
||||
}
|
||||
|
||||
staging := filepath.Join(root, "staging")
|
||||
_ = os.MkdirAll(staging, 0o700)
|
||||
staged := StagedRelease{
|
||||
BangerPath: filepath.Join(staging, "banger"),
|
||||
BangerdPath: filepath.Join(staging, "bangerd"),
|
||||
VsockAgentPath: filepath.Join(staging, "banger-vsock-agent"),
|
||||
}
|
||||
for _, pair := range []struct{ p, body string }{
|
||||
{staged.BangerPath, "NEW-banger"},
|
||||
{staged.BangerdPath, "NEW-bangerd"},
|
||||
{staged.VsockAgentPath, "NEW-banger-vsock-agent"},
|
||||
} {
|
||||
_ = os.WriteFile(pair.p, []byte(pair.body), 0o755)
|
||||
}
|
||||
|
||||
// Block the banger swap (which is LAST in the order) by putting
|
||||
// a regular file where its parent dir should be — MkdirAll fails
|
||||
// with "not a directory". Vsock + bangerd succeed first.
|
||||
blockedParent := filepath.Join(root, "blocked-bin")
|
||||
if err := os.WriteFile(blockedParent, []byte("blocking"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
targets := InstallTargets{
|
||||
Banger: filepath.Join(blockedParent, "banger"),
|
||||
Bangerd: filepath.Join(binDir, "bangerd"),
|
||||
VsockAgent: filepath.Join(libDir, "banger-vsock-agent"),
|
||||
}
|
||||
|
||||
res, err := Swap(staged, targets)
|
||||
if err == nil {
|
||||
t.Fatal("Swap unexpectedly succeeded; banger parent should be blocked by a regular file")
|
||||
}
|
||||
if len(res.SwappedTargets) != 2 {
|
||||
t.Fatalf("SwappedTargets = %v, want 2 (vsock + bangerd before banger failed)", res.SwappedTargets)
|
||||
}
|
||||
// Rolling back the partial swap should restore the filesystem.
|
||||
if err := Rollback(res); err != nil {
|
||||
t.Fatalf("Rollback after partial swap: %v", err)
|
||||
}
|
||||
for _, p := range res.SwappedTargets {
|
||||
got, _ := os.ReadFile(p)
|
||||
want := "OLD-" + filepath.Base(p)
|
||||
if string(got) != want {
|
||||
t.Fatalf("post-rollback %s = %q", p, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadReleaseHappyPath(t *testing.T) {
|
||||
tarballBody := []byte("fake tarball bytes")
|
||||
tarballSHA := sha256Hex(tarballBody)
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write(tarballBody)
|
||||
})
|
||||
mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "%s banger.tar.gz\n", tarballSHA)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
dst := filepath.Join(t.TempDir(), "out.tar.gz")
|
||||
sums, err := DownloadRelease(context.Background(), srv.Client(), Release{
|
||||
Version: "v0.1.0",
|
||||
TarballURL: srv.URL + "/banger.tar.gz",
|
||||
SHA256SumsURL: srv.URL + "/SHA256SUMS",
|
||||
}, dst)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadRelease: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(sums), "banger.tar.gz") {
|
||||
t.Fatalf("returned sums body missing tarball name: %q", sums)
|
||||
}
|
||||
got, _ := os.ReadFile(dst)
|
||||
if !bytes.Equal(got, tarballBody) {
|
||||
t.Fatalf("downloaded body differs from served body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadReleaseRejectsTarballMissingFromSums(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("body"))
|
||||
})
|
||||
mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Sums for a different file; tarball name not listed.
|
||||
fmt.Fprintf(w, "%s unrelated\n", sha256Hex([]byte("body")))
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
dst := filepath.Join(t.TempDir(), "out.tar.gz")
|
||||
_, err := DownloadRelease(context.Background(), srv.Client(), Release{
|
||||
TarballURL: srv.URL + "/banger.tar.gz",
|
||||
SHA256SumsURL: srv.URL + "/SHA256SUMS",
|
||||
}, dst)
|
||||
if err == nil || !strings.Contains(err.Error(), "does not list") {
|
||||
t.Fatalf("err = %v, want SHA256SUMS-missing rejection", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadReleasePropagatesShaMismatch(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("served body"))
|
||||
})
|
||||
mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Wrong digest for the tarball.
|
||||
fmt.Fprintf(w, "%s banger.tar.gz\n", sha256Hex([]byte("expected body")))
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
dst := filepath.Join(t.TempDir(), "out.tar.gz")
|
||||
_, err := DownloadRelease(context.Background(), srv.Client(), Release{
|
||||
TarballURL: srv.URL + "/banger.tar.gz",
|
||||
SHA256SumsURL: srv.URL + "/SHA256SUMS",
|
||||
}, dst)
|
||||
if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") {
|
||||
t.Fatalf("err = %v, want sha256 mismatch", err)
|
||||
}
|
||||
if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("partial tarball should be removed; stat err = %v", statErr)
|
||||
}
|
||||
}
|
||||
107
internal/updater/stage.go
Normal file
107
internal/updater/stage.go
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
package updater
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// expectedReleaseEntries is the canonical set of files a release
|
||||
// tarball must contain. Anything missing OR anything extra is
|
||||
// rejected — banger update should not unpack arbitrary files into
|
||||
// the staging dir.
|
||||
var expectedReleaseEntries = []string{
|
||||
"banger",
|
||||
"bangerd",
|
||||
"banger-vsock-agent",
|
||||
}
|
||||
|
||||
// StagedRelease describes the result of unpacking a release tarball
|
||||
// into a staging directory.
|
||||
type StagedRelease struct {
|
||||
BangerPath string
|
||||
BangerdPath string
|
||||
VsockAgentPath string
|
||||
}
|
||||
|
||||
// StageTarball reads the gzipped tar at tarballPath and extracts the
|
||||
// expected three banger binaries into stagingDir. Any extra entries,
|
||||
// any path-traversal members, any non-regular-file members, and any
|
||||
// missing required entry are rejected.
|
||||
//
|
||||
// The extracted binaries are mode 0o755 regardless of what the
|
||||
// tarball claims — banger update is a privileged operation; we
|
||||
// don't honour weird modes from the wire.
|
||||
func StageTarball(tarballPath, stagingDir string) (StagedRelease, error) {
|
||||
if err := os.MkdirAll(stagingDir, 0o700); err != nil {
|
||||
return StagedRelease{}, err
|
||||
}
|
||||
f, err := os.Open(tarballPath)
|
||||
if err != nil {
|
||||
return StagedRelease{}, err
|
||||
}
|
||||
defer f.Close()
|
||||
gz, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return StagedRelease{}, fmt.Errorf("open gzip: %w", err)
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
expected := map[string]struct{}{}
|
||||
for _, name := range expectedReleaseEntries {
|
||||
expected[name] = struct{}{}
|
||||
}
|
||||
seen := map[string]string{}
|
||||
|
||||
tr := tar.NewReader(gz)
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return StagedRelease{}, fmt.Errorf("read tar: %w", err)
|
||||
}
|
||||
rel := filepath.Clean(hdr.Name)
|
||||
if rel == "." || rel == string(filepath.Separator) {
|
||||
continue
|
||||
}
|
||||
if filepath.IsAbs(rel) || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
|
||||
return StagedRelease{}, fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
|
||||
}
|
||||
if _, ok := expected[rel]; !ok {
|
||||
return StagedRelease{}, fmt.Errorf("unexpected entry in release tarball: %q (allowed: %v)", hdr.Name, expectedReleaseEntries)
|
||||
}
|
||||
if hdr.Typeflag != tar.TypeReg {
|
||||
return StagedRelease{}, fmt.Errorf("entry %q is not a regular file (typeflag %d)", hdr.Name, hdr.Typeflag)
|
||||
}
|
||||
dst := filepath.Join(stagingDir, rel)
|
||||
out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return StagedRelease{}, err
|
||||
}
|
||||
if _, err := io.Copy(out, tr); err != nil {
|
||||
_ = out.Close()
|
||||
return StagedRelease{}, err
|
||||
}
|
||||
if err := out.Close(); err != nil {
|
||||
return StagedRelease{}, err
|
||||
}
|
||||
seen[rel] = dst
|
||||
}
|
||||
|
||||
for _, want := range expectedReleaseEntries {
|
||||
if _, ok := seen[want]; !ok {
|
||||
return StagedRelease{}, fmt.Errorf("release tarball is missing required entry %q", want)
|
||||
}
|
||||
}
|
||||
return StagedRelease{
|
||||
BangerPath: seen["banger"],
|
||||
BangerdPath: seen["bangerd"],
|
||||
VsockAgentPath: seen["banger-vsock-agent"],
|
||||
}, nil
|
||||
}
|
||||
135
internal/updater/swap.go
Normal file
135
internal/updater/swap.go
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
package updater
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"banger/internal/system"
|
||||
)
|
||||
|
||||
// previousSuffix is the filename suffix appended to the
|
||||
// pre-swap binary so Rollback knows where to restore from.
|
||||
// Pinned as a constant so the swap and rollback paths can't
|
||||
// disagree on it.
|
||||
const previousSuffix = ".previous"
|
||||
|
||||
// InstallTargets lists the absolute on-disk paths the updater
|
||||
// writes during a swap. Hardcoded to the system-install layout —
|
||||
// banger update is a system-mode operation; the developer non-
|
||||
// system-mode flow doesn't go through this code path.
|
||||
type InstallTargets struct {
|
||||
Banger string // /usr/local/bin/banger
|
||||
Bangerd string // /usr/local/bin/bangerd
|
||||
VsockAgent string // /usr/local/lib/banger/banger-vsock-agent
|
||||
}
|
||||
|
||||
// DefaultInstallTargets returns the canonical paths a system install
|
||||
// uses (`banger system install` writes to these). Exposed for
|
||||
// testability; production callers use it as-is.
|
||||
func DefaultInstallTargets() InstallTargets {
|
||||
return InstallTargets{
|
||||
Banger: "/usr/local/bin/banger",
|
||||
Bangerd: "/usr/local/bin/bangerd",
|
||||
VsockAgent: "/usr/local/lib/banger/banger-vsock-agent",
|
||||
}
|
||||
}
|
||||
|
||||
// SwapResult records what was swapped, so Rollback knows what to
|
||||
// undo. A nil SwapResult means no swap was attempted yet (nothing
|
||||
// to roll back).
|
||||
type SwapResult struct {
|
||||
Targets InstallTargets
|
||||
// SwappedTargets is the subset of Targets that were actually
|
||||
// renamed into place. If the second of three Renames fails,
|
||||
// SwappedTargets contains only the first; rollback unwinds in
|
||||
// reverse order.
|
||||
SwappedTargets []string
|
||||
}
|
||||
|
||||
// Swap atomically replaces each of the three banger binaries with
|
||||
// its staged counterpart. Order:
|
||||
//
|
||||
// 1. banger-vsock-agent (companion; not currently running, swap is safe)
|
||||
// 2. bangerd (the to-be-restarted daemon binary)
|
||||
// 3. banger (the CLI; least disruptive last)
|
||||
//
|
||||
// Each AtomicReplace leaves a `.previous` backup so Rollback can
|
||||
// restore the prior install if a later step fails.
|
||||
//
|
||||
// Returns the SwapResult even on partial failure so the caller can
|
||||
// drive Rollback against what HAS been swapped.
|
||||
func Swap(staged StagedRelease, targets InstallTargets) (SwapResult, error) {
|
||||
res := SwapResult{Targets: targets}
|
||||
steps := []struct {
|
||||
src, dst string
|
||||
}{
|
||||
{src: staged.VsockAgentPath, dst: targets.VsockAgent},
|
||||
{src: staged.BangerdPath, dst: targets.Bangerd},
|
||||
{src: staged.BangerPath, dst: targets.Banger},
|
||||
}
|
||||
for _, s := range steps {
|
||||
if err := ensureParentDir(s.dst); err != nil {
|
||||
return res, fmt.Errorf("prepare %s: %w", s.dst, err)
|
||||
}
|
||||
if err := system.AtomicReplace(s.src, s.dst, previousSuffix); err != nil {
|
||||
return res, fmt.Errorf("swap %s: %w", s.dst, err)
|
||||
}
|
||||
res.SwappedTargets = append(res.SwappedTargets, s.dst)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Rollback undoes a Swap by restoring each .previous backup in
|
||||
// reverse order. Returns the joined errors of every individual
|
||||
// rollback that failed; a half-rolled-back tree is the worst case
|
||||
// and the operator gets enough information to fix it manually.
|
||||
//
|
||||
// Tolerant of partial input — passing a SwapResult that only
|
||||
// recorded the first two of three swaps rolls back exactly those
|
||||
// two.
|
||||
func Rollback(res SwapResult) error {
|
||||
var errs []error
|
||||
for i := len(res.SwappedTargets) - 1; i >= 0; i-- {
|
||||
dst := res.SwappedTargets[i]
|
||||
if err := system.AtomicReplaceRollback(dst, previousSuffix); err != nil {
|
||||
errs = append(errs, fmt.Errorf("rollback %s: %w", dst, err))
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// CleanupBackups removes every .previous backup left behind by a
|
||||
// successful update. Called after `banger doctor` confirms the new
|
||||
// install is healthy — we don't keep ancient backups around forever.
|
||||
func CleanupBackups(res SwapResult) error {
|
||||
var errs []error
|
||||
for _, dst := range res.SwappedTargets {
|
||||
if err := os.Remove(dst + previousSuffix); err != nil && !os.IsNotExist(err) {
|
||||
errs = append(errs, fmt.Errorf("remove %s%s: %w", dst, previousSuffix, err))
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func ensureParentDir(p string) error {
|
||||
parent := dirOf(p)
|
||||
if parent == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := os.Stat(parent); err == nil {
|
||||
return nil
|
||||
}
|
||||
return os.MkdirAll(parent, 0o755)
|
||||
}
|
||||
|
||||
// dirOf is a tiny path.Dir wrapper that returns "" for paths with
|
||||
// no separator (so the ensure-parent logic doesn't try to mkdir(".")).
|
||||
func dirOf(p string) string {
|
||||
for i := len(p) - 1; i >= 0; i-- {
|
||||
if p[i] == '/' {
|
||||
return p[:i]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue