package imagecat

import (
	"archive/tar"
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"io"
	"io/fs"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/klauspost/compress/zstd"
)

// bogusSHA256 is a well-formed (64 hex chars) digest that will not match
// any bundle produced by these tests.
const bogusSHA256 = "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef"

// fixtureEntry is one regular file to place in a test bundle.
type fixtureEntry struct {
	name string
	data []byte
}

// packFixture writes entries, in order, as regular files (mode 0644) into
// a tar archive, compresses the archive with zstd, and returns the
// compressed bytes together with their sha256 hex digest.
//
// Every writer error fails the test immediately, so a malformed fixture
// can never be mistaken for a Fetch failure.
func packFixture(t *testing.T, entries []fixtureEntry) ([]byte, string) {
	t.Helper()

	var rawTar bytes.Buffer
	tw := tar.NewWriter(&rawTar)
	for _, e := range entries {
		if err := tw.WriteHeader(&tar.Header{
			Name:     e.name,
			Size:     int64(len(e.data)),
			Mode:     0o644,
			Typeflag: tar.TypeReg,
		}); err != nil {
			t.Fatalf("tar header %q: %v", e.name, err)
		}
		if _, err := tw.Write(e.data); err != nil {
			t.Fatalf("tar write %q: %v", e.name, err)
		}
	}
	if err := tw.Close(); err != nil {
		t.Fatalf("tar close: %v", err)
	}

	var zstBuf bytes.Buffer
	zw, err := zstd.NewWriter(&zstBuf)
	if err != nil {
		t.Fatalf("zstd writer: %v", err)
	}
	if _, err := io.Copy(zw, &rawTar); err != nil {
		t.Fatalf("zstd compress: %v", err)
	}
	// Close flushes the final zstd frame; skipping it would yield a
	// truncated bundle.
	if err := zw.Close(); err != nil {
		t.Fatalf("zstd close: %v", err)
	}

	payload := zstBuf.Bytes()
	sum := sha256.Sum256(payload)
	return payload, hex.EncodeToString(sum[:])
}

// makeBundle builds a valid .tar.zst bundle holding the given rootfs bytes
// followed by the JSON-encoded manifest. It returns the bundle bytes and
// their sha256 hex digest.
func makeBundle(t *testing.T, manifest Manifest, rootfs []byte) ([]byte, string) {
	t.Helper()
	manifestJSON, err := json.Marshal(manifest)
	if err != nil {
		t.Fatal(err)
	}
	return packFixture(t, []fixtureEntry{
		{RootfsFilename, rootfs},
		{ManifestFilename, manifestJSON},
	})
}

// serveBundle starts an httptest server that answers every request with
// payload as application/octet-stream. The caller must close the server.
func serveBundle(t *testing.T, payload []byte) *httptest.Server {
	t.Helper()
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/octet-stream")
		_, _ = w.Write(payload) // best-effort: the client reports short reads
	}))
}

func TestFetchHappyPath(t *testing.T) {
	manifest := Manifest{
		Name:      "debian-bookworm",
		Distro:    "debian",
		Arch:      "x86_64",
		KernelRef: "generic-6.12",
	}
	rootfs := []byte("not-actually-an-ext4-but-that's-fine-for-the-test")
	bundle, sum := makeBundle(t, manifest, rootfs)
	srv := serveBundle(t, bundle)
	t.Cleanup(srv.Close)

	dest := t.TempDir()
	got, err := Fetch(context.Background(), srv.Client(), dest, CatEntry{
		Name:          "debian-bookworm",
		TarballURL:    srv.URL + "/bundle.tar.zst",
		TarballSHA256: sum,
	})
	if err != nil {
		t.Fatalf("Fetch: %v", err)
	}
	if got.Name != "debian-bookworm" || got.KernelRef != "generic-6.12" || got.Distro != "debian" {
		t.Fatalf("manifest = %+v", got)
	}
	if b, err := os.ReadFile(filepath.Join(dest, RootfsFilename)); err != nil || !bytes.Equal(b, rootfs) {
		t.Fatalf("rootfs content mismatch: err=%v, %q", err, b)
	}
	if _, err := os.Stat(filepath.Join(dest, ManifestFilename)); err != nil {
		t.Fatalf("manifest missing: %v", err)
	}
}

func TestFetchRejectsSHA256Mismatch(t *testing.T) {
	manifest := Manifest{Name: "debian-bookworm"}
	bundle, _ := makeBundle(t, manifest, []byte("abc"))
	srv := serveBundle(t, bundle)
	t.Cleanup(srv.Close)

	dest := t.TempDir()
	_, err := Fetch(context.Background(), srv.Client(), dest, CatEntry{
		Name:          "debian-bookworm",
		TarballURL:    srv.URL + "/bundle.tar.zst",
		TarballSHA256: bogusSHA256,
	})
	if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") {
		t.Fatalf("want sha256 mismatch error, got %v", err)
	}

	// dest must not retain partial files after a failed verification.
	for _, name := range []string{RootfsFilename, ManifestFilename} {
		if _, err := os.Stat(filepath.Join(dest, name)); !errors.Is(err, fs.ErrNotExist) {
			t.Fatalf("%q should be cleaned up on sha256 failure, got %v", name, err)
		}
	}
}

func TestFetchRejectsUnexpectedTarEntry(t *testing.T) {
	// Hand-roll a bundle with a third, disallowed entry.
	bundle, sum := packFixture(t, []fixtureEntry{
		{RootfsFilename, []byte("rootfs")},
		{ManifestFilename, []byte(`{"name":"x"}`)},
		{"extra", []byte("should be rejected")},
	})
	srv := serveBundle(t, bundle)
	t.Cleanup(srv.Close)

	_, err := Fetch(context.Background(), srv.Client(), t.TempDir(), CatEntry{
		Name:          "x",
		TarballURL:    srv.URL + "/bundle.tar.zst",
		TarballSHA256: sum,
	})
	if err == nil || !strings.Contains(err.Error(), "unexpected bundle entry") {
		t.Fatalf("want unexpected entry error, got %v", err)
	}
}

func TestFetchRejectsMissingManifest(t *testing.T) {
	// Bundle with only rootfs.
	bundle, sum := packFixture(t, []fixtureEntry{
		{RootfsFilename, []byte("abc")},
	})
	srv := serveBundle(t, bundle)
	t.Cleanup(srv.Close)

	_, err := Fetch(context.Background(), srv.Client(), t.TempDir(), CatEntry{
		Name:          "x",
		TarballURL:    srv.URL + "/bundle.tar.zst",
		TarballSHA256: sum,
	})
	if err == nil || !strings.Contains(err.Error(), "missing required files") {
		t.Fatalf("want missing-required-files error, got %v", err)
	}
}

func TestFetchRejectsHTTPFailure(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "not found", http.StatusNotFound)
	}))
	t.Cleanup(srv.Close)

	_, err := Fetch(context.Background(), srv.Client(), t.TempDir(), CatEntry{
		Name:          "x",
		TarballURL:    srv.URL + "/missing.tar.zst",
		TarballSHA256: bogusSHA256,
	})
	if err == nil || !strings.Contains(err.Error(), "HTTP") {
		t.Fatalf("want HTTP error, got %v", err)
	}
}

func TestFetchRejectsEmptyURL(t *testing.T) {
	_, err := Fetch(context.Background(), http.DefaultClient, t.TempDir(), CatEntry{
		Name:          "x",
		TarballURL:    "",
		TarballSHA256: "abc",
	})
	if err == nil || !strings.Contains(err.Error(), "no tarball URL") {
		t.Fatalf("want no-URL error, got %v", err)
	}
}

func TestFetchRejectsEmptySHA256(t *testing.T) {
	_, err := Fetch(context.Background(), http.DefaultClient, t.TempDir(), CatEntry{
		Name:       "x",
		TarballURL: "https://example.com/x.tar.zst",
	})
	if err == nil || !strings.Contains(err.Error(), "no tarball sha256") {
		t.Fatalf("want no-sha error, got %v", err)
	}
}

func TestFetchRejectsInvalidName(t *testing.T) {
	_, err := Fetch(context.Background(), http.DefaultClient, t.TempDir(), CatEntry{
		Name:          "",
		TarballURL:    "https://example.com/x.tar.zst",
		TarballSHA256: "abc",
	})
	if err == nil || !strings.Contains(err.Error(), "image name is required") {
		t.Fatalf("want name-required error, got %v", err)
	}
}