package updater

import (
	"archive/tar"
	"bytes"
	"compress/gzip"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"testing"
)

// tarFixtureEntry is one regular file written into a test tarball.
type tarFixtureEntry struct {
	name string
	body []byte
}

// makeOrderedTarball returns a gzip-compressed tar archive containing one
// regular file (mode 0755) per entry, written in exactly the given order.
// Entry names are not validated, so callers can inject hostile names such
// as "../escape".
func makeOrderedTarball(t *testing.T, entries []tarFixtureEntry) []byte {
	t.Helper()
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	tw := tar.NewWriter(gz)
	for _, e := range entries {
		hdr := &tar.Header{
			Name:     e.name,
			Mode:     0o755,
			Size:     int64(len(e.body)),
			Typeflag: tar.TypeReg,
		}
		if err := tw.WriteHeader(hdr); err != nil {
			t.Fatalf("write header %q: %v", e.name, err)
		}
		if _, err := tw.Write(e.body); err != nil {
			t.Fatalf("write body %q: %v", e.name, err)
		}
	}
	if err := tw.Close(); err != nil {
		t.Fatalf("close tar: %v", err)
	}
	if err := gz.Close(); err != nil {
		t.Fatalf("close gzip: %v", err)
	}
	return buf.Bytes()
}

// makeReleaseTarball returns a gzip-compressed tarball whose root holds one
// regular file per entry in bodies (entry name -> contents). It writes
// exactly what it is given: required entries may be omitted and extra ones
// added. This lets the stage tests exercise StageTarball's validation
// without a real banger build. Entries are written in sorted name order, so
// the archive layout does not depend on random map iteration order.
func makeReleaseTarball(t *testing.T, bodies map[string][]byte) []byte {
	t.Helper()
	names := make([]string, 0, len(bodies))
	for name := range bodies {
		names = append(names, name)
	}
	sort.Strings(names)

	entries := make([]tarFixtureEntry, 0, len(names))
	for _, name := range names {
		entries = append(entries, tarFixtureEntry{name: name, body: bodies[name]})
	}
	return makeOrderedTarball(t, entries)
}

// sha256Hex returns the lowercase hex-encoded SHA-256 digest of b, the
// format written into the fake SHA256SUMS bodies below.
func sha256Hex(b []byte) string {
	sum := sha256.Sum256(b)
	return hex.EncodeToString(sum[:])
}

// writeFixture writes body to path with perm. It fails the test immediately
// if the write fails, so a broken setup is not misreported as a failure of
// the code under test.
func writeFixture(t *testing.T, path string, body []byte, perm os.FileMode) {
	t.Helper()
	if err := os.WriteFile(path, body, perm); err != nil {
		t.Fatalf("write %s: %v", path, err)
	}
}

// readFixture returns the contents of path as a string, failing the test if
// the file cannot be read.
func readFixture(t *testing.T, path string) string {
	t.Helper()
	b, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("read %s: %v", path, err)
	}
	return string(b)
}

// mkdirFixtures creates every dir (and its parents) with mode 0755,
// failing the test on the first error.
func mkdirFixtures(t *testing.T, dirs ...string) {
	t.Helper()
	for _, dir := range dirs {
		if err := os.MkdirAll(dir, 0o755); err != nil {
			t.Fatalf("mkdir %s: %v", dir, err)
		}
	}
}

// writeReleaseTarball writes a tarball body into a fresh temp dir and
// returns the file's path.
func writeReleaseTarball(t *testing.T, body []byte) string {
	t.Helper()
	path := filepath.Join(t.TempDir(), "release.tar.gz")
	writeFixture(t, path, body, 0o644)
	return path
}

// writeStagedFixture creates dir (mode 0700) and fills it with
// "NEW-<basename>" payloads for the three release binaries. It returns a
// StagedRelease pointing at those files.
func writeStagedFixture(t *testing.T, dir string) StagedRelease {
	t.Helper()
	if err := os.MkdirAll(dir, 0o700); err != nil {
		t.Fatalf("mkdir %s: %v", dir, err)
	}
	staged := StagedRelease{
		BangerPath:     filepath.Join(dir, "banger"),
		BangerdPath:    filepath.Join(dir, "bangerd"),
		VsockAgentPath: filepath.Join(dir, "banger-vsock-agent"),
	}
	for _, p := range []string{staged.BangerPath, staged.BangerdPath, staged.VsockAgentPath} {
		writeFixture(t, p, []byte("NEW-"+filepath.Base(p)), 0o755)
	}
	return staged
}

func TestStageTarballHappyPath(t *testing.T) {
	tarball := writeReleaseTarball(t, makeReleaseTarball(t, map[string][]byte{
		"banger":             []byte("BANGER"),
		"bangerd":            []byte("BANGERD"),
		"banger-vsock-agent": []byte("AGENT"),
	}))
	staging := filepath.Join(t.TempDir(), "staged")

	got, err := StageTarball(tarball, staging)
	if err != nil {
		t.Fatalf("StageTarball: %v", err)
	}
	for _, want := range []struct{ path, body string }{
		{got.BangerPath, "BANGER"},
		{got.BangerdPath, "BANGERD"},
		{got.VsockAgentPath, "AGENT"},
	} {
		info, err := os.Stat(want.path)
		if err != nil {
			t.Fatalf("stat %s: %v", want.path, err)
		}
		if perm := info.Mode().Perm(); perm != 0o755 {
			t.Errorf("%s mode = %o, want 0755", want.path, perm)
		}
		if body := readFixture(t, want.path); body != want.body {
			t.Errorf("%s content = %q, want %q", want.path, body, want.body)
		}
	}
}

func TestStageTarballRejectsExtraEntry(t *testing.T) {
	tarball := writeReleaseTarball(t, makeReleaseTarball(t, map[string][]byte{
		"banger":             []byte("a"),
		"bangerd":            []byte("b"),
		"banger-vsock-agent": []byte("c"),
		"bonus.txt":          []byte("not allowed"),
	}))

	_, err := StageTarball(tarball, t.TempDir())
	if err == nil || !strings.Contains(err.Error(), "unexpected entry") {
		t.Fatalf("err = %v, want unexpected-entry rejection", err)
	}
}

func TestStageTarballRejectsMissingEntry(t *testing.T) {
	tarball := writeReleaseTarball(t, makeReleaseTarball(t, map[string][]byte{
		"banger":  []byte("a"),
		"bangerd": []byte("b"),
		// banger-vsock-agent intentionally missing
	}))

	_, err := StageTarball(tarball, t.TempDir())
	if err == nil || !strings.Contains(err.Error(), "missing required entry") {
		t.Fatalf("err = %v, want missing-required rejection", err)
	}
}

func TestStageTarballRejectsPathTraversal(t *testing.T) {
	// Use a fixed entry order so the "../escape" entry follows two
	// legitimate ones. The vsock agent is deliberately absent, so the test
	// only passes if the unsafe-path error is what StageTarball reports.
	tarball := writeReleaseTarball(t, makeOrderedTarball(t, []tarFixtureEntry{
		{name: "banger", body: []byte("a")},
		{name: "bangerd", body: []byte("b")},
		{name: "../escape", body: []byte("x")},
	}))

	_, err := StageTarball(tarball, t.TempDir())
	if err == nil || !strings.Contains(err.Error(), "unsafe path") {
		t.Fatalf("err = %v, want unsafe-path rejection", err)
	}
}

func TestSwapAndRollback(t *testing.T) {
	root := t.TempDir()
	binDir := filepath.Join(root, "bin")
	libDir := filepath.Join(root, "lib", "banger")
	mkdirFixtures(t, binDir, libDir)

	targets := InstallTargets{
		Banger:     filepath.Join(binDir, "banger"),
		Bangerd:    filepath.Join(binDir, "bangerd"),
		VsockAgent: filepath.Join(libDir, "banger-vsock-agent"),
	}
	installed := []string{targets.Banger, targets.Bangerd, targets.VsockAgent}
	for _, p := range installed {
		writeFixture(t, p, []byte("OLD-"+filepath.Base(p)), 0o755)
	}
	staged := writeStagedFixture(t, filepath.Join(root, "staging"))

	res, err := Swap(staged, targets)
	if err != nil {
		t.Fatalf("Swap: %v", err)
	}
	if len(res.SwappedTargets) != 3 {
		t.Fatalf("SwappedTargets len = %d, want 3", len(res.SwappedTargets))
	}
	for _, p := range installed {
		if got, want := readFixture(t, p), "NEW-"+filepath.Base(p); got != want {
			t.Fatalf("%s content = %q, want %q", p, got, want)
		}
		prev, err := os.ReadFile(p + previousSuffix)
		if err != nil {
			t.Fatalf("missing backup at %s.previous: %v", p, err)
		}
		if string(prev) != "OLD-"+filepath.Base(p) {
			t.Fatalf(".previous content = %q", prev)
		}
	}

	if err := Rollback(res); err != nil {
		t.Fatalf("Rollback: %v", err)
	}
	for _, p := range installed {
		if got, want := readFixture(t, p), "OLD-"+filepath.Base(p); got != want {
			t.Fatalf("post-rollback %s = %q, want %q", p, got, want)
		}
		if _, err := os.Stat(p + previousSuffix); !os.IsNotExist(err) {
			t.Fatalf(".previous should be cleaned after rollback; stat err = %v", err)
		}
	}
}

func TestSwapPartialFailureRollsBackCleanly(t *testing.T) {
	root := t.TempDir()
	binDir := filepath.Join(root, "bin")
	libDir := filepath.Join(root, "lib", "banger")
	mkdirFixtures(t, binDir, libDir)

	// Block the banger swap (which is LAST in the order) by putting a
	// regular file where its parent dir should be, so MkdirAll fails with
	// "not a directory". Vsock + bangerd succeed first.
	blockedParent := filepath.Join(root, "blocked-bin")
	writeFixture(t, blockedParent, []byte("blocking"), 0o644)

	targets := InstallTargets{
		Banger:     filepath.Join(blockedParent, "banger"),
		Bangerd:    filepath.Join(binDir, "bangerd"),
		VsockAgent: filepath.Join(libDir, "banger-vsock-agent"),
	}
	// Only the two binaries that will swap successfully get an old copy.
	succeeding := []string{targets.Bangerd, targets.VsockAgent}
	for _, p := range succeeding {
		writeFixture(t, p, []byte("OLD-"+filepath.Base(p)), 0o755)
	}
	staged := writeStagedFixture(t, filepath.Join(root, "staging"))

	res, err := Swap(staged, targets)
	if err == nil {
		t.Fatal("Swap unexpectedly succeeded; banger parent should be blocked by a regular file")
	}
	if len(res.SwappedTargets) != 2 {
		t.Fatalf("SwappedTargets = %v, want 2 (vsock + bangerd before banger failed)", res.SwappedTargets)
	}

	// Rolling back the partial swap should restore the filesystem exactly
	// as TestSwapAndRollback expects for a full swap.
	if err := Rollback(res); err != nil {
		t.Fatalf("Rollback after partial swap: %v", err)
	}
	for _, p := range succeeding {
		if got, want := readFixture(t, p), "OLD-"+filepath.Base(p); got != want {
			t.Fatalf("post-rollback %s = %q, want %q", p, got, want)
		}
		if _, err := os.Stat(p + previousSuffix); !os.IsNotExist(err) {
			t.Fatalf(".previous should be cleaned after rollback; stat err = %v", err)
		}
	}
}

func TestDownloadReleaseHappyPath(t *testing.T) {
	tarballBody := []byte("fake tarball bytes")
	tarballSHA := sha256Hex(tarballBody)

	mux := http.NewServeMux()
	mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write(tarballBody)
	})
	mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprintf(w, "%s banger.tar.gz\n", tarballSHA)
	})
	srv := httptest.NewServer(mux)
	defer srv.Close()

	dst := filepath.Join(t.TempDir(), "out.tar.gz")
	sums, err := DownloadRelease(context.Background(), srv.Client(), Release{
		Version:       "v0.1.0",
		TarballURL:    srv.URL + "/banger.tar.gz",
		SHA256SumsURL: srv.URL + "/SHA256SUMS",
	}, dst)
	if err != nil {
		t.Fatalf("DownloadRelease: %v", err)
	}
	if !strings.Contains(string(sums), "banger.tar.gz") {
		t.Fatalf("returned sums body missing tarball name: %q", sums)
	}
	if got := readFixture(t, dst); got != string(tarballBody) {
		t.Fatalf("downloaded body = %q, want %q", got, tarballBody)
	}
}

func TestDownloadReleaseRejectsTarballMissingFromSums(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("body"))
	})
	mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, _ *http.Request) {
		// Sums for a different file; tarball name not listed.
		fmt.Fprintf(w, "%s unrelated\n", sha256Hex([]byte("body")))
	})
	srv := httptest.NewServer(mux)
	defer srv.Close()

	dst := filepath.Join(t.TempDir(), "out.tar.gz")
	_, err := DownloadRelease(context.Background(), srv.Client(), Release{
		TarballURL:    srv.URL + "/banger.tar.gz",
		SHA256SumsURL: srv.URL + "/SHA256SUMS",
	}, dst)
	if err == nil || !strings.Contains(err.Error(), "does not list") {
		t.Fatalf("err = %v, want SHA256SUMS-missing rejection", err)
	}
}

func TestDownloadReleasePropagatesShaMismatch(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/banger.tar.gz", func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("served body"))
	})
	mux.HandleFunc("/SHA256SUMS", func(w http.ResponseWriter, _ *http.Request) {
		// Wrong digest for the tarball.
		fmt.Fprintf(w, "%s banger.tar.gz\n", sha256Hex([]byte("expected body")))
	})
	srv := httptest.NewServer(mux)
	defer srv.Close()

	dst := filepath.Join(t.TempDir(), "out.tar.gz")
	_, err := DownloadRelease(context.Background(), srv.Client(), Release{
		TarballURL:    srv.URL + "/banger.tar.gz",
		SHA256SumsURL: srv.URL + "/SHA256SUMS",
	}, dst)
	if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") {
		t.Fatalf("err = %v, want sha256 mismatch", err)
	}
	if _, statErr := os.Stat(dst); !os.IsNotExist(statErr) {
		t.Fatalf("partial tarball should be removed; stat err = %v", statErr)
	}
}