package kernelcat

import (
	"archive/tar"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/klauspost/compress/zstd"
)

// Fetch downloads the tarball for entry, verifies its SHA256, extracts it
// into EntryDir(kernelsDir, entry.Name), and writes a manifest via
// WriteLocal. It returns the entry as re-read by ReadLocal.
//
// Any existing local copy of entry.Name is removed (DeleteLocal) before the
// download starts. On failure the partially-populated target directory is
// removed too, so a failed Fetch leaves no copy of the entry behind.
//
// The tarball is expected to be a tar+zstd archive whose root contains
// vmlinux and optionally initrd.img and/or a modules/ directory. Path
// traversal entries (..), absolute-path members, and members that would be
// written through a symlink are rejected.
//
// Extraction is streamed and happens before the checksum is known. The
// checks in extractTar are therefore what keep an unverified archive from
// writing outside the target directory.
func Fetch(ctx context.Context, client *http.Client, kernelsDir string, entry CatEntry) (Entry, error) {
	if err := ValidateName(entry.Name); err != nil {
		return Entry{}, err
	}
	if strings.TrimSpace(entry.TarballURL) == "" {
		return Entry{}, fmt.Errorf("catalog entry %q has no tarball URL", entry.Name)
	}
	// Compare against the same trimmed value the emptiness check uses.
	wantSum := strings.TrimSpace(entry.TarballSHA256)
	if wantSum == "" {
		return Entry{}, fmt.Errorf("catalog entry %q has no tarball sha256", entry.Name)
	}
	if client == nil {
		client = http.DefaultClient
	}
	if err := DeleteLocal(kernelsDir, entry.Name); err != nil {
		return Entry{}, fmt.Errorf("clear prior catalog entry: %w", err)
	}
	targetDir := EntryDir(kernelsDir, entry.Name)
	if err := os.MkdirAll(targetDir, 0o755); err != nil {
		return Entry{}, err
	}

	// From here on, every early return removes the partially-populated
	// directory. Defers run LIFO, so the decoder and response body
	// registered below are closed before the removal happens.
	succeeded := false
	defer func() {
		if !succeeded {
			_ = os.RemoveAll(targetDir)
		}
	}()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, entry.TarballURL, nil)
	if err != nil {
		return Entry{}, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return Entry{}, fmt.Errorf("fetch %s: %w", entry.TarballURL, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return Entry{}, fmt.Errorf("fetch %s: HTTP %s", entry.TarballURL, resp.Status)
	}

	// Hash the compressed transport bytes exactly as they are consumed.
	hasher := sha256.New()
	tee := io.TeeReader(resp.Body, hasher)
	zr, err := zstd.NewReader(tee)
	if err != nil {
		return Entry{}, fmt.Errorf("init zstd: %w", err)
	}
	defer zr.Close()

	if err := extractTar(zr, targetDir); err != nil {
		return Entry{}, err
	}
	// The tar reader stops at the end-of-archive marker, which leaves tar
	// record padding unread in the *decompressed* stream. Drain the decoder
	// so it consumes, and the tee hashes, its entire input in order.
	if _, err := io.Copy(io.Discard, zr); err != nil {
		return Entry{}, fmt.Errorf("drain tarball: %w", err)
	}
	// With default options the zstd decoder may read its input from
	// background goroutines. Close it before touching tee directly so the
	// body and hasher are never read or written from two goroutines at once.
	// The deferred Close above then has nothing left to do.
	zr.Close()
	if _, err := io.Copy(io.Discard, tee); err != nil {
		return Entry{}, fmt.Errorf("drain tarball: %w", err)
	}

	got := hex.EncodeToString(hasher.Sum(nil))
	if !strings.EqualFold(got, wantSum) {
		return Entry{}, fmt.Errorf("tarball sha256 mismatch: got %s, want %s", got, entry.TarballSHA256)
	}

	kernelPath := filepath.Join(targetDir, kernelFilename)
	if _, err := os.Stat(kernelPath); err != nil {
		return Entry{}, fmt.Errorf("tarball missing %s: %w", kernelFilename, err)
	}
	kernelSum, err := SumFile(kernelPath)
	if err != nil {
		return Entry{}, err
	}
	stored := Entry{
		Name:          entry.Name,
		Distro:        entry.Distro,
		Arch:          entry.Arch,
		KernelVersion: entry.KernelVersion,
		SHA256:        kernelSum,
		Source:        "pull:" + entry.TarballURL,
		ImportedAt:    time.Now().UTC(),
	}
	if err := WriteLocal(kernelsDir, stored); err != nil {
		return Entry{}, err
	}
	// The manifest is on disk. Keep the directory even if re-reading it fails.
	succeeded = true
	return ReadLocal(kernelsDir, entry.Name)
}

// extractTar writes each regular file, directory, and symlink member read
// from r into target. It refuses any member whose cleaned path is absolute
// or climbs out with "..". It also refuses any member that would be
// created beneath, or on top of, a symlink already present under target.
// Relative symlink targets must resolve lexically inside target.
// Hardlinks, device nodes, and FIFOs are skipped.
func extractTar(r io.Reader, target string) error {
	absTarget, err := filepath.Abs(target)
	if err != nil {
		return err
	}
	within := func(p string) bool {
		return p == absTarget || strings.HasPrefix(p, absTarget+string(filepath.Separator))
	}

	tr := tar.NewReader(r)
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return fmt.Errorf("read tarball: %w", err)
		}
		rel := filepath.Clean(hdr.Name)
		if rel == "." || rel == string(filepath.Separator) {
			continue
		}
		if filepath.IsAbs(rel) || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			return fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
		}
		dst := filepath.Join(absTarget, rel)
		if !within(dst) {
			return fmt.Errorf("unsafe path in tarball: %q", hdr.Name)
		}
		if t := hdr.Typeflag; t != tar.TypeDir && t != tar.TypeReg && t != tar.TypeSymlink {
			// Hardlinks / device nodes / fifos: skip silently. Kernel
			// module trees shouldn't need them.
			continue
		}

		// The lexical checks above do not account for symlinks created by
		// earlier members. Without this check, "x -> /etc" followed by
		// "x/passwd", or "vmlinux -> /some/file" followed by a regular
		// "vmlinux", would write to the host filesystem.
		link, err := firstSymlinkUnder(absTarget, rel)
		if err != nil {
			return err
		}
		if link != "" {
			return fmt.Errorf("unsafe path in tarball: %q passes through symlink %q", hdr.Name, link)
		}

		// Keep only permission bits. os.FileMode stores setuid/setgid/sticky
		// in higher bits that a crafted header value could otherwise set.
		perm := os.FileMode(hdr.Mode).Perm()

		switch hdr.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(dst, perm|0o755); err != nil {
				return err
			}
		case tar.TypeReg:
			if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
				return err
			}
			f, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm|0o600)
			if err != nil {
				return err
			}
			if _, err := io.Copy(f, tr); err != nil {
				_ = f.Close()
				return err
			}
			if err := f.Close(); err != nil {
				return err
			}
		case tar.TypeSymlink:
			if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
				return err
			}
			// Relative targets must stay inside target. Absolute targets are
			// kept verbatim (presumably for links like a module tree's
			// build/source — confirm). Anything on the host that follows
			// them resolves them against the HOST root, not absTarget.
			// Extraction stays safe only because no later member is written
			// through an existing symlink (checked above).
			if !filepath.IsAbs(hdr.Linkname) {
				resolved := filepath.Join(filepath.Dir(dst), hdr.Linkname)
				if !within(resolved) {
					return fmt.Errorf("unsafe symlink in tarball: %q -> %q", hdr.Name, hdr.Linkname)
				}
			}
			if err := os.Symlink(hdr.Linkname, dst); err != nil {
				return err
			}
		}
	}
}

// firstSymlinkUnder walks rel, a cleaned relative path, one component at a
// time starting at root. It returns the first existing component that is a
// symlink, checking the final component too, or "" if there is none. The
// walk stops at the first component that does not exist yet, because
// nothing below it can exist either.
func firstSymlinkUnder(root, rel string) (string, error) {
	cur := root
	for _, part := range strings.Split(rel, string(filepath.Separator)) {
		cur = filepath.Join(cur, part)
		fi, err := os.Lstat(cur)
		if os.IsNotExist(err) {
			return "", nil
		}
		if err != nil {
			return "", err
		}
		if fi.Mode()&os.ModeSymlink != 0 {
			return cur, nil
		}
	}
	return "", nil
}