banger/internal/imagecat/fetch_test.go
Thales Maciel 4004ce2e7e
imagecat,kernelcat: bound staged download, hash before extract
Both Fetch flows previously streamed resp.Body straight into
zstd → tar → on-disk extractor with the SHA256 check tacked on at
the END. A bad mirror or an attacker that's compromised the catalog
host could ship a multi-gigabyte tarball, watch banger expand it to
disk, and only THEN see the helpful "sha256 mismatch" message —
having already filled the host filesystem.

Reorder the operations: stage the compressed tarball to a temp file
under the destination directory through an io.LimitReader (cap +1
bytes), hash on the way in, refuse to decompress if either the cap
trips or the SHA mismatches. Worst-case disk use is bounded by the
cap, not by the source.

Cap is exposed as a package var (MaxFetchedBundleBytes,
MaxFetchedKernelBytes) so callers can tune per-deployment and tests
can squeeze it down to provoke the rejection. Default 8 GiB —
generous enough for a 4 GiB rootfs (which compresses to ~1-2 GiB),
tight enough to make a "fill the host disk" attack expensive.

The temp file lives in the destination dir so extraction stays on
the same filesystem and we don't pay for cross-FS rename. defer
os.Remove cleans up; the existing per-package cleanup() handler
still removes any partial extraction on hash mismatch / extraction
failure.

Tests: each package gets a TestFetchRejectsOversizedTarballBeforeExtraction
test that sets the cap to 64 bytes, points Fetch at a multi-KB
tarball, and asserts (a) error mentions "cap", (b) destination dir
is left clean (no leaked rootfs / manifest / kernel tree). All
existing tests still pass — happy path, hash mismatch, missing
files, path traversal, HTTP error, etc.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 16:09:55 -03:00

286 lines
8.3 KiB
Go

package imagecat
import (
"archive/tar"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/klauspost/compress/zstd"
)
// makeBundle assembles an in-memory .tar.zst bundle that holds the rootfs
// payload followed by the JSON-encoded manifest. It returns the compressed
// bytes together with the hex-encoded SHA-256 digest of those bytes. Any
// failure while building the bundle aborts the calling test.
func makeBundle(t *testing.T, manifest Manifest, rootfs []byte) ([]byte, string) {
	t.Helper()

	encodedManifest, err := json.Marshal(manifest)
	if err != nil {
		t.Fatal(err)
	}

	var tarBuf bytes.Buffer
	tarWriter := tar.NewWriter(&tarBuf)

	// addFile appends one regular file (mode 0644) to the tar stream.
	addFile := func(name string, data []byte) {
		hdr := &tar.Header{
			Name:     name,
			Size:     int64(len(data)),
			Mode:     0o644,
			Typeflag: tar.TypeReg,
		}
		if err := tarWriter.WriteHeader(hdr); err != nil {
			t.Fatal(err)
		}
		if _, err := tarWriter.Write(data); err != nil {
			t.Fatal(err)
		}
	}
	// Entry order: rootfs first, manifest second.
	addFile(RootfsFilename, rootfs)
	addFile(ManifestFilename, encodedManifest)
	if err := tarWriter.Close(); err != nil {
		t.Fatal(err)
	}

	var compressed bytes.Buffer
	enc, err := zstd.NewWriter(&compressed)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := io.Copy(enc, &tarBuf); err != nil {
		t.Fatal(err)
	}
	// Close flushes the final zstd frame; the digest must cover it.
	if err := enc.Close(); err != nil {
		t.Fatal(err)
	}

	payload := compressed.Bytes()
	digest := sha256.Sum256(payload)
	return payload, hex.EncodeToString(digest[:])
}
// serveBundle starts an httptest server that answers every request with
// payload, labelled as application/octet-stream. The caller is responsible
// for closing the returned server.
func serveBundle(t *testing.T, payload []byte) *httptest.Server {
	t.Helper()
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/octet-stream")
		// Write errors are deliberately ignored here.
		_, _ = w.Write(payload)
	}
	return httptest.NewServer(http.HandlerFunc(handler))
}
func TestFetchHappyPath(t *testing.T) {
manifest := Manifest{
Name: "debian-bookworm",
Distro: "debian",
Arch: "x86_64",
KernelRef: "generic-6.12",
}
rootfs := []byte("not-actually-an-ext4-but-that's-fine-for-the-test")
bundle, sum := makeBundle(t, manifest, rootfs)
srv := serveBundle(t, bundle)
t.Cleanup(srv.Close)
dest := t.TempDir()
got, err := Fetch(context.Background(), srv.Client(), dest, CatEntry{
Name: "debian-bookworm",
TarballURL: srv.URL + "/bundle.tar.zst",
TarballSHA256: sum,
})
if err != nil {
t.Fatalf("Fetch: %v", err)
}
if got.Name != "debian-bookworm" || got.KernelRef != "generic-6.12" || got.Distro != "debian" {
t.Fatalf("manifest = %+v", got)
}
if b, err := os.ReadFile(filepath.Join(dest, RootfsFilename)); err != nil || !bytes.Equal(b, rootfs) {
t.Fatalf("rootfs content mismatch: err=%v, %q", err, b)
}
if _, err := os.Stat(filepath.Join(dest, ManifestFilename)); err != nil {
t.Fatalf("manifest missing: %v", err)
}
}
func TestFetchRejectsSHA256Mismatch(t *testing.T) {
manifest := Manifest{Name: "debian-bookworm"}
bundle, _ := makeBundle(t, manifest, []byte("abc"))
srv := serveBundle(t, bundle)
t.Cleanup(srv.Close)
dest := t.TempDir()
_, err := Fetch(context.Background(), srv.Client(), dest, CatEntry{
Name: "debian-bookworm",
TarballURL: srv.URL + "/bundle.tar.zst",
TarballSHA256: "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef",
})
if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") {
t.Fatalf("want sha256 mismatch error, got %v", err)
}
// Cleanup: dest should not contain partial files.
if _, err := os.Stat(filepath.Join(dest, RootfsFilename)); !os.IsNotExist(err) {
t.Fatalf("rootfs should be cleaned up on sha256 failure, got %v", err)
}
if _, err := os.Stat(filepath.Join(dest, ManifestFilename)); !os.IsNotExist(err) {
t.Fatalf("manifest should be cleaned up on sha256 failure, got %v", err)
}
}
// TestFetchRejectsOversizedTarballBeforeExtraction pins the disk-bound
// cap: with MaxFetchedBundleBytes set very low, the staged-tarball
// download must trip the limit and refuse to even decompress, leaving
// the destination dir clean. This is the "compromised mirror floods
// the host" scenario.
//
// The test mutates the package-level MaxFetchedBundleBytes and restores
// it via t.Cleanup.
func TestFetchRejectsOversizedTarballBeforeExtraction(t *testing.T) {
	manifest := Manifest{Name: "debian-bookworm"}
	bundle, sum := makeBundle(t, manifest, bytes.Repeat([]byte("x"), 4096))
	srv := serveBundle(t, bundle)
	t.Cleanup(srv.Close)

	prev := MaxFetchedBundleBytes
	MaxFetchedBundleBytes = 64
	t.Cleanup(func() { MaxFetchedBundleBytes = prev })

	dest := t.TempDir()
	_, err := Fetch(context.Background(), srv.Client(), dest, CatEntry{
		Name:          "debian-bookworm",
		TarballURL:    srv.URL + "/bundle.tar.zst",
		TarballSHA256: sum,
	})
	if err == nil {
		t.Fatal("Fetch succeeded against an oversized tarball; want size-cap rejection")
	}
	if !strings.Contains(err.Error(), "cap") {
		t.Fatalf("err = %v, want size-cap message", err)
	}

	// dest must be untouched: no rootfs, no manifest, no leftover tmp.
	// A failed directory listing must fail the test: with the error
	// dropped, entries would be nil and the cleanliness check below
	// would pass without having inspected anything.
	entries, err := os.ReadDir(dest)
	if err != nil {
		t.Fatalf("reading dest after size-cap rejection: %v", err)
	}
	if len(entries) != 0 {
		names := make([]string, 0, len(entries))
		for _, e := range entries {
			names = append(names, e.Name())
		}
		t.Fatalf("dest left dirty after size-cap rejection: %v", names)
	}
}
// TestFetchRejectsUnexpectedTarEntry verifies that Fetch refuses a bundle
// whose digest matches but which carries an entry besides the rootfs and
// the manifest, reporting an "unexpected bundle entry" error.
func TestFetchRejectsUnexpectedTarEntry(t *testing.T) {
	// Hand-roll a bundle with a third, disallowed entry.
	var rawTar bytes.Buffer
	tw := tar.NewWriter(&rawTar)
	for _, e := range []struct{ name, data string }{
		{RootfsFilename, "rootfs"},
		{ManifestFilename, `{"name":"x"}`},
		{"extra", "should be rejected"},
	} {
		if err := tw.WriteHeader(&tar.Header{
			Name:     e.name,
			Size:     int64(len(e.data)),
			Mode:     0o644,
			Typeflag: tar.TypeReg,
		}); err != nil {
			t.Fatal(err)
		}
		if _, err := tw.Write([]byte(e.data)); err != nil {
			t.Fatal(err)
		}
	}
	if err := tw.Close(); err != nil {
		t.Fatal(err)
	}

	// Check every compression step, as makeBundle does: a silently
	// truncated bundle would make Fetch fail for the wrong reason, and a
	// nil writer would panic instead of producing a readable failure.
	var zstBuf bytes.Buffer
	zw, err := zstd.NewWriter(&zstBuf)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := io.Copy(zw, &rawTar); err != nil {
		t.Fatal(err)
	}
	if err := zw.Close(); err != nil {
		t.Fatal(err)
	}
	sum := sha256.Sum256(zstBuf.Bytes())

	srv := serveBundle(t, zstBuf.Bytes())
	t.Cleanup(srv.Close)

	_, err = Fetch(context.Background(), srv.Client(), t.TempDir(), CatEntry{
		Name:          "x",
		TarballURL:    srv.URL + "/bundle.tar.zst",
		TarballSHA256: hex.EncodeToString(sum[:]),
	})
	if err == nil || !strings.Contains(err.Error(), "unexpected bundle entry") {
		t.Fatalf("want unexpected entry error, got %v", err)
	}
}
// TestFetchRejectsMissingManifest verifies that Fetch refuses a bundle
// that contains only the rootfs, reporting a "missing required files"
// error even though the advertised digest matches.
func TestFetchRejectsMissingManifest(t *testing.T) {
	// Bundle with only rootfs.
	rootfs := []byte("abc")
	var rawTar bytes.Buffer
	tw := tar.NewWriter(&rawTar)
	if err := tw.WriteHeader(&tar.Header{
		Name:     RootfsFilename,
		Size:     int64(len(rootfs)),
		Mode:     0o644,
		Typeflag: tar.TypeReg,
	}); err != nil {
		t.Fatal(err)
	}
	if _, err := tw.Write(rootfs); err != nil {
		t.Fatal(err)
	}
	if err := tw.Close(); err != nil {
		t.Fatal(err)
	}

	// Check every compression step, as makeBundle does: a broken fixture
	// should fail here, not surface as a misleading Fetch error or a
	// nil-pointer panic.
	var zstBuf bytes.Buffer
	zw, err := zstd.NewWriter(&zstBuf)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := io.Copy(zw, &rawTar); err != nil {
		t.Fatal(err)
	}
	if err := zw.Close(); err != nil {
		t.Fatal(err)
	}
	sum := sha256.Sum256(zstBuf.Bytes())

	srv := serveBundle(t, zstBuf.Bytes())
	t.Cleanup(srv.Close)

	_, err = Fetch(context.Background(), srv.Client(), t.TempDir(), CatEntry{
		Name:          "x",
		TarballURL:    srv.URL + "/bundle.tar.zst",
		TarballSHA256: hex.EncodeToString(sum[:]),
	})
	if err == nil || !strings.Contains(err.Error(), "missing required files") {
		t.Fatalf("want missing-required-files error, got %v", err)
	}
}
func TestFetchRejectsHTTPFailure(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "not found", http.StatusNotFound)
}))
t.Cleanup(srv.Close)
_, err := Fetch(context.Background(), srv.Client(), t.TempDir(), CatEntry{
Name: "x",
TarballURL: srv.URL + "/missing.tar.zst",
TarballSHA256: "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef",
})
if err == nil || !strings.Contains(err.Error(), "HTTP") {
t.Fatalf("want HTTP error, got %v", err)
}
}
func TestFetchRejectsEmptyURL(t *testing.T) {
_, err := Fetch(context.Background(), http.DefaultClient, t.TempDir(), CatEntry{
Name: "x",
TarballURL: "",
TarballSHA256: "abc",
})
if err == nil || !strings.Contains(err.Error(), "no tarball URL") {
t.Fatalf("want no-URL error, got %v", err)
}
}
func TestFetchRejectsEmptySHA256(t *testing.T) {
_, err := Fetch(context.Background(), http.DefaultClient, t.TempDir(), CatEntry{
Name: "x",
TarballURL: "https://example.com/x.tar.zst",
})
if err == nil || !strings.Contains(err.Error(), "no tarball sha256") {
t.Fatalf("want no-sha error, got %v", err)
}
}
func TestFetchRejectsInvalidName(t *testing.T) {
_, err := Fetch(context.Background(), http.DefaultClient, t.TempDir(), CatEntry{
Name: "",
TarballURL: "https://example.com/x.tar.zst",
TarballSHA256: "abc",
})
if err == nil || !strings.Contains(err.Error(), "image name is required") {
t.Fatalf("want name-required error, got %v", err)
}
}