Harden model download with timeout and checksum
This commit is contained in:
parent
64c8c26bce
commit
0df8c356af
4 changed files with 42 additions and 3 deletions
|
|
@ -94,6 +94,7 @@ Hotkey notes:
|
|||
|
||||
AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model
|
||||
downloaded to `~/.cache/aman/models/` during daemon initialization.
|
||||
Model downloads use a network timeout and SHA256 verification before activation.
|
||||
|
||||
Use `-v/--verbose` to enable DEBUG logs, including recognized/processed
|
||||
transcript text and llama.cpp logs (`llama::` prefix). Without `-v`, logs are
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
|
|
@ -9,7 +10,14 @@ import sys
|
|||
import urllib.request
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from constants import MODEL_DIR, MODEL_NAME, MODEL_PATH, MODEL_URL
|
||||
from constants import (
|
||||
MODEL_DIR,
|
||||
MODEL_DOWNLOAD_TIMEOUT_SEC,
|
||||
MODEL_NAME,
|
||||
MODEL_PATH,
|
||||
MODEL_SHA256,
|
||||
MODEL_URL,
|
||||
)
|
||||
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
|
|
@ -90,17 +98,19 @@ def ensure_model():
|
|||
tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
|
||||
logging.info("downloading model: %s", MODEL_NAME)
|
||||
try:
|
||||
with urllib.request.urlopen(MODEL_URL) as resp:
|
||||
with urllib.request.urlopen(MODEL_URL, timeout=MODEL_DOWNLOAD_TIMEOUT_SEC) as resp:
|
||||
total = resp.getheader("Content-Length")
|
||||
total_size = int(total) if total else None
|
||||
downloaded = 0
|
||||
next_log = 0
|
||||
digest = hashlib.sha256()
|
||||
with open(tmp_path, "wb") as handle:
|
||||
while True:
|
||||
chunk = resp.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
handle.write(chunk)
|
||||
digest.update(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total_size:
|
||||
progress = downloaded / total_size
|
||||
|
|
@ -109,7 +119,9 @@ def ensure_model():
|
|||
next_log += 0.1
|
||||
elif downloaded // (50 * 1024 * 1024) > (downloaded - len(chunk)) // (50 * 1024 * 1024):
|
||||
logging.info("model download %d MB", downloaded // (1024 * 1024))
|
||||
_assert_expected_model_checksum(digest.hexdigest())
|
||||
os.replace(tmp_path, MODEL_PATH)
|
||||
logging.info("model checksum verified")
|
||||
except Exception:
|
||||
try:
|
||||
if tmp_path.exists():
|
||||
|
|
@ -117,6 +129,16 @@ def ensure_model():
|
|||
except Exception:
|
||||
pass
|
||||
raise
|
||||
return MODEL_PATH
|
||||
|
||||
|
||||
def _assert_expected_model_checksum(checksum: str) -> None:
    """Verify a downloaded model digest against the pinned ``MODEL_SHA256``.

    The two hex strings are compared after case folding, so the letter case
    of either digest does not matter.

    Args:
        checksum: Hex digest computed for the downloaded model file.

    Raises:
        RuntimeError: If ``checksum`` differs from ``MODEL_SHA256``.
    """
    digests_match = checksum.casefold() == MODEL_SHA256.casefold()
    if not digests_match:
        raise RuntimeError(
            "downloaded model checksum mismatch "
            f"(expected {MODEL_SHA256}, got {checksum})"
        )
|
||||
|
||||
|
||||
def _load_llama_bindings():
|
||||
|
|
|
|||
|
|
@ -12,5 +12,7 @@ MODEL_URL = (
|
|||
"https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/"
|
||||
"Llama-3.2-3B-Instruct-Q4_K_M.gguf"
|
||||
)
|
||||
MODEL_SHA256 = "6c1a2b41161032677be168d354123594c0e6e67d2b9227c84f296ad037c728ff"
|
||||
MODEL_DOWNLOAD_TIMEOUT_SEC = 60
|
||||
MODEL_DIR = Path.home() / ".cache" / "aman" / "models"
|
||||
MODEL_PATH = MODEL_DIR / MODEL_NAME
|
||||
|
|
|
|||
|
|
@ -7,7 +7,12 @@ SRC = ROOT / "src"
|
|||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from aiprocess import _extract_cleaned_text, _supports_response_format
|
||||
from aiprocess import (
|
||||
_assert_expected_model_checksum,
|
||||
_extract_cleaned_text,
|
||||
_supports_response_format,
|
||||
)
|
||||
from constants import MODEL_SHA256
|
||||
|
||||
|
||||
class ExtractCleanedTextTests(unittest.TestCase):
|
||||
|
|
@ -84,5 +89,14 @@ class SupportsResponseFormatTests(unittest.TestCase):
|
|||
self.assertFalse(_supports_response_format(chat_completion))
|
||||
|
||||
|
||||
class ModelChecksumTests(unittest.TestCase):
    """Unit tests for ``_assert_expected_model_checksum``."""

    def test_accepts_expected_checksum_case_insensitive(self):
        """An upper-cased copy of the pinned digest must be accepted."""
        uppercased_digest = MODEL_SHA256.upper()
        _assert_expected_model_checksum(uppercased_digest)

    def test_rejects_unexpected_checksum(self):
        """A digest of 64 zeros must raise a checksum-mismatch error."""
        bogus_digest = "0" * 64
        with self.assertRaisesRegex(RuntimeError, "checksum mismatch"):
            _assert_expected_model_checksum(bogus_digest)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue