aman/tests/test_aiprocess.py
Thales Maciel fa91f313c4
Some checks are pending
ci / test-and-build (push) Waiting to run
Simplify editor cleanup and keep live ASR metadata
Keep the daemon path on the full ASR result so word timings and detected language survive into the editor pipeline instead of falling back to a plain transcript string.

Add PipelineEngine.run_asr_result(), have aman call it when live ASR data is available, and cover the word-aware alignment behavior in the daemon tests.

Collapse the llama cleanup flow to a single JSON-shaped completion while leaving the legacy pass1/pass2 parameters in place as compatibility no-ops.

Validated with PYTHONPATH=src python3 -m unittest tests.test_aiprocess tests.test_aman.
2026-03-12 13:24:36 -03:00

382 lines
13 KiB
Python

import json
import os
import sys
import tempfile
import unittest
from hashlib import sha256
from pathlib import Path
from unittest.mock import patch
# Make the project's src/ directory importable when tests are run from the
# repository root (e.g. `python3 -m unittest tests.test_aiprocess`).
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
import aiprocess
from aiprocess import (
ExternalApiProcessor,
LlamaProcessor,
_assert_expected_model_checksum,
_build_request_payload,
_build_user_prompt_xml,
_explicit_generation_kwargs,
_extract_cleaned_text,
_profile_generation_kwargs,
_supports_response_format,
ensure_model,
)
from constants import MODEL_SHA256
class ExtractCleanedTextTests(unittest.TestCase):
    """Behavior of _extract_cleaned_text over chat-completion payloads."""

    def test_extracts_cleaned_text_from_json_object(self):
        # A JSON object with the required "cleaned_text" key; any embedded
        # markup inside the value must survive untouched.
        content = '{"cleaned_text":"Hello <transcript>literal</transcript> world"}'
        payload = {"choices": [{"message": {"content": content}}]}
        self.assertEqual(
            _extract_cleaned_text(payload),
            "Hello <transcript>literal</transcript> world",
        )

    def test_extracts_cleaned_text_from_json_string(self):
        # A bare JSON string (not an object) is also accepted, with escaped
        # quotes decoded.
        content = '"He said \\\"hello\\\""'
        payload = {"choices": [{"message": {"content": content}}]}
        self.assertEqual(_extract_cleaned_text(payload), 'He said "hello"')

    def test_rejects_non_json_output(self):
        payload = {
            "choices": [{"message": {"content": "<transcript>Hello</transcript>"}}]
        }
        with self.assertRaisesRegex(RuntimeError, "expected JSON"):
            _extract_cleaned_text(payload)

    def test_rejects_json_without_required_key(self):
        payload = {"choices": [{"message": {"content": '{"text":"hello"}'}}]}
        with self.assertRaisesRegex(RuntimeError, "missing cleaned_text"):
            _extract_cleaned_text(payload)
class SupportsResponseFormatTests(unittest.TestCase):
    """Introspection of chat-completion callables for supported kwargs."""

    def test_supports_response_format_when_parameter_exists(self):
        def chat_completion(*, messages, temperature, response_format):
            return None

        self.assertTrue(_supports_response_format(chat_completion))

    def test_does_not_support_response_format_when_missing(self):
        def chat_completion(*, messages, temperature):
            return None

        self.assertFalse(_supports_response_format(chat_completion))

    def test_fast_profile_sets_max_tokens_when_supported(self):
        def chat_completion(*, messages, temperature, max_tokens):
            return None

        self.assertEqual(
            _profile_generation_kwargs(chat_completion, "fast"),
            {"max_tokens": 192},
        )

    def test_non_fast_profile_does_not_set_generation_overrides(self):
        def chat_completion(*, messages, temperature, max_tokens):
            return None

        self.assertEqual(
            _profile_generation_kwargs(chat_completion, "polished"), {}
        )

    def test_explicit_generation_kwargs_honors_supported_params(self):
        def chat_completion(*, messages, temperature, top_p, max_tokens):
            return None

        # Only top_p and max_tokens exist on the callable's signature; the
        # unsupported overrides must be silently dropped.
        kwargs = _explicit_generation_kwargs(
            chat_completion,
            top_p=0.9,
            top_k=40,
            max_tokens=128,
            repeat_penalty=1.1,
            min_p=0.05,
        )
        self.assertEqual(kwargs, {"top_p": 0.9, "max_tokens": 128})
class _WarmupClient:
def __init__(self, response_payload: dict):
self.response_payload = response_payload
self.calls = []
def create_chat_completion(
self,
*,
messages,
temperature,
response_format=None,
max_tokens=None,
):
self.calls.append(
{
"messages": messages,
"temperature": temperature,
"response_format": response_format,
"max_tokens": max_tokens,
}
)
return self.response_payload
class LlamaWarmupTests(unittest.TestCase):
    """Warmup and single-completion processing behavior of LlamaProcessor."""

    @staticmethod
    def _bare_processor(client):
        # Bypass __init__ (which would load the real model) and inject the
        # fake client directly.
        processor = object.__new__(LlamaProcessor)
        processor.client = client
        return processor

    def test_warmup_uses_json_mode_and_low_token_budget(self):
        client = _WarmupClient(
            {"choices": [{"message": {"content": '{"cleaned_text":"ok"}'}}]}
        )
        processor = self._bare_processor(client)

        processor.warmup(profile="fast")

        self.assertEqual(len(client.calls), 1)
        call = client.calls[0]
        self.assertEqual(call["temperature"], 0.0)
        self.assertEqual(call["response_format"], {"type": "json_object"})
        self.assertEqual(call["max_tokens"], 32)
        # The warmup user prompt follows the XML request contract.
        user_content = call["messages"][1]["content"]
        self.assertIn("<request>", user_content)
        self.assertIn("<transcript>warmup</transcript>", user_content)
        self.assertIn("<language>auto</language>", user_content)

    def test_warmup_raises_on_non_json_response(self):
        processor = self._bare_processor(
            _WarmupClient({"choices": [{"message": {"content": "not-json"}}]})
        )
        with self.assertRaisesRegex(RuntimeError, "expected JSON"):
            processor.warmup(profile="default")

    def test_process_with_metrics_uses_single_completion_timing_shape(self):
        client = _WarmupClient(
            {"choices": [{"message": {"content": '{"cleaned_text":"friday"}'}}]}
        )
        processor = self._bare_processor(client)

        cleaned_text, timings = processor.process_with_metrics(
            "thursday, I mean friday",
            lang="en",
            dictionary_context="",
            profile="default",
        )

        self.assertEqual(cleaned_text, "friday")
        self.assertEqual(len(client.calls), 1)
        call = client.calls[0]
        self.assertEqual(call["messages"][0]["content"], aiprocess.SYSTEM_PROMPT)
        self.assertIn('{"cleaned_text":"..."}', call["messages"][1]["content"])
        # Collapsed flow: pass1 is a no-op and pass2 accounts for all the time.
        self.assertEqual(timings.pass1_ms, 0.0)
        self.assertGreater(timings.pass2_ms, 0.0)
        self.assertEqual(timings.pass2_ms, timings.total_ms)
class ModelChecksumTests(unittest.TestCase):
    """Validation of the pinned model checksum constant."""

    def test_accepts_expected_checksum_case_insensitive(self):
        # Must not raise even when the hex digest casing differs.
        _assert_expected_model_checksum(MODEL_SHA256.upper())

    def test_rejects_unexpected_checksum(self):
        with self.assertRaisesRegex(RuntimeError, "checksum mismatch"):
            _assert_expected_model_checksum("0" * 64)
class RequestPayloadTests(unittest.TestCase):
    """Shape of the request payload and its XML user-prompt rendering."""

    def test_build_request_payload_with_dictionary(self):
        payload = _build_request_payload(
            "hello", lang="en", dictionary_context="Docker"
        )
        self.assertEqual(payload["language"], "en")
        self.assertEqual(payload["transcript"], "hello")
        self.assertEqual(payload["dictionary"], "Docker")

    def test_build_request_payload_omits_empty_dictionary(self):
        # Whitespace-only dictionary context is treated as absent.
        payload = _build_request_payload("hello", lang="en", dictionary_context=" ")
        self.assertEqual(payload["language"], "en")
        self.assertEqual(payload["transcript"], "hello")
        self.assertNotIn("dictionary", payload)

    def test_user_prompt_is_xml_and_escapes_literals(self):
        payload = _build_request_payload(
            'keep <transcript> and "quotes"',
            lang="en",
            dictionary_context="Docker & systemd",
        )
        xml = _build_user_prompt_xml(payload)
        # Structural tags are emitted verbatim; user-supplied markup and
        # ampersands are XML-escaped.
        for fragment in (
            "<request>",
            "<language>en</language>",
            "&lt;transcript&gt;",
            "&amp;",
            "<output_contract>",
        ):
            self.assertIn(fragment, xml)
class _Response:
def __init__(self, payload: bytes):
self.payload = payload
self.offset = 0
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def getheader(self, name: str):
if name.lower() == "content-length":
return str(len(self.payload))
return None
def read(self, size: int = -1) -> bytes:
if self.offset >= len(self.payload):
return b""
if size < 0:
chunk = self.payload[self.offset :]
self.offset = len(self.payload)
return chunk
chunk = self.payload[self.offset : self.offset + size]
self.offset += len(chunk)
return chunk
class EnsureModelTests(unittest.TestCase):
    """ensure_model() cache-hit, checksum-driven redownload, and failure paths."""

    def test_existing_valid_model_skips_download(self):
        payload = b"valid-model"
        digest = sha256(payload).hexdigest()
        with tempfile.TemporaryDirectory() as td:
            model_path = Path(td) / "model.gguf"
            model_path.write_bytes(payload)
            with patch.object(aiprocess, "MODEL_PATH", model_path), \
                    patch.object(aiprocess, "MODEL_DIR", model_path.parent), \
                    patch.object(aiprocess, "MODEL_SHA256", digest), \
                    patch("aiprocess.urllib.request.urlopen") as urlopen:
                self.assertEqual(ensure_model(), model_path)
                # Cached file already matches the checksum: no network call.
                urlopen.assert_not_called()

    def test_existing_invalid_model_triggers_redownload(self):
        stale = b"bad-model"
        fresh = b"good-model"
        digest = sha256(fresh).hexdigest()
        with tempfile.TemporaryDirectory() as td:
            model_path = Path(td) / "model.gguf"
            model_path.write_bytes(stale)
            with patch.object(aiprocess, "MODEL_PATH", model_path), \
                    patch.object(aiprocess, "MODEL_DIR", model_path.parent), \
                    patch.object(aiprocess, "MODEL_SHA256", digest), \
                    patch(
                        "aiprocess.urllib.request.urlopen",
                        return_value=_Response(fresh),
                    ) as urlopen:
                self.assertEqual(ensure_model(), model_path)
                # The stale cache must be replaced by the downloaded bytes.
                self.assertEqual(model_path.read_bytes(), fresh)
                urlopen.assert_called_once()

    def test_invalid_cached_model_and_redownload_failure_raises_clear_error(self):
        with tempfile.TemporaryDirectory() as td:
            model_path = Path(td) / "model.gguf"
            model_path.write_bytes(b"bad-model")
            with patch.object(aiprocess, "MODEL_PATH", model_path), \
                    patch.object(aiprocess, "MODEL_DIR", model_path.parent), \
                    patch.object(aiprocess, "MODEL_SHA256", "f" * 64), \
                    patch(
                        "aiprocess.urllib.request.urlopen",
                        side_effect=RuntimeError("network down"),
                    ):
                with self.assertRaisesRegex(
                    RuntimeError, "cached model checksum mismatch and redownload failed"
                ):
                    ensure_model()
class ExternalApiProcessorTests(unittest.TestCase):
    """Construction and request behavior of ExternalApiProcessor."""

    @staticmethod
    def _make_processor():
        # Shared constructor arguments used by every test in this class.
        return ExternalApiProcessor(
            provider="openai",
            base_url="https://api.openai.com/v1",
            model="gpt-4o-mini",
            api_key_env_var="AMAN_EXTERNAL_API_KEY",
            timeout_ms=1000,
            max_retries=0,
        )

    def test_requires_api_key_env_var(self):
        # With a scrubbed environment, construction must fail loudly.
        with patch.dict(os.environ, {}, clear=True):
            with self.assertRaisesRegex(RuntimeError, "missing external api key"):
                self._make_processor()

    def test_process_uses_chat_completion_endpoint(self):
        body = json.dumps(
            {"choices": [{"message": {"content": '{"cleaned_text":"clean"}'}}]}
        ).encode("utf-8")
        with patch.dict(
            os.environ, {"AMAN_EXTERNAL_API_KEY": "test-key"}, clear=True
        ), patch(
            "aiprocess.urllib.request.urlopen",
            return_value=_Response(body),
        ) as urlopen:
            processor = self._make_processor()
            self.assertEqual(
                processor.process("raw text", dictionary_context="Docker"),
                "clean",
            )
            request = urlopen.call_args[0][0]
            self.assertTrue(request.full_url.endswith("/chat/completions"))

    def test_warmup_is_a_noop(self):
        with patch.dict(
            os.environ, {"AMAN_EXTERNAL_API_KEY": "test-key"}, clear=True
        ):
            processor = self._make_processor()
            with patch("aiprocess.urllib.request.urlopen") as urlopen:
                processor.warmup(profile="fast")
                urlopen.assert_not_called()
if __name__ == "__main__":
    # Allow running this file directly: `python3 tests/test_aiprocess.py`.
    unittest.main()