"""Unit tests for the ``aiprocess`` helpers and its model download logic."""

import sys
import tempfile
import unittest
from contextlib import contextmanager
from hashlib import sha256
from pathlib import Path
from unittest.mock import patch

# Make ``<repo>/src`` importable so the modules under test resolve when the
# tests are run straight from a checkout.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

import aiprocess  # noqa: E402  (must follow the sys.path tweak above)
from aiprocess import (  # noqa: E402
    _assert_expected_model_checksum,
    _extract_cleaned_text,
    _profile_generation_kwargs,
    _supports_response_format,
    ensure_model,
)
from constants import MODEL_SHA256  # noqa: E402


def _chat_payload(content: str) -> dict:
    """Return a minimal chat-completion response whose message is *content*."""
    return {"choices": [{"message": {"content": content}}]}


class ExtractCleanedTextTests(unittest.TestCase):
    """Parsing of the model's reply by ``_extract_cleaned_text``."""

    def test_extracts_cleaned_text_from_json_object(self):
        payload = _chat_payload('{"cleaned_text":"Hello literal world"}')
        result = _extract_cleaned_text(payload)
        self.assertEqual(result, "Hello literal world")

    def test_extracts_cleaned_text_from_json_string(self):
        # The content is a bare JSON string literal rather than an object.
        payload = _chat_payload('"He said \\\"hello\\\""')
        result = _extract_cleaned_text(payload)
        self.assertEqual(result, 'He said "hello"')

    def test_rejects_non_json_output(self):
        payload = _chat_payload("Hello")
        with self.assertRaisesRegex(RuntimeError, "expected JSON"):
            _extract_cleaned_text(payload)

    def test_rejects_json_without_required_key(self):
        payload = _chat_payload('{"text":"hello"}')
        with self.assertRaisesRegex(RuntimeError, "missing cleaned_text"):
            _extract_cleaned_text(payload)


class SupportsResponseFormatTests(unittest.TestCase):
    """Detection of a ``response_format`` parameter by ``_supports_response_format``."""

    def test_supports_response_format_when_parameter_exists(self):
        def chat_completion(*, messages, temperature, response_format):
            return None

        self.assertTrue(_supports_response_format(chat_completion))

    def test_does_not_support_response_format_when_missing(self):
        def chat_completion(*, messages, temperature):
            return None

        self.assertFalse(_supports_response_format(chat_completion))


class ProfileGenerationKwargsTests(unittest.TestCase):
    """Per-profile generation overrides from ``_profile_generation_kwargs``."""

    def test_fast_profile_sets_max_tokens_when_supported(self):
        def chat_completion(*, messages, temperature, max_tokens):
            return None

        kwargs = _profile_generation_kwargs(chat_completion, "fast")
        self.assertEqual(kwargs, {"max_tokens": 192})

    def test_non_fast_profile_does_not_set_generation_overrides(self):
        def chat_completion(*, messages, temperature, max_tokens):
            return None

        kwargs = _profile_generation_kwargs(chat_completion, "polished")
        self.assertEqual(kwargs, {})


class ModelChecksumTests(unittest.TestCase):
    """Validation performed by ``_assert_expected_model_checksum``."""

    def test_accepts_expected_checksum_case_insensitive(self):
        # Passing means no exception for the upper-cased expected digest.
        _assert_expected_model_checksum(MODEL_SHA256.upper())

    def test_rejects_unexpected_checksum(self):
        with self.assertRaisesRegex(RuntimeError, "checksum mismatch"):
            _assert_expected_model_checksum("0" * 64)


class _Response:
    """Minimal fake for the object returned by ``urllib.request.urlopen``.

    Supports the context-manager protocol, ``getheader("Content-Length")``
    and chunked ``read`` over an in-memory payload.
    """

    def __init__(self, payload: bytes):
        self.payload = payload
        self.offset = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False  # never suppress exceptions raised inside the ``with``

    def getheader(self, name: str):
        if name.lower() == "content-length":
            return str(len(self.payload))
        return None

    def read(self, size: int = -1) -> bytes:
        """Return up to *size* bytes; a negative or ``None`` size reads the rest.

        The default lets the fake tolerate a bare ``read()`` call, which real
        HTTP responses accept.
        """
        if self.offset >= len(self.payload):
            return b""
        if size is None or size < 0:
            end = len(self.payload)
        else:
            end = self.offset + size
        chunk = self.payload[self.offset : end]
        self.offset += len(chunk)
        return chunk


@contextmanager
def _patched_model_env(model_path: Path, checksum: str, **urlopen_kwargs):
    """Point ``aiprocess`` at *model_path* / *checksum* and mock ``urlopen``.

    Extra keyword arguments configure the ``urlopen`` mock (for example
    ``return_value`` or ``side_effect``). The mock is yielded so callers can
    assert whether a download was attempted.
    """
    with patch.object(aiprocess, "MODEL_PATH", model_path), patch.object(
        aiprocess, "MODEL_DIR", model_path.parent
    ), patch.object(aiprocess, "MODEL_SHA256", checksum), patch(
        "aiprocess.urllib.request.urlopen", **urlopen_kwargs
    ) as urlopen:
        yield urlopen


class EnsureModelTests(unittest.TestCase):
    """Cache validation and re-download behaviour of ``ensure_model``."""

    def setUp(self):
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.model_path = Path(tmp.name) / "model.gguf"

    def test_existing_valid_model_skips_download(self):
        payload = b"valid-model"
        checksum = sha256(payload).hexdigest()
        self.model_path.write_bytes(payload)

        with _patched_model_env(self.model_path, checksum) as urlopen:
            out = ensure_model()

        self.assertEqual(out, self.model_path)
        urlopen.assert_not_called()

    def test_existing_invalid_model_triggers_redownload(self):
        cached_payload = b"bad-model"
        downloaded_payload = b"good-model"
        expected_checksum = sha256(downloaded_payload).hexdigest()
        self.model_path.write_bytes(cached_payload)

        with _patched_model_env(
            self.model_path,
            expected_checksum,
            return_value=_Response(downloaded_payload),
        ) as urlopen:
            out = ensure_model()

        self.assertEqual(out, self.model_path)
        self.assertEqual(self.model_path.read_bytes(), downloaded_payload)
        urlopen.assert_called_once()

    def test_invalid_cached_model_and_redownload_failure_raises_clear_error(self):
        self.model_path.write_bytes(b"bad-model")

        with _patched_model_env(
            self.model_path,
            "f" * 64,
            side_effect=RuntimeError("network down"),
        ):
            with self.assertRaisesRegex(
                RuntimeError, "cached model checksum mismatch and redownload failed"
            ):
                ensure_model()


if __name__ == "__main__":
    unittest.main()