Both Fetch flows previously streamed resp.Body straight into zstd → tar → on-disk extractor, with the SHA256 check tacked on at the END. A bad mirror, or an attacker who has compromised the catalog host, could ship a multi-gigabyte tarball, watch banger expand it to disk, and only THEN see the helpful "sha256 mismatch" message — having already filled the host filesystem. Reorder the operations: stage the compressed tarball to a temp file under the destination directory through an io.LimitReader (cap + 1 bytes), hash it on the way in, and refuse to decompress if either the cap trips or the SHA mismatches. Worst-case disk use before the SHA256 check is bounded by the cap, not by the source. The cap is exposed as a package var (MaxFetchedBundleBytes, MaxFetchedKernelBytes) so callers can tune it per deployment and tests can squeeze it down to provoke the rejection. Default 8 GiB — generous enough for a 4 GiB rootfs (which compresses to ~1-2 GiB), tight enough to make a "fill the host disk" attack expensive. The temp file lives in the destination dir so staging and extraction stay on the same filesystem and we don't pay for a cross-FS rename. A deferred os.Remove deletes the staged file; the existing per-package cleanup() handler still removes any partial extraction on extraction failure (a hash mismatch now fails before anything is extracted). Tests: each package gets a TestFetchRejectsOversizedTarballBeforeExtraction that sets the cap to 64 bytes, points Fetch at a multi-KB tarball, and asserts (a) the error mentions "cap", and (b) the destination dir is left clean (no leaked rootfs / manifest / kernel tree). All existing tests still pass — happy path, hash mismatch, missing files, path traversal, HTTP error, etc. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
211 lines
6.7 KiB
Go
211 lines
6.7 KiB
Go
package imagecat
|
|
|
|
import (
|
|
"archive/tar"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/klauspost/compress/zstd"
|
|
)
|
|
|
|
// Member names of a bundle .tar.zst. A bundle must contain both of these
// as regular files at the archive root, and no other members.
const (
	// RootfsFilename is the bundle member holding the ext4 root filesystem image.
	RootfsFilename = "rootfs.ext4"

	// ManifestFilename is the bundle member holding the JSON-encoded Manifest.
	ManifestFilename = "manifest.json"
)
|
|
|
|
// MaxFetchedBundleBytes caps the size of the compressed bundle that Fetch
// will download. Fetch stages the download to a temp file under destDir,
// hashing it on the way in, and refuses to decompress if the download
// exceeds this cap or its SHA256 does not match the catalog entry. A
// response whose Content-Length already exceeds the cap is rejected
// before any bytes are written.
//
// The cap bounds only the disk used by the staged (compressed) download.
// It does not limit the size of the files extracted from a bundle whose
// hash verifies.
//
// The default, 8 GiB, is generous for a banger rootfs bundle (a 4 GiB
// ext4 image typically zstd-compresses to ~1-2 GiB). This is a
// process-wide setting, not a per-call one: assign it before calling
// Fetch, and do not change it while a Fetch may be running.
var MaxFetchedBundleBytes int64 = 8 << 30 // 8 GiB
|
|
|
|
// Manifest is decoded from the manifest.json member at the root of a
// bundle. It carries the content-describing subset of CatEntry's fields;
// where to download the bundle and its sha256 are catalog concerns and
// are deliberately absent here.
type Manifest struct {
	// Name identifies the image. When it is blank, Fetch substitutes
	// the catalog entry's name.
	Name string `json:"name"`

	// The remaining fields are optional metadata and are omitted from
	// JSON output when empty.
	Distro      string `json:"distro,omitempty"`
	Arch        string `json:"arch,omitempty"`
	KernelRef   string `json:"kernel_ref,omitempty"`
	Description string `json:"description,omitempty"`
}
|
|
|
|
// Fetch downloads entry's tarball, verifies its SHA256, and writes
|
|
// rootfs.ext4 + manifest.json into destDir. Returns the parsed
|
|
// manifest. On any error the partially-written files are removed so
|
|
// destDir is left in its pre-call state.
|
|
//
|
|
// destDir must already exist. Fetch does not create it, mirroring
|
|
// kernelcat.Fetch so callers manage their own staging.
|
|
func Fetch(ctx context.Context, client *http.Client, destDir string, entry CatEntry) (Manifest, error) {
|
|
if err := ValidateName(entry.Name); err != nil {
|
|
return Manifest{}, err
|
|
}
|
|
if strings.TrimSpace(entry.TarballURL) == "" {
|
|
return Manifest{}, fmt.Errorf("catalog entry %q has no tarball URL", entry.Name)
|
|
}
|
|
if strings.TrimSpace(entry.TarballSHA256) == "" {
|
|
return Manifest{}, fmt.Errorf("catalog entry %q has no tarball sha256", entry.Name)
|
|
}
|
|
if client == nil {
|
|
client = http.DefaultClient
|
|
}
|
|
|
|
absDest, err := filepath.Abs(destDir)
|
|
if err != nil {
|
|
return Manifest{}, err
|
|
}
|
|
info, err := os.Stat(absDest)
|
|
if err != nil {
|
|
return Manifest{}, err
|
|
}
|
|
if !info.IsDir() {
|
|
return Manifest{}, fmt.Errorf("destDir %q is not a directory", destDir)
|
|
}
|
|
|
|
cleanup := func() {
|
|
_ = os.Remove(filepath.Join(absDest, RootfsFilename))
|
|
_ = os.Remove(filepath.Join(absDest, ManifestFilename))
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, entry.TarballURL, nil)
|
|
if err != nil {
|
|
return Manifest{}, err
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return Manifest{}, fmt.Errorf("fetch %s: %w", entry.TarballURL, err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return Manifest{}, fmt.Errorf("fetch %s: HTTP %s", entry.TarballURL, resp.Status)
|
|
}
|
|
|
|
if resp.ContentLength > MaxFetchedBundleBytes {
|
|
return Manifest{}, fmt.Errorf("tarball advertised %d bytes, exceeds %d-byte cap", resp.ContentLength, MaxFetchedBundleBytes)
|
|
}
|
|
|
|
// Stage the compressed tarball on disk first so we can verify the
|
|
// SHA256 BEFORE decompressing or extracting. Cap the read at
|
|
// MaxFetchedBundleBytes+1 — anything larger is refused.
|
|
tmp, err := os.CreateTemp(absDest, "banger-bundle-*.tar.zst")
|
|
if err != nil {
|
|
return Manifest{}, fmt.Errorf("create staging file: %w", err)
|
|
}
|
|
tmpPath := tmp.Name()
|
|
defer os.Remove(tmpPath)
|
|
|
|
hasher := sha256.New()
|
|
limited := io.LimitReader(resp.Body, MaxFetchedBundleBytes+1)
|
|
n, copyErr := io.Copy(io.MultiWriter(tmp, hasher), limited)
|
|
if closeErr := tmp.Close(); copyErr == nil && closeErr != nil {
|
|
copyErr = closeErr
|
|
}
|
|
if copyErr != nil {
|
|
return Manifest{}, fmt.Errorf("download tarball: %w", copyErr)
|
|
}
|
|
if n > MaxFetchedBundleBytes {
|
|
return Manifest{}, fmt.Errorf("tarball exceeded %d-byte cap before sha256 check", MaxFetchedBundleBytes)
|
|
}
|
|
|
|
got := hex.EncodeToString(hasher.Sum(nil))
|
|
if !strings.EqualFold(got, entry.TarballSHA256) {
|
|
return Manifest{}, fmt.Errorf("tarball sha256 mismatch: got %s, want %s", got, entry.TarballSHA256)
|
|
}
|
|
|
|
src, err := os.Open(tmpPath)
|
|
if err != nil {
|
|
return Manifest{}, fmt.Errorf("reopen staged tarball: %w", err)
|
|
}
|
|
defer src.Close()
|
|
zr, err := zstd.NewReader(src)
|
|
if err != nil {
|
|
return Manifest{}, fmt.Errorf("init zstd: %w", err)
|
|
}
|
|
defer zr.Close()
|
|
|
|
if err := extractBundle(zr, absDest); err != nil {
|
|
cleanup()
|
|
return Manifest{}, err
|
|
}
|
|
|
|
if _, err := os.Stat(filepath.Join(absDest, RootfsFilename)); err != nil {
|
|
cleanup()
|
|
return Manifest{}, fmt.Errorf("bundle missing %s: %w", RootfsFilename, err)
|
|
}
|
|
manifestData, err := os.ReadFile(filepath.Join(absDest, ManifestFilename))
|
|
if err != nil {
|
|
cleanup()
|
|
return Manifest{}, fmt.Errorf("read manifest: %w", err)
|
|
}
|
|
var manifest Manifest
|
|
if err := json.Unmarshal(manifestData, &manifest); err != nil {
|
|
cleanup()
|
|
return Manifest{}, fmt.Errorf("parse manifest: %w", err)
|
|
}
|
|
if strings.TrimSpace(manifest.Name) == "" {
|
|
manifest.Name = entry.Name
|
|
}
|
|
return manifest, nil
|
|
}
|
|
|
|
// extractBundle writes the bundle's two regular-file entries into
|
|
// absDest, refusing any other member type, any extra entry, and any
|
|
// path that escapes absDest.
|
|
func extractBundle(r io.Reader, absDest string) error {
|
|
tr := tar.NewReader(r)
|
|
seen := map[string]bool{}
|
|
for {
|
|
hdr, err := tr.Next()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("read bundle: %w", err)
|
|
}
|
|
rel := filepath.Clean(hdr.Name)
|
|
if rel == "." || rel == string(filepath.Separator) {
|
|
continue
|
|
}
|
|
if filepath.IsAbs(rel) || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
|
|
return fmt.Errorf("unsafe path in bundle: %q", hdr.Name)
|
|
}
|
|
if rel != RootfsFilename && rel != ManifestFilename {
|
|
return fmt.Errorf("unexpected bundle entry %q (expected %s or %s at the root)", hdr.Name, RootfsFilename, ManifestFilename)
|
|
}
|
|
if hdr.Typeflag != tar.TypeReg {
|
|
return fmt.Errorf("bundle entry %q is not a regular file", hdr.Name)
|
|
}
|
|
dst := filepath.Join(absDest, rel)
|
|
f, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err := io.Copy(f, tr); err != nil {
|
|
_ = f.Close()
|
|
return err
|
|
}
|
|
if err := f.Close(); err != nil {
|
|
return err
|
|
}
|
|
seen[rel] = true
|
|
}
|
|
if !seen[RootfsFilename] || !seen[ManifestFilename] {
|
|
return fmt.Errorf("bundle is missing required files: want both %s and %s", RootfsFilename, ManifestFilename)
|
|
}
|
|
return nil
|
|
}
|