# Unit tests for helpers in ``aiprocess``: cleaned-text extraction,
# ``response_format`` support detection, and model checksum validation.
import sys
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
# Repository root: the directory one level above the folder holding this file.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT.joinpath("src")

# Prepend (rather than append) so modules under ``src`` take precedence over
# any same-named module found later on ``sys.path``; the membership check keeps
# repeated imports of this file from adding duplicate entries.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
|
|
|
|
from aiprocess import (
|
|
_assert_expected_model_checksum,
|
|
_extract_cleaned_text,
|
|
_supports_response_format,
|
|
)
|
|
from constants import MODEL_SHA256
|
|
|
|
|
|
class ExtractCleanedTextTests(unittest.TestCase):
    """Tests for ``_extract_cleaned_text``.

    Each test feeds a payload of the form
    ``{"choices": [{"message": {"content": <content>}}]}`` and checks either
    the extracted text or the ``RuntimeError`` raised for bad content.
    """

    @staticmethod
    def _payload(content):
        """Wrap *content* in the single-choice response shape used by every test."""
        return {"choices": [{"message": {"content": content}}]}

    def test_extracts_cleaned_text_from_json_object(self):
        # Markup inside the JSON value is expected back verbatim.
        payload = self._payload(
            '{"cleaned_text":"Hello <transcript>literal</transcript> world"}'
        )

        result = _extract_cleaned_text(payload)

        self.assertEqual(result, "Hello <transcript>literal</transcript> world")

    def test_extracts_cleaned_text_from_json_string(self):
        # A bare JSON string literal (with escaped inner quotes) is also accepted
        # and decoded.
        payload = self._payload('"He said \\\"hello\\\""')

        result = _extract_cleaned_text(payload)

        self.assertEqual(result, 'He said "hello"')

    def test_rejects_non_json_output(self):
        # Raw tagged text that is not JSON must be rejected, not passed through.
        payload = self._payload("<transcript>Hello</transcript>")

        with self.assertRaisesRegex(RuntimeError, "expected JSON"):
            _extract_cleaned_text(payload)

    def test_rejects_json_without_required_key(self):
        # Valid JSON whose object lacks the ``cleaned_text`` key is an error.
        payload = self._payload('{"text":"hello"}')

        with self.assertRaisesRegex(RuntimeError, "missing cleaned_text"):
            _extract_cleaned_text(payload)
|
|
|
|
|
|
class SupportsResponseFormatTests(unittest.TestCase):
    """Tests for ``_supports_response_format``.

    Each test builds a keyword-only stub callable and checks whether the helper
    reports support for a ``response_format`` argument.
    """

    def test_does_not_support_response_format_when_missing(self):
        # Stub lacking ``response_format``: the helper should report False.
        def chat_completion(*, messages, temperature):
            return None

        supported = _supports_response_format(chat_completion)
        self.assertFalse(supported)

    def test_supports_response_format_when_parameter_exists(self):
        # Stub declaring ``response_format``: the helper should report True.
        def chat_completion(*, messages, temperature, response_format):
            return None

        supported = _supports_response_format(chat_completion)
        self.assertTrue(supported)
|
|
|
|
|
|
class ModelChecksumTests(unittest.TestCase):
    """Tests for ``_assert_expected_model_checksum`` against ``MODEL_SHA256``."""

    def test_accepts_expected_checksum_case_insensitive(self):
        # Check both casings. If MODEL_SHA256 were already stored upper-case,
        # passing only ``.upper()`` would not exercise case-insensitivity at all.
        for variant in (MODEL_SHA256.lower(), MODEL_SHA256.upper()):
            with self.subTest(checksum=variant):
                # Success means the call returns without raising.
                _assert_expected_model_checksum(variant)

    def test_rejects_unexpected_checksum(self):
        # A well-formed 64-hex-digit value that is not the expected digest.
        with self.assertRaisesRegex(RuntimeError, "checksum mismatch"):
            _assert_expected_model_checksum("0" * 64)
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly discovers and runs the TestCase classes above;
    # verbosity=1 is unittest's default, spelled out explicitly.
    unittest.main(verbosity=1)
|