package kernelcat

import (
	"archive/tar"
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/klauspost/compress/zstd"
)

// tarballFile is one entry of a synthetic test tarball. The entry kind is
// chosen by precedence: dir makes a directory, otherwise a non-empty link
// makes a symlink, otherwise the entry is a regular file holding data.
type tarballFile struct {
	name string // path of the member inside the archive
	mode int64  // permission bits; ignored for dirs, and 0 means 0o644 for regular files
	data []byte // regular-file contents
	link string // symlink target
	dir  bool   // directory entry (always written with mode 0o755)
}

// buildTestTarball writes files into an in-memory tar archive, compresses
// the archive with zstd, and returns the compressed bytes along with their
// hex-encoded SHA-256 digest. Any encoding failure aborts the calling test.
func buildTestTarball(t *testing.T, files []tarballFile) ([]byte, string) {
	t.Helper()

	var archive bytes.Buffer
	tw := tar.NewWriter(&archive)
	for _, entry := range files {
		// Start from a regular-file header and specialise it below.
		hdr := &tar.Header{Name: entry.name, Mode: entry.mode, Typeflag: tar.TypeReg}
		switch {
		case entry.dir:
			hdr.Typeflag = tar.TypeDir
			hdr.Mode = 0o755
		case entry.link != "":
			hdr.Typeflag = tar.TypeSymlink
			hdr.Linkname = entry.link
		default:
			hdr.Size = int64(len(entry.data))
			if hdr.Mode == 0 {
				hdr.Mode = 0o644
			}
		}

		if err := tw.WriteHeader(hdr); err != nil {
			t.Fatalf("tar WriteHeader: %v", err)
		}
		// Only regular files carry a payload after the header.
		if hdr.Typeflag != tar.TypeReg {
			continue
		}
		if _, err := tw.Write(entry.data); err != nil {
			t.Fatalf("tar Write: %v", err)
		}
	}
	if err := tw.Close(); err != nil {
		t.Fatalf("tar Close: %v", err)
	}

	var compressed bytes.Buffer
	zw, err := zstd.NewWriter(&compressed)
	if err != nil {
		t.Fatalf("zstd NewWriter: %v", err)
	}
	if _, err := zw.Write(archive.Bytes()); err != nil {
		t.Fatalf("zstd Write: %v", err)
	}
	if err := zw.Close(); err != nil {
		t.Fatalf("zstd Close: %v", err)
	}

	out := compressed.Bytes()
	digest := sha256.Sum256(out)
	return out, hex.EncodeToString(digest[:])
}

// serveTarball starts an httptest server that answers every request with
// body (as application/octet-stream) and closes the server via t.Cleanup.
func serveTarball(t *testing.T, body []byte) *httptest.Server {
	t.Helper()

	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/octet-stream")
		// A failed write is left for the client side (Fetch) to report.
		_, _ = w.Write(body)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	t.Cleanup(srv.Close)
	return srv
}

// TestFetchExtractsTarballAndWritesManifest covers the happy path: a tarball
// whose digest matches TarballSHA256 ends up unpacked under kernelsDir/<Name>
// together with a manifest.json, and the returned entry reflects the request.
func TestFetchExtractsTarballAndWritesManifest(t *testing.T) {
	t.Parallel()

	body, sum := buildTestTarball(t, []tarballFile{
		{name: "vmlinux", data: []byte("kernel-bytes")},
		{name: "initrd.img", data: []byte("initrd-bytes")},
		{name: "modules", dir: true},
		{name: "modules/modules.dep", data: []byte("dep")},
	})
	srv := serveTarball(t, body)
	kernelsDir := t.TempDir()

	entry := CatEntry{
		Name:          "void-6.12",
		Distro:        "void",
		Arch:          "x86_64",
		KernelVersion: "6.12.79_1",
		TarballURL:    srv.URL + "/pkg.tar.zst",
		TarballSHA256: sum,
	}
	stored, err := Fetch(context.Background(), nil, kernelsDir, entry)
	if err != nil {
		t.Fatalf("Fetch: %v", err)
	}
	if stored.Name != "void-6.12" || stored.Distro != "void" {
		t.Fatalf("stored = %+v", stored)
	}
	if stored.SHA256 == "" {
		t.Errorf("SHA256 not populated")
	}

	entryDir := filepath.Join(kernelsDir, "void-6.12")
	wantFiles := []string{"vmlinux", "initrd.img", "modules/modules.dep", "manifest.json"}
	for _, rel := range wantFiles {
		if _, err := os.Stat(filepath.Join(entryDir, rel)); err != nil {
			t.Errorf("expected %s in catalog: %v", rel, err)
		}
	}
}

