aman/tests/test_aiprocess.py

200 lines
6.4 KiB
Python

import sys
import tempfile
import unittest
from hashlib import sha256
from pathlib import Path
from unittest.mock import patch
# Locate the directory one level above this file's own directory and make its
# "src" subdirectory importable, so the project-local imports that follow
# can be resolved.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT.joinpath("src")
if str(SRC) not in sys.path:
    # Prepend: entries at the front of sys.path are searched first.
    sys.path.insert(0, str(SRC))
import aiprocess
from aiprocess import (
_assert_expected_model_checksum,
_extract_cleaned_text,
_profile_generation_kwargs,
_supports_response_format,
ensure_model,
)
from constants import MODEL_SHA256
class ExtractCleanedTextTests(unittest.TestCase):
    """Checks for how ``_extract_cleaned_text`` reads a completion payload."""

    @staticmethod
    def _payload_with(content):
        """Wrap ``content`` as the message content of a single-choice payload."""
        return {"choices": [{"message": {"content": content}}]}

    def test_extracts_cleaned_text_from_json_object(self):
        # A JSON object with a "cleaned_text" key yields that value as-is,
        # including the literal <transcript> tags embedded in it.
        raw = '{"cleaned_text":"Hello <transcript>literal</transcript> world"}'
        extracted = _extract_cleaned_text(self._payload_with(raw))
        self.assertEqual(extracted, "Hello <transcript>literal</transcript> world")

    def test_extracts_cleaned_text_from_json_string(self):
        # A bare JSON string is accepted as well; its escaped quotes are
        # expected to come back unescaped.
        raw = '"He said \\\"hello\\\""'
        extracted = _extract_cleaned_text(self._payload_with(raw))
        self.assertEqual(extracted, 'He said "hello"')

    def test_rejects_non_json_output(self):
        # Content that is not JSON at all must be refused.
        payload = self._payload_with("<transcript>Hello</transcript>")
        self.assertRaisesRegex(
            RuntimeError, "expected JSON", _extract_cleaned_text, payload
        )

    def test_rejects_json_without_required_key(self):
        # Valid JSON that lacks the "cleaned_text" key must be refused.
        payload = self._payload_with('{"text":"hello"}')
        self.assertRaisesRegex(
            RuntimeError, "missing cleaned_text", _extract_cleaned_text, payload
        )
class SupportsResponseFormatTests(unittest.TestCase):
    """Checks driven by which keyword parameters a chat-completion callable declares.

    Covers ``_supports_response_format`` and, despite the class name, the
    ``_profile_generation_kwargs`` cases as well.
    """

    def test_supports_response_format_when_parameter_exists(self):
        # The callable declares ``response_format`` -> reported as supported.
        def chat_completion(*, messages, temperature, response_format):
            return None

        supported = _supports_response_format(chat_completion)
        self.assertTrue(supported)

    def test_does_not_support_response_format_when_missing(self):
        # No ``response_format`` parameter -> reported as unsupported.
        def chat_completion(*, messages, temperature):
            return None

        supported = _supports_response_format(chat_completion)
        self.assertFalse(supported)

    def test_fast_profile_sets_max_tokens_when_supported(self):
        # With a ``max_tokens`` parameter available, the "fast" profile is
        # expected to produce exactly one override.
        def chat_completion(*, messages, temperature, max_tokens):
            return None

        self.assertEqual(
            _profile_generation_kwargs(chat_completion, "fast"),
            {"max_tokens": 192},
        )

    def test_non_fast_profile_does_not_set_generation_overrides(self):
        # The "polished" profile is expected to produce no overrides, even
        # though the callable accepts ``max_tokens``.
        def chat_completion(*, messages, temperature, max_tokens):
            return None

        self.assertEqual(_profile_generation_kwargs(chat_completion, "polished"), {})
class ModelChecksumTests(unittest.TestCase):
    """Checks for ``_assert_expected_model_checksum``."""

    def test_accepts_expected_checksum_case_insensitive(self):
        # Passing means no exception: the upper-cased expected digest must
        # still be accepted.
        upper_digest = MODEL_SHA256.upper()
        _assert_expected_model_checksum(upper_digest)

    def test_rejects_unexpected_checksum(self):
        # A 64-character digest of zeros must be refused.
        wrong_digest = "0" * 64
        self.assertRaisesRegex(
            RuntimeError,
            "checksum mismatch",
            _assert_expected_model_checksum,
            wrong_digest,
        )
class _Response:
    """In-memory stand-in for the response object returned by ``urlopen``.

    Provides only what the tests in this file use: context-manager support,
    a ``Content-Length`` header lookup, and sequential chunked reads over a
    fixed byte payload.
    """

    def __init__(self, payload: bytes):
        self.payload = payload
        self.offset = 0  # index of the next unread byte in ``payload``

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # False: exceptions raised inside the ``with`` body are not suppressed.
        return False

    def getheader(self, name: str):
        """Return the payload length as a string for Content-Length, else None.

        The header name is matched case-insensitively.
        """
        is_length_header = name.lower() == "content-length"
        return str(len(self.payload)) if is_length_header else None

    def read(self, size: int) -> bytes:
        """Return up to ``size`` unread bytes and advance past them.

        Returns ``b""`` once the whole payload has been consumed.
        """
        start = self.offset
        if start >= len(self.payload):
            return b""
        data = self.payload[start : start + size]
        self.offset = start + len(data)
        return data
class EnsureModelTests(unittest.TestCase):
    """Checks for ``ensure_model`` using a temporary model file and a patched ``urlopen``."""

    @staticmethod
    def _model_constants(model_path, checksum):
        """Return a patcher overriding aiprocess's model path, directory and checksum.

        ``MODEL_DIR`` is set to the parent directory of ``model_path``.
        """
        return patch.multiple(
            aiprocess,
            MODEL_PATH=model_path,
            MODEL_DIR=model_path.parent,
            MODEL_SHA256=checksum,
        )

    def test_existing_valid_model_skips_download(self):
        # The cached file already hashes to the expected checksum, so
        # ``urlopen`` must never be called.
        content = b"valid-model"
        digest = sha256(content).hexdigest()
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = Path(tmp_dir) / "model.gguf"
            model_path.write_bytes(content)
            with self._model_constants(model_path, digest):
                with patch("aiprocess.urllib.request.urlopen") as urlopen:
                    out = ensure_model()
                    self.assertEqual(out, model_path)
                    urlopen.assert_not_called()

    def test_existing_invalid_model_triggers_redownload(self):
        # The cached bytes do not match the expected checksum; the fake
        # response supplies bytes that do, and they must replace the file.
        stale = b"bad-model"
        fresh = b"good-model"
        digest = sha256(fresh).hexdigest()
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = Path(tmp_dir) / "model.gguf"
            model_path.write_bytes(stale)
            with self._model_constants(model_path, digest):
                with patch(
                    "aiprocess.urllib.request.urlopen",
                    return_value=_Response(fresh),
                ) as urlopen:
                    out = ensure_model()
                    self.assertEqual(out, model_path)
                    self.assertEqual(model_path.read_bytes(), fresh)
                    urlopen.assert_called_once()

    def test_invalid_cached_model_and_redownload_failure_raises_clear_error(self):
        # Cached file is wrong and the download attempt itself raises; the
        # resulting error message must mention both facts.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = Path(tmp_dir) / "model.gguf"
            model_path.write_bytes(b"bad-model")
            with self._model_constants(model_path, "f" * 64):
                with patch(
                    "aiprocess.urllib.request.urlopen",
                    side_effect=RuntimeError("network down"),
                ):
                    self.assertRaisesRegex(
                        RuntimeError,
                        "cached model checksum mismatch and redownload failed",
                        ensure_model,
                    )
# Run this module's tests through unittest's command-line entry point when
# the file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()