Harden model download with timeout and checksum
This commit is contained in:
parent
64c8c26bce
commit
0df8c356af
4 changed files with 42 additions and 3 deletions
|
|
@ -94,6 +94,7 @@ Hotkey notes:
|
||||||
|
|
||||||
AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model
|
AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model
|
||||||
downloaded to `~/.cache/aman/models/` during daemon initialization.
|
downloaded to `~/.cache/aman/models/` during daemon initialization.
|
||||||
|
Model downloads use a network timeout and SHA256 verification before activation.
|
||||||
|
|
||||||
Use `-v/--verbose` to enable DEBUG logs, including recognized/processed
|
Use `-v/--verbose` to enable DEBUG logs, including recognized/processed
|
||||||
transcript text and llama.cpp logs (`llama::` prefix). Without `-v`, logs are
|
transcript text and llama.cpp logs (`llama::` prefix). Without `-v`, logs are
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
|
import hashlib
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -9,7 +10,14 @@ import sys
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from typing import Any, Callable, cast
|
from typing import Any, Callable, cast
|
||||||
|
|
||||||
from constants import MODEL_DIR, MODEL_NAME, MODEL_PATH, MODEL_URL
|
from constants import (
|
||||||
|
MODEL_DIR,
|
||||||
|
MODEL_DOWNLOAD_TIMEOUT_SEC,
|
||||||
|
MODEL_NAME,
|
||||||
|
MODEL_PATH,
|
||||||
|
MODEL_SHA256,
|
||||||
|
MODEL_URL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
SYSTEM_PROMPT = (
|
SYSTEM_PROMPT = (
|
||||||
|
|
@ -90,17 +98,19 @@ def ensure_model():
|
||||||
tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
|
tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
|
||||||
logging.info("downloading model: %s", MODEL_NAME)
|
logging.info("downloading model: %s", MODEL_NAME)
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(MODEL_URL) as resp:
|
with urllib.request.urlopen(MODEL_URL, timeout=MODEL_DOWNLOAD_TIMEOUT_SEC) as resp:
|
||||||
total = resp.getheader("Content-Length")
|
total = resp.getheader("Content-Length")
|
||||||
total_size = int(total) if total else None
|
total_size = int(total) if total else None
|
||||||
downloaded = 0
|
downloaded = 0
|
||||||
next_log = 0
|
next_log = 0
|
||||||
|
digest = hashlib.sha256()
|
||||||
with open(tmp_path, "wb") as handle:
|
with open(tmp_path, "wb") as handle:
|
||||||
while True:
|
while True:
|
||||||
chunk = resp.read(1024 * 1024)
|
chunk = resp.read(1024 * 1024)
|
||||||
if not chunk:
|
if not chunk:
|
||||||
break
|
break
|
||||||
handle.write(chunk)
|
handle.write(chunk)
|
||||||
|
digest.update(chunk)
|
||||||
downloaded += len(chunk)
|
downloaded += len(chunk)
|
||||||
if total_size:
|
if total_size:
|
||||||
progress = downloaded / total_size
|
progress = downloaded / total_size
|
||||||
|
|
@ -109,7 +119,9 @@ def ensure_model():
|
||||||
next_log += 0.1
|
next_log += 0.1
|
||||||
elif downloaded // (50 * 1024 * 1024) > (downloaded - len(chunk)) // (50 * 1024 * 1024):
|
elif downloaded // (50 * 1024 * 1024) > (downloaded - len(chunk)) // (50 * 1024 * 1024):
|
||||||
logging.info("model download %d MB", downloaded // (1024 * 1024))
|
logging.info("model download %d MB", downloaded // (1024 * 1024))
|
||||||
|
_assert_expected_model_checksum(digest.hexdigest())
|
||||||
os.replace(tmp_path, MODEL_PATH)
|
os.replace(tmp_path, MODEL_PATH)
|
||||||
|
logging.info("model checksum verified")
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
if tmp_path.exists():
|
if tmp_path.exists():
|
||||||
|
|
@ -117,6 +129,16 @@ def ensure_model():
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
raise
|
raise
|
||||||
|
return MODEL_PATH
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_expected_model_checksum(checksum: str) -> None:
|
||||||
|
if checksum.casefold() == MODEL_SHA256.casefold():
|
||||||
|
return
|
||||||
|
raise RuntimeError(
|
||||||
|
"downloaded model checksum mismatch "
|
||||||
|
f"(expected {MODEL_SHA256}, got {checksum})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _load_llama_bindings():
|
def _load_llama_bindings():
|
||||||
|
|
|
||||||
|
|
@ -12,5 +12,7 @@ MODEL_URL = (
|
||||||
"https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/"
|
"https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/"
|
||||||
"Llama-3.2-3B-Instruct-Q4_K_M.gguf"
|
"Llama-3.2-3B-Instruct-Q4_K_M.gguf"
|
||||||
)
|
)
|
||||||
|
MODEL_SHA256 = "6c1a2b41161032677be168d354123594c0e6e67d2b9227c84f296ad037c728ff"
|
||||||
|
MODEL_DOWNLOAD_TIMEOUT_SEC = 60
|
||||||
MODEL_DIR = Path.home() / ".cache" / "aman" / "models"
|
MODEL_DIR = Path.home() / ".cache" / "aman" / "models"
|
||||||
MODEL_PATH = MODEL_DIR / MODEL_NAME
|
MODEL_PATH = MODEL_DIR / MODEL_NAME
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,12 @@ SRC = ROOT / "src"
|
||||||
if str(SRC) not in sys.path:
|
if str(SRC) not in sys.path:
|
||||||
sys.path.insert(0, str(SRC))
|
sys.path.insert(0, str(SRC))
|
||||||
|
|
||||||
from aiprocess import _extract_cleaned_text, _supports_response_format
|
from aiprocess import (
|
||||||
|
_assert_expected_model_checksum,
|
||||||
|
_extract_cleaned_text,
|
||||||
|
_supports_response_format,
|
||||||
|
)
|
||||||
|
from constants import MODEL_SHA256
|
||||||
|
|
||||||
|
|
||||||
class ExtractCleanedTextTests(unittest.TestCase):
|
class ExtractCleanedTextTests(unittest.TestCase):
|
||||||
|
|
@ -84,5 +89,14 @@ class SupportsResponseFormatTests(unittest.TestCase):
|
||||||
self.assertFalse(_supports_response_format(chat_completion))
|
self.assertFalse(_supports_response_format(chat_completion))
|
||||||
|
|
||||||
|
|
||||||
|
class ModelChecksumTests(unittest.TestCase):
|
||||||
|
def test_accepts_expected_checksum_case_insensitive(self):
|
||||||
|
_assert_expected_model_checksum(MODEL_SHA256.upper())
|
||||||
|
|
||||||
|
def test_rejects_unexpected_checksum(self):
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "checksum mismatch"):
|
||||||
|
_assert_expected_model_checksum("0" * 64)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue