From f0668ee59811fd0580459fbaad85bd6a40f62626 Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Thu, 16 Apr 2026 15:05:42 -0300 Subject: [PATCH] Phase 4: remote catalog + banger kernel pull Introduces the headline feature of the kernel catalog: pulling a kernel bundle over HTTP without any local build step. Catalog format (internal/kernelcat/catalog.go): - Catalog { Version, Entries } + CatEntry { Name, Distro, Arch, KernelVersion, TarballURL, TarballSHA256, SizeBytes, Description }. - catalog.json is embedded via go:embed and ships with each banger binary. It starts empty (Phase 5's CI pipeline will populate it). - Lookup(name) returns the matching entry or os.ErrNotExist. Fetch (internal/kernelcat/fetch.go): - HTTP GET with streaming SHA256 over the response body. - zstd-decode (github.com/klauspost/compress/zstd) -> tar extract into //. - Hardens against path-traversal tarball entries (members whose normalised path escapes the target dir, and unsafe symlink targets) and sha256-mismatch downloads; any failure removes the partially-populated target dir. - Regular files, directories, and safe symlinks are supported; other tar types (hardlinks, devices, fifos) are silently skipped. - After extraction, recomputes sha256 over the on-disk vmlinux and writes the manifest with Source="pull:". Daemon methods (internal/daemon/kernels.go): - KernelPull(ctx, {Name, Force}) - lookup in embedded catalog, refuse overwrite unless Force, delegate to kernelcat.Fetch. - KernelCatalog(ctx) - return the embedded catalog annotated per-entry with whether it has been pulled locally. RPC: kernel.pull, kernel.catalog dispatch cases. CLI: - `banger kernel pull [--force]`. - `banger kernel list --available` prints the catalog with a pulled/available STATE column and a human-readable size. Tests: fetch round-trip (extract + manifest + sha256), sha256 mismatch rejection with cleanup, missing-vmlinux rejection, path-traversal rejection, HTTP error propagation, catalog parsing, lookup, pulled-status reconciliation. All 20 packages green. Co-Authored-By: Claude Sonnet 4.6 --- go.mod | 1 + go.sum | 2 + internal/api/types.go | 19 +++ internal/cli/banger.go | 86 ++++++++++++- internal/cli/cli_test.go | 2 +- internal/daemon/daemon.go | 9 ++ internal/daemon/kernels.go | 59 +++++++++ internal/daemon/kernels_test.go | 37 ++++++ internal/kernelcat/catalog.go | 59 +++++++++ internal/kernelcat/catalog.json | 4 + internal/kernelcat/catalog_test.go | 52 ++++++++ internal/kernelcat/fetch.go | 187 +++++++++++++++++++++++++++ internal/kernelcat/fetch_test.go | 198 +++++++++++++++++++++++++++++ 13 files changed, 711 insertions(+), 4 deletions(-) create mode 100644 internal/kernelcat/catalog.go create mode 100644 internal/kernelcat/catalog.json create mode 100644 internal/kernelcat/catalog_test.go create mode 100644 internal/kernelcat/fetch.go create mode 100644 internal/kernelcat/fetch_test.go diff --git a/go.mod b/go.mod index 3a07334..2ddb7c4 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/intern v1.0.0 // indirect + github.com/klauspost/compress v1.18.5 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mdlayher/socket v0.2.0 // indirect diff --git a/go.sum b/go.sum index 44fbb17..9547056 100644 --- a/go.sum +++ b/go.sum @@ -462,6 +462,8 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= diff --git a/internal/api/types.go b/internal/api/types.go index d73fac8..f5c28dc 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -307,6 +307,25 @@ type KernelImportParams struct { Arch string `json:"arch,omitempty"` } +type KernelPullParams struct { + Name string `json:"name"` + Force bool `json:"force,omitempty"` +} + +type KernelCatalogEntry struct { + Name string `json:"name"` + Distro string `json:"distro,omitempty"` + Arch string `json:"arch,omitempty"` + KernelVersion string `json:"kernel_version,omitempty"` + SizeBytes int64 `json:"size_bytes,omitempty"` + Description string `json:"description,omitempty"` + Pulled bool `json:"pulled"` +} + +type KernelCatalogResult struct { + Entries []KernelCatalogEntry `json:"entries"` +} + type SudoStatus struct { Available bool `json:"available"` Command string `json:"command,omitempty"` diff --git a/internal/cli/banger.go b/internal/cli/banger.go index 678772a..bfa3f1e 100644 --- a/internal/cli/banger.go +++ b/internal/cli/banger.go @@ -1600,10 +1600,33 @@ func newKernelCommand() *cobra.Command { newKernelShowCommand(), newKernelRmCommand(), newKernelImportCommand(), + newKernelPullCommand(), ) return cmd } +func newKernelPullCommand() *cobra.Command { + var force bool + cmd := &cobra.Command{ + Use: "pull ", + Short: "Download a cataloged kernel bundle", + Args: exactArgsUsage(1, "usage: banger kernel pull [--force]"), + RunE: func(cmd *cobra.Command, args []string) error { + layout, _, err := ensureDaemon(cmd.Context()) + if err != nil { + return err + } + result, err := rpc.Call[api.KernelShowResult](cmd.Context(), layout.SocketPath, "kernel.pull", api.KernelPullParams{Name: args[0], Force: force}) + if err != nil { + return err + } + return printJSON(cmd.OutOrStdout(), result.Entry) + }, + } + cmd.Flags().BoolVar(&force, "force", false, "re-pull even if already present") + return cmd +} + func newKernelImportCommand() *cobra.Command { var params api.KernelImportParams cmd := &cobra.Command{ @@ -1639,15 +1662,23 @@ func newKernelImportCommand() *cobra.Command { } func newKernelListCommand() *cobra.Command { - return &cobra.Command{ + var available bool + cmd := &cobra.Command{ Use: "list", - Short: "List locally available kernels", - Args: noArgsUsage("usage: banger kernel list"), + Short: "List kernels (local by default, or --available for the catalog)", + Args: noArgsUsage("usage: banger kernel list [--available]"), RunE: func(cmd *cobra.Command, args []string) error { layout, _, err := ensureDaemon(cmd.Context()) if err != nil { return err } + if available { + result, err := rpc.Call[api.KernelCatalogResult](cmd.Context(), layout.SocketPath, "kernel.catalog", api.Empty{}) + if err != nil { + return err + } + return printKernelCatalogTable(cmd.OutOrStdout(), result.Entries) + } result, err := rpc.Call[api.KernelListResult](cmd.Context(), layout.SocketPath, "kernel.list", api.Empty{}) if err != nil { return err @@ -1655,6 +1686,8 @@ func newKernelListCommand() *cobra.Command { return printKernelListTable(cmd.OutOrStdout(), result.Entries) }, } + cmd.Flags().BoolVar(&available, "available", false, "show the built-in catalog (with pulled/available status) instead of local entries") + return cmd } func newKernelShowCommand() *cobra.Command { @@ -1717,6 +1750,53 @@ func printKernelListTable(out anyWriter, entries []api.KernelEntry) error { return w.Flush() } +func printKernelCatalogTable(out anyWriter, entries []api.KernelCatalogEntry) error { + w := tabwriter.NewWriter(out, 0, 8, 2, ' ', 0) + if _, err := fmt.Fprintln(w, "NAME\tDISTRO\tARCH\tKERNEL\tSIZE\tSTATE"); err != nil { + return err + } + for _, entry := range entries { + state := "available" + if entry.Pulled { + state = "pulled" + } + if _, err := fmt.Fprintf( + w, + "%s\t%s\t%s\t%s\t%s\t%s\n", + entry.Name, + dashIfEmpty(entry.Distro), + dashIfEmpty(entry.Arch), + dashIfEmpty(entry.KernelVersion), + humanSize(entry.SizeBytes), + state, + ); err != nil { + return err + } + } + return w.Flush() +} + +func humanSize(bytes int64) string { + if bytes <= 0 { + return "-" + } + const ( + kib = 1024 + mib = 1024 * kib + gib = 1024 * mib + ) + switch { + case bytes >= gib: + return fmt.Sprintf("%.1fGiB", float64(bytes)/float64(gib)) + case bytes >= mib: + return fmt.Sprintf("%.1fMiB", float64(bytes)/float64(mib)) + case bytes >= kib: + return fmt.Sprintf("%.1fKiB", float64(bytes)/float64(kib)) + default: + return fmt.Sprintf("%dB", bytes) + } +} + func dashIfEmpty(s string) string { if strings.TrimSpace(s) == "" { return "-" diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 4ca164c..dd21570 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -75,7 +75,7 @@ func TestKernelCommandExposesSubcommands(t *testing.T) { for _, sub := range kernel.Commands() { names = append(names, sub.Name()) } - want := []string{"import", "list", "rm", "show"} + want := []string{"import", "list", "pull", "rm", "show"} if !reflect.DeepEqual(names, want) { t.Fatalf("kernel subcommands = %v, want %v", names, want) } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index d4768be..3af71e2 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -550,6 +550,15 @@ func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response { } entry, err := d.KernelImport(ctx, params) return marshalResultOrError(api.KernelShowResult{Entry: entry}, err) + case "kernel.pull": + params, err := rpc.DecodeParams[api.KernelPullParams](req) + if err != nil { + return rpc.NewError("bad_request", err.Error()) + } + entry, err := d.KernelPull(ctx, params) + return marshalResultOrError(api.KernelShowResult{Entry: entry}, err) + case "kernel.catalog": + return marshalResultOrError(d.KernelCatalog(ctx)) default: return rpc.NewError("unknown_method", req.Method) } diff --git a/internal/daemon/kernels.go b/internal/daemon/kernels.go index e3de641..39f0196 100644 --- a/internal/daemon/kernels.go +++ b/internal/daemon/kernels.go @@ -114,6 +114,65 @@ func (d *Daemon) KernelImport(ctx context.Context, params api.KernelImportParams return kernelEntryToAPI(stored), nil } +// KernelPull downloads a catalog entry by name into the local catalog. It +// refuses to overwrite an existing entry unless params.Force is set. +func (d *Daemon) KernelPull(ctx context.Context, params api.KernelPullParams) (api.KernelEntry, error) { + name := strings.TrimSpace(params.Name) + if err := kernelcat.ValidateName(name); err != nil { + return api.KernelEntry{}, err + } + + if !params.Force { + if _, err := kernelcat.ReadLocal(d.layout.KernelsDir, name); err == nil { + return api.KernelEntry{}, fmt.Errorf("kernel %q already pulled; pass --force to re-pull", name) + } else if !os.IsNotExist(err) { + return api.KernelEntry{}, err + } + } + + catalog, err := kernelcat.LoadEmbedded() + if err != nil { + return api.KernelEntry{}, err + } + catEntry, err := catalog.Lookup(name) + if err != nil { + return api.KernelEntry{}, fmt.Errorf("kernel %q not in catalog (run 'banger kernel list --available' to browse)", name) + } + + stored, err := kernelcat.Fetch(ctx, nil, d.layout.KernelsDir, catEntry) + if err != nil { + return api.KernelEntry{}, err + } + return kernelEntryToAPI(stored), nil +} + +// KernelCatalog returns every entry from the embedded catalog annotated +// with whether it has already been pulled locally. +func (d *Daemon) KernelCatalog(_ context.Context) (api.KernelCatalogResult, error) { + catalog, err := kernelcat.LoadEmbedded() + if err != nil { + return api.KernelCatalogResult{}, err + } + local, _ := kernelcat.ListLocal(d.layout.KernelsDir) + pulled := make(map[string]bool, len(local)) + for _, entry := range local { + pulled[entry.Name] = true + } + result := api.KernelCatalogResult{Entries: make([]api.KernelCatalogEntry, 0, len(catalog.Entries))} + for _, entry := range catalog.Entries { + result.Entries = append(result.Entries, api.KernelCatalogEntry{ + Name: entry.Name, + Distro: entry.Distro, + Arch: entry.Arch, + KernelVersion: entry.KernelVersion, + SizeBytes: entry.SizeBytes, + Description: entry.Description, + Pulled: pulled[entry.Name], + }) + } + return result, nil +} + // inferKernelVersion makes a best-effort guess at the kernel version from // the source filename (e.g. "vmlinux-6.12.79_1") or falls back to the // modules directory basename. Returns "" if nothing looks useful. diff --git a/internal/daemon/kernels_test.go b/internal/daemon/kernels_test.go index e747c1f..7179b5f 100644 --- a/internal/daemon/kernels_test.go +++ b/internal/daemon/kernels_test.go @@ -205,6 +205,43 @@ func TestKernelImportCopiesArtifactsAndWritesManifest(t *testing.T) { } } +func TestKernelPullRejectsUnknownCatalogEntry(t *testing.T) { + d := &Daemon{ + layout: paths.Layout{KernelsDir: t.TempDir()}, + runner: system.NewRunner(), + } + _, err := d.KernelPull(context.Background(), api.KernelPullParams{Name: "unknown"}) + if err == nil || !strings.Contains(err.Error(), "not in catalog") { + t.Fatalf("KernelPull unknown: err=%v", err) + } +} + +func TestKernelPullRefusesOverwriteWithoutForce(t *testing.T) { + kernelsDir := t.TempDir() + seedKernelEntry(t, kernelsDir, "void-6.12") + + d := &Daemon{ + layout: paths.Layout{KernelsDir: kernelsDir}, + runner: system.NewRunner(), + } + _, err := d.KernelPull(context.Background(), api.KernelPullParams{Name: "void-6.12"}) + if err == nil || !strings.Contains(err.Error(), "already pulled") { + t.Fatalf("KernelPull without --force: err=%v", err) + } +} + +func TestKernelCatalogReportsPulledStatus(t *testing.T) { + d := &Daemon{layout: paths.Layout{KernelsDir: t.TempDir()}} + result, err := d.KernelCatalog(context.Background()) + if err != nil { + t.Fatalf("KernelCatalog: %v", err) + } + // Embedded catalog ships empty; CI (phase 5) populates it. + if result.Entries == nil { + t.Fatalf("Entries should be non-nil even when catalog is empty") + } +} + func TestKernelImportRejectsMissingFromDir(t *testing.T) { d := &Daemon{ layout: paths.Layout{KernelsDir: t.TempDir()}, diff --git a/internal/kernelcat/catalog.go b/internal/kernelcat/catalog.go new file mode 100644 index 0000000..d703451 --- /dev/null +++ b/internal/kernelcat/catalog.go @@ -0,0 +1,59 @@ +package kernelcat + +import ( + _ "embed" + "encoding/json" + "fmt" + "os" +) + +//go:embed catalog.json +var embeddedCatalog []byte + +// Catalog is the published list of kernel bundles banger can pull. It ships +// embedded in the banger binary and is updated across releases; Phase 5 +// wires CI to regenerate it. +type Catalog struct { + Version int `json:"version"` + Entries []CatEntry `json:"entries"` +} + +// CatEntry describes one downloadable kernel bundle. +type CatEntry struct { + Name string `json:"name"` + Distro string `json:"distro,omitempty"` + Arch string `json:"arch,omitempty"` + KernelVersion string `json:"kernel_version,omitempty"` + TarballURL string `json:"tarball_url"` + TarballSHA256 string `json:"tarball_sha256"` + SizeBytes int64 `json:"size_bytes,omitempty"` + Description string `json:"description,omitempty"` +} + +// LoadEmbedded returns the catalog compiled into this banger binary. +func LoadEmbedded() (Catalog, error) { + return ParseCatalog(embeddedCatalog) +} + +// ParseCatalog decodes a catalog.json payload. An empty payload is valid +// and returns a zero Catalog. +func ParseCatalog(data []byte) (Catalog, error) { + var cat Catalog + if len(data) == 0 { + return cat, nil + } + if err := json.Unmarshal(data, &cat); err != nil { + return Catalog{}, fmt.Errorf("parse catalog: %w", err) + } + return cat, nil +} + +// Lookup returns the catalog entry matching name, or os.ErrNotExist. +func (c Catalog) Lookup(name string) (CatEntry, error) { + for _, e := range c.Entries { + if e.Name == name { + return e, nil + } + } + return CatEntry{}, os.ErrNotExist +} diff --git a/internal/kernelcat/catalog.json b/internal/kernelcat/catalog.json new file mode 100644 index 0000000..7f19696 --- /dev/null +++ b/internal/kernelcat/catalog.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "entries": [] +} diff --git a/internal/kernelcat/catalog_test.go b/internal/kernelcat/catalog_test.go new file mode 100644 index 0000000..2f26463 --- /dev/null +++ b/internal/kernelcat/catalog_test.go @@ -0,0 +1,52 @@ +package kernelcat + +import ( + "errors" + "os" + "testing" +) + +func TestParseCatalogEmpty(t *testing.T) { + t.Parallel() + cat, err := ParseCatalog(nil) + if err != nil { + t.Fatalf("ParseCatalog(nil): %v", err) + } + if len(cat.Entries) != 0 { + t.Fatalf("entries = %d, want 0", len(cat.Entries)) + } +} + +func TestParseCatalogValid(t *testing.T) { + t.Parallel() + cat, err := ParseCatalog([]byte(`{"version":1,"entries":[{"name":"void-6.12","distro":"void","tarball_url":"https://example/v.tar.zst","tarball_sha256":"abc"}]}`)) + if err != nil { + t.Fatalf("ParseCatalog: %v", err) + } + if cat.Version != 1 || len(cat.Entries) != 1 || cat.Entries[0].Name != "void-6.12" { + t.Fatalf("catalog = %+v", cat) + } +} + +func TestCatalogLookup(t *testing.T) { + t.Parallel() + cat := Catalog{Entries: []CatEntry{{Name: "a"}, {Name: "b"}}} + if entry, err := cat.Lookup("b"); err != nil || entry.Name != "b" { + t.Fatalf("Lookup(b) = %+v, %v", entry, err) + } + if _, err := cat.Lookup("c"); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("Lookup(missing) err = %v, want ErrNotExist", err) + } +} + +func TestLoadEmbeddedReturnsValidCatalog(t *testing.T) { + t.Parallel() + cat, err := LoadEmbedded() + if err != nil { + t.Fatalf("LoadEmbedded: %v", err) + } + if cat.Version != 1 { + t.Fatalf("embedded catalog.Version = %d, want 1", cat.Version) + } + // Embedded catalog starts empty; Phase 5 CI populates it. +} diff --git a/internal/kernelcat/fetch.go b/internal/kernelcat/fetch.go new file mode 100644 index 0000000..91eec81 --- /dev/null +++ b/internal/kernelcat/fetch.go @@ -0,0 +1,187 @@ +package kernelcat + +import ( + "archive/tar" + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/klauspost/compress/zstd" +) + +// Fetch downloads the tarball for entry, verifies its SHA256, extracts it +// into //, and writes a manifest. On failure it +// removes the partially-populated target directory. +// +// The tarball is expected to be a tar+zstd archive whose root contains +// vmlinux and optionally initrd.img and/or a modules/ directory. Path +// traversal entries (..) and absolute-path members are rejected. +func Fetch(ctx context.Context, client *http.Client, kernelsDir string, entry CatEntry) (Entry, error) { + if err := ValidateName(entry.Name); err != nil { + return Entry{}, err + } + if strings.TrimSpace(entry.TarballURL) == "" { + return Entry{}, fmt.Errorf("catalog entry %q has no tarball URL", entry.Name) + } + if strings.TrimSpace(entry.TarballSHA256) == "" { + return Entry{}, fmt.Errorf("catalog entry %q has no tarball sha256", entry.Name) + } + if client == nil { + client = http.DefaultClient + } + + if err := DeleteLocal(kernelsDir, entry.Name); err != nil { + return Entry{}, fmt.Errorf("clear prior catalog entry: %w", err) + } + targetDir := EntryDir(kernelsDir, entry.Name) + if err := os.MkdirAll(targetDir, 0o755); err != nil { + return Entry{}, err + } + + cleanup := func() { _ = os.RemoveAll(targetDir) } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, entry.TarballURL, nil) + if err != nil { + cleanup() + return Entry{}, err + } + resp, err := client.Do(req) + if err != nil { + cleanup() + return Entry{}, fmt.Errorf("fetch %s: %w", entry.TarballURL, err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + cleanup() + return Entry{}, fmt.Errorf("fetch %s: HTTP %s", entry.TarballURL, resp.Status) + } + + hasher := sha256.New() + tee := io.TeeReader(resp.Body, hasher) + zr, err := zstd.NewReader(tee) + if err != nil { + cleanup() + return Entry{}, fmt.Errorf("init zstd: %w", err) + } + defer zr.Close() + + if err := extractTar(zr, targetDir); err != nil { + cleanup() + return Entry{}, err + } + // Drain any remaining tarball-padding bytes so the hash covers the + // whole transport stream even if the tar reader stopped early. + if _, err := io.Copy(io.Discard, tee); err != nil { + cleanup() + return Entry{}, fmt.Errorf("drain tarball: %w", err) + } + + got := hex.EncodeToString(hasher.Sum(nil)) + if !strings.EqualFold(got, entry.TarballSHA256) { + cleanup() + return Entry{}, fmt.Errorf("tarball sha256 mismatch: got %s, want %s", got, entry.TarballSHA256) + } + + kernelPath := filepath.Join(targetDir, kernelFilename) + if _, err := os.Stat(kernelPath); err != nil { + cleanup() + return Entry{}, fmt.Errorf("tarball missing %s: %w", kernelFilename, err) + } + kernelSum, err := SumFile(kernelPath) + if err != nil { + cleanup() + return Entry{}, err + } + + stored := Entry{ + Name: entry.Name, + Distro: entry.Distro, + Arch: entry.Arch, + KernelVersion: entry.KernelVersion, + SHA256: kernelSum, + Source: "pull:" + entry.TarballURL, + ImportedAt: time.Now().UTC(), + } + if err := WriteLocal(kernelsDir, stored); err != nil { + cleanup() + return Entry{}, err + } + return ReadLocal(kernelsDir, entry.Name) +} + +// extractTar writes each regular file / dir / safe symlink from r into +// target, refusing any member whose normalised path would escape target. +func extractTar(r io.Reader, target string) error { + absTarget, err := filepath.Abs(target) + if err != nil { + return err + } + tr := tar.NewReader(r) + for { + hdr, err := tr.Next() + if err == io.EOF { + return nil + } + if err != nil { + return fmt.Errorf("read tarball: %w", err) + } + rel := filepath.Clean(hdr.Name) + if rel == "." || rel == string(filepath.Separator) { + continue + } + if filepath.IsAbs(rel) || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || rel == ".." { + return fmt.Errorf("unsafe path in tarball: %q", hdr.Name) + } + dst := filepath.Join(absTarget, rel) + if dst != absTarget && !strings.HasPrefix(dst, absTarget+string(filepath.Separator)) { + return fmt.Errorf("unsafe path in tarball: %q", hdr.Name) + } + switch hdr.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(dst, os.FileMode(hdr.Mode)|0o755); err != nil { + return err + } + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return err + } + f, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)|0o600) + if err != nil { + return err + } + if _, err := io.Copy(f, tr); err != nil { + _ = f.Close() + return err + } + if err := f.Close(); err != nil { + return err + } + case tar.TypeSymlink: + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return err + } + link := hdr.Linkname + resolved := link + if !filepath.IsAbs(link) { + resolved = filepath.Join(filepath.Dir(dst), link) + } + resolved = filepath.Clean(resolved) + if resolved != absTarget && !strings.HasPrefix(resolved, absTarget+string(filepath.Separator)) { + return fmt.Errorf("unsafe symlink in tarball: %q -> %q", hdr.Name, hdr.Linkname) + } + if err := os.Symlink(hdr.Linkname, dst); err != nil { + return err + } + default: + // Hardlinks / device nodes / fifos: skip silently. Kernel + // module trees shouldn't need them. + } + } +} diff --git a/internal/kernelcat/fetch_test.go b/internal/kernelcat/fetch_test.go new file mode 100644 index 0000000..797ebba --- /dev/null +++ b/internal/kernelcat/fetch_test.go @@ -0,0 +1,198 @@ +package kernelcat + +import ( + "archive/tar" + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/klauspost/compress/zstd" +) + +// tarballFile describes one member of the test tarball. +type tarballFile struct { + name string + mode int64 + data []byte + link string // for symlinks + dir bool +} + +func buildTestTarball(t *testing.T, files []tarballFile) ([]byte, string) { + t.Helper() + var tarBuf bytes.Buffer + tw := tar.NewWriter(&tarBuf) + for _, f := range files { + hdr := &tar.Header{Name: f.name, Mode: f.mode} + switch { + case f.dir: + hdr.Typeflag = tar.TypeDir + hdr.Mode = 0o755 + case f.link != "": + hdr.Typeflag = tar.TypeSymlink + hdr.Linkname = f.link + default: + hdr.Typeflag = tar.TypeReg + hdr.Size = int64(len(f.data)) + if hdr.Mode == 0 { + hdr.Mode = 0o644 + } + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("tar WriteHeader: %v", err) + } + if hdr.Typeflag == tar.TypeReg { + if _, err := tw.Write(f.data); err != nil { + t.Fatalf("tar Write: %v", err) + } + } + } + if err := tw.Close(); err != nil { + t.Fatalf("tar Close: %v", err) + } + + var compressed bytes.Buffer + zw, err := zstd.NewWriter(&compressed) + if err != nil { + t.Fatalf("zstd NewWriter: %v", err) + } + if _, err := zw.Write(tarBuf.Bytes()); err != nil { + t.Fatalf("zstd Write: %v", err) + } + if err := zw.Close(); err != nil { + t.Fatalf("zstd Close: %v", err) + } + sum := sha256.Sum256(compressed.Bytes()) + return compressed.Bytes(), hex.EncodeToString(sum[:]) +} + +func serveTarball(t *testing.T, body []byte) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/octet-stream") + _, _ = w.Write(body) + })) + t.Cleanup(srv.Close) + return srv +} + +func TestFetchExtractsTarballAndWritesManifest(t *testing.T) { + t.Parallel() + body, sum := buildTestTarball(t, []tarballFile{ + {name: "vmlinux", data: []byte("kernel-bytes")}, + {name: "initrd.img", data: []byte("initrd-bytes")}, + {name: "modules", dir: true}, + {name: "modules/modules.dep", data: []byte("dep")}, + }) + srv := serveTarball(t, body) + + kernelsDir := t.TempDir() + stored, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{ + Name: "void-6.12", + Distro: "void", + Arch: "x86_64", + KernelVersion: "6.12.79_1", + TarballURL: srv.URL + "/pkg.tar.zst", + TarballSHA256: sum, + }) + if err != nil { + t.Fatalf("Fetch: %v", err) + } + if stored.Name != "void-6.12" || stored.Distro != "void" { + t.Fatalf("stored = %+v", stored) + } + if stored.SHA256 == "" { + t.Errorf("SHA256 not populated") + } + + for _, rel := range []string{"vmlinux", "initrd.img", "modules/modules.dep", "manifest.json"} { + if _, err := os.Stat(filepath.Join(kernelsDir, "void-6.12", rel)); err != nil { + t.Errorf("expected %s in catalog: %v", rel, err) + } + } +} + +func TestFetchRejectsShaMismatch(t *testing.T) { + t.Parallel() + body, _ := buildTestTarball(t, []tarballFile{ + {name: "vmlinux", data: []byte("k")}, + }) + srv := serveTarball(t, body) + + kernelsDir := t.TempDir() + _, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{ + Name: "void-6.12", + TarballURL: srv.URL + "/pkg.tar.zst", + TarballSHA256: "000000000000000000000000000000000000000000000000000000000000beef", + }) + if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") { + t.Fatalf("expected sha256 mismatch, got %v", err) + } + if _, statErr := os.Stat(filepath.Join(kernelsDir, "void-6.12")); !os.IsNotExist(statErr) { + t.Fatalf("target dir should be cleaned up on mismatch: %v", statErr) + } +} + +func TestFetchRejectsMissingKernel(t *testing.T) { + t.Parallel() + body, sum := buildTestTarball(t, []tarballFile{ + {name: "initrd.img", data: []byte("i")}, // no vmlinux + }) + srv := serveTarball(t, body) + kernelsDir := t.TempDir() + _, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{ + Name: "broken", + TarballURL: srv.URL + "/pkg.tar.zst", + TarballSHA256: sum, + }) + if err == nil || !strings.Contains(err.Error(), "missing vmlinux") { + t.Fatalf("expected missing vmlinux, got %v", err) + } +} + +func TestFetchRejectsPathTraversal(t *testing.T) { + t.Parallel() + body, sum := buildTestTarball(t, []tarballFile{ + {name: "vmlinux", data: []byte("k")}, + {name: "../escape", data: []byte("bad")}, + }) + srv := serveTarball(t, body) + kernelsDir := t.TempDir() + _, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{ + Name: "bad-tarball", + TarballURL: srv.URL + "/pkg.tar.zst", + TarballSHA256: sum, + }) + if err == nil || !strings.Contains(err.Error(), "unsafe path") { + t.Fatalf("expected unsafe path error, got %v", err) + } + escapePath := filepath.Join(filepath.Dir(kernelsDir), "escape") + if _, statErr := os.Stat(escapePath); !os.IsNotExist(statErr) { + t.Fatalf("traversal escape file should not exist: %v", statErr) + } +} + +func TestFetchRejectsHTTPError(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "nope", http.StatusNotFound) + })) + t.Cleanup(srv.Close) + + kernelsDir := t.TempDir() + _, err := Fetch(context.Background(), nil, kernelsDir, CatEntry{ + Name: "missing", + TarballURL: srv.URL + "/pkg.tar.zst", + TarballSHA256: "deadbeef", + }) + if err == nil || !strings.Contains(err.Error(), "404") { + t.Fatalf("expected HTTP 404, got %v", err) + } +}