diff --git a/src/aiprocess.py b/src/aiprocess.py index 2c9a129..a098bc4 100644 --- a/src/aiprocess.py +++ b/src/aiprocess.py @@ -8,6 +8,7 @@ import logging import os import sys import urllib.request +from pathlib import Path from typing import Any, Callable, cast from constants import ( @@ -92,8 +93,19 @@ class LlamaProcessor: def ensure_model(): + had_invalid_cache = False if MODEL_PATH.exists(): - return MODEL_PATH + existing_checksum = _sha256_file(MODEL_PATH) + if existing_checksum.casefold() == MODEL_SHA256.casefold(): + return MODEL_PATH + had_invalid_cache = True + logging.warning( + "cached model checksum mismatch (%s); expected %s, redownloading", + existing_checksum, + MODEL_SHA256, + ) + MODEL_PATH.unlink() + MODEL_DIR.mkdir(parents=True, exist_ok=True) tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp") logging.info("downloading model: %s", MODEL_NAME) @@ -128,6 +140,8 @@ def ensure_model(): tmp_path.unlink() except Exception: pass + if had_invalid_cache: + raise RuntimeError("cached model checksum mismatch and redownload failed") raise return MODEL_PATH @@ -141,6 +155,17 @@ def _assert_expected_model_checksum(checksum: str) -> None: ) +def _sha256_file(path: Path) -> str: + digest = hashlib.sha256() + with open(path, "rb") as handle: + while True: + chunk = handle.read(1024 * 1024) + if not chunk: + break + digest.update(chunk) + return digest.hexdigest() + + def _load_llama_bindings(): try: from llama_cpp import Llama, llama_cpp as llama_cpp_lib # type: ignore[import-not-found] diff --git a/tests/test_aiprocess.py b/tests/test_aiprocess.py index 51d28c8..05ca6ab 100644 --- a/tests/test_aiprocess.py +++ b/tests/test_aiprocess.py @@ -1,16 +1,21 @@ import sys +import tempfile import unittest +from hashlib import sha256 from pathlib import Path +from unittest.mock import patch ROOT = Path(__file__).resolve().parents[1] SRC = ROOT / "src" if str(SRC) not in sys.path: sys.path.insert(0, str(SRC)) +import aiprocess from aiprocess import ( _assert_expected_model_checksum, _extract_cleaned_text, _supports_response_format, + ensure_model, ) from constants import MODEL_SHA256 @@ -98,5 +103,81 @@ class ModelChecksumTests(unittest.TestCase): _assert_expected_model_checksum("0" * 64) +class _Response: + def __init__(self, payload: bytes): + self.payload = payload + self.offset = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def getheader(self, name: str): + if name.lower() == "content-length": + return str(len(self.payload)) + return None + + def read(self, size: int) -> bytes: + if self.offset >= len(self.payload): + return b"" + chunk = self.payload[self.offset : self.offset + size] + self.offset += len(chunk) + return chunk + + +class EnsureModelTests(unittest.TestCase): + def test_existing_valid_model_skips_download(self): + payload = b"valid-model" + checksum = sha256(payload).hexdigest() + with tempfile.TemporaryDirectory() as td: + model_path = Path(td) / "model.gguf" + model_path.write_bytes(payload) + with patch.object(aiprocess, "MODEL_PATH", model_path), patch.object( + aiprocess, "MODEL_DIR", model_path.parent + ), patch.object(aiprocess, "MODEL_SHA256", checksum), patch( + "aiprocess.urllib.request.urlopen" + ) as urlopen: + out = ensure_model() + + self.assertEqual(out, model_path) + urlopen.assert_not_called() + + def test_existing_invalid_model_triggers_redownload(self): + cached_payload = b"bad-model" + downloaded_payload = b"good-model" + expected_checksum = sha256(downloaded_payload).hexdigest() + with tempfile.TemporaryDirectory() as td: + model_path = Path(td) / "model.gguf" + model_path.write_bytes(cached_payload) + with patch.object(aiprocess, "MODEL_PATH", model_path), patch.object( + aiprocess, "MODEL_DIR", model_path.parent + ), patch.object(aiprocess, "MODEL_SHA256", expected_checksum), patch( + "aiprocess.urllib.request.urlopen", + return_value=_Response(downloaded_payload), + ) as urlopen: + out = ensure_model() + + self.assertEqual(out, model_path) + self.assertEqual(model_path.read_bytes(), downloaded_payload) + urlopen.assert_called_once() + + def test_invalid_cached_model_and_redownload_failure_raises_clear_error(self): + with tempfile.TemporaryDirectory() as td: + model_path = Path(td) / "model.gguf" + model_path.write_bytes(b"bad-model") + with patch.object(aiprocess, "MODEL_PATH", model_path), patch.object( + aiprocess, "MODEL_DIR", model_path.parent + ), patch.object(aiprocess, "MODEL_SHA256", "f" * 64), patch( + "aiprocess.urllib.request.urlopen", + side_effect=RuntimeError("network down"), + ): + with self.assertRaisesRegex( + RuntimeError, "cached model checksum mismatch and redownload failed" + ): + ensure_model() + + if __name__ == "__main__": unittest.main()