Harden model download with timeout and checksum

This commit is contained in:
Thales Maciel 2026-02-26 16:30:04 -03:00
parent 64c8c26bce
commit 0df8c356af
4 changed files with 42 additions and 3 deletions

View file

@ -94,6 +94,7 @@ Hotkey notes:
AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model
downloaded to `~/.cache/aman/models/` during daemon initialization.
Model downloads use a network timeout and SHA256 verification before activation.
Use `-v/--verbose` to enable DEBUG logs, including recognized/processed
transcript text and llama.cpp logs (`llama::` prefix). Without `-v`, logs are

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import ctypes
import hashlib
import inspect
import json
import logging
@ -9,7 +10,14 @@ import sys
import urllib.request
from typing import Any, Callable, cast
from constants import MODEL_DIR, MODEL_NAME, MODEL_PATH, MODEL_URL
from constants import (
MODEL_DIR,
MODEL_DOWNLOAD_TIMEOUT_SEC,
MODEL_NAME,
MODEL_PATH,
MODEL_SHA256,
MODEL_URL,
)
SYSTEM_PROMPT = (
@ -90,17 +98,19 @@ def ensure_model():
tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
logging.info("downloading model: %s", MODEL_NAME)
try:
with urllib.request.urlopen(MODEL_URL) as resp:
with urllib.request.urlopen(MODEL_URL, timeout=MODEL_DOWNLOAD_TIMEOUT_SEC) as resp:
total = resp.getheader("Content-Length")
total_size = int(total) if total else None
downloaded = 0
next_log = 0
digest = hashlib.sha256()
with open(tmp_path, "wb") as handle:
while True:
chunk = resp.read(1024 * 1024)
if not chunk:
break
handle.write(chunk)
digest.update(chunk)
downloaded += len(chunk)
if total_size:
progress = downloaded / total_size
@ -109,7 +119,9 @@ def ensure_model():
next_log += 0.1
elif downloaded // (50 * 1024 * 1024) > (downloaded - len(chunk)) // (50 * 1024 * 1024):
logging.info("model download %d MB", downloaded // (1024 * 1024))
_assert_expected_model_checksum(digest.hexdigest())
os.replace(tmp_path, MODEL_PATH)
logging.info("model checksum verified")
except Exception:
try:
if tmp_path.exists():
@ -117,6 +129,16 @@ def ensure_model():
except Exception:
pass
raise
return MODEL_PATH
def _assert_expected_model_checksum(checksum: str) -> None:
    """Validate a downloaded model's SHA256 digest against the pinned constant.

    Comparison is case-insensitive (hex digests may arrive upper- or
    lower-case). Raises RuntimeError on mismatch so the caller can discard
    the temporary download before activation.
    """
    if checksum.casefold() != MODEL_SHA256.casefold():
        raise RuntimeError(
            "downloaded model checksum mismatch "
            f"(expected {MODEL_SHA256}, got {checksum})"
        )
def _load_llama_bindings():

View file

@ -12,5 +12,7 @@ MODEL_URL = (
"https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/"
"Llama-3.2-3B-Instruct-Q4_K_M.gguf"
)
# Pinned SHA256 hex digest of the model file; the download is verified
# against this before being moved into place (see ensure_model).
MODEL_SHA256 = "6c1a2b41161032677be168d354123594c0e6e67d2b9227c84f296ad037c728ff"
# Network timeout (seconds) passed to urllib.request.urlopen for the
# model download, so a stalled connection fails instead of hanging.
MODEL_DOWNLOAD_TIMEOUT_SEC = 60
# Local cache directory where the model is stored after download.
MODEL_DIR = Path.home() / ".cache" / "aman" / "models"
# Full path of the downloaded model file.
MODEL_PATH = MODEL_DIR / MODEL_NAME

View file

@ -7,7 +7,12 @@ SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))
from aiprocess import _extract_cleaned_text, _supports_response_format
from aiprocess import (
_assert_expected_model_checksum,
_extract_cleaned_text,
_supports_response_format,
)
from constants import MODEL_SHA256
class ExtractCleanedTextTests(unittest.TestCase):
@ -84,5 +89,14 @@ class SupportsResponseFormatTests(unittest.TestCase):
self.assertFalse(_supports_response_format(chat_completion))
class ModelChecksumTests(unittest.TestCase):
    """Tests for the SHA256 verification helper used after model download."""

    def test_accepts_expected_checksum_case_insensitive(self):
        # An upper-cased digest of the pinned constant must not raise.
        uppercase_digest = MODEL_SHA256.upper()
        _assert_expected_model_checksum(uppercase_digest)

    def test_rejects_unexpected_checksum(self):
        # Any digest other than the pinned constant must raise RuntimeError.
        bogus_digest = "0" * 64
        with self.assertRaisesRegex(RuntimeError, "checksum mismatch"):
            _assert_expected_model_checksum(bogus_digest)
# Allow running this test module directly (python test_file.py) in
# addition to discovery via `python -m unittest`.
if __name__ == "__main__":
    unittest.main()