banger/internal/kernelcat/fetch.go
Thales Maciel 4004ce2e7e
imagecat,kernelcat: bound staged download, hash before extract
Both Fetch flows previously streamed resp.Body straight into
zstd → tar → on-disk extractor with the SHA256 check tacked on at
the END. A bad mirror or an attacker that's compromised the catalog
host could ship a multi-gigabyte tarball, watch banger expand it to
disk, and only THEN see the helpful "sha256 mismatch" message —
having already filled the host filesystem.

Reorder the operations: stage the compressed tarball to a temp file
under the destination directory through an io.LimitReader (cap + 1
bytes), hash on the way in, and refuse to decompress if either the cap
trips or the SHA mismatches. Worst-case disk use before the hash check
is bounded by the cap, not by the source.

Cap is exposed as a package var (MaxFetchedBundleBytes,
MaxFetchedKernelBytes) so callers can tune per-deployment and tests
can squeeze it down to provoke the rejection. Default 8 GiB —
generous enough for a 4 GiB rootfs (which compresses to ~1-2 GiB),
tight enough to make a "fill the host disk" attack expensive.

The temp file lives in the destination dir so extraction stays on
the same filesystem and we don't pay for cross-FS rename. defer
os.Remove cleans up; the existing per-package cleanup() handler
still removes any partial extraction on hash mismatch / extraction
failure.

Tests: each package gets a TestFetchRejectsOversizedTarballBeforeExtraction
test that sets the cap to 64 bytes, points Fetch at a multi-KB
tarball, and asserts (a) error mentions "cap", (b) destination dir
is left clean (no leaked rootfs / manifest / kernel tree). All
existing tests still pass — happy path, hash mismatch, missing
files, path traversal, HTTP error, etc.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 16:09:55 -03:00

227 lines
6.9 KiB
Go

package kernelcat
import (
"archive/tar"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/klauspost/compress/zstd"
)
// MaxFetchedKernelBytes is the largest compressed kernel tarball, in bytes,
// that Fetch will accept.
//
// Fetch stages the download in a temp file under the entry's directory and
// hashes it as the bytes arrive. A response that advertises a larger
// Content-Length, or that actually delivers more than this many bytes, is
// rejected before anything is decompressed. The staged file therefore never
// grows more than one byte past the cap, even when the server misbehaves.
//
// This is a package-level setting shared by every Fetch in the process, not
// a per-call option. It is read without synchronization, so assign it before
// starting any Fetch. Tests shrink it to provoke the rejection path.
var MaxFetchedKernelBytes int64 = 8 * 1024 * 1024 * 1024 // 8 GiB
// Fetch downloads the tarball for entry, verifies its SHA256, extracts it
// into <kernelsDir>/<entry.Name>/, and writes a manifest. On failure it
// removes the partially-populated target directory.
//
// The compressed download is staged to a temp file inside the target
// directory before anything is decompressed. The staging is capped at
// MaxFetchedKernelBytes, which is read once at the start of the call.
// Decompression starts only after both the size cap and the SHA256 check
// pass. A server-advertised Content-Length above the cap is rejected before
// any body bytes are read.
//
// The tarball is expected to be a tar+zstd archive whose root contains
// vmlinux and optionally initrd.img and/or a modules/ directory. Path
// traversal entries (..) and absolute-path members are rejected.
//
// Surrounding whitespace in entry.TarballSHA256 is ignored, matching the
// emptiness check, and the hex comparison is case-insensitive.
func Fetch(ctx context.Context, client *http.Client, kernelsDir string, entry CatEntry) (Entry, error) {
	if err := ValidateName(entry.Name); err != nil {
		return Entry{}, err
	}
	if strings.TrimSpace(entry.TarballURL) == "" {
		return Entry{}, fmt.Errorf("catalog entry %q has no tarball URL", entry.Name)
	}
	wantSHA := strings.TrimSpace(entry.TarballSHA256)
	if wantSHA == "" {
		return Entry{}, fmt.Errorf("catalog entry %q has no tarball sha256", entry.Name)
	}
	if client == nil {
		client = http.DefaultClient
	}
	// Read the cap once so the Content-Length pre-check, the read limit and
	// the post-copy check all use the same value for this call.
	limit := MaxFetchedKernelBytes
	if err := DeleteLocal(kernelsDir, entry.Name); err != nil {
		return Entry{}, fmt.Errorf("clear prior catalog entry: %w", err)
	}
	targetDir := EntryDir(kernelsDir, entry.Name)
	if err := os.MkdirAll(targetDir, 0o755); err != nil {
		return Entry{}, err
	}
	// Best-effort: removes everything staged or extracted so far.
	cleanup := func() { _ = os.RemoveAll(targetDir) }
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, entry.TarballURL, nil)
	if err != nil {
		cleanup()
		return Entry{}, err
	}
	resp, err := client.Do(req)
	if err != nil {
		cleanup()
		return Entry{}, fmt.Errorf("fetch %s: %w", entry.TarballURL, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		cleanup()
		return Entry{}, fmt.Errorf("fetch %s: HTTP %s", entry.TarballURL, resp.Status)
	}
	// ContentLength is -1 when unknown, so this only fires for an explicitly
	// oversized response. stageKernelTarball enforces the cap either way.
	if resp.ContentLength > limit {
		cleanup()
		return Entry{}, fmt.Errorf("tarball advertised %d bytes, exceeds %d-byte cap", resp.ContentLength, limit)
	}
	tmpPath, got, err := stageKernelTarball(resp.Body, targetDir, limit)
	if err != nil {
		cleanup()
		return Entry{}, err
	}
	// Best-effort removal of the staged archive once extraction is done.
	defer os.Remove(tmpPath)
	if !strings.EqualFold(got, wantSHA) {
		cleanup()
		return Entry{}, fmt.Errorf("tarball sha256 mismatch: got %s, want %s", got, wantSHA)
	}
	src, err := os.Open(tmpPath)
	if err != nil {
		cleanup()
		return Entry{}, fmt.Errorf("reopen staged tarball: %w", err)
	}
	defer src.Close()
	zr, err := zstd.NewReader(src)
	if err != nil {
		cleanup()
		return Entry{}, fmt.Errorf("init zstd: %w", err)
	}
	defer zr.Close()
	if err := extractTar(zr, targetDir); err != nil {
		cleanup()
		return Entry{}, err
	}
	kernelPath := filepath.Join(targetDir, kernelFilename)
	if _, err := os.Stat(kernelPath); err != nil {
		cleanup()
		return Entry{}, fmt.Errorf("tarball missing %s: %w", kernelFilename, err)
	}
	kernelSum, err := SumFile(kernelPath)
	if err != nil {
		cleanup()
		return Entry{}, err
	}
	stored := Entry{
		Name:          entry.Name,
		Distro:        entry.Distro,
		Arch:          entry.Arch,
		KernelVersion: entry.KernelVersion,
		SHA256:        kernelSum,
		Source:        "pull:" + entry.TarballURL,
		ImportedAt:    time.Now().UTC(),
	}
	if err := WriteLocal(kernelsDir, stored); err != nil {
		cleanup()
		return Entry{}, err
	}
	return ReadLocal(kernelsDir, entry.Name)
}

// stageKernelTarball copies body into a new temp file under dir, hashing it
// on the way in. It reads at most limit+1 bytes, so an oversized stream is
// detected without staging more than one byte past the cap.
//
// It returns the staged file's path and its lowercase hex SHA256. On error
// it removes any temp file it created and returns empty strings.
func stageKernelTarball(body io.Reader, dir string, limit int64) (path, sum string, err error) {
	tmp, err := os.CreateTemp(dir, "banger-kernel-*.tar.zst")
	if err != nil {
		return "", "", fmt.Errorf("create staging file: %w", err)
	}
	path = tmp.Name()

	// Read one byte past the cap so that "exactly limit" and "more than
	// limit" can be told apart. When limit is math.MaxInt64, limit+1 wraps
	// negative (Go defines signed overflow as wrapping), and io.LimitReader
	// would then read nothing. In that case keep the limit itself.
	readLimit := limit
	if readLimit+1 > readLimit {
		readLimit++
	}

	hasher := sha256.New()
	n, copyErr := io.Copy(io.MultiWriter(tmp, hasher), io.LimitReader(body, readLimit))
	if closeErr := tmp.Close(); copyErr == nil && closeErr != nil {
		copyErr = closeErr
	}
	switch {
	case copyErr != nil:
		err = fmt.Errorf("download tarball: %w", copyErr)
	case n > limit:
		err = fmt.Errorf("tarball exceeded %d-byte cap before sha256 check", limit)
	}
	if err != nil {
		_ = os.Remove(path) // best-effort; the caller also removes dir on error
		return "", "", err
	}
	return path, hex.EncodeToString(hasher.Sum(nil)), nil
}
// extractTar writes each directory, regular-file and symlink member of the
// tar stream r into target. It refuses any member that could land outside
// target.
//
// Two escape routes are closed:
//   - Lexical: a member whose cleaned name is absolute or climbs out with
//     "..", or a relative symlink whose link text climbs out of target.
//   - Through the filesystem: os.MkdirAll, os.OpenFile and os.Symlink all
//     follow symlinks in the path they are given. Without a check, a member
//     such as "modules/evil.ko" arriving after a symlink
//     "modules -> /lib/modules" would be written onto the host. Before
//     each write, the destination (or its deepest existing ancestor) is
//     resolved with filepath.EvalSymlinks and must stay inside target.
//
// Absolute symlink targets are still accepted as link text, because they are
// interpreted at runtime against the eventual rootfs (`/` inside the VM).
// Nothing is ever written through them at extraction time. Hard links,
// device nodes and FIFOs are skipped.
//
// target, and any missing parents, is created if absent. The first error
// aborts extraction, and members already written are left for the caller to
// remove. The symlink checks assume nothing else modifies target while
// extraction runs.
func extractTar(r io.Reader, target string) error {
	absTarget, err := filepath.Abs(target)
	if err != nil {
		return err
	}
	// The root must exist so its symlink-free real path can serve as the
	// containment boundary for EvalSymlinks results.
	if err := os.MkdirAll(absTarget, 0o755); err != nil {
		return err
	}
	realTarget, err := filepath.EvalSymlinks(absTarget)
	if err != nil {
		return err
	}
	tr := tar.NewReader(r)
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return fmt.Errorf("read tarball: %w", err)
		}
		rel := filepath.Clean(hdr.Name)
		if rel == "." || rel == string(filepath.Separator) {
			continue
		}
		if filepath.IsAbs(rel) || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			return fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
		}
		dst := filepath.Join(absTarget, rel)
		if !pathIsWithin(absTarget, dst) {
			return fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
		}
		switch hdr.Typeflag {
		case tar.TypeDir:
			if err := guardSymlinkEscape(realTarget, dst, hdr.Name); err != nil {
				return err
			}
			if err := os.MkdirAll(dst, os.FileMode(hdr.Mode)|0o755); err != nil {
				return err
			}
		case tar.TypeReg:
			if err := guardSymlinkEscape(realTarget, dst, hdr.Name); err != nil {
				return err
			}
			if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
				return err
			}
			f, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)|0o600)
			if err != nil {
				return err
			}
			if _, err := io.Copy(f, tr); err != nil {
				_ = f.Close() // the copy error is the one worth reporting
				return err
			}
			if err := f.Close(); err != nil {
				return err
			}
		case tar.TypeSymlink:
			if err := guardSymlinkEscape(realTarget, dst, hdr.Name); err != nil {
				return err
			}
			if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
				return err
			}
			// Only relative link text needs a lexical escape check here.
			// Absolute targets are meant for the VM's root, and the guard
			// above stops later members from being written through them.
			if !filepath.IsAbs(hdr.Linkname) {
				resolved := filepath.Clean(filepath.Join(filepath.Dir(dst), hdr.Linkname))
				if !pathIsWithin(absTarget, resolved) {
					return fmt.Errorf("unsafe symlink in tarball: %q -> %q", hdr.Name, hdr.Linkname)
				}
			}
			if err := os.Symlink(hdr.Linkname, dst); err != nil {
				return err
			}
		default:
			// Hardlinks / device nodes / fifos: skip silently. Kernel
			// module trees shouldn't need them.
		}
	}
}

// guardSymlinkEscape returns an error if p resolves through symlinks to a
// location outside realRoot. If p does not exist yet, its deepest existing
// ancestor is checked instead. name is the tar member name used in errors.
//
// Only the existing prefix needs checking. Components below it do not exist
// yet, so MkdirAll creates them as real directories rather than following
// them as links. A dangling symlink on the path makes EvalSymlinks fail, and
// the member is rejected, because writing through the link would create its
// (possibly external) target.
func guardSymlinkEscape(realRoot, p, name string) error {
	existing := p
	for {
		// Any Lstat failure (missing, not a directory, permission) means we
		// look one level up. If that level is a regular file or unreadable,
		// the caller's MkdirAll/OpenFile fails on its own.
		if _, err := os.Lstat(existing); err == nil {
			break
		}
		parent := filepath.Dir(existing)
		if parent == existing {
			break
		}
		existing = parent
	}
	resolved, err := filepath.EvalSymlinks(existing)
	if err != nil {
		return fmt.Errorf("unsafe path in tarball: %q: %w", name, err)
	}
	if !pathIsWithin(realRoot, resolved) {
		return fmt.Errorf("unsafe path in tarball: %q resolves outside target", name)
	}
	return nil
}

// pathIsWithin reports whether p equals root or lies beneath it. Both paths
// are expected to be absolute and cleaned.
func pathIsWithin(root, p string) bool {
	rel, err := filepath.Rel(root, p)
	return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}