Verify cached model checksum before use

This commit is contained in:
Thales Maciel 2026-02-26 16:38:37 -03:00
parent 8b3532f2ca
commit 386ba4af92
2 changed files with 107 additions and 1 deletions

View file

@ -8,6 +8,7 @@ import logging
import os import os
import sys import sys
import urllib.request import urllib.request
from pathlib import Path
from typing import Any, Callable, cast from typing import Any, Callable, cast
from constants import ( from constants import (
@ -92,8 +93,19 @@ class LlamaProcessor:
def ensure_model(): def ensure_model():
had_invalid_cache = False
if MODEL_PATH.exists(): if MODEL_PATH.exists():
return MODEL_PATH existing_checksum = _sha256_file(MODEL_PATH)
if existing_checksum.casefold() == MODEL_SHA256.casefold():
return MODEL_PATH
had_invalid_cache = True
logging.warning(
"cached model checksum mismatch (%s); expected %s, redownloading",
existing_checksum,
MODEL_SHA256,
)
MODEL_PATH.unlink()
MODEL_DIR.mkdir(parents=True, exist_ok=True) MODEL_DIR.mkdir(parents=True, exist_ok=True)
tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp") tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
logging.info("downloading model: %s", MODEL_NAME) logging.info("downloading model: %s", MODEL_NAME)
@ -128,6 +140,8 @@ def ensure_model():
tmp_path.unlink() tmp_path.unlink()
except Exception: except Exception:
pass pass
if had_invalid_cache:
raise RuntimeError("cached model checksum mismatch and redownload failed")
raise raise
return MODEL_PATH return MODEL_PATH
@ -141,6 +155,17 @@ def _assert_expected_model_checksum(checksum: str) -> None:
) )
def _sha256_file(path: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is opened in binary mode and consumed in 1 MiB blocks, so the
    whole file never has to be held in memory at once.
    """
    block_size = 1024 * 1024
    hasher = hashlib.sha256()
    with open(path, "rb") as stream:
        # Two-argument iter() keeps calling read() until it yields b"" (EOF).
        for block in iter(lambda: stream.read(block_size), b""):
            hasher.update(block)
    return hasher.hexdigest()
def _load_llama_bindings(): def _load_llama_bindings():
try: try:
from llama_cpp import Llama, llama_cpp as llama_cpp_lib # type: ignore[import-not-found] from llama_cpp import Llama, llama_cpp as llama_cpp_lib # type: ignore[import-not-found]

View file

@ -1,16 +1,21 @@
import sys import sys
import tempfile
import unittest import unittest
from hashlib import sha256
from pathlib import Path from pathlib import Path
from unittest.mock import patch
ROOT = Path(__file__).resolve().parents[1] ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src" SRC = ROOT / "src"
if str(SRC) not in sys.path: if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC)) sys.path.insert(0, str(SRC))
import aiprocess
from aiprocess import ( from aiprocess import (
_assert_expected_model_checksum, _assert_expected_model_checksum,
_extract_cleaned_text, _extract_cleaned_text,
_supports_response_format, _supports_response_format,
ensure_model,
) )
from constants import MODEL_SHA256 from constants import MODEL_SHA256
@ -98,5 +103,81 @@ class ModelChecksumTests(unittest.TestCase):
_assert_expected_model_checksum("0" * 64) _assert_expected_model_checksum("0" * 64)
class _Response:
    """In-memory stand-in for the object returned by ``urlopen`` in these tests.

    It serves a fixed byte payload through ``read``, reports the payload size
    through the ``content-length`` header, and works as a context manager.
    """

    def __init__(self, payload: bytes):
        # ``offset`` is the index of the next unread byte in ``payload``.
        self.payload, self.offset = payload, 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # A falsy return lets any exception raised inside the block propagate.
        return False

    def getheader(self, name: str):
        """Return the payload length as a string for ``content-length``, else ``None``."""
        wants_length = name.lower() == "content-length"
        return str(len(self.payload)) if wants_length else None

    def read(self, size: int) -> bytes:
        """Return up to *size* unread bytes and advance; ``b""`` once exhausted."""
        start = self.offset
        if start >= len(self.payload):
            return b""
        piece = self.payload[start : start + size]
        self.offset = start + len(piece)
        return piece
class EnsureModelTests(unittest.TestCase):
    """Checks how ``ensure_model`` treats a model file already present on disk.

    Each case writes a small fake model into a temporary directory, points the
    ``aiprocess`` module constants at it, and replaces ``urlopen`` so that no
    real network request is made.
    """

    def test_existing_valid_model_skips_download(self):
        # The cached bytes hash to the expected digest, so urlopen must stay unused.
        model_bytes = b"valid-model"
        with tempfile.TemporaryDirectory() as workdir:
            cached = Path(workdir) / "model.gguf"
            cached.write_bytes(model_bytes)
            constants = patch.multiple(
                aiprocess,
                MODEL_PATH=cached,
                MODEL_DIR=cached.parent,
                MODEL_SHA256=sha256(model_bytes).hexdigest(),
            )
            with constants, patch("aiprocess.urllib.request.urlopen") as fake_urlopen:
                result = ensure_model()

            self.assertEqual(result, cached)
            fake_urlopen.assert_not_called()

    def test_existing_invalid_model_triggers_redownload(self):
        # The cached bytes do not match the expected digest; the fake response
        # serves bytes that do, and those must end up at the model path.
        stale_bytes = b"bad-model"
        fresh_bytes = b"good-model"
        with tempfile.TemporaryDirectory() as workdir:
            cached = Path(workdir) / "model.gguf"
            cached.write_bytes(stale_bytes)
            constants = patch.multiple(
                aiprocess,
                MODEL_PATH=cached,
                MODEL_DIR=cached.parent,
                MODEL_SHA256=sha256(fresh_bytes).hexdigest(),
            )
            download = patch(
                "aiprocess.urllib.request.urlopen",
                return_value=_Response(fresh_bytes),
            )
            with constants, download as fake_urlopen:
                result = ensure_model()

            self.assertEqual(result, cached)
            self.assertEqual(cached.read_bytes(), fresh_bytes)
            fake_urlopen.assert_called_once()

    def test_invalid_cached_model_and_redownload_failure_raises_clear_error(self):
        # The cache is invalid and urlopen raises; the resulting RuntimeError
        # must carry the checksum-mismatch message.
        with tempfile.TemporaryDirectory() as workdir:
            cached = Path(workdir) / "model.gguf"
            cached.write_bytes(b"bad-model")
            constants = patch.multiple(
                aiprocess,
                MODEL_PATH=cached,
                MODEL_DIR=cached.parent,
                MODEL_SHA256="f" * 64,
            )
            broken_network = patch(
                "aiprocess.urllib.request.urlopen",
                side_effect=RuntimeError("network down"),
            )
            with constants, broken_network:
                with self.assertRaisesRegex(
                    RuntimeError, "cached model checksum mismatch and redownload failed"
                ):
                    ensure_model()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()