// Package download contains transport-level primitives shared by // banger's catalog and update flows. Today it exposes one helper // (FetchVerified). When imagecat and kernelcat are next touched, their // duplicate copies of the same logic could fold into this package // without a behaviour change. package download import ( "context" "crypto/sha256" "encoding/hex" "fmt" "io" "net/http" "os" "strings" ) // FetchVerified streams `url` into `dstPath`, capped at maxBytes // bytes, hashing the body on the fly and refusing payloads whose // SHA256 doesn't match expectedSHA256. // // Failures detected before the download starts (invalid arguments, // request error, non-2xx status, advertised ContentLength > cap) // return without creating or touching dstPath. Once dstPath has been // created, any failure (body exceeds cap mid-stream, write error, // sha256 mismatch) removes it before returning, so the caller doesn't // have to disambiguate "did we leave a partial file?". // // Returns the number of bytes written. On success the caller owns // dstPath and is responsible for removing it when done with the file. // // expectedSHA256 is matched case-insensitively. Pass a nil // client to use http.DefaultClient. 
//
// Surrounding whitespace in expectedSHA256 is ignored when comparing, and
// any positive maxBytes up to the largest int64 is honoured as the cap.
func FetchVerified(ctx context.Context, client *http.Client, url, expectedSHA256 string, maxBytes int64, dstPath string) (int64, error) {
	if client == nil {
		client = http.DefaultClient
	}
	if maxBytes <= 0 {
		return 0, fmt.Errorf("FetchVerified: maxBytes must be > 0, got %d", maxBytes)
	}
	// Normalise once so the emptiness check and the final comparison agree:
	// a digest with a trailing newline (e.g. read from a .sha256 file) must
	// not pass validation only to fail the match later.
	want := strings.TrimSpace(expectedSHA256)
	if want == "" {
		return 0, fmt.Errorf("FetchVerified: expectedSHA256 is required")
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return 0, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return 0, fmt.Errorf("fetch %s: %w", url, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return 0, fmt.Errorf("fetch %s: HTTP %s", url, resp.Status)
	}
	// Cheap early rejection. An unknown length is reported as -1 and falls
	// through to the streaming cap below.
	if resp.ContentLength > maxBytes {
		return 0, fmt.Errorf("fetch %s: advertised %d bytes exceeds %d-byte cap", url, resp.ContentLength, maxBytes)
	}

	f, err := os.Create(dstPath)
	if err != nil {
		return 0, err
	}
	// From here on dstPath exists; every non-success exit must remove it.
	verified := false
	defer func() {
		if !verified {
			// Best effort: the error being returned is the one that matters.
			_ = os.Remove(dstPath)
		}
	}()

	// Read one byte past the cap so an oversized body is detectable. When
	// maxBytes is already the largest int64, maxBytes+1 would wrap negative
	// and the limited reader would report EOF immediately; no body can
	// exceed that cap anyway, so read with the cap itself.
	const maxInt64 = 1<<63 - 1
	readLimit := maxBytes
	if maxBytes < maxInt64 {
		readLimit = maxBytes + 1
	}

	hasher := sha256.New()
	n, copyErr := io.Copy(io.MultiWriter(f, hasher), io.LimitReader(resp.Body, readLimit))
	// Always close the file; a close error only matters if the copy itself
	// succeeded (it may signal a deferred write failure).
	if closeErr := f.Close(); copyErr == nil && closeErr != nil {
		copyErr = closeErr
	}
	if copyErr != nil {
		return 0, fmt.Errorf("download %s: %w", url, copyErr)
	}
	if n > maxBytes {
		return 0, fmt.Errorf("download %s: body exceeded %d-byte cap before sha256 check", url, maxBytes)
	}

	got := hex.EncodeToString(hasher.Sum(nil))
	if !strings.EqualFold(got, want) {
		return 0, fmt.Errorf("sha256 mismatch for %s: got %s, want %s", url, got, want)
	}

	verified = true
	return n, nil
}