"""Unit tests for the ``aiprocess`` module.

Covers parsing of the model's JSON reply, generation-kwarg selection based on
the client's call signature, warmup and single-pass processing against a fake
llama client, model checksum validation, request payload / XML user-prompt
construction, and the managed-model download and probe paths. Every test that
touches the model cache patches ``urllib.request.urlopen`` so no real network
request can be made.
"""

import sys
import tempfile
import unittest
from hashlib import sha256
from pathlib import Path
from unittest.mock import patch

# Make the in-repo ``src`` directory importable without installing the package.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

import aiprocess
from aiprocess import (
    LlamaProcessor,
    _assert_expected_model_checksum,
    _build_request_payload,
    _build_user_prompt_xml,
    _explicit_generation_kwargs,
    _extract_cleaned_text,
    _profile_generation_kwargs,
    _supports_response_format,
    ensure_model,
    probe_managed_model,
)
from constants import MODEL_SHA256


def _completion(content: str) -> dict:
    """Return a minimal chat-completion response whose single message has *content*."""
    return {"choices": [{"message": {"content": content}}]}


def _bare_processor(client) -> LlamaProcessor:
    """Return a ``LlamaProcessor`` built without ``__init__``, using *client*.

    The tests only need the ``client`` attribute, so the constructor is
    bypassed entirely.
    """
    processor = object.__new__(LlamaProcessor)
    processor.client = client
    return processor


class ExtractCleanedTextTests(unittest.TestCase):
    """Parsing of the model reply by ``_extract_cleaned_text``."""

    def test_extracts_cleaned_text_from_json_object(self):
        payload = _completion('{"cleaned_text":"Hello literal world"}')

        result = _extract_cleaned_text(payload)

        self.assertEqual(result, "Hello literal world")

    def test_extracts_cleaned_text_from_json_string(self):
        # The reply may be a bare JSON string literal instead of an object.
        payload = _completion('"He said \\\"hello\\\""')

        result = _extract_cleaned_text(payload)

        self.assertEqual(result, 'He said "hello"')

    def test_rejects_non_json_output(self):
        payload = _completion("Hello")

        with self.assertRaisesRegex(RuntimeError, "expected JSON"):
            _extract_cleaned_text(payload)

    def test_rejects_json_without_required_key(self):
        payload = _completion('{"text":"hello"}')

        with self.assertRaisesRegex(RuntimeError, "missing cleaned_text"):
            _extract_cleaned_text(payload)


class SupportsResponseFormatTests(unittest.TestCase):
    """Signature-introspection helpers that pick which kwargs to pass.

    Covers ``_supports_response_format``, ``_profile_generation_kwargs`` and
    ``_explicit_generation_kwargs``; each test passes a stub whose keyword
    parameters stand in for the real ``create_chat_completion`` signature.
    """

    def test_supports_response_format_when_parameter_exists(self):
        def chat_completion(*, messages, temperature, response_format):
            return None

        self.assertTrue(_supports_response_format(chat_completion))

    def test_does_not_support_response_format_when_missing(self):
        def chat_completion(*, messages, temperature):
            return None

        self.assertFalse(_supports_response_format(chat_completion))

    def test_fast_profile_sets_max_tokens_when_supported(self):
        def chat_completion(*, messages, temperature, max_tokens):
            return None

        kwargs = _profile_generation_kwargs(chat_completion, "fast")

        self.assertEqual(kwargs, {"max_tokens": 192})

    def test_non_fast_profile_does_not_set_generation_overrides(self):
        def chat_completion(*, messages, temperature, max_tokens):
            return None

        kwargs = _profile_generation_kwargs(chat_completion, "polished")

        self.assertEqual(kwargs, {})

    def test_explicit_generation_kwargs_honors_supported_params(self):
        # Only top_p and max_tokens appear in the stub's signature, so the
        # other requested overrides must be dropped.
        def chat_completion(*, messages, temperature, top_p, max_tokens):
            return None

        kwargs = _explicit_generation_kwargs(
            chat_completion,
            top_p=0.9,
            top_k=40,
            max_tokens=128,
            repeat_penalty=1.1,
            min_p=0.05,
        )

        self.assertEqual(kwargs, {"top_p": 0.9, "max_tokens": 128})


class _WarmupClient:
    """Fake llama client that records each call and returns a canned payload."""

    def __init__(self, response_payload: dict):
        self.response_payload = response_payload
        # One dict per create_chat_completion call, in call order.
        self.calls = []

    def create_chat_completion(
        self,
        *,
        messages,
        temperature,
        response_format=None,
        max_tokens=None,
    ):
        self.calls.append(
            {
                "messages": messages,
                "temperature": temperature,
                "response_format": response_format,
                "max_tokens": max_tokens,
            }
        )
        return self.response_payload


class LlamaWarmupTests(unittest.TestCase):
    """``LlamaProcessor.warmup`` and ``process_with_metrics`` against a fake client."""

    def test_warmup_uses_json_mode_and_low_token_budget(self):
        client = _WarmupClient(_completion('{"cleaned_text":"ok"}'))
        processor = _bare_processor(client)

        processor.warmup(profile="fast")

        self.assertEqual(len(client.calls), 1)
        call = client.calls[0]
        self.assertEqual(call["temperature"], 0.0)
        self.assertEqual(call["response_format"], {"type": "json_object"})
        self.assertEqual(call["max_tokens"], 32)
        user_content = call["messages"][1]["content"]
        self.assertIn("warmup", user_content)
        self.assertIn("auto", user_content)

    def test_warmup_raises_on_non_json_response(self):
        client = _WarmupClient(_completion("not-json"))
        processor = _bare_processor(client)

        with self.assertRaisesRegex(RuntimeError, "expected JSON"):
            processor.warmup(profile="default")

    def test_process_with_metrics_uses_single_completion_timing_shape(self):
        client = _WarmupClient(_completion('{"cleaned_text":"friday"}'))
        processor = _bare_processor(client)

        cleaned_text, timings = processor.process_with_metrics(
            "thursday, I mean friday",
            lang="en",
            dictionary_context="",
            profile="default",
        )

        self.assertEqual(cleaned_text, "friday")
        self.assertEqual(len(client.calls), 1)
        call = client.calls[0]
        self.assertEqual(call["messages"][0]["content"], aiprocess.SYSTEM_PROMPT)
        self.assertIn('{"cleaned_text":"..."}', call["messages"][1]["content"])
        # A single completion: all measured time is attributed to pass 2.
        self.assertEqual(timings.pass1_ms, 0.0)
        self.assertGreater(timings.pass2_ms, 0.0)
        self.assertEqual(timings.pass2_ms, timings.total_ms)


class ModelChecksumTests(unittest.TestCase):
    """``_assert_expected_model_checksum`` against ``MODEL_SHA256``."""

    def test_accepts_expected_checksum_case_insensitive(self):
        _assert_expected_model_checksum(MODEL_SHA256.upper())

    def test_rejects_unexpected_checksum(self):
        with self.assertRaisesRegex(RuntimeError, "checksum mismatch"):
            _assert_expected_model_checksum("0" * 64)


class RequestPayloadTests(unittest.TestCase):
    """Request payload dict and XML user-prompt construction."""

    def test_build_request_payload_with_dictionary(self):
        payload = _build_request_payload("hello", lang="en", dictionary_context="Docker")

        self.assertEqual(payload["language"], "en")
        self.assertEqual(payload["transcript"], "hello")
        self.assertEqual(payload["dictionary"], "Docker")

    def test_build_request_payload_omits_empty_dictionary(self):
        payload = _build_request_payload("hello", lang="en", dictionary_context=" ")

        self.assertEqual(payload["language"], "en")
        self.assertEqual(payload["transcript"], "hello")
        self.assertNotIn("dictionary", payload)

    def test_user_prompt_is_xml_and_escapes_literals(self):
        payload = _build_request_payload(
            'keep and "quotes"',
            lang="en",
            dictionary_context="Docker & systemd",
        )

        xml = _build_user_prompt_xml(payload)

        self.assertIn("en", xml)
        self.assertIn("<transcript>", xml)
        # A bare "&" is a substring of "&amp;", so asserting on "&" alone
        # cannot tell escaped output from unescaped output. Check for the
        # entity and for the absence of the raw text instead.
        self.assertIn("&amp;", xml)
        self.assertNotIn("Docker & systemd", xml)


class _Response:
    """Minimal stand-in for a ``urlopen`` response serving *payload* bytes."""

    def __init__(self, payload: bytes):
        self.payload = payload
        # Read cursor into ``payload``.
        self.offset = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Do not suppress exceptions raised inside the ``with`` body.
        return False

    def getheader(self, name: str):
        if name.lower() == "content-length":
            return str(len(self.payload))
        return None

    def read(self, size: int = -1) -> bytes:
        if self.offset >= len(self.payload):
            return b""
        if size < 0:
            chunk = self.payload[self.offset :]
            self.offset = len(self.payload)
            return chunk
        chunk = self.payload[self.offset : self.offset + size]
        self.offset += len(chunk)
        return chunk


class EnsureModelTests(unittest.TestCase):
    """``ensure_model`` download/validation and the ``probe_managed_model`` check.

    Every test patches the module's model location, and ``urlopen`` is patched
    in each test so a regression can never trigger a real download.
    """

    def test_existing_valid_model_skips_download(self):
        payload = b"valid-model"
        checksum = sha256(payload).hexdigest()
        with tempfile.TemporaryDirectory() as td:
            model_path = Path(td) / "model.gguf"
            model_path.write_bytes(payload)
            with patch.object(aiprocess, "MODEL_PATH", model_path), patch.object(
                aiprocess, "MODEL_DIR", model_path.parent
            ), patch.object(aiprocess, "MODEL_SHA256", checksum), patch(
                "aiprocess.urllib.request.urlopen"
            ) as urlopen:
                out = ensure_model()

                self.assertEqual(out, model_path)
                urlopen.assert_not_called()

    def test_existing_invalid_model_triggers_redownload(self):
        cached_payload = b"bad-model"
        downloaded_payload = b"good-model"
        expected_checksum = sha256(downloaded_payload).hexdigest()
        with tempfile.TemporaryDirectory() as td:
            model_path = Path(td) / "model.gguf"
            model_path.write_bytes(cached_payload)
            with patch.object(aiprocess, "MODEL_PATH", model_path), patch.object(
                aiprocess, "MODEL_DIR", model_path.parent
            ), patch.object(aiprocess, "MODEL_SHA256", expected_checksum), patch(
                "aiprocess.urllib.request.urlopen",
                return_value=_Response(downloaded_payload),
            ) as urlopen:
                out = ensure_model()

                self.assertEqual(out, model_path)
                self.assertEqual(model_path.read_bytes(), downloaded_payload)
                urlopen.assert_called_once()

    def test_invalid_cached_model_and_redownload_failure_raises_clear_error(self):
        with tempfile.TemporaryDirectory() as td:
            model_path = Path(td) / "model.gguf"
            model_path.write_bytes(b"bad-model")
            with patch.object(aiprocess, "MODEL_PATH", model_path), patch.object(
                aiprocess, "MODEL_DIR", model_path.parent
            ), patch.object(aiprocess, "MODEL_SHA256", "f" * 64), patch(
                "aiprocess.urllib.request.urlopen",
                side_effect=RuntimeError("network down"),
            ):
                with self.assertRaisesRegex(
                    RuntimeError, "cached model checksum mismatch and redownload failed"
                ):
                    ensure_model()

    def test_probe_managed_model_is_read_only_for_valid_cache(self):
        payload = b"valid-model"
        checksum = sha256(payload).hexdigest()
        with tempfile.TemporaryDirectory() as td:
            model_path = Path(td) / "model.gguf"
            model_path.write_bytes(payload)
            with patch.object(aiprocess, "MODEL_PATH", model_path), patch.object(
                aiprocess, "MODEL_SHA256", checksum
            ), patch("aiprocess.urllib.request.urlopen") as urlopen:
                result = probe_managed_model()

                self.assertEqual(result.status, "ready")
                self.assertIn("ready", result.message)
                urlopen.assert_not_called()

    def test_probe_managed_model_reports_missing_cache(self):
        with tempfile.TemporaryDirectory() as td:
            model_path = Path(td) / "model.gguf"
            # urlopen is patched so the probe can never reach the network.
            with patch.object(aiprocess, "MODEL_PATH", model_path), patch(
                "aiprocess.urllib.request.urlopen"
            ) as urlopen:
                result = probe_managed_model()

                self.assertEqual(result.status, "missing")
                self.assertIn(str(model_path), result.message)
                urlopen.assert_not_called()

    def test_probe_managed_model_reports_invalid_checksum(self):
        with tempfile.TemporaryDirectory() as td:
            model_path = Path(td) / "model.gguf"
            model_path.write_bytes(b"bad-model")
            # urlopen is patched so the probe can never reach the network.
            with patch.object(aiprocess, "MODEL_PATH", model_path), patch.object(
                aiprocess, "MODEL_SHA256", "f" * 64
            ), patch("aiprocess.urllib.request.urlopen") as urlopen:
                result = probe_managed_model()

                self.assertEqual(result.status, "invalid")
                self.assertIn("checksum mismatch", result.message)
                urlopen.assert_not_called()


if __name__ == "__main__":
    unittest.main()