Verify cached model checksum before use
This commit is contained in:
parent
8b3532f2ca
commit
386ba4af92
2 changed files with 107 additions and 1 deletions
|
|
@@ -8,6 +8,7 @@ import logging
|
|||
import os
|
||||
import sys
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from constants import (
|
||||
|
|
@@ -92,8 +93,19 @@ class LlamaProcessor:
|
|||
|
||||
|
||||
def ensure_model():
|
||||
had_invalid_cache = False
|
||||
if MODEL_PATH.exists():
|
||||
return MODEL_PATH
|
||||
existing_checksum = _sha256_file(MODEL_PATH)
|
||||
if existing_checksum.casefold() == MODEL_SHA256.casefold():
|
||||
return MODEL_PATH
|
||||
had_invalid_cache = True
|
||||
logging.warning(
|
||||
"cached model checksum mismatch (%s); expected %s, redownloading",
|
||||
existing_checksum,
|
||||
MODEL_SHA256,
|
||||
)
|
||||
MODEL_PATH.unlink()
|
||||
|
||||
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
|
||||
logging.info("downloading model: %s", MODEL_NAME)
|
||||
|
|
@@ -128,6 +140,8 @@ def ensure_model():
|
|||
tmp_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
if had_invalid_cache:
|
||||
raise RuntimeError("cached model checksum mismatch and redownload failed")
|
||||
raise
|
||||
return MODEL_PATH
|
||||
|
||||
|
|
@@ -141,6 +155,17 @@ def _assert_expected_model_checksum(checksum: str) -> None:
|
|||
)
|
||||
|
||||
|
||||
def _sha256_file(path: Path, *, chunk_size: int = 1024 * 1024) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is opened in binary mode and read in pieces of at most
    ``chunk_size`` bytes (1 MiB by default), so the whole file is never
    held in memory at once.

    Args:
        path: Location of the file to hash.
        chunk_size: Maximum number of bytes requested per read; must be
            positive.

    Returns:
        The lowercase hex digest produced by ``hashlib.sha256``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        OSError: If the file cannot be opened or read.
    """
    if chunk_size <= 0:
        # read(0) returns b"" immediately, which would silently hash nothing.
        raise ValueError("chunk_size must be a positive integer")

    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        # Two-argument iter() stops at the first empty read, i.e. end of file.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def _load_llama_bindings():
|
||||
try:
|
||||
from llama_cpp import Llama, llama_cpp as llama_cpp_lib # type: ignore[import-not-found]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue