download: shared FetchVerified helper for capped + hashed downloads

imagecat.Fetch and kernelcat.Fetch each implement the same pattern:
HTTP GET with a Content-Length pre-check, an io.LimitReader cap on
the body, on-the-fly sha256 hashing, and refusal on either the cap
trip or a hash mismatch. The about-to-arrive `banger update` flow
makes a third caller, which is the right number to factor.

  * internal/download.FetchVerified(ctx, client, url, expectedSHA256,
    maxBytes, dstPath): streams the body to dstPath through a
    sha256 hasher, capped at maxBytes+1 bytes so an oversize body
    is detected before the hash check fires. On any failure
    (HTTP error, ContentLength > cap, body exceeds cap, write
    error, hash mismatch) the partial file is removed before
    returning so callers don't have to disambiguate "did we leave
    bytes on disk?".

Imagecat and kernelcat are NOT migrated to this helper in this
commit — they each have their own destination-dir layout and
post-verify decompress/extract steps that don't fit a one-size
helper. Lift them later if it stays clean; for now the helper
is sized for the updater's "fetch tarball + SHA256SUMS" need.

Tests cover happy path, hash mismatch, advertised Content-Length
over cap, lying server (chunked, no Content-Length, but oversize
body), HTTP non-2xx, and the two arg-validation rejections (empty
expected hash, non-positive maxBytes).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Thales Maciel 2026-04-28 18:44:27 -03:00
parent fa3a7a3e31
commit abd5d6f5ab
No known key found for this signature in database
GPG key ID: 33112E6833C34679
2 changed files with 212 additions and 0 deletions

View file

@ -0,0 +1,86 @@
// Package download contains transport-level primitives shared by
// banger's catalog and update flows. Today it exposes one helper
// (FetchVerified). When imagecat and kernelcat are next touched, their
// duplicate copies of the same logic could fold into this package
// without a behaviour change.
package download
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"os"
"strings"
)
// FetchVerified streams `url` into `dstPath`, capped at maxBytes
// bytes, hashing the body on the fly and refusing payloads whose
// SHA256 doesn't match expectedSHA256.
//
// Checks run in this order: argument validation (no network I/O),
// HTTP status, advertised Content-Length, then the streamed body
// (cap and sha256). dstPath is only created once the first three
// have passed, so a failure in any of them leaves dstPath untouched.
// Any failure after that point (body exceeds cap mid-stream, write
// or close error, sha256 mismatch) removes dstPath before returning,
// so the caller never has to disambiguate "did we leave a partial
// file?".
//
// Returns the number of bytes written. On error the count is always
// 0. The caller owns cleanup of dstPath after a successful return.
//
// expectedSHA256 is matched case-insensitively and surrounding
// whitespace is ignored, so a digest carrying a stray trailing
// newline still matches. Pass a nil client to use
// http.DefaultClient.
func FetchVerified(ctx context.Context, client *http.Client, url, expectedSHA256 string, maxBytes int64, dstPath string) (int64, error) {
	if client == nil {
		client = http.DefaultClient
	}
	if maxBytes <= 0 {
		return 0, fmt.Errorf("FetchVerified: maxBytes must be > 0, got %d", maxBytes)
	}
	// Normalise once so the emptiness check and the final comparison
	// agree on what the expected digest is.
	wantSum := strings.TrimSpace(expectedSHA256)
	if wantSum == "" {
		return 0, fmt.Errorf("FetchVerified: expectedSHA256 is required")
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return 0, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return 0, fmt.Errorf("fetch %s: %w", url, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return 0, fmt.Errorf("fetch %s: HTTP %s", url, resp.Status)
	}
	// Cheap early exit. ContentLength is -1 when unknown, which passes
	// here and is then policed by the streaming cap below.
	if resp.ContentLength > maxBytes {
		return 0, fmt.Errorf("fetch %s: advertised %d bytes exceeds %d-byte cap", url, resp.ContentLength, maxBytes)
	}

	// Read one byte past the cap so an oversize body can be told apart
	// from one that is exactly maxBytes long. At the int64 maximum,
	// maxBytes+1 would wrap negative and io.LimitReader would report
	// EOF immediately, so keep the limit at maxBytes there; no body
	// can exceed that cap anyway.
	limit := maxBytes
	if limit < 1<<63-1 {
		limit++
	}

	f, err := os.Create(dstPath)
	if err != nil {
		return 0, err
	}
	// discard removes the file created above and reports cause.
	discard := func(cause error) (int64, error) {
		_ = os.Remove(dstPath) // best effort; cause is the error worth reporting
		return 0, cause
	}

	hasher := sha256.New()
	n, copyErr := io.Copy(io.MultiWriter(f, hasher), io.LimitReader(resp.Body, limit))
	// Always close before a possible Remove; a close failure only
	// matters when the copy itself succeeded.
	if closeErr := f.Close(); copyErr == nil {
		copyErr = closeErr
	}
	if copyErr != nil {
		return discard(fmt.Errorf("download %s: %w", url, copyErr))
	}
	if n > maxBytes {
		return discard(fmt.Errorf("download %s: body exceeded %d-byte cap before sha256 check", url, maxBytes))
	}

	got := hex.EncodeToString(hasher.Sum(nil))
	if !strings.EqualFold(got, wantSum) {
		return discard(fmt.Errorf("sha256 mismatch for %s: got %s, want %s", url, got, wantSum))
	}
	return n, nil
}

View file

@ -0,0 +1,126 @@
package download
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
// sha256Hex returns the hex-encoded SHA-256 digest of b, in the same
// form FetchVerified produces for its own comparison.
func sha256Hex(b []byte) string {
	digest := sha256.New()
	digest.Write(b) // hash.Hash documents that Write never returns an error
	return hex.EncodeToString(digest.Sum(nil))
}
// serveBody starts an httptest server that answers every request with
// body, labelled application/octet-stream, and registers the server's
// shutdown with t.Cleanup so callers don't have to close it.
func serveBody(t *testing.T, body []byte) *httptest.Server {
	t.Helper()

	respond := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/octet-stream")
		_, _ = w.Write(body)
	}

	server := httptest.NewServer(http.HandlerFunc(respond))
	t.Cleanup(server.Close)
	return server
}
// TestFetchVerifiedHappyPath downloads a body that is under the cap and
// matches the expected hash, then checks both the returned byte count
// and the bytes that landed on disk.
func TestFetchVerifiedHappyPath(t *testing.T) {
	body := bytes.Repeat([]byte("ok"), 1024)
	srv := serveBody(t, body)
	dst := filepath.Join(t.TempDir(), "out")
	n, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex(body), 1<<20, dst)
	if err != nil {
		t.Fatalf("FetchVerified: %v", err)
	}
	if n != int64(len(body)) {
		t.Fatalf("n = %d, want %d", n, len(body))
	}
	// Check the read error explicitly: a missing or unreadable file
	// should be reported as such, not as a content mismatch.
	got, err := os.ReadFile(dst)
	if err != nil {
		t.Fatalf("read %s: %v", dst, err)
	}
	if !bytes.Equal(got, body) {
		t.Fatalf("file content differs from served body")
	}
}
// TestFetchVerifiedRejectsHashMismatch serves one payload while
// expecting the digest of another; the call must fail with a sha256
// mismatch and leave nothing at the destination path.
func TestFetchVerifiedRejectsHashMismatch(t *testing.T) {
	server := serveBody(t, []byte("payload"))
	target := filepath.Join(t.TempDir(), "out")

	_, err := FetchVerified(context.Background(), server.Client(), server.URL, sha256Hex([]byte("other")), 1<<10, target)
	if err == nil {
		t.Fatalf("err = %v, want sha256 mismatch", err)
	}
	if !strings.Contains(err.Error(), "sha256 mismatch") {
		t.Fatalf("err = %v, want sha256 mismatch", err)
	}

	_, statErr := os.Stat(target)
	if os.IsNotExist(statErr) {
		return
	}
	t.Fatalf("partial file should be removed; stat err = %v", statErr)
}
// TestFetchVerifiedRejectsContentLengthOverCap serves a 2048-byte body
// against a 64-byte cap; the call must fail with a cap error and no
// file may exist at the destination afterwards.
func TestFetchVerifiedRejectsContentLengthOverCap(t *testing.T) {
	payload := bytes.Repeat([]byte("x"), 2048)
	server := serveBody(t, payload)
	target := filepath.Join(t.TempDir(), "out")

	_, err := FetchVerified(context.Background(), server.Client(), server.URL, sha256Hex(payload), 64, target)
	if err == nil {
		t.Fatalf("err = %v, want cap rejection", err)
	}
	if !strings.Contains(err.Error(), "cap") {
		t.Fatalf("err = %v, want cap rejection", err)
	}

	_, statErr := os.Stat(target)
	if os.IsNotExist(statErr) {
		return
	}
	t.Fatalf("dst created despite oversize Content-Length: %v", statErr)
}
// TestFetchVerifiedRejectsLyingContentLength covers a server that sends
// no Content-Length and then streams more than the cap. The rejection
// must come from the mid-stream cap, and the partial file must be
// removed.
func TestFetchVerifiedRejectsLyingContentLength(t *testing.T) {
	body := bytes.Repeat([]byte("y"), 2048)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Flush the header before writing the body. Without a Flush,
		// net/http adds a Content-Length itself for small responses,
		// and the request would be refused by the Content-Length
		// pre-check instead of the streaming cap under test here.
		if flusher, ok := w.(http.Flusher); ok {
			flusher.Flush()
		}
		_, _ = w.Write(body)
	}))
	t.Cleanup(srv.Close)
	dst := filepath.Join(t.TempDir(), "out")
	_, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex(body), 64, dst)
	// Match the mid-stream message specifically; a bare "cap" would
	// also match the advertised-Content-Length rejection.
	if err == nil || !strings.Contains(err.Error(), "body exceeded") {
		t.Fatalf("err = %v, want mid-stream cap rejection on lying server", err)
	}
	if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) {
		t.Fatalf("partial file from lying server should be removed; stat err = %v", statErr)
	}
}
// TestFetchVerifiedRejectsHTTPError checks that a non-2xx response is
// reported with its status and that no file is left at the
// destination path.
func TestFetchVerifiedRejectsHTTPError(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "missing", http.StatusNotFound)
	}))
	t.Cleanup(srv.Close)
	dst := filepath.Join(t.TempDir(), "out")
	_, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex([]byte{}), 1<<10, dst)
	if err == nil || !strings.Contains(err.Error(), "404") {
		t.Fatalf("err = %v, want 404 mention", err)
	}
	// Same on-disk check as the other rejection tests: the error body
	// must not end up at dst.
	if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) {
		t.Fatalf("dst should not exist after HTTP error; stat err = %v", statErr)
	}
}
// TestFetchVerifiedRejectsEmptyExpectedSHA checks that an empty or
// whitespace-only expected hash is rejected as an argument error and
// that nothing is written to the destination path.
func TestFetchVerifiedRejectsEmptyExpectedSHA(t *testing.T) {
	srv := serveBody(t, []byte("body"))
	// Whitespace-only values are included because the hash is
	// whitespace-trimmed before the emptiness check.
	for _, blank := range []string{"", "   ", "\t\n"} {
		dst := filepath.Join(t.TempDir(), "out")
		_, err := FetchVerified(context.Background(), srv.Client(), srv.URL, blank, 1<<10, dst)
		if err == nil || !strings.Contains(err.Error(), "expectedSHA256") {
			t.Fatalf("expectedSHA256 = %q: err = %v, want empty-sha rejection", blank, err)
		}
		if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) {
			t.Fatalf("expectedSHA256 = %q: dst should not exist; stat err = %v", blank, statErr)
		}
	}
}
// TestFetchVerifiedRejectsZeroMaxBytes checks that every non-positive
// cap (zero and negative) is rejected as an argument error and that
// nothing is written to the destination path.
func TestFetchVerifiedRejectsZeroMaxBytes(t *testing.T) {
	srv := serveBody(t, []byte("body"))
	for _, maxBytes := range []int64{0, -1, -1 << 20} {
		dst := filepath.Join(t.TempDir(), "out")
		_, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex([]byte("body")), maxBytes, dst)
		if err == nil || !strings.Contains(err.Error(), "maxBytes") {
			t.Fatalf("maxBytes = %d: err = %v, want maxBytes rejection", maxBytes, err)
		}
		if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) {
			t.Fatalf("maxBytes = %d: dst should not exist; stat err = %v", maxBytes, statErr)
		}
	}
}