package kernelcat

import (
	"archive/tar"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/klauspost/compress/zstd"
)

// MaxFetchedKernelBytes caps the size of the compressed kernel tarball that
// Fetch will download.
//
// Without this the previous flow streamed straight into the tar+zstd
// extractor and only verified SHA256 afterwards, so a malicious or
// compromised mirror could fill the host disk before the hash check
// fired. Now Fetch stages the download to a temp file under the entry's
// target directory, hashes on the way in, and refuses to decompress on
// hash mismatch, so the disk used before the hash check is bounded by this
// cap. Fetch rejects a response whose advertised Content-Length exceeds the
// cap and also stops reading once the body itself exceeds it. The cap
// applies to the compressed download only, not to the extracted output.
//
// Override per-call by setting this var before invoking Fetch.
// NOTE(review): this is a plain package-level variable with no
// synchronization visible in this file; changing it while a Fetch is in
// flight looks racy — confirm with callers.
var MaxFetchedKernelBytes int64 = 8 << 30 // 8 GiB

// Fetch downloads the tarball for entry, verifies its SHA256, extracts it
// into the entry's directory (EntryDir(kernelsDir, entry.Name)), and writes
// a manifest. Any existing local entry with the same name is deleted before
// the download starts. A nil client means http.DefaultClient. On failure it
// removes the partially-populated target directory.
//
// The tarball is expected to be a tar+zstd archive whose root contains
// vmlinux and optionally initrd.img and/or a modules/ directory. Path
// traversal entries (..) and absolute-path members are rejected.
func Fetch(ctx context.Context, client *http.Client, kernelsDir string, entry CatEntry) (Entry, error) { if err := ValidateName(entry.Name); err != nil { return Entry{}, err } if strings.TrimSpace(entry.TarballURL) == "" { return Entry{}, fmt.Errorf("catalog entry %q has no tarball URL", entry.Name) } if strings.TrimSpace(entry.TarballSHA256) == "" { return Entry{}, fmt.Errorf("catalog entry %q has no tarball sha256", entry.Name) } if client == nil { client = http.DefaultClient } if err := DeleteLocal(kernelsDir, entry.Name); err != nil { return Entry{}, fmt.Errorf("clear prior catalog entry: %w", err) } targetDir := EntryDir(kernelsDir, entry.Name) if err := os.MkdirAll(targetDir, 0o755); err != nil { return Entry{}, err } cleanup := func() { _ = os.RemoveAll(targetDir) } req, err := http.NewRequestWithContext(ctx, http.MethodGet, entry.TarballURL, nil) if err != nil { cleanup() return Entry{}, err } resp, err := client.Do(req) if err != nil { cleanup() return Entry{}, fmt.Errorf("fetch %s: %w", entry.TarballURL, err) } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { cleanup() return Entry{}, fmt.Errorf("fetch %s: HTTP %s", entry.TarballURL, resp.Status) } if resp.ContentLength > MaxFetchedKernelBytes { cleanup() return Entry{}, fmt.Errorf("tarball advertised %d bytes, exceeds %d-byte cap", resp.ContentLength, MaxFetchedKernelBytes) } // Stage compressed download to a temp file first so we can verify // SHA256 BEFORE decompressing or extracting. Cap reads to // MaxFetchedKernelBytes+1 — anything larger is refused. 
tmp, err := os.CreateTemp(targetDir, "banger-kernel-*.tar.zst") if err != nil { cleanup() return Entry{}, fmt.Errorf("create staging file: %w", err) } tmpPath := tmp.Name() defer os.Remove(tmpPath) hasher := sha256.New() limited := io.LimitReader(resp.Body, MaxFetchedKernelBytes+1) n, copyErr := io.Copy(io.MultiWriter(tmp, hasher), limited) if closeErr := tmp.Close(); copyErr == nil && closeErr != nil { copyErr = closeErr } if copyErr != nil { cleanup() return Entry{}, fmt.Errorf("download tarball: %w", copyErr) } if n > MaxFetchedKernelBytes { cleanup() return Entry{}, fmt.Errorf("tarball exceeded %d-byte cap before sha256 check", MaxFetchedKernelBytes) } got := hex.EncodeToString(hasher.Sum(nil)) if !strings.EqualFold(got, entry.TarballSHA256) { cleanup() return Entry{}, fmt.Errorf("tarball sha256 mismatch: got %s, want %s", got, entry.TarballSHA256) } src, err := os.Open(tmpPath) if err != nil { cleanup() return Entry{}, fmt.Errorf("reopen staged tarball: %w", err) } defer src.Close() zr, err := zstd.NewReader(src) if err != nil { cleanup() return Entry{}, fmt.Errorf("init zstd: %w", err) } defer zr.Close() if err := extractTar(zr, targetDir); err != nil { cleanup() return Entry{}, err } kernelPath := filepath.Join(targetDir, kernelFilename) if _, err := os.Stat(kernelPath); err != nil { cleanup() return Entry{}, fmt.Errorf("tarball missing %s: %w", kernelFilename, err) } kernelSum, err := SumFile(kernelPath) if err != nil { cleanup() return Entry{}, err } stored := Entry{ Name: entry.Name, Distro: entry.Distro, Arch: entry.Arch, KernelVersion: entry.KernelVersion, SHA256: kernelSum, Source: "pull:" + entry.TarballURL, ImportedAt: time.Now().UTC(), } if err := WriteLocal(kernelsDir, stored); err != nil { cleanup() return Entry{}, err } return ReadLocal(kernelsDir, entry.Name) } // extractTar writes each regular file / dir / safe symlink from r into // target, refusing any member whose normalised path would escape target. 
func extractTar(r io.Reader, target string) error {
	// Members are processed in archive order. Two layers of defence keep
	// writes inside target:
	//   1. a lexical check on the cleaned member name, and
	//   2. a filesystem check that the directory the member will be created
	//      in does not resolve, through a symlink extracted earlier, to a
	//      location outside target.
	// Returns nil at end of archive, or the first error encountered; files
	// already written are left in place for the caller to clean up.
	absTarget, err := filepath.Abs(target)
	if err != nil {
		return err
	}
	sep := string(filepath.Separator)
	inside := func(p string) bool {
		return p == absTarget || strings.HasPrefix(p, absTarget+sep)
	}

	tr := tar.NewReader(r)
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return fmt.Errorf("read tarball: %w", err)
		}

		rel := filepath.Clean(hdr.Name)
		if rel == "." || rel == sep {
			// The archive root itself: nothing to create.
			continue
		}
		if filepath.IsAbs(rel) || rel == ".." || strings.HasPrefix(rel, ".."+sep) {
			return fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
		}
		dst := filepath.Join(absTarget, rel)
		if !inside(dst) {
			return fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
		}

		switch hdr.Typeflag {
		case tar.TypeDir, tar.TypeReg, tar.TypeSymlink:
			// Handled below.
		default:
			// Hardlinks / device nodes / fifos: skip silently. Kernel
			// module trees shouldn't need them.
			continue
		}

		// The lexical checks above cannot see symlinks already on disk: a
		// member "evil -> /somewhere" followed by "evil/file" would pass
		// them and still be written outside target.
		if err := tarParentResolvesInside(absTarget, filepath.Dir(dst)); err != nil {
			return fmt.Errorf("unsafe path in tarball: %q: %w", hdr.Name, err)
		}

		switch hdr.Typeflag {
		case tar.TypeDir:
			// OR-ing 0o755 guarantees the directory stays traversable and
			// writable for the rest of the extraction.
			if err := os.MkdirAll(dst, os.FileMode(hdr.Mode)|0o755); err != nil {
				return err
			}
		case tar.TypeReg:
			if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
				return err
			}
			// O_CREATE|O_TRUNC follows an existing symlink at dst, which
			// would redirect the write to the link's target. Replace the
			// link with the regular file instead.
			if fi, lerr := os.Lstat(dst); lerr == nil && fi.Mode()&os.ModeSymlink != 0 {
				if err := os.Remove(dst); err != nil {
					return err
				}
			}
			f, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)|0o600)
			if err != nil {
				return err
			}
			if _, err := io.Copy(f, tr); err != nil {
				_ = f.Close() // the copy error is the one worth reporting
				return err
			}
			if err := f.Close(); err != nil {
				return err
			}
		case tar.TypeSymlink:
			if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
				return err
			}
			// Absolute targets are interpreted at runtime against the
			// eventual rootfs (`/` inside the VM), so they are allowed as
			// link *values*; the parent check above is what stops later
			// members from being written through them on the host. Only
			// relative targets need an escape check at write time.
			if !filepath.IsAbs(hdr.Linkname) {
				resolved := filepath.Clean(filepath.Join(filepath.Dir(dst), hdr.Linkname))
				if !inside(resolved) {
					return fmt.Errorf("unsafe symlink in tarball: %q -> %q", hdr.Name, hdr.Linkname)
				}
			}
			if err := os.Symlink(hdr.Linkname, dst); err != nil {
				return err
			}
		}
	}
}

// tarParentResolvesInside reports an error when dir, after following any
// symlinks that already exist on disk, lands outside root. dir must be root
// or lexically beneath it. Path components that do not exist yet cannot be
// symlinks, so only the deepest existing ancestor of dir is resolved.
func tarParentResolvesInside(root, dir string) error {
	existing := dir
	for existing != root {
		if _, err := os.Lstat(existing); err == nil {
			break
		}
		existing = filepath.Dir(existing)
	}
	if existing == root {
		// Nothing below root exists on this path yet, so there is no
		// symlink to be redirected by.
		return nil
	}
	// Resolve both sides so that a root which itself sits behind a symlink
	// (for example a symlinked temp directory) still compares correctly.
	realRoot, err := filepath.EvalSymlinks(root)
	if err != nil {
		return err
	}
	realDir, err := filepath.EvalSymlinks(existing)
	if err != nil {
		// Includes dangling symlinks: refuse rather than guess.
		return err
	}
	if realDir != realRoot && !strings.HasPrefix(realDir, realRoot+string(filepath.Separator)) {
		return fmt.Errorf("%s resolves outside the extraction directory", existing)
	}
	return nil
}