Remove log_transcript config and enforce JSON AI output
This commit is contained in:
parent
c3503fbbde
commit
1423e44008
8 changed files with 198 additions and 62 deletions
88
tests/test_aiprocess.py
Normal file
88
tests/test_aiprocess.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from aiprocess import _extract_cleaned_text, _supports_response_format
|
||||
|
||||
|
||||
class ExtractCleanedTextTests(unittest.TestCase):
|
||||
def test_extracts_cleaned_text_from_json_object(self):
|
||||
payload = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": '{"cleaned_text":"Hello <transcript>literal</transcript> world"}'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
result = _extract_cleaned_text(payload)
|
||||
|
||||
self.assertEqual(result, "Hello <transcript>literal</transcript> world")
|
||||
|
||||
def test_extracts_cleaned_text_from_json_string(self):
|
||||
payload = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": '"He said \\\"hello\\\""'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
result = _extract_cleaned_text(payload)
|
||||
|
||||
self.assertEqual(result, 'He said "hello"')
|
||||
|
||||
def test_rejects_non_json_output(self):
|
||||
payload = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "<transcript>Hello</transcript>"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "expected JSON"):
|
||||
_extract_cleaned_text(payload)
|
||||
|
||||
def test_rejects_json_without_required_key(self):
|
||||
payload = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": '{"text":"hello"}'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "missing cleaned_text"):
|
||||
_extract_cleaned_text(payload)
|
||||
|
||||
|
||||
class SupportsResponseFormatTests(unittest.TestCase):
|
||||
def test_supports_response_format_when_parameter_exists(self):
|
||||
def chat_completion(*, messages, temperature, response_format):
|
||||
return None
|
||||
|
||||
self.assertTrue(_supports_response_format(chat_completion))
|
||||
|
||||
def test_does_not_support_response_format_when_missing(self):
|
||||
def chat_completion(*, messages, temperature):
|
||||
return None
|
||||
|
||||
self.assertFalse(_supports_response_format(chat_completion))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue