Both Fetch flows previously streamed resp.Body straight into zstd → tar → on-disk extractor with the SHA256 check tacked on at the END. A bad mirror or an attacker that's compromised the catalog host could ship a multi-gigabyte tarball, watch banger expand it to disk, and only THEN see the helpful "sha256 mismatch" message — having already filled the host filesystem. Reorder the operations: stage the compressed tarball to a temp file under the destination directory through an io.LimitReader (cap +1 bytes), hash on the way in, refuse to decompress if either the cap trips or the SHA mismatches. Worst-case disk use is bounded by the cap, not by the source. Cap is exposed as a package var (MaxFetchedBundleBytes, MaxFetchedKernelBytes) so callers can tune per-deployment and tests can squeeze it down to provoke the rejection. Default 8 GiB — generous enough for a 4 GiB rootfs (which compresses to ~1-2 GiB), tight enough to make a "fill the host disk" attack expensive. The temp file lives in the destination dir so extraction stays on the same filesystem and we don't pay for cross-FS rename. defer os.Remove cleans up; the existing per-package cleanup() handler still removes any partial extraction on hash mismatch / extraction failure. Tests: each package gets a TestFetchRejectsOversizedTarballBeforeExtraction that sets the cap to 64 bytes, points Fetch at a multi-KB tarball, and asserts (a) error mentions "cap", (b) destination dir is left clean (no leaked rootfs / manifest / kernel tree). All existing tests still pass — happy path, hash mismatch, missing files, path traversal, HTTP error, etc. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
227 lines
6.9 KiB
Go
227 lines
6.9 KiB
Go
package kernelcat
|
|
|
|
import (
|
|
"archive/tar"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/klauspost/compress/zstd"
|
|
)
|
|
|
|
// MaxFetchedKernelBytes caps the compressed kernel-tarball download.
// Without this the previous flow streamed straight into the tar+zstd
// extractor and only verified SHA256 afterwards, so a malicious or
// compromised mirror could fill the host disk before the hash check
// fired. Now we stage to a temp file under targetDir, hash on the
// way in, and refuse to decompress on hash mismatch — worst-case
// disk use is bounded by this cap. Override per-call by setting this
// var before invoking Fetch.
//
// NOTE(review): package-level mutable state — writing this var while
// another goroutine runs Fetch is a data race; confirm callers only
// set it once at startup (or in single-threaded tests).
var MaxFetchedKernelBytes int64 = 8 << 30 // 8 GiB
|
|
|
|
// Fetch downloads the tarball for entry, verifies its SHA256, extracts it
|
|
// into <kernelsDir>/<entry.Name>/, and writes a manifest. On failure it
|
|
// removes the partially-populated target directory.
|
|
//
|
|
// The tarball is expected to be a tar+zstd archive whose root contains
|
|
// vmlinux and optionally initrd.img and/or a modules/ directory. Path
|
|
// traversal entries (..) and absolute-path members are rejected.
|
|
func Fetch(ctx context.Context, client *http.Client, kernelsDir string, entry CatEntry) (Entry, error) {
|
|
if err := ValidateName(entry.Name); err != nil {
|
|
return Entry{}, err
|
|
}
|
|
if strings.TrimSpace(entry.TarballURL) == "" {
|
|
return Entry{}, fmt.Errorf("catalog entry %q has no tarball URL", entry.Name)
|
|
}
|
|
if strings.TrimSpace(entry.TarballSHA256) == "" {
|
|
return Entry{}, fmt.Errorf("catalog entry %q has no tarball sha256", entry.Name)
|
|
}
|
|
if client == nil {
|
|
client = http.DefaultClient
|
|
}
|
|
|
|
if err := DeleteLocal(kernelsDir, entry.Name); err != nil {
|
|
return Entry{}, fmt.Errorf("clear prior catalog entry: %w", err)
|
|
}
|
|
targetDir := EntryDir(kernelsDir, entry.Name)
|
|
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
|
return Entry{}, err
|
|
}
|
|
|
|
cleanup := func() { _ = os.RemoveAll(targetDir) }
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, entry.TarballURL, nil)
|
|
if err != nil {
|
|
cleanup()
|
|
return Entry{}, err
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
cleanup()
|
|
return Entry{}, fmt.Errorf("fetch %s: %w", entry.TarballURL, err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
cleanup()
|
|
return Entry{}, fmt.Errorf("fetch %s: HTTP %s", entry.TarballURL, resp.Status)
|
|
}
|
|
|
|
if resp.ContentLength > MaxFetchedKernelBytes {
|
|
cleanup()
|
|
return Entry{}, fmt.Errorf("tarball advertised %d bytes, exceeds %d-byte cap", resp.ContentLength, MaxFetchedKernelBytes)
|
|
}
|
|
|
|
// Stage compressed download to a temp file first so we can verify
|
|
// SHA256 BEFORE decompressing or extracting. Cap reads to
|
|
// MaxFetchedKernelBytes+1 — anything larger is refused.
|
|
tmp, err := os.CreateTemp(targetDir, "banger-kernel-*.tar.zst")
|
|
if err != nil {
|
|
cleanup()
|
|
return Entry{}, fmt.Errorf("create staging file: %w", err)
|
|
}
|
|
tmpPath := tmp.Name()
|
|
defer os.Remove(tmpPath)
|
|
|
|
hasher := sha256.New()
|
|
limited := io.LimitReader(resp.Body, MaxFetchedKernelBytes+1)
|
|
n, copyErr := io.Copy(io.MultiWriter(tmp, hasher), limited)
|
|
if closeErr := tmp.Close(); copyErr == nil && closeErr != nil {
|
|
copyErr = closeErr
|
|
}
|
|
if copyErr != nil {
|
|
cleanup()
|
|
return Entry{}, fmt.Errorf("download tarball: %w", copyErr)
|
|
}
|
|
if n > MaxFetchedKernelBytes {
|
|
cleanup()
|
|
return Entry{}, fmt.Errorf("tarball exceeded %d-byte cap before sha256 check", MaxFetchedKernelBytes)
|
|
}
|
|
|
|
got := hex.EncodeToString(hasher.Sum(nil))
|
|
if !strings.EqualFold(got, entry.TarballSHA256) {
|
|
cleanup()
|
|
return Entry{}, fmt.Errorf("tarball sha256 mismatch: got %s, want %s", got, entry.TarballSHA256)
|
|
}
|
|
|
|
src, err := os.Open(tmpPath)
|
|
if err != nil {
|
|
cleanup()
|
|
return Entry{}, fmt.Errorf("reopen staged tarball: %w", err)
|
|
}
|
|
defer src.Close()
|
|
zr, err := zstd.NewReader(src)
|
|
if err != nil {
|
|
cleanup()
|
|
return Entry{}, fmt.Errorf("init zstd: %w", err)
|
|
}
|
|
defer zr.Close()
|
|
|
|
if err := extractTar(zr, targetDir); err != nil {
|
|
cleanup()
|
|
return Entry{}, err
|
|
}
|
|
|
|
kernelPath := filepath.Join(targetDir, kernelFilename)
|
|
if _, err := os.Stat(kernelPath); err != nil {
|
|
cleanup()
|
|
return Entry{}, fmt.Errorf("tarball missing %s: %w", kernelFilename, err)
|
|
}
|
|
kernelSum, err := SumFile(kernelPath)
|
|
if err != nil {
|
|
cleanup()
|
|
return Entry{}, err
|
|
}
|
|
|
|
stored := Entry{
|
|
Name: entry.Name,
|
|
Distro: entry.Distro,
|
|
Arch: entry.Arch,
|
|
KernelVersion: entry.KernelVersion,
|
|
SHA256: kernelSum,
|
|
Source: "pull:" + entry.TarballURL,
|
|
ImportedAt: time.Now().UTC(),
|
|
}
|
|
if err := WriteLocal(kernelsDir, stored); err != nil {
|
|
cleanup()
|
|
return Entry{}, err
|
|
}
|
|
return ReadLocal(kernelsDir, entry.Name)
|
|
}
|
|
|
|
// extractTar writes each regular file / dir / safe symlink from r into
|
|
// target, refusing any member whose normalised path would escape target.
|
|
func extractTar(r io.Reader, target string) error {
|
|
absTarget, err := filepath.Abs(target)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tr := tar.NewReader(r)
|
|
for {
|
|
hdr, err := tr.Next()
|
|
if err == io.EOF {
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("read tarball: %w", err)
|
|
}
|
|
rel := filepath.Clean(hdr.Name)
|
|
if rel == "." || rel == string(filepath.Separator) {
|
|
continue
|
|
}
|
|
if filepath.IsAbs(rel) || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || rel == ".." {
|
|
return fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
|
|
}
|
|
dst := filepath.Join(absTarget, rel)
|
|
if dst != absTarget && !strings.HasPrefix(dst, absTarget+string(filepath.Separator)) {
|
|
return fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
|
|
}
|
|
switch hdr.Typeflag {
|
|
case tar.TypeDir:
|
|
if err := os.MkdirAll(dst, os.FileMode(hdr.Mode)|0o755); err != nil {
|
|
return err
|
|
}
|
|
case tar.TypeReg:
|
|
if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
|
|
return err
|
|
}
|
|
f, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)|0o600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err := io.Copy(f, tr); err != nil {
|
|
_ = f.Close()
|
|
return err
|
|
}
|
|
if err := f.Close(); err != nil {
|
|
return err
|
|
}
|
|
case tar.TypeSymlink:
|
|
if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
|
|
return err
|
|
}
|
|
// Absolute targets are interpreted at runtime against the
|
|
// eventual rootfs (`/` inside the VM), so they're rooted
|
|
// inside absTarget by construction. Only relative targets
|
|
// need an escape check at write time.
|
|
if !filepath.IsAbs(hdr.Linkname) {
|
|
resolved := filepath.Clean(filepath.Join(filepath.Dir(dst), hdr.Linkname))
|
|
if resolved != absTarget && !strings.HasPrefix(resolved, absTarget+string(filepath.Separator)) {
|
|
return fmt.Errorf("unsafe symlink in tarball: %q -> %q", hdr.Name, hdr.Linkname)
|
|
}
|
|
}
|
|
if err := os.Symlink(hdr.Linkname, dst); err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
// Hardlinks / device nodes / fifos: skip silently. Kernel
|
|
// module trees shouldn't need them.
|
|
}
|
|
}
|
|
}
|