package download import ( "bytes" "context" "crypto/sha256" "encoding/hex" "net/http" "net/http/httptest" "os" "path/filepath" "strings" "testing" ) func sha256Hex(b []byte) string { sum := sha256.Sum256(b) return hex.EncodeToString(sum[:]) } func serveBody(t *testing.T, body []byte) *httptest.Server { t.Helper() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/octet-stream") _, _ = w.Write(body) })) t.Cleanup(srv.Close) return srv } func TestFetchVerifiedHappyPath(t *testing.T) { body := bytes.Repeat([]byte("ok"), 1024) srv := serveBody(t, body) dst := filepath.Join(t.TempDir(), "out") n, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex(body), 1<<20, dst) if err != nil { t.Fatalf("FetchVerified: %v", err) } if n != int64(len(body)) { t.Fatalf("n = %d, want %d", n, len(body)) } got, _ := os.ReadFile(dst) if !bytes.Equal(got, body) { t.Fatalf("file content differs from served body") } } func TestFetchVerifiedRejectsHashMismatch(t *testing.T) { body := []byte("payload") srv := serveBody(t, body) dst := filepath.Join(t.TempDir(), "out") wrongHash := sha256Hex([]byte("other")) _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, wrongHash, 1<<10, dst) if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") { t.Fatalf("err = %v, want sha256 mismatch", err) } if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) { t.Fatalf("partial file should be removed; stat err = %v", statErr) } } func TestFetchVerifiedRejectsContentLengthOverCap(t *testing.T) { body := bytes.Repeat([]byte("x"), 2048) srv := serveBody(t, body) dst := filepath.Join(t.TempDir(), "out") _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex(body), 64, dst) if err == nil || !strings.Contains(err.Error(), "cap") { t.Fatalf("err = %v, want cap rejection", err) } if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) { t.Fatalf("dst created despite oversize Content-Length: %v", statErr) } } func TestFetchVerifiedRejectsLyingContentLength(t *testing.T) { // Server returns no Content-Length but a body bigger than cap. body := bytes.Repeat([]byte("y"), 2048) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Force chunked: don't set Content-Length. _, _ = w.Write(body) })) t.Cleanup(srv.Close) dst := filepath.Join(t.TempDir(), "out") _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex(body), 64, dst) if err == nil || !strings.Contains(err.Error(), "cap") { t.Fatalf("err = %v, want cap rejection on lying server", err) } if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) { t.Fatalf("partial file from lying server should be removed; stat err = %v", statErr) } } func TestFetchVerifiedRejectsHTTPError(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "missing", http.StatusNotFound) })) t.Cleanup(srv.Close) dst := filepath.Join(t.TempDir(), "out") _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex([]byte{}), 1<<10, dst) if err == nil || !strings.Contains(err.Error(), "404") { t.Fatalf("err = %v, want 404 mention", err) } } func TestFetchVerifiedRejectsEmptyExpectedSHA(t *testing.T) { srv := serveBody(t, []byte("body")) dst := filepath.Join(t.TempDir(), "out") _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, "", 1<<10, dst) if err == nil || !strings.Contains(err.Error(), "expectedSHA256") { t.Fatalf("err = %v, want empty-sha rejection", err) } } func TestFetchVerifiedRejectsZeroMaxBytes(t *testing.T) { srv := serveBody(t, []byte("body")) dst := filepath.Join(t.TempDir(), "out") _, err := FetchVerified(context.Background(), srv.Client(), srv.URL, sha256Hex([]byte("body")), 0, dst) if err == nil || !strings.Contains(err.Error(), "maxBytes") { t.Fatalf("err = %v, want maxBytes rejection", err) } }