// TestFetchRejectsShaMismatch checks that a download whose digest differs
// from TarballSHA256 fails with a "sha256 mismatch" error and leaves no
// kernelsDir/<Name> directory behind.
func TestFetchRejectsShaMismatch(t *testing.T) {
	t.Parallel()

	body, _ := buildTestTarball(t, []tarballFile{{name: "vmlinux", data: []byte("k")}})
	srv := serveTarball(t, body)
	kernelsDir := t.TempDir()

	const wrongSum = "000000000000000000000000000000000000000000000000000000000000beef"
	_, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
		Name:          "void-6.12",
		TarballURL:    srv.URL + "/pkg.tar.zst",
		TarballSHA256: wrongSum,
	})
	if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") {
		t.Fatalf("expected sha256 mismatch, got %v", err)
	}

	targetDir := filepath.Join(kernelsDir, "void-6.12")
	if _, statErr := os.Stat(targetDir); !os.IsNotExist(statErr) {
		t.Fatalf("target dir should be cleaned up on mismatch: %v", statErr)
	}
}

// TestFetchRejectsOversizedTarballBeforeExtraction pins the download size
// cap that bounds disk usage. With MaxFetchedKernelBytes lowered to a tiny
// value, the staged download exceeds the limit and Fetch refuses to
// decompress it, so a compromised mirror cannot fill the host disk before
// the SHA-256 check runs. The test must not call t.Parallel because it
// temporarily overrides the package-level MaxFetchedKernelBytes.
func TestFetchRejectsOversizedTarballBeforeExtraction(t *testing.T) { body, sum := buildTestTarball(t, []tarballFile{ {name: "vmlinux", data: bytes.Repeat([]byte("k"), 4096)}, }) srv := serveTarball(t, body) prev := MaxFetchedKernelBytes MaxFetchedKernelBytes = 64 t.Cleanup(func() { MaxFetchedKernelBytes = prev }) kernelsDir := t.TempDir() _, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{ Name: "void-6.12", TarballURL: srv.URL + "/pkg.tar.zst", TarballSHA256: sum, }) if err == nil { t.Fatal("Fetch succeeded against oversized tarball; want size-cap rejection") } if !strings.Contains(err.Error(), "cap") { t.Fatalf("err = %v, want size-cap message", err) } // targetDir should be cleaned up by the existing cleanup() path. if _, statErr := os.Stat(filepath.Join(kernelsDir, "void-6.12")); !os.IsNotExist(statErr) { t.Fatalf("target dir should be removed on size-cap rejection: %v", statErr) } } func TestFetchRejectsMissingKernel(t *testing.T) { t.Parallel() body, sum := buildTestTarball(t, []tarballFile{ {name: "initrd.img", data: []byte("i")}, // no vmlinux }) srv := serveTarball(t, body) kernelsDir := t.TempDir() _, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{ Name: "broken", TarballURL: srv.URL + "/pkg.tar.zst", TarballSHA256: sum, }) if err == nil || !strings.Contains(err.Error(), "missing vmlinux") { t.Fatalf("expected missing vmlinux, got %v", err) } } func TestFetchRejectsPathTraversal(t *testing.T) { t.Parallel() body, sum := buildTestTarball(t, []tarballFile{ {name: "vmlinux", data: []byte("k")}, {name: "../escape", data: []byte("bad")}, }) srv := serveTarball(t, body) kernelsDir := t.TempDir() _, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{ Name: "bad-tarball", TarballURL: srv.URL + "/pkg.tar.zst", TarballSHA256: sum, }) if err == nil || !strings.Contains(err.Error(), "unsafe path") { t.Fatalf("expected unsafe path error, got %v", err) } escapePath := filepath.Join(filepath.Dir(kernelsDir), "escape") if _, statErr 
:= os.Stat(escapePath); !os.IsNotExist(statErr) { t.Fatalf("traversal escape file should not exist: %v", statErr) } } func TestFetchRejectsHTTPError(t *testing.T) { t.Parallel() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "nope", http.StatusNotFound) })) t.Cleanup(srv.Close) kernelsDir := t.TempDir() _, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{ Name: "missing", TarballURL: srv.URL + "/pkg.tar.zst", TarballSHA256: "deadbeef", }) if err == nil || !strings.Contains(err.Error(), "404") { t.Fatalf("expected HTTP 404, got %v", err) } }