download: shared FetchVerified helper for capped + hashed downloads
imagecat.Fetch and kernelcat.Fetch each implement the same pattern:
HTTP GET with a Content-Length pre-check, an io.LimitReader cap on
the body, on-the-fly sha256 hashing, and refusal on either the cap
trip or a hash mismatch. The about-to-arrive `banger update` flow
makes a third caller, which is the right number to factor.
* internal/download.FetchVerified(ctx, client, url, expectedSHA256,
maxBytes, dstPath): streams the body to dstPath through a
sha256 hasher, capped at maxBytes+1 bytes so an oversize body
is detected before the hash check fires. On any failure
(HTTP error, ContentLength > cap, body exceeds cap, write
error, hash mismatch) the partial file is removed before
returning so callers don't have to disambiguate "did we leave
bytes on disk?".
Imagecat and kernelcat are NOT migrated to this helper in this
commit — they each have their own destination-dir layout and
post-verify decompress/extract steps that don't fit a one-size
helper. Lift them later if it stays clean; for now the helper
is sized for the updater's "fetch tarball + sha256SUMS" need.
Tests cover happy path, hash mismatch, advertised Content-Length
over cap, lying server (chunked, no Content-Length, but oversize
body), HTTP non-2xx, and the two arg-validation rejections (empty
expected hash, non-positive maxBytes).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
fa3a7a3e31
commit
abd5d6f5ab
2 changed files with 212 additions and 0 deletions
86
internal/download/verified.go
Normal file
86
internal/download/verified.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
// Package download contains transport-level primitives shared by
|
||||
// banger's catalog and update flows. Today it exposes one helper
|
||||
// (FetchVerified). When imagecat and kernelcat are next touched, their
|
||||
// duplicate copies of the same logic could fold into this package
|
||||
// without a behaviour change.
|
||||
package download
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FetchVerified streams `url` into `dstPath`, capped at maxBytes
|
||||
// bytes, hashing the body on the fly and refusing payloads whose
|
||||
// SHA256 doesn't match expectedSHA256.
|
||||
//
|
||||
// On any failure (HTTP error, ContentLength > cap, body exceeds
|
||||
// cap mid-stream, write error, sha256 mismatch) dstPath is removed
|
||||
// before returning so the caller doesn't have to disambiguate
|
||||
// "did we leave a partial file?".
|
||||
//
|
||||
// Returns the number of bytes written. The caller owns successful
|
||||
// cleanup of dstPath when it's done with the file.
|
||||
//
|
||||
// expectedSHA256 is matched case-insensitively. Pass an empty
|
||||
// client to use http.DefaultClient.
|
||||
func FetchVerified(ctx context.Context, client *http.Client, url, expectedSHA256 string, maxBytes int64, dstPath string) (int64, error) {
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
if maxBytes <= 0 {
|
||||
return 0, fmt.Errorf("FetchVerified: maxBytes must be > 0, got %d", maxBytes)
|
||||
}
|
||||
if strings.TrimSpace(expectedSHA256) == "" {
|
||||
return 0, fmt.Errorf("FetchVerified: expectedSHA256 is required")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("fetch %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return 0, fmt.Errorf("fetch %s: HTTP %s", url, resp.Status)
|
||||
}
|
||||
if resp.ContentLength > maxBytes {
|
||||
return 0, fmt.Errorf("fetch %s: advertised %d bytes exceeds %d-byte cap", url, resp.ContentLength, maxBytes)
|
||||
}
|
||||
|
||||
f, err := os.Create(dstPath)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
hasher := sha256.New()
|
||||
limited := io.LimitReader(resp.Body, maxBytes+1)
|
||||
n, copyErr := io.Copy(io.MultiWriter(f, hasher), limited)
|
||||
if closeErr := f.Close(); copyErr == nil && closeErr != nil {
|
||||
copyErr = closeErr
|
||||
}
|
||||
if copyErr != nil {
|
||||
_ = os.Remove(dstPath)
|
||||
return 0, fmt.Errorf("download %s: %w", url, copyErr)
|
||||
}
|
||||
if n > maxBytes {
|
||||
_ = os.Remove(dstPath)
|
||||
return 0, fmt.Errorf("download %s: body exceeded %d-byte cap before sha256 check", url, maxBytes)
|
||||
}
|
||||
|
||||
got := hex.EncodeToString(hasher.Sum(nil))
|
||||
if !strings.EqualFold(got, expectedSHA256) {
|
||||
_ = os.Remove(dstPath)
|
||||
return 0, fmt.Errorf("sha256 mismatch for %s: got %s, want %s", url, got, expectedSHA256)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
126
internal/download/verified_test.go
Normal file
126
internal/download/verified_test.go
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
package download
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func sha256Hex(b []byte) string {
|
||||
sum := sha256.Sum256(b)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func serveBody(t *testing.T, body []byte) *httptest.Server {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(body)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestFetchVerifiedHappyPath(t *testing.T) {
|
||||
body := bytes.Repeat([]byte("ok"), 1024)
|
||||
srv := serveBody(t, body)
|
||||
dst := filepath.Join(t.TempDir(), "out")
|
||||
|
||||
n, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex(body), 1<<20, dst)
|
||||
if err != nil {
|
||||
t.Fatalf("FetchVerified: %v", err)
|
||||
}
|
||||
if n != int64(len(body)) {
|
||||
t.Fatalf("n = %d, want %d", n, len(body))
|
||||
}
|
||||
got, _ := os.ReadFile(dst)
|
||||
if !bytes.Equal(got, body) {
|
||||
t.Fatalf("file content differs from served body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchVerifiedRejectsHashMismatch(t *testing.T) {
|
||||
body := []byte("payload")
|
||||
srv := serveBody(t, body)
|
||||
dst := filepath.Join(t.TempDir(), "out")
|
||||
wrongHash := sha256Hex([]byte("other"))
|
||||
|
||||
_, err := FetchVerified(context.Background(), srv.Client(), srv.URL, wrongHash, 1<<10, dst)
|
||||
if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") {
|
||||
t.Fatalf("err = %v, want sha256 mismatch", err)
|
||||
}
|
||||
if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("partial file should be removed; stat err = %v", statErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchVerifiedRejectsContentLengthOverCap(t *testing.T) {
|
||||
body := bytes.Repeat([]byte("x"), 2048)
|
||||
srv := serveBody(t, body)
|
||||
dst := filepath.Join(t.TempDir(), "out")
|
||||
|
||||
_, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex(body), 64, dst)
|
||||
if err == nil || !strings.Contains(err.Error(), "cap") {
|
||||
t.Fatalf("err = %v, want cap rejection", err)
|
||||
}
|
||||
if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("dst created despite oversize Content-Length: %v", statErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchVerifiedRejectsLyingContentLength(t *testing.T) {
|
||||
// Server returns no Content-Length but a body bigger than cap.
|
||||
body := bytes.Repeat([]byte("y"), 2048)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Force chunked: don't set Content-Length.
|
||||
_, _ = w.Write(body)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
dst := filepath.Join(t.TempDir(), "out")
|
||||
|
||||
_, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex(body), 64, dst)
|
||||
if err == nil || !strings.Contains(err.Error(), "cap") {
|
||||
t.Fatalf("err = %v, want cap rejection on lying server", err)
|
||||
}
|
||||
if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("partial file from lying server should be removed; stat err = %v", statErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchVerifiedRejectsHTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "missing", http.StatusNotFound)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
dst := filepath.Join(t.TempDir(), "out")
|
||||
|
||||
_, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex([]byte{}), 1<<10, dst)
|
||||
if err == nil || !strings.Contains(err.Error(), "404") {
|
||||
t.Fatalf("err = %v, want 404 mention", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchVerifiedRejectsEmptyExpectedSHA(t *testing.T) {
|
||||
srv := serveBody(t, []byte("body"))
|
||||
dst := filepath.Join(t.TempDir(), "out")
|
||||
_, err := FetchVerified(context.Background(), srv.Client(), srv.URL, "", 1<<10, dst)
|
||||
if err == nil || !strings.Contains(err.Error(), "expectedSHA256") {
|
||||
t.Fatalf("err = %v, want empty-sha rejection", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchVerifiedRejectsZeroMaxBytes(t *testing.T) {
|
||||
srv := serveBody(t, []byte("body"))
|
||||
dst := filepath.Join(t.TempDir(), "out")
|
||||
_, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex([]byte("body")), 0, dst)
|
||||
if err == nil || !strings.Contains(err.Error(), "maxBytes") {
|
||||
t.Fatalf("err = %v, want maxBytes rejection", err)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue