Phase 4: remote catalog + banger kernel pull
Introduces the headline feature of the kernel catalog: pulling a kernel
bundle over HTTP without any local build step.
Catalog format (internal/kernelcat/catalog.go):
- Catalog { Version, Entries } + CatEntry { Name, Distro, Arch,
KernelVersion, TarballURL, TarballSHA256, SizeBytes, Description }.
- catalog.json is embedded via go:embed and ships with each banger
binary. It starts empty (Phase 5's CI pipeline will populate it).
- Lookup(name) returns the matching entry or os.ErrNotExist.
Fetch (internal/kernelcat/fetch.go):
- HTTP GET with streaming SHA256 over the response body.
- zstd-decode (github.com/klauspost/compress/zstd) -> tar extract into
<kernelsDir>/<name>/.
- Hardens against path-traversal tarball entries (members whose
normalised path escapes the target dir, and unsafe symlink
targets) and sha256-mismatch downloads; any failure removes the
partially-populated target dir.
- Regular files, directories, and safe symlinks are supported; other
tar types (hardlinks, devices, fifos) are silently skipped.
- After extraction, recomputes sha256 over the on-disk vmlinux and
writes the manifest with Source="pull:<url>".
Daemon methods (internal/daemon/kernels.go):
- KernelPull(ctx, {Name, Force}) - lookup in embedded catalog, refuse
overwrite unless Force, delegate to kernelcat.Fetch.
- KernelCatalog(ctx) - return the embedded catalog annotated per-entry
with whether it has been pulled locally.
RPC: kernel.pull, kernel.catalog dispatch cases.
CLI:
- `banger kernel pull <name> [--force]`.
- `banger kernel list --available` prints the catalog with a
pulled/available STATE column and a human-readable size.
Tests: fetch round-trip (extract + manifest + sha256), sha256 mismatch
rejection with cleanup, missing-vmlinux rejection, path-traversal
rejection, HTTP error propagation, catalog parsing, lookup,
pulled-status reconciliation. All 20 packages green.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
7192ba24ae
commit
f0668ee598
13 changed files with 711 additions and 4 deletions
|
|
@ -307,6 +307,25 @@ type KernelImportParams struct {
|
|||
Arch string `json:"arch,omitempty"`
|
||||
}
|
||||
|
||||
type KernelPullParams struct {
|
||||
Name string `json:"name"`
|
||||
Force bool `json:"force,omitempty"`
|
||||
}
|
||||
|
||||
type KernelCatalogEntry struct {
|
||||
Name string `json:"name"`
|
||||
Distro string `json:"distro,omitempty"`
|
||||
Arch string `json:"arch,omitempty"`
|
||||
KernelVersion string `json:"kernel_version,omitempty"`
|
||||
SizeBytes int64 `json:"size_bytes,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Pulled bool `json:"pulled"`
|
||||
}
|
||||
|
||||
type KernelCatalogResult struct {
|
||||
Entries []KernelCatalogEntry `json:"entries"`
|
||||
}
|
||||
|
||||
type SudoStatus struct {
|
||||
Available bool `json:"available"`
|
||||
Command string `json:"command,omitempty"`
|
||||
|
|
|
|||
|
|
@ -1600,10 +1600,33 @@ func newKernelCommand() *cobra.Command {
|
|||
newKernelShowCommand(),
|
||||
newKernelRmCommand(),
|
||||
newKernelImportCommand(),
|
||||
newKernelPullCommand(),
|
||||
)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newKernelPullCommand() *cobra.Command {
|
||||
var force bool
|
||||
cmd := &cobra.Command{
|
||||
Use: "pull <name>",
|
||||
Short: "Download a cataloged kernel bundle",
|
||||
Args: exactArgsUsage(1, "usage: banger kernel pull <name> [--force]"),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := rpc.Call[api.KernelShowResult](cmd.Context(), layout.SocketPath, "kernel.pull", api.KernelPullParams{Name: args[0], Force: force})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return printJSON(cmd.OutOrStdout(), result.Entry)
|
||||
},
|
||||
}
|
||||
cmd.Flags().BoolVar(&force, "force", false, "re-pull even if already present")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newKernelImportCommand() *cobra.Command {
|
||||
var params api.KernelImportParams
|
||||
cmd := &cobra.Command{
|
||||
|
|
@ -1639,15 +1662,23 @@ func newKernelImportCommand() *cobra.Command {
|
|||
}
|
||||
|
||||
func newKernelListCommand() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
var available bool
|
||||
cmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List locally available kernels",
|
||||
Args: noArgsUsage("usage: banger kernel list"),
|
||||
Short: "List kernels (local by default, or --available for the catalog)",
|
||||
Args: noArgsUsage("usage: banger kernel list [--available]"),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
layout, _, err := ensureDaemon(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if available {
|
||||
result, err := rpc.Call[api.KernelCatalogResult](cmd.Context(), layout.SocketPath, "kernel.catalog", api.Empty{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return printKernelCatalogTable(cmd.OutOrStdout(), result.Entries)
|
||||
}
|
||||
result, err := rpc.Call[api.KernelListResult](cmd.Context(), layout.SocketPath, "kernel.list", api.Empty{})
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -1655,6 +1686,8 @@ func newKernelListCommand() *cobra.Command {
|
|||
return printKernelListTable(cmd.OutOrStdout(), result.Entries)
|
||||
},
|
||||
}
|
||||
cmd.Flags().BoolVar(&available, "available", false, "show the built-in catalog (with pulled/available status) instead of local entries")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newKernelShowCommand() *cobra.Command {
|
||||
|
|
@ -1717,6 +1750,53 @@ func printKernelListTable(out anyWriter, entries []api.KernelEntry) error {
|
|||
return w.Flush()
|
||||
}
|
||||
|
||||
func printKernelCatalogTable(out anyWriter, entries []api.KernelCatalogEntry) error {
|
||||
w := tabwriter.NewWriter(out, 0, 8, 2, ' ', 0)
|
||||
if _, err := fmt.Fprintln(w, "NAME\tDISTRO\tARCH\tKERNEL\tSIZE\tSTATE"); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
state := "available"
|
||||
if entry.Pulled {
|
||||
state = "pulled"
|
||||
}
|
||||
if _, err := fmt.Fprintf(
|
||||
w,
|
||||
"%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
entry.Name,
|
||||
dashIfEmpty(entry.Distro),
|
||||
dashIfEmpty(entry.Arch),
|
||||
dashIfEmpty(entry.KernelVersion),
|
||||
humanSize(entry.SizeBytes),
|
||||
state,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
func humanSize(bytes int64) string {
|
||||
if bytes <= 0 {
|
||||
return "-"
|
||||
}
|
||||
const (
|
||||
kib = 1024
|
||||
mib = 1024 * kib
|
||||
gib = 1024 * mib
|
||||
)
|
||||
switch {
|
||||
case bytes >= gib:
|
||||
return fmt.Sprintf("%.1fGiB", float64(bytes)/float64(gib))
|
||||
case bytes >= mib:
|
||||
return fmt.Sprintf("%.1fMiB", float64(bytes)/float64(mib))
|
||||
case bytes >= kib:
|
||||
return fmt.Sprintf("%.1fKiB", float64(bytes)/float64(kib))
|
||||
default:
|
||||
return fmt.Sprintf("%dB", bytes)
|
||||
}
|
||||
}
|
||||
|
||||
func dashIfEmpty(s string) string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return "-"
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ func TestKernelCommandExposesSubcommands(t *testing.T) {
|
|||
for _, sub := range kernel.Commands() {
|
||||
names = append(names, sub.Name())
|
||||
}
|
||||
want := []string{"import", "list", "rm", "show"}
|
||||
want := []string{"import", "list", "pull", "rm", "show"}
|
||||
if !reflect.DeepEqual(names, want) {
|
||||
t.Fatalf("kernel subcommands = %v, want %v", names, want)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -550,6 +550,15 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
|
|||
}
|
||||
entry, err := d.KernelImport(ctx, params)
|
||||
return marshalResultOrError(api.KernelShowResult{Entry: entry}, err)
|
||||
case "kernel.pull":
|
||||
params, err := rpc.DecodeParams[api.KernelPullParams](req)
|
||||
if err != nil {
|
||||
return rpc.NewError("bad_request", err.Error())
|
||||
}
|
||||
entry, err := d.KernelPull(ctx, params)
|
||||
return marshalResultOrError(api.KernelShowResult{Entry: entry}, err)
|
||||
case "kernel.catalog":
|
||||
return marshalResultOrError(d.KernelCatalog(ctx))
|
||||
default:
|
||||
return rpc.NewError("unknown_method", req.Method)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -114,6 +114,65 @@ func (d *Daemon) KernelImport(ctx context.Context, params api.KernelImportParams
|
|||
return kernelEntryToAPI(stored), nil
|
||||
}
|
||||
|
||||
// KernelPull downloads a catalog entry by name into the local catalog. It
|
||||
// refuses to overwrite an existing entry unless params.Force is set.
|
||||
func (d *Daemon) KernelPull(ctx context.Context, params api.KernelPullParams) (api.KernelEntry, error) {
|
||||
name := strings.TrimSpace(params.Name)
|
||||
if err := kernelcat.ValidateName(name); err != nil {
|
||||
return api.KernelEntry{}, err
|
||||
}
|
||||
|
||||
if !params.Force {
|
||||
if _, err := kernelcat.ReadLocal(d.layout.KernelsDir, name); err == nil {
|
||||
return api.KernelEntry{}, fmt.Errorf("kernel %q already pulled; pass --force to re-pull", name)
|
||||
} else if !os.IsNotExist(err) {
|
||||
return api.KernelEntry{}, err
|
||||
}
|
||||
}
|
||||
|
||||
catalog, err := kernelcat.LoadEmbedded()
|
||||
if err != nil {
|
||||
return api.KernelEntry{}, err
|
||||
}
|
||||
catEntry, err := catalog.Lookup(name)
|
||||
if err != nil {
|
||||
return api.KernelEntry{}, fmt.Errorf("kernel %q not in catalog (run 'banger kernel list --available' to browse)", name)
|
||||
}
|
||||
|
||||
stored, err := kernelcat.Fetch(ctx, nil, d.layout.KernelsDir, catEntry)
|
||||
if err != nil {
|
||||
return api.KernelEntry{}, err
|
||||
}
|
||||
return kernelEntryToAPI(stored), nil
|
||||
}
|
||||
|
||||
// KernelCatalog returns every entry from the embedded catalog annotated
|
||||
// with whether it has already been pulled locally.
|
||||
func (d *Daemon) KernelCatalog(_ context.Context) (api.KernelCatalogResult, error) {
|
||||
catalog, err := kernelcat.LoadEmbedded()
|
||||
if err != nil {
|
||||
return api.KernelCatalogResult{}, err
|
||||
}
|
||||
local, _ := kernelcat.ListLocal(d.layout.KernelsDir)
|
||||
pulled := make(map[string]bool, len(local))
|
||||
for _, entry := range local {
|
||||
pulled[entry.Name] = true
|
||||
}
|
||||
result := api.KernelCatalogResult{Entries: make([]api.KernelCatalogEntry, 0, len(catalog.Entries))}
|
||||
for _, entry := range catalog.Entries {
|
||||
result.Entries = append(result.Entries, api.KernelCatalogEntry{
|
||||
Name: entry.Name,
|
||||
Distro: entry.Distro,
|
||||
Arch: entry.Arch,
|
||||
KernelVersion: entry.KernelVersion,
|
||||
SizeBytes: entry.SizeBytes,
|
||||
Description: entry.Description,
|
||||
Pulled: pulled[entry.Name],
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// inferKernelVersion makes a best-effort guess at the kernel version from
|
||||
// the source filename (e.g. "vmlinux-6.12.79_1") or falls back to the
|
||||
// modules directory basename. Returns "" if nothing looks useful.
|
||||
|
|
|
|||
|
|
@ -205,6 +205,43 @@ func TestKernelImportCopiesArtifactsAndWritesManifest(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestKernelPullRejectsUnknownCatalogEntry(t *testing.T) {
|
||||
d := &Daemon{
|
||||
layout: paths.Layout{KernelsDir: t.TempDir()},
|
||||
runner: system.NewRunner(),
|
||||
}
|
||||
_, err := d.KernelPull(context.Background(), api.KernelPullParams{Name: "unknown"})
|
||||
if err == nil || !strings.Contains(err.Error(), "not in catalog") {
|
||||
t.Fatalf("KernelPull unknown: err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelPullRefusesOverwriteWithoutForce(t *testing.T) {
|
||||
kernelsDir := t.TempDir()
|
||||
seedKernelEntry(t, kernelsDir, "void-6.12")
|
||||
|
||||
d := &Daemon{
|
||||
layout: paths.Layout{KernelsDir: kernelsDir},
|
||||
runner: system.NewRunner(),
|
||||
}
|
||||
_, err := d.KernelPull(context.Background(), api.KernelPullParams{Name: "void-6.12"})
|
||||
if err == nil || !strings.Contains(err.Error(), "already pulled") {
|
||||
t.Fatalf("KernelPull without --force: err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelCatalogReportsPulledStatus(t *testing.T) {
|
||||
d := &Daemon{layout: paths.Layout{KernelsDir: t.TempDir()}}
|
||||
result, err := d.KernelCatalog(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("KernelCatalog: %v", err)
|
||||
}
|
||||
// Embedded catalog ships empty; CI (phase 5) populates it.
|
||||
if result.Entries == nil {
|
||||
t.Fatalf("Entries should be non-nil even when catalog is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelImportRejectsMissingFromDir(t *testing.T) {
|
||||
d := &Daemon{
|
||||
layout: paths.Layout{KernelsDir: t.TempDir()},
|
||||
|
|
|
|||
59
internal/kernelcat/catalog.go
Normal file
59
internal/kernelcat/catalog.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package kernelcat
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
//go:embed catalog.json
|
||||
var embeddedCatalog []byte
|
||||
|
||||
// Catalog is the published list of kernel bundles banger can pull. It ships
|
||||
// embedded in the banger binary and is updated across releases; Phase 5
|
||||
// wires CI to regenerate it.
|
||||
type Catalog struct {
|
||||
Version int `json:"version"`
|
||||
Entries []CatEntry `json:"entries"`
|
||||
}
|
||||
|
||||
// CatEntry describes one downloadable kernel bundle.
|
||||
type CatEntry struct {
|
||||
Name string `json:"name"`
|
||||
Distro string `json:"distro,omitempty"`
|
||||
Arch string `json:"arch,omitempty"`
|
||||
KernelVersion string `json:"kernel_version,omitempty"`
|
||||
TarballURL string `json:"tarball_url"`
|
||||
TarballSHA256 string `json:"tarball_sha256"`
|
||||
SizeBytes int64 `json:"size_bytes,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// LoadEmbedded returns the catalog compiled into this banger binary.
|
||||
func LoadEmbedded() (Catalog, error) {
|
||||
return ParseCatalog(embeddedCatalog)
|
||||
}
|
||||
|
||||
// ParseCatalog decodes a catalog.json payload. An empty payload is valid
|
||||
// and returns a zero Catalog.
|
||||
func ParseCatalog(data []byte) (Catalog, error) {
|
||||
var cat Catalog
|
||||
if len(data) == 0 {
|
||||
return cat, nil
|
||||
}
|
||||
if err := json.Unmarshal(data, &cat); err != nil {
|
||||
return Catalog{}, fmt.Errorf("parse catalog: %w", err)
|
||||
}
|
||||
return cat, nil
|
||||
}
|
||||
|
||||
// Lookup returns the catalog entry matching name, or os.ErrNotExist.
|
||||
func (c Catalog) Lookup(name string) (CatEntry, error) {
|
||||
for _, e := range c.Entries {
|
||||
if e.Name == name {
|
||||
return e, nil
|
||||
}
|
||||
}
|
||||
return CatEntry{}, os.ErrNotExist
|
||||
}
|
||||
4
internal/kernelcat/catalog.json
Normal file
4
internal/kernelcat/catalog.json
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"version": 1,
|
||||
"entries": []
|
||||
}
|
||||
52
internal/kernelcat/catalog_test.go
Normal file
52
internal/kernelcat/catalog_test.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package kernelcat
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseCatalogEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
cat, err := ParseCatalog(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCatalog(nil): %v", err)
|
||||
}
|
||||
if len(cat.Entries) != 0 {
|
||||
t.Fatalf("entries = %d, want 0", len(cat.Entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCatalogValid(t *testing.T) {
|
||||
t.Parallel()
|
||||
cat, err := ParseCatalog([]byte(`{"version":1,"entries":[{"name":"void-6.12","distro":"void","tarball_url":"https://example/v.tar.zst","tarball_sha256":"abc"}]}`))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCatalog: %v", err)
|
||||
}
|
||||
if cat.Version != 1 || len(cat.Entries) != 1 || cat.Entries[0].Name != "void-6.12" {
|
||||
t.Fatalf("catalog = %+v", cat)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCatalogLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
cat := Catalog{Entries: []CatEntry{{Name: "a"}, {Name: "b"}}}
|
||||
if entry, err := cat.Lookup("b"); err != nil || entry.Name != "b" {
|
||||
t.Fatalf("Lookup(b) = %+v, %v", entry, err)
|
||||
}
|
||||
if _, err := cat.Lookup("c"); !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("Lookup(missing) err = %v, want ErrNotExist", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadEmbeddedReturnsValidCatalog(t *testing.T) {
|
||||
t.Parallel()
|
||||
cat, err := LoadEmbedded()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadEmbedded: %v", err)
|
||||
}
|
||||
if cat.Version != 1 {
|
||||
t.Fatalf("embedded catalog.Version = %d, want 1", cat.Version)
|
||||
}
|
||||
// Embedded catalog starts empty; Phase 5 CI populates it.
|
||||
}
|
||||
187
internal/kernelcat/fetch.go
Normal file
187
internal/kernelcat/fetch.go
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
package kernelcat
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
// Fetch downloads the tarball for entry, verifies its SHA256, extracts it
|
||||
// into <kernelsDir>/<entry.Name>/, and writes a manifest. On failure it
|
||||
// removes the partially-populated target directory.
|
||||
//
|
||||
// The tarball is expected to be a tar+zstd archive whose root contains
|
||||
// vmlinux and optionally initrd.img and/or a modules/ directory. Path
|
||||
// traversal entries (..) and absolute-path members are rejected.
|
||||
func Fetch(ctx context.Context, client *http.Client, kernelsDir string, entry CatEntry) (Entry, error) {
|
||||
if err := ValidateName(entry.Name); err != nil {
|
||||
return Entry{}, err
|
||||
}
|
||||
if strings.TrimSpace(entry.TarballURL) == "" {
|
||||
return Entry{}, fmt.Errorf("catalog entry %q has no tarball URL", entry.Name)
|
||||
}
|
||||
if strings.TrimSpace(entry.TarballSHA256) == "" {
|
||||
return Entry{}, fmt.Errorf("catalog entry %q has no tarball sha256", entry.Name)
|
||||
}
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
|
||||
if err := DeleteLocal(kernelsDir, entry.Name); err != nil {
|
||||
return Entry{}, fmt.Errorf("clear prior catalog entry: %w", err)
|
||||
}
|
||||
targetDir := EntryDir(kernelsDir, entry.Name)
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return Entry{}, err
|
||||
}
|
||||
|
||||
cleanup := func() { _ = os.RemoveAll(targetDir) }
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, entry.TarballURL, nil)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return Entry{}, err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return Entry{}, fmt.Errorf("fetch %s: %w", entry.TarballURL, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
cleanup()
|
||||
return Entry{}, fmt.Errorf("fetch %s: HTTP %s", entry.TarballURL, resp.Status)
|
||||
}
|
||||
|
||||
hasher := sha256.New()
|
||||
tee := io.TeeReader(resp.Body, hasher)
|
||||
zr, err := zstd.NewReader(tee)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return Entry{}, fmt.Errorf("init zstd: %w", err)
|
||||
}
|
||||
defer zr.Close()
|
||||
|
||||
if err := extractTar(zr, targetDir); err != nil {
|
||||
cleanup()
|
||||
return Entry{}, err
|
||||
}
|
||||
// Drain any remaining tarball-padding bytes so the hash covers the
|
||||
// whole transport stream even if the tar reader stopped early.
|
||||
if _, err := io.Copy(io.Discard, tee); err != nil {
|
||||
cleanup()
|
||||
return Entry{}, fmt.Errorf("drain tarball: %w", err)
|
||||
}
|
||||
|
||||
got := hex.EncodeToString(hasher.Sum(nil))
|
||||
if !strings.EqualFold(got, entry.TarballSHA256) {
|
||||
cleanup()
|
||||
return Entry{}, fmt.Errorf("tarball sha256 mismatch: got %s, want %s", got, entry.TarballSHA256)
|
||||
}
|
||||
|
||||
kernelPath := filepath.Join(targetDir, kernelFilename)
|
||||
if _, err := os.Stat(kernelPath); err != nil {
|
||||
cleanup()
|
||||
return Entry{}, fmt.Errorf("tarball missing %s: %w", kernelFilename, err)
|
||||
}
|
||||
kernelSum, err := SumFile(kernelPath)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return Entry{}, err
|
||||
}
|
||||
|
||||
stored := Entry{
|
||||
Name: entry.Name,
|
||||
Distro: entry.Distro,
|
||||
Arch: entry.Arch,
|
||||
KernelVersion: entry.KernelVersion,
|
||||
SHA256: kernelSum,
|
||||
Source: "pull:" + entry.TarballURL,
|
||||
ImportedAt: time.Now().UTC(),
|
||||
}
|
||||
if err := WriteLocal(kernelsDir, stored); err != nil {
|
||||
cleanup()
|
||||
return Entry{}, err
|
||||
}
|
||||
return ReadLocal(kernelsDir, entry.Name)
|
||||
}
|
||||
|
||||
// extractTar writes each regular file / dir / safe symlink from r into
|
||||
// target, refusing any member whose normalised path would escape target.
|
||||
func extractTar(r io.Reader, target string) error {
|
||||
absTarget, err := filepath.Abs(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tr := tar.NewReader(r)
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("read tarball: %w", err)
|
||||
}
|
||||
rel := filepath.Clean(hdr.Name)
|
||||
if rel == "." || rel == string(filepath.Separator) {
|
||||
continue
|
||||
}
|
||||
if filepath.IsAbs(rel) || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || rel == ".." {
|
||||
return fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
|
||||
}
|
||||
dst := filepath.Join(absTarget, rel)
|
||||
if dst != absTarget && !strings.HasPrefix(dst, absTarget+string(filepath.Separator)) {
|
||||
return fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
|
||||
}
|
||||
switch hdr.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.MkdirAll(dst, os.FileMode(hdr.Mode)|0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
case tar.TypeReg:
|
||||
if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
f, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)|0o600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.Copy(f, tr); err != nil {
|
||||
_ = f.Close()
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
case tar.TypeSymlink:
|
||||
if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
link := hdr.Linkname
|
||||
resolved := link
|
||||
if !filepath.IsAbs(link) {
|
||||
resolved = filepath.Join(filepath.Dir(dst), link)
|
||||
}
|
||||
resolved = filepath.Clean(resolved)
|
||||
if resolved != absTarget && !strings.HasPrefix(resolved, absTarget+string(filepath.Separator)) {
|
||||
return fmt.Errorf("unsafe symlink in tarball: %q -> %q", hdr.Name, hdr.Linkname)
|
||||
}
|
||||
if err := os.Symlink(hdr.Linkname, dst); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
// Hardlinks / device nodes / fifos: skip silently. Kernel
|
||||
// module trees shouldn't need them.
|
||||
}
|
||||
}
|
||||
}
|
||||
198
internal/kernelcat/fetch_test.go
Normal file
198
internal/kernelcat/fetch_test.go
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
package kernelcat
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
// tarballFile describes one member of the test tarball.
|
||||
type tarballFile struct {
|
||||
name string
|
||||
mode int64
|
||||
data []byte
|
||||
link string // for symlinks
|
||||
dir bool
|
||||
}
|
||||
|
||||
func buildTestTarball(t *testing.T, files []tarballFile) ([]byte, string) {
|
||||
t.Helper()
|
||||
var tarBuf bytes.Buffer
|
||||
tw := tar.NewWriter(&tarBuf)
|
||||
for _, f := range files {
|
||||
hdr := &tar.Header{Name: f.name, Mode: f.mode}
|
||||
switch {
|
||||
case f.dir:
|
||||
hdr.Typeflag = tar.TypeDir
|
||||
hdr.Mode = 0o755
|
||||
case f.link != "":
|
||||
hdr.Typeflag = tar.TypeSymlink
|
||||
hdr.Linkname = f.link
|
||||
default:
|
||||
hdr.Typeflag = tar.TypeReg
|
||||
hdr.Size = int64(len(f.data))
|
||||
if hdr.Mode == 0 {
|
||||
hdr.Mode = 0o644
|
||||
}
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
t.Fatalf("tar WriteHeader: %v", err)
|
||||
}
|
||||
if hdr.Typeflag == tar.TypeReg {
|
||||
if _, err := tw.Write(f.data); err != nil {
|
||||
t.Fatalf("tar Write: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := tw.Close(); err != nil {
|
||||
t.Fatalf("tar Close: %v", err)
|
||||
}
|
||||
|
||||
var compressed bytes.Buffer
|
||||
zw, err := zstd.NewWriter(&compressed)
|
||||
if err != nil {
|
||||
t.Fatalf("zstd NewWriter: %v", err)
|
||||
}
|
||||
if _, err := zw.Write(tarBuf.Bytes()); err != nil {
|
||||
t.Fatalf("zstd Write: %v", err)
|
||||
}
|
||||
if err := zw.Close(); err != nil {
|
||||
t.Fatalf("zstd Close: %v", err)
|
||||
}
|
||||
sum := sha256.Sum256(compressed.Bytes())
|
||||
return compressed.Bytes(), hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func serveTarball(t *testing.T, body []byte) *httptest.Server {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(body)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestFetchExtractsTarballAndWritesManifest(t *testing.T) {
|
||||
t.Parallel()
|
||||
body, sum := buildTestTarball(t, []tarballFile{
|
||||
{name: "vmlinux", data: []byte("kernel-bytes")},
|
||||
{name: "initrd.img", data: []byte("initrd-bytes")},
|
||||
{name: "modules", dir: true},
|
||||
{name: "modules/modules.dep", data: []byte("dep")},
|
||||
})
|
||||
srv := serveTarball(t, body)
|
||||
|
||||
kernelsDir := t.TempDir()
|
||||
stored, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
|
||||
Name: "void-6.12",
|
||||
Distro: "void",
|
||||
Arch: "x86_64",
|
||||
KernelVersion: "6.12.79_1",
|
||||
TarballURL: srv.URL + "/pkg.tar.zst",
|
||||
TarballSHA256: sum,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Fetch: %v", err)
|
||||
}
|
||||
if stored.Name != "void-6.12" || stored.Distro != "void" {
|
||||
t.Fatalf("stored = %+v", stored)
|
||||
}
|
||||
if stored.SHA256 == "" {
|
||||
t.Errorf("SHA256 not populated")
|
||||
}
|
||||
|
||||
for _, rel := range []string{"vmlinux", "initrd.img", "modules/modules.dep", "manifest.json"} {
|
||||
if _, err := os.Stat(filepath.Join(kernelsDir, "void-6.12", rel)); err != nil {
|
||||
t.Errorf("expected %s in catalog: %v", rel, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchRejectsShaMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
body, _ := buildTestTarball(t, []tarballFile{
|
||||
{name: "vmlinux", data: []byte("k")},
|
||||
})
|
||||
srv := serveTarball(t, body)
|
||||
|
||||
kernelsDir := t.TempDir()
|
||||
_, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
|
||||
Name: "void-6.12",
|
||||
TarballURL: srv.URL + "/pkg.tar.zst",
|
||||
TarballSHA256: "000000000000000000000000000000000000000000000000000000000000beef",
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") {
|
||||
t.Fatalf("expected sha256 mismatch, got %v", err)
|
||||
}
|
||||
if _, statErr := os.Stat(filepath.Join(kernelsDir, "void-6.12")); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("target dir should be cleaned up on mismatch: %v", statErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchRejectsMissingKernel(t *testing.T) {
|
||||
t.Parallel()
|
||||
body, sum := buildTestTarball(t, []tarballFile{
|
||||
{name: "initrd.img", data: []byte("i")}, // no vmlinux
|
||||
})
|
||||
srv := serveTarball(t, body)
|
||||
kernelsDir := t.TempDir()
|
||||
_, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
|
||||
Name: "broken",
|
||||
TarballURL: srv.URL + "/pkg.tar.zst",
|
||||
TarballSHA256: sum,
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "missing vmlinux") {
|
||||
t.Fatalf("expected missing vmlinux, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchRejectsPathTraversal(t *testing.T) {
|
||||
t.Parallel()
|
||||
body, sum := buildTestTarball(t, []tarballFile{
|
||||
{name: "vmlinux", data: []byte("k")},
|
||||
{name: "../escape", data: []byte("bad")},
|
||||
})
|
||||
srv := serveTarball(t, body)
|
||||
kernelsDir := t.TempDir()
|
||||
_, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
|
||||
Name: "bad-tarball",
|
||||
TarballURL: srv.URL + "/pkg.tar.zst",
|
||||
TarballSHA256: sum,
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "unsafe path") {
|
||||
t.Fatalf("expected unsafe path error, got %v", err)
|
||||
}
|
||||
escapePath := filepath.Join(filepath.Dir(kernelsDir), "escape")
|
||||
if _, statErr := os.Stat(escapePath); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("traversal escape file should not exist: %v", statErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchRejectsHTTPError(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, "nope", http.StatusNotFound)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
kernelsDir := t.TempDir()
|
||||
_, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{
|
||||
Name: "missing",
|
||||
TarballURL: srv.URL + "/pkg.tar.zst",
|
||||
TarballSHA256: "deadbeef",
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "404") {
|
||||
t.Fatalf("expected HTTP 404, got %v", err)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue