diff --git a/internal/download/verified.go b/internal/download/verified.go new file mode 100644 index 0000000..7f51743 --- /dev/null +++ b/internal/download/verified.go @@ -0,0 +1,86 @@ +// Package download contains transport-level primitives shared by +// banger's catalog and update flows. Today it exposes one helper +// (FetchVerified). When imagecat and kernelcat are next touched, their +// duplicate copies of the same logic could fold into this package +// without a behaviour change. +package download + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "strings" +) + +// FetchVerified streams `url` into `dstPath`, capped at maxBytes +// bytes, hashing the body on the fly and refusing payloads whose +// SHA256 doesn't match expectedSHA256. +// +// On any failure (HTTP error, ContentLength > cap, body exceeds +// cap mid-stream, write error, sha256 mismatch) dstPath is removed +// before returning so the caller doesn't have to disambiguate +// "did we leave a partial file?". +// +// Returns the number of bytes written. The caller owns successful +// cleanup of dstPath when it's done with the file. +// +// expectedSHA256 is matched case-insensitively. Pass a nil +// client to use http.DefaultClient.
// FetchVerified downloads url into dstPath and accepts the result only if
// its SHA-256 digest equals expectedSHA256.
//
// Parameters:
//   - client: HTTP client to use; nil selects http.DefaultClient.
//   - expectedSHA256: hex digest of the payload. Surrounding whitespace is
//     ignored and the comparison is case-insensitive. It must decode to
//     exactly sha256.Size bytes; anything else is rejected before a request
//     is made.
//   - maxBytes: upper bound on the body size; must be > 0.
//
// Failures detected before dstPath is created (argument validation, request
// error, non-2xx status, advertised Content-Length above maxBytes) do not
// touch dstPath. Once the file has been created, every failure (copy or
// close error, body larger than maxBytes, digest mismatch) removes it
// before returning, so the caller never sees a partial file.
//
// On success it returns the number of bytes written and the caller owns
// dstPath.
func FetchVerified(ctx context.Context, client *http.Client, url, expectedSHA256 string, maxBytes int64, dstPath string) (int64, error) {
	// Largest int64, spelled locally so the block needs no extra import.
	const maxInt64 = 1<<63 - 1

	if client == nil {
		client = http.DefaultClient
	}
	if maxBytes <= 0 {
		return 0, fmt.Errorf("FetchVerified: maxBytes must be > 0, got %d", maxBytes)
	}

	// Normalise once and use the same value for validation and comparison,
	// so a digest read from a file with a trailing newline still matches.
	want := strings.TrimSpace(expectedSHA256)
	if want == "" {
		return 0, fmt.Errorf("FetchVerified: expectedSHA256 is required")
	}
	// A digest that can never match is a caller bug; fail before spending
	// bandwidth on a download that is guaranteed to be rejected.
	if raw, err := hex.DecodeString(want); err != nil || len(raw) != sha256.Size {
		return 0, fmt.Errorf("FetchVerified: expectedSHA256 must be %d hex characters, got %q", hex.EncodedLen(sha256.Size), want)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return 0, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return 0, fmt.Errorf("fetch %s: %w", url, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return 0, fmt.Errorf("fetch %s: HTTP %s", url, resp.Status)
	}
	if resp.ContentLength > maxBytes {
		return 0, fmt.Errorf("fetch %s: advertised %d bytes exceeds %d-byte cap", url, resp.ContentLength, maxBytes)
	}

	f, err := os.Create(dstPath)
	if err != nil {
		return 0, err
	}
	// From here on, any return that has not set verified removes dstPath.
	// The file is always closed before such a return is reached.
	verified := false
	defer func() {
		if !verified {
			_ = os.Remove(dstPath) // best effort; the primary error is already being returned
		}
	}()

	// Read one byte past the cap so an oversized body is detectable.
	// Guard the increment: maxBytes+1 would overflow to a negative limit
	// at maxInt64, and a negative limit makes LimitReader return EOF
	// immediately.
	limit := maxBytes
	if limit < maxInt64 {
		limit++
	}

	hasher := sha256.New()
	n, copyErr := io.Copy(io.MultiWriter(f, hasher), io.LimitReader(resp.Body, limit))
	// A failed Close can mean buffered data never reached disk; surface it
	// unless a copy error already explains the failure.
	if closeErr := f.Close(); copyErr == nil && closeErr != nil {
		copyErr = closeErr
	}
	if copyErr != nil {
		return 0, fmt.Errorf("download %s: %w", url, copyErr)
	}
	if n > maxBytes {
		return 0, fmt.Errorf("download %s: body exceeded %d-byte cap before sha256 check", url, maxBytes)
	}

	got := hex.EncodeToString(hasher.Sum(nil))
	if !strings.EqualFold(got, want) {
		return 0, fmt.Errorf("sha256 mismatch for %s: got %s, want %s", url, got, want)
	}

	verified = true
	return n, nil
}
"encoding/hex" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +func sha256Hex(b []byte) string { + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]) +} + +func serveBody(t *testing.T, body []byte) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/octet-stream") + _, _ = w.Write(body) + })) + t.Cleanup(srv.Close) + return srv +} + +func TestFetchVerifiedHappyPath(t *testing.T) { + body := bytes.Repeat([]byte("ok"), 1024) + srv := serveBody(t, body) + dst := filepath.Join(t.TempDir(), "out") + + n, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex(body), 1<<20, dst) + if err != nil { + t.Fatalf("FetchVerified: %v", err) + } + if n != int64(len(body)) { + t.Fatalf("n = %d, want %d", n, len(body)) + } + got, _ := os.ReadFile(dst) + if !bytes.Equal(got, body) { + t.Fatalf("file content differs from served body") + } +} + +func TestFetchVerifiedRejectsHashMismatch(t *testing.T) { + body := []byte("payload") + srv := serveBody(t, body) + dst := filepath.Join(t.TempDir(), "out") + wrongHash := sha256Hex([]byte("other")) + + _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, wrongHash, 1<<10, dst) + if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") { + t.Fatalf("err = %v, want sha256 mismatch", err) + } + if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) { + t.Fatalf("partial file should be removed; stat err = %v", statErr) + } +} + +func TestFetchVerifiedRejectsContentLengthOverCap(t *testing.T) { + body := bytes.Repeat([]byte("x"), 2048) + srv := serveBody(t, body) + dst := filepath.Join(t.TempDir(), "out") + + _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex(body), 64, dst) + if err == nil || !strings.Contains(err.Error(), "cap") { + t.Fatalf("err = %v, want cap rejection", err) + } + if _, statErr := 
os.Stat(dst); !os.IsNotExist(statErr) { + t.Fatalf("dst created despite oversize Content-Length: %v", statErr) + } +} + +func TestFetchVerifiedRejectsLyingContentLength(t *testing.T) { + // Server returns no Content-Length but a body bigger than cap. + body := bytes.Repeat([]byte("y"), 2048) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Force chunked: don't set Content-Length. + _, _ = w.Write(body) + })) + t.Cleanup(srv.Close) + dst := filepath.Join(t.TempDir(), "out") + + _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex(body), 64, dst) + if err == nil || !strings.Contains(err.Error(), "cap") { + t.Fatalf("err = %v, want cap rejection on lying server", err) + } + if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) { + t.Fatalf("partial file from lying server should be removed; stat err = %v", statErr) + } +} + +func TestFetchVerifiedRejectsHTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "missing", http.StatusNotFound) + })) + t.Cleanup(srv.Close) + dst := filepath.Join(t.TempDir(), "out") + + _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex([]byte{}), 1<<10, dst) + if err == nil || !strings.Contains(err.Error(), "404") { + t.Fatalf("err = %v, want 404 mention", err) + } +} + +func TestFetchVerifiedRejectsEmptyExpectedSHA(t *testing.T) { + srv := serveBody(t, []byte("body")) + dst := filepath.Join(t.TempDir(), "out") + _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, "", 1<<10, dst) + if err == nil || !strings.Contains(err.Error(), "expectedSHA256") { + t.Fatalf("err = %v, want empty-sha rejection", err) + } +} + +func TestFetchVerifiedRejectsZeroMaxBytes(t *testing.T) { + srv := serveBody(t, []byte("body")) + dst := filepath.Join(t.TempDir(), "out") + _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex([]byte("body")), 0, 
dst) + if err == nil || !strings.Contains(err.Error(), "maxBytes") { + t.Fatalf("err = %v, want maxBytes rejection", err) + } +}