Both Fetch flows previously streamed resp.Body straight into the zstd → tar → on-disk extractor, with the SHA256 check tacked on at the END. A bad mirror, or an attacker who has compromised the catalog host, could ship a multi-gigabyte tarball, watch banger expand it to disk, and only THEN see the helpful "sha256 mismatch" message — having already filled the host filesystem. Reorder the operations: stage the compressed tarball to a temp file under the destination directory through an io.LimitReader (cap+1 bytes), hash on the way in, and refuse to decompress if either the cap trips or the SHA mismatches. Worst-case disk use is bounded by the cap, not by the source. The cap is exposed as a package var (MaxFetchedBundleBytes, MaxFetchedKernelBytes) so callers can tune it per deployment and tests can squeeze it down to provoke the rejection. The default is 8 GiB — generous enough for a 4 GiB rootfs (which compresses to ~1-2 GiB), tight enough to make a "fill the host disk" attack expensive. The temp file lives in the destination dir so extraction stays on the same filesystem and we don't pay for a cross-FS rename. A deferred os.Remove cleans it up; the existing per-package cleanup() handler still removes any partial extraction on hash mismatch / extraction failure. Tests: each package gets a TestFetchRejectsOversizedTarballBeforeExtraction that sets the cap to 64 bytes, points Fetch at a multi-KB tarball, and asserts (a) the error mentions "cap", (b) the destination dir is left clean (no leaked rootfs / manifest / kernel tree). All existing tests still pass — happy path, hash mismatch, missing files, path traversal, HTTP error, etc. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
231 lines
6.6 KiB
Go
231 lines
6.6 KiB
Go
package kernelcat
|
|
|
|
import (
|
|
"archive/tar"
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/klauspost/compress/zstd"
|
|
)
|
|
|
|
// tarballFile is the recipe for a single entry that buildTestTarball writes
// into the test archive. The entry kind is chosen in this order: dir wins
// over link, and anything else becomes a regular file.
type tarballFile struct {
	// name is the entry path recorded in the tar header.
	name string
	// mode is the header mode. Regular files with a zero mode are written
	// as 0o644; directories are always written as 0o755.
	mode int64
	// data holds the contents; it is only written for regular-file entries.
	data []byte
	// link is the symlink target; a non-empty value yields a symlink entry.
	link string
	// dir marks the entry as a directory.
	dir bool
}
|
|
|
|
func buildTestTarball(t *testing.T, files []tarballFile) ([]byte, string) {
|
|
t.Helper()
|
|
var tarBuf bytes.Buffer
|
|
tw := tar.NewWriter(&tarBuf)
|
|
for _, f := range files {
|
|
hdr := &tar.Header{Name: f.name, Mode: f.mode}
|
|
switch {
|
|
case f.dir:
|
|
hdr.Typeflag = tar.TypeDir
|
|
hdr.Mode = 0o755
|
|
case f.link != "":
|
|
hdr.Typeflag = tar.TypeSymlink
|
|
hdr.Linkname = f.link
|
|
default:
|
|
hdr.Typeflag = tar.TypeReg
|
|
hdr.Size = int64(len(f.data))
|
|
if hdr.Mode == 0 {
|
|
hdr.Mode = 0o644
|
|
}
|
|
}
|
|
if err := tw.WriteHeader(hdr); err != nil {
|
|
t.Fatalf("tar WriteHeader: %v", err)
|
|
}
|
|
if hdr.Typeflag == tar.TypeReg {
|
|
if _, err := tw.Write(f.data); err != nil {
|
|
t.Fatalf("tar Write: %v", err)
|
|
}
|
|
}
|
|
}
|
|
if err := tw.Close(); err != nil {
|
|
t.Fatalf("tar Close: %v", err)
|
|
}
|
|
|
|
var compressed bytes.Buffer
|
|
zw, err := zstd.NewWriter(&compressed)
|
|
if err != nil {
|
|
t.Fatalf("zstd NewWriter: %v", err)
|
|
}
|
|
if _, err := zw.Write(tarBuf.Bytes()); err != nil {
|
|
t.Fatalf("zstd Write: %v", err)
|
|
}
|
|
if err := zw.Close(); err != nil {
|
|
t.Fatalf("zstd Close: %v", err)
|
|
}
|
|
sum := sha256.Sum256(compressed.Bytes())
|
|
return compressed.Bytes(), hex.EncodeToString(sum[:])
|
|
}
|
|
|
|
// serveTarball starts an httptest server that answers every request with
// body, labelled as application/octet-stream. The server is shut down
// automatically when the test finishes.
func serveTarball(t *testing.T, body []byte) *httptest.Server {
	t.Helper()

	handler := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/octet-stream")
		// Write errors are deliberately ignored: a client that hangs up
		// early is the client's problem, not the fixture's.
		_, _ = w.Write(body)
	}

	srv := httptest.NewServer(http.HandlerFunc(handler))
	t.Cleanup(srv.Close)
	return srv
}
|
|
|
|
func TestFetchExtractsTarballAndWritesManifest(t *testing.T) {
|
|
t.Parallel()
|
|
body, sum := buildTestTarball(t, []tarballFile{
|
|
{name: "vmlinux", data: []byte("kernel-bytes")},
|
|
{name: "initrd.img", data: []byte("initrd-bytes")},
|
|
{name: "modules", dir: true},
|
|
{name: "modules/modules.dep", data: []byte("dep")},
|
|
})
|
|
srv := serveTarball(t, body)
|
|
|
|
kernelsDir := t.TempDir()
|
|
stored, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
|
|
Name: "void-6.12",
|
|
Distro: "void",
|
|
Arch: "x86_64",
|
|
KernelVersion: "6.12.79_1",
|
|
TarballURL: srv.URL + "/pkg.tar.zst",
|
|
TarballSHA256: sum,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Fetch: %v", err)
|
|
}
|
|
if stored.Name != "void-6.12" || stored.Distro != "void" {
|
|
t.Fatalf("stored = %+v", stored)
|
|
}
|
|
if stored.SHA256 == "" {
|
|
t.Errorf("SHA256 not populated")
|
|
}
|
|
|
|
for _, rel := range []string{"vmlinux", "initrd.img", "modules/modules.dep", "manifest.json"} {
|
|
if _, err := os.Stat(filepath.Join(kernelsDir, "void-6.12", rel)); err != nil {
|
|
t.Errorf("expected %s in catalog: %v", rel, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestFetchRejectsShaMismatch(t *testing.T) {
|
|
t.Parallel()
|
|
body, _ := buildTestTarball(t, []tarballFile{
|
|
{name: "vmlinux", data: []byte("k")},
|
|
})
|
|
srv := serveTarball(t, body)
|
|
|
|
kernelsDir := t.TempDir()
|
|
_, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
|
|
Name: "void-6.12",
|
|
TarballURL: srv.URL + "/pkg.tar.zst",
|
|
TarballSHA256: "000000000000000000000000000000000000000000000000000000000000beef",
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") {
|
|
t.Fatalf("expected sha256 mismatch, got %v", err)
|
|
}
|
|
if _, statErr := os.Stat(filepath.Join(kernelsDir, "void-6.12")); !os.IsNotExist(statErr) {
|
|
t.Fatalf("target dir should be cleaned up on mismatch: %v", statErr)
|
|
}
|
|
}
|
|
|
|
// TestFetchRejectsOversizedTarballBeforeExtraction pins the new
|
|
// disk-bound cap: with MaxFetchedKernelBytes set artificially low the
|
|
// staged download trips the limit and refuses to decompress, so a
|
|
// compromised mirror can't fill the host disk before the SHA check
|
|
// fires.
|
|
func TestFetchRejectsOversizedTarballBeforeExtraction(t *testing.T) {
|
|
body, sum := buildTestTarball(t, []tarballFile{
|
|
{name: "vmlinux", data: bytes.Repeat([]byte("k"), 4096)},
|
|
})
|
|
srv := serveTarball(t, body)
|
|
|
|
prev := MaxFetchedKernelBytes
|
|
MaxFetchedKernelBytes = 64
|
|
t.Cleanup(func() { MaxFetchedKernelBytes = prev })
|
|
|
|
kernelsDir := t.TempDir()
|
|
_, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
|
|
Name: "void-6.12",
|
|
TarballURL: srv.URL + "/pkg.tar.zst",
|
|
TarballSHA256: sum,
|
|
})
|
|
if err == nil {
|
|
t.Fatal("Fetch succeeded against oversized tarball; want size-cap rejection")
|
|
}
|
|
if !strings.Contains(err.Error(), "cap") {
|
|
t.Fatalf("err = %v, want size-cap message", err)
|
|
}
|
|
// targetDir should be cleaned up by the existing cleanup() path.
|
|
if _, statErr := os.Stat(filepath.Join(kernelsDir, "void-6.12")); !os.IsNotExist(statErr) {
|
|
t.Fatalf("target dir should be removed on size-cap rejection: %v", statErr)
|
|
}
|
|
}
|
|
|
|
func TestFetchRejectsMissingKernel(t *testing.T) {
|
|
t.Parallel()
|
|
body, sum := buildTestTarball(t, []tarballFile{
|
|
{name: "initrd.img", data: []byte("i")}, // no vmlinux
|
|
})
|
|
srv := serveTarball(t, body)
|
|
kernelsDir := t.TempDir()
|
|
_, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
|
|
Name: "broken",
|
|
TarballURL: srv.URL + "/pkg.tar.zst",
|
|
TarballSHA256: sum,
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "missing vmlinux") {
|
|
t.Fatalf("expected missing vmlinux, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestFetchRejectsPathTraversal(t *testing.T) {
|
|
t.Parallel()
|
|
body, sum := buildTestTarball(t, []tarballFile{
|
|
{name: "vmlinux", data: []byte("k")},
|
|
{name: "../escape", data: []byte("bad")},
|
|
})
|
|
srv := serveTarball(t, body)
|
|
kernelsDir := t.TempDir()
|
|
_, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
|
|
Name: "bad-tarball",
|
|
TarballURL: srv.URL + "/pkg.tar.zst",
|
|
TarballSHA256: sum,
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "unsafe path") {
|
|
t.Fatalf("expected unsafe path error, got %v", err)
|
|
}
|
|
escapePath := filepath.Join(filepath.Dir(kernelsDir), "escape")
|
|
if _, statErr := os.Stat(escapePath); !os.IsNotExist(statErr) {
|
|
t.Fatalf("traversal escape file should not exist: %v", statErr)
|
|
}
|
|
}
|
|
|
|
func TestFetchRejectsHTTPError(t *testing.T) {
|
|
t.Parallel()
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
http.Error(w, "nope", http.StatusNotFound)
|
|
}))
|
|
t.Cleanup(srv.Close)
|
|
|
|
kernelsDir := t.TempDir()
|
|
_, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
|
|
Name: "missing",
|
|
TarballURL: srv.URL + "/pkg.tar.zst",
|
|
TarballSHA256: "deadbeef",
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "404") {
|
|
t.Fatalf("expected HTTP 404, got %v", err)
|
|
}
|
|
}
|