banger/internal/imagecat/fetch.go
Thales Maciel 4004ce2e7e
imagecat,kernelcat: bound staged download, hash before extract
Both Fetch flows previously streamed resp.Body straight into
zstd → tar → on-disk extractor with the SHA256 check tacked on at
the END. A bad mirror or an attacker that's compromised the catalog
host could ship a multi-gigabyte tarball, watch banger expand it to
disk, and only THEN see the helpful "sha256 mismatch" message —
having already filled the host filesystem.

Reorder the operations: stage the compressed tarball to a temp file
under the destination directory through an io.LimitReader (cap +1
bytes), hash on the way in, refuse to decompress if either the cap
trips or the SHA mismatches. Worst-case disk use is bounded by the
cap, not by the source.

Cap is exposed as a package var (MaxFetchedBundleBytes,
MaxFetchedKernelBytes) so callers can tune per-deployment and tests
can squeeze it down to provoke the rejection. Default 8 GiB —
generous enough for a 4 GiB rootfs (which compresses to ~1-2 GiB),
tight enough to make a "fill the host disk" attack expensive.

The temp file lives in the destination dir so extraction stays on
the same filesystem and we don't pay for cross-FS rename. defer
os.Remove cleans up; the existing per-package cleanup() handler
still removes any partial extraction on hash mismatch / extraction
failure.

Tests: each package gets a TestFetchRejectsOversizedTarballBefore
Extraction that sets the cap to 64 bytes, points Fetch at a multi-KB
tarball, and asserts (a) error mentions "cap", (b) destination dir
is left clean (no leaked rootfs / manifest / kernel tree). All
existing tests still pass — happy path, hash mismatch, missing
files, path traversal, HTTP error, etc.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 16:09:55 -03:00

211 lines
6.7 KiB
Go

package imagecat
import (
"archive/tar"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/klauspost/compress/zstd"
)
// Bundle filenames expected at the root of the .tar.zst.
const (
RootfsFilename = "rootfs.ext4"
ManifestFilename = "manifest.json"
)
// MaxFetchedBundleBytes caps the compressed bundle download. The
// previous flow streamed straight into a tar+zstd extractor and only
// hashed afterwards, so a malicious or compromised source could
// consume unbounded disk before the SHA mismatch fired. We now stage
// the download to a temp file under destDir, hash it on the way in,
// and refuse to decompress if the hash is wrong — bounding worst-case
// disk use to this cap. Generous enough for any legitimate banger
// rootfs bundle (a 4 GiB ext4 typically zstd-compresses to ~1-2 GiB);
// override per-call by setting this var before invoking Fetch.
var MaxFetchedBundleBytes int64 = 8 << 30 // 8 GiB
// Manifest is the metadata file embedded inside a bundle. It mirrors
// the subset of CatEntry fields that describe the bundle's content
// (the remote URL + sha256 are catalog concerns, not bundle concerns).
type Manifest struct {
Name string `json:"name"`
Distro string `json:"distro,omitempty"`
Arch string `json:"arch,omitempty"`
KernelRef string `json:"kernel_ref,omitempty"`
Description string `json:"description,omitempty"`
}
// Fetch downloads entry's tarball, verifies its SHA256, and writes
// rootfs.ext4 + manifest.json into destDir. Returns the parsed
// manifest. On any error the partially-written files are removed so
// destDir is left in its pre-call state.
//
// destDir must already exist. Fetch does not create it, mirroring
// kernelcat.Fetch so callers manage their own staging.
func Fetch(ctx context.Context, client *http.Client, destDir string, entry CatEntry) (Manifest, error) {
if err := ValidateName(entry.Name); err != nil {
return Manifest{}, err
}
if strings.TrimSpace(entry.TarballURL) == "" {
return Manifest{}, fmt.Errorf("catalog entry %q has no tarball URL", entry.Name)
}
if strings.TrimSpace(entry.TarballSHA256) == "" {
return Manifest{}, fmt.Errorf("catalog entry %q has no tarball sha256", entry.Name)
}
if client == nil {
client = http.DefaultClient
}
absDest, err := filepath.Abs(destDir)
if err != nil {
return Manifest{}, err
}
info, err := os.Stat(absDest)
if err != nil {
return Manifest{}, err
}
if !info.IsDir() {
return Manifest{}, fmt.Errorf("destDir %q is not a directory", destDir)
}
cleanup := func() {
_ = os.Remove(filepath.Join(absDest, RootfsFilename))
_ = os.Remove(filepath.Join(absDest, ManifestFilename))
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, entry.TarballURL, nil)
if err != nil {
return Manifest{}, err
}
resp, err := client.Do(req)
if err != nil {
return Manifest{}, fmt.Errorf("fetch %s: %w", entry.TarballURL, err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return Manifest{}, fmt.Errorf("fetch %s: HTTP %s", entry.TarballURL, resp.Status)
}
if resp.ContentLength > MaxFetchedBundleBytes {
return Manifest{}, fmt.Errorf("tarball advertised %d bytes, exceeds %d-byte cap", resp.ContentLength, MaxFetchedBundleBytes)
}
// Stage the compressed tarball on disk first so we can verify the
// SHA256 BEFORE decompressing or extracting. Cap the read at
// MaxFetchedBundleBytes+1 — anything larger is refused.
tmp, err := os.CreateTemp(absDest, "banger-bundle-*.tar.zst")
if err != nil {
return Manifest{}, fmt.Errorf("create staging file: %w", err)
}
tmpPath := tmp.Name()
defer os.Remove(tmpPath)
hasher := sha256.New()
limited := io.LimitReader(resp.Body, MaxFetchedBundleBytes+1)
n, copyErr := io.Copy(io.MultiWriter(tmp, hasher), limited)
if closeErr := tmp.Close(); copyErr == nil && closeErr != nil {
copyErr = closeErr
}
if copyErr != nil {
return Manifest{}, fmt.Errorf("download tarball: %w", copyErr)
}
if n > MaxFetchedBundleBytes {
return Manifest{}, fmt.Errorf("tarball exceeded %d-byte cap before sha256 check", MaxFetchedBundleBytes)
}
got := hex.EncodeToString(hasher.Sum(nil))
if !strings.EqualFold(got, entry.TarballSHA256) {
return Manifest{}, fmt.Errorf("tarball sha256 mismatch: got %s, want %s", got, entry.TarballSHA256)
}
src, err := os.Open(tmpPath)
if err != nil {
return Manifest{}, fmt.Errorf("reopen staged tarball: %w", err)
}
defer src.Close()
zr, err := zstd.NewReader(src)
if err != nil {
return Manifest{}, fmt.Errorf("init zstd: %w", err)
}
defer zr.Close()
if err := extractBundle(zr, absDest); err != nil {
cleanup()
return Manifest{}, err
}
if _, err := os.Stat(filepath.Join(absDest, RootfsFilename)); err != nil {
cleanup()
return Manifest{}, fmt.Errorf("bundle missing %s: %w", RootfsFilename, err)
}
manifestData, err := os.ReadFile(filepath.Join(absDest, ManifestFilename))
if err != nil {
cleanup()
return Manifest{}, fmt.Errorf("read manifest: %w", err)
}
var manifest Manifest
if err := json.Unmarshal(manifestData, &manifest); err != nil {
cleanup()
return Manifest{}, fmt.Errorf("parse manifest: %w", err)
}
if strings.TrimSpace(manifest.Name) == "" {
manifest.Name = entry.Name
}
return manifest, nil
}
// extractBundle writes the bundle's two regular-file entries into
// absDest, refusing any other member type, any extra entry, and any
// path that escapes absDest.
func extractBundle(r io.Reader, absDest string) error {
tr := tar.NewReader(r)
seen := map[string]bool{}
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("read bundle: %w", err)
}
rel := filepath.Clean(hdr.Name)
if rel == "." || rel == string(filepath.Separator) {
continue
}
if filepath.IsAbs(rel) || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
return fmt.Errorf("unsafe path in bundle: %q", hdr.Name)
}
if rel != RootfsFilename && rel != ManifestFilename {
return fmt.Errorf("unexpected bundle entry %q (expected %s or %s at the root)", hdr.Name, RootfsFilename, ManifestFilename)
}
if hdr.Typeflag != tar.TypeReg {
return fmt.Errorf("bundle entry %q is not a regular file", hdr.Name)
}
dst := filepath.Join(absDest, rel)
f, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
if err != nil {
return err
}
if _, err := io.Copy(f, tr); err != nil {
_ = f.Close()
return err
}
if err := f.Close(); err != nil {
return err
}
seen[rel] = true
}
if !seen[RootfsFilename] || !seen[ManifestFilename] {
return fmt.Errorf("bundle is missing required files: want both %s and %s", RootfsFilename, ManifestFilename)
}
return nil
}