diff --git a/README.md b/README.md
index 573d4b0..6884ea5 100644
--- a/README.md
+++ b/README.md
@@ -94,6 +94,7 @@ Hotkey notes:
 AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model
 downloaded to `~/.cache/aman/models/` during daemon initialization.
+Model downloads use a network timeout and SHA256 verification before activation.
 
 Use `-v/--verbose` to enable DEBUG logs, including recognized/processed
 transcript text and llama.cpp logs (`llama::` prefix). Without `-v`, logs are
diff --git a/src/aiprocess.py b/src/aiprocess.py
index 11a92ba..2c9a129 100644
--- a/src/aiprocess.py
+++ b/src/aiprocess.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import ctypes
+import hashlib
 import inspect
 import json
 import logging
@@ -9,7 +10,14 @@ import sys
 import urllib.request
 from typing import Any, Callable, cast
 
-from constants import MODEL_DIR, MODEL_NAME, MODEL_PATH, MODEL_URL
+from constants import (
+    MODEL_DIR,
+    MODEL_DOWNLOAD_TIMEOUT_SEC,
+    MODEL_NAME,
+    MODEL_PATH,
+    MODEL_SHA256,
+    MODEL_URL,
+)
 
 
 SYSTEM_PROMPT = (
@@ -90,17 +98,19 @@ def ensure_model():
     tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
     logging.info("downloading model: %s", MODEL_NAME)
     try:
-        with urllib.request.urlopen(MODEL_URL) as resp:
+        with urllib.request.urlopen(MODEL_URL, timeout=MODEL_DOWNLOAD_TIMEOUT_SEC) as resp:
             total = resp.getheader("Content-Length")
             total_size = int(total) if total else None
             downloaded = 0
             next_log = 0
+            digest = hashlib.sha256()
             with open(tmp_path, "wb") as handle:
                 while True:
                     chunk = resp.read(1024 * 1024)
                     if not chunk:
                         break
                     handle.write(chunk)
+                    digest.update(chunk)
                     downloaded += len(chunk)
                     if total_size:
                         progress = downloaded / total_size
@@ -109,7 +119,9 @@ def ensure_model():
                         next_log += 0.1
                 elif downloaded // (50 * 1024 * 1024) > (downloaded - len(chunk)) // (50 * 1024 * 1024):
                     logging.info("model download %d MB", downloaded // (1024 * 1024))
+        _assert_expected_model_checksum(digest.hexdigest())
         os.replace(tmp_path, MODEL_PATH)
+        logging.info("model checksum verified")
     except Exception:
         try:
             if tmp_path.exists():
@@ -117,6 +129,16 @@ def ensure_model():
         except Exception:
             pass
         raise
+    return MODEL_PATH
+
+
+def _assert_expected_model_checksum(checksum: str) -> None:
+    if checksum.casefold() == MODEL_SHA256.casefold():
+        return
+    raise RuntimeError(
+        "downloaded model checksum mismatch "
+        f"(expected {MODEL_SHA256}, got {checksum})"
+    )
 
 
 def _load_llama_bindings():
diff --git a/src/constants.py b/src/constants.py
index df06122..4566733 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -12,5 +12,7 @@ MODEL_URL = (
     "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/"
     "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
 )
+MODEL_SHA256 = "6c1a2b41161032677be168d354123594c0e6e67d2b9227c84f296ad037c728ff"
+MODEL_DOWNLOAD_TIMEOUT_SEC = 60
 MODEL_DIR = Path.home() / ".cache" / "aman" / "models"
 MODEL_PATH = MODEL_DIR / MODEL_NAME
diff --git a/tests/test_aiprocess.py b/tests/test_aiprocess.py
index 1724d00..51d28c8 100644
--- a/tests/test_aiprocess.py
+++ b/tests/test_aiprocess.py
@@ -7,7 +7,12 @@ SRC = ROOT / "src"
 if str(SRC) not in sys.path:
     sys.path.insert(0, str(SRC))
 
-from aiprocess import _extract_cleaned_text, _supports_response_format
+from aiprocess import (
+    _assert_expected_model_checksum,
+    _extract_cleaned_text,
+    _supports_response_format,
+)
+from constants import MODEL_SHA256
 
 
 class ExtractCleanedTextTests(unittest.TestCase):
@@ -84,5 +89,14 @@ class SupportsResponseFormatTests(unittest.TestCase):
         self.assertFalse(_supports_response_format(chat_completion))
 
 
+class ModelChecksumTests(unittest.TestCase):
+    def test_accepts_expected_checksum_case_insensitive(self):
+        _assert_expected_model_checksum(MODEL_SHA256.upper())
+
+    def test_rejects_unexpected_checksum(self):
+        with self.assertRaisesRegex(RuntimeError, "checksum mismatch"):
+            _assert_expected_model_checksum("0" * 64)
+
+
 if __name__ == "__main__":
     unittest.main()