Verify cached model checksum before use
This commit is contained in:
parent
8b3532f2ca
commit
386ba4af92
2 changed files with 107 additions and 1 deletions
|
|
@ -8,6 +8,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Callable, cast
|
from typing import Any, Callable, cast
|
||||||
|
|
||||||
from constants import (
|
from constants import (
|
||||||
|
|
@ -92,8 +93,19 @@ class LlamaProcessor:
|
||||||
|
|
||||||
|
|
||||||
def ensure_model():
|
def ensure_model():
|
||||||
|
had_invalid_cache = False
|
||||||
if MODEL_PATH.exists():
|
if MODEL_PATH.exists():
|
||||||
return MODEL_PATH
|
existing_checksum = _sha256_file(MODEL_PATH)
|
||||||
|
if existing_checksum.casefold() == MODEL_SHA256.casefold():
|
||||||
|
return MODEL_PATH
|
||||||
|
had_invalid_cache = True
|
||||||
|
logging.warning(
|
||||||
|
"cached model checksum mismatch (%s); expected %s, redownloading",
|
||||||
|
existing_checksum,
|
||||||
|
MODEL_SHA256,
|
||||||
|
)
|
||||||
|
MODEL_PATH.unlink()
|
||||||
|
|
||||||
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
|
tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
|
||||||
logging.info("downloading model: %s", MODEL_NAME)
|
logging.info("downloading model: %s", MODEL_NAME)
|
||||||
|
|
@ -128,6 +140,8 @@ def ensure_model():
|
||||||
tmp_path.unlink()
|
tmp_path.unlink()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
if had_invalid_cache:
|
||||||
|
raise RuntimeError("cached model checksum mismatch and redownload failed")
|
||||||
raise
|
raise
|
||||||
return MODEL_PATH
|
return MODEL_PATH
|
||||||
|
|
||||||
|
|
@ -141,6 +155,17 @@ def _assert_expected_model_checksum(checksum: str) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256_file(path: Path) -> str:
|
||||||
|
digest = hashlib.sha256()
|
||||||
|
with open(path, "rb") as handle:
|
||||||
|
while True:
|
||||||
|
chunk = handle.read(1024 * 1024)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
digest.update(chunk)
|
||||||
|
return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def _load_llama_bindings():
|
def _load_llama_bindings():
|
||||||
try:
|
try:
|
||||||
from llama_cpp import Llama, llama_cpp as llama_cpp_lib # type: ignore[import-not-found]
|
from llama_cpp import Llama, llama_cpp as llama_cpp_lib # type: ignore[import-not-found]
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,21 @@
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from hashlib import sha256
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
ROOT = Path(__file__).resolve().parents[1]
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
SRC = ROOT / "src"
|
SRC = ROOT / "src"
|
||||||
if str(SRC) not in sys.path:
|
if str(SRC) not in sys.path:
|
||||||
sys.path.insert(0, str(SRC))
|
sys.path.insert(0, str(SRC))
|
||||||
|
|
||||||
|
import aiprocess
|
||||||
from aiprocess import (
|
from aiprocess import (
|
||||||
_assert_expected_model_checksum,
|
_assert_expected_model_checksum,
|
||||||
_extract_cleaned_text,
|
_extract_cleaned_text,
|
||||||
_supports_response_format,
|
_supports_response_format,
|
||||||
|
ensure_model,
|
||||||
)
|
)
|
||||||
from constants import MODEL_SHA256
|
from constants import MODEL_SHA256
|
||||||
|
|
||||||
|
|
@ -98,5 +103,81 @@ class ModelChecksumTests(unittest.TestCase):
|
||||||
_assert_expected_model_checksum("0" * 64)
|
_assert_expected_model_checksum("0" * 64)
|
||||||
|
|
||||||
|
|
||||||
|
class _Response:
|
||||||
|
def __init__(self, payload: bytes):
|
||||||
|
self.payload = payload
|
||||||
|
self.offset = 0
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def getheader(self, name: str):
|
||||||
|
if name.lower() == "content-length":
|
||||||
|
return str(len(self.payload))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def read(self, size: int) -> bytes:
|
||||||
|
if self.offset >= len(self.payload):
|
||||||
|
return b""
|
||||||
|
chunk = self.payload[self.offset : self.offset + size]
|
||||||
|
self.offset += len(chunk)
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
|
class EnsureModelTests(unittest.TestCase):
    """Behavioral tests for ensure_model's cache-checksum handling."""

    def test_existing_valid_model_skips_download(self):
        """A cached file whose checksum matches must be returned untouched."""
        blob = b"valid-model"
        digest = sha256(blob).hexdigest()
        with tempfile.TemporaryDirectory() as workdir:
            cached = Path(workdir) / "model.gguf"
            cached.write_bytes(blob)
            with patch.object(aiprocess, "MODEL_PATH", cached):
                with patch.object(aiprocess, "MODEL_DIR", cached.parent):
                    with patch.object(aiprocess, "MODEL_SHA256", digest):
                        with patch("aiprocess.urllib.request.urlopen") as urlopen:
                            result = ensure_model()

            self.assertEqual(result, cached)
            urlopen.assert_not_called()

    def test_existing_invalid_model_triggers_redownload(self):
        """A checksum mismatch must discard the cache and fetch a fresh copy."""
        stale = b"bad-model"
        fresh = b"good-model"
        wanted = sha256(fresh).hexdigest()
        with tempfile.TemporaryDirectory() as workdir:
            cached = Path(workdir) / "model.gguf"
            cached.write_bytes(stale)
            with patch.object(aiprocess, "MODEL_PATH", cached):
                with patch.object(aiprocess, "MODEL_DIR", cached.parent):
                    with patch.object(aiprocess, "MODEL_SHA256", wanted):
                        with patch(
                            "aiprocess.urllib.request.urlopen",
                            return_value=_Response(fresh),
                        ) as urlopen:
                            result = ensure_model()

            self.assertEqual(result, cached)
            self.assertEqual(cached.read_bytes(), fresh)
            urlopen.assert_called_once()

    def test_invalid_cached_model_and_redownload_failure_raises_clear_error(self):
        """A bad cache plus a failed download must surface a specific error."""
        with tempfile.TemporaryDirectory() as workdir:
            cached = Path(workdir) / "model.gguf"
            cached.write_bytes(b"bad-model")
            with patch.object(aiprocess, "MODEL_PATH", cached):
                with patch.object(aiprocess, "MODEL_DIR", cached.parent):
                    with patch.object(aiprocess, "MODEL_SHA256", "f" * 64):
                        with patch(
                            "aiprocess.urllib.request.urlopen",
                            side_effect=RuntimeError("network down"),
                        ):
                            with self.assertRaisesRegex(
                                RuntimeError,
                                "cached model checksum mismatch and redownload failed",
                            ):
                                ensure_model()
|
||||||
|
|
||||||
|
|
||||||
# Allow running this test module directly (python test_aiprocess.py)
# in addition to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue