Make the milestone 3 runtime story predictable instead of treating doctor, self-check, and startup failures as loosely related surfaces. Split doctor and self-check into distinct read-only flows, add tri-state diagnostic status with stable IDs and next steps, and reuse that wording in CLI output, service logs, and tray-triggered diagnostics. Add non-mutating config/model probes, a make runtime-check gate, and public recovery/validation docs for the X11 GA roadmap. Validation: make runtime-check; PYTHONPATH=src python3 -m unittest discover -s tests -p 'test_*.py'; python3 -m py_compile src/*.py tests/*.py; PYTHONPATH=src python3 -m aman doctor --help; PYTHONPATH=src python3 -m aman self-check --help. Leave milestone 3 open in the roadmap until the manual X11 validation rows are filled.
419 lines · 15 KiB · Python
"""Unit tests for aiprocess: prompt building, model management, and processors."""

import json
import os
import sys
import tempfile
import unittest
from hashlib import sha256
from pathlib import Path
from unittest.mock import patch

# Make src/ importable when the tests are run from a source checkout
# (e.g. `python tests/test_aiprocess.py` without PYTHONPATH=src).
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

import aiprocess
from aiprocess import (
    ExternalApiProcessor,
    LlamaProcessor,
    _assert_expected_model_checksum,
    _build_request_payload,
    _build_user_prompt_xml,
    _explicit_generation_kwargs,
    _extract_cleaned_text,
    _profile_generation_kwargs,
    _supports_response_format,
    ensure_model,
    probe_managed_model,
)
from constants import MODEL_SHA256
|
class ExtractCleanedTextTests(unittest.TestCase):
    """Behaviour of ``_extract_cleaned_text`` on chat-completion payloads."""

    @staticmethod
    def _completion(content):
        # Wrap raw assistant content in the chat-completion envelope the
        # extractor expects.
        return {"choices": [{"message": {"content": content}}]}

    def test_extracts_cleaned_text_from_json_object(self):
        response = self._completion(
            '{"cleaned_text":"Hello <transcript>literal</transcript> world"}'
        )

        self.assertEqual(
            _extract_cleaned_text(response),
            "Hello <transcript>literal</transcript> world",
        )

    def test_extracts_cleaned_text_from_json_string(self):
        response = self._completion('"He said \\\"hello\\\""')

        self.assertEqual(_extract_cleaned_text(response), 'He said "hello"')

    def test_rejects_non_json_output(self):
        response = self._completion("<transcript>Hello</transcript>")

        with self.assertRaisesRegex(RuntimeError, "expected JSON"):
            _extract_cleaned_text(response)

    def test_rejects_json_without_required_key(self):
        response = self._completion('{"text":"hello"}')

        with self.assertRaisesRegex(RuntimeError, "missing cleaned_text"):
            _extract_cleaned_text(response)
|
|
class SupportsResponseFormatTests(unittest.TestCase):
    """Capability probing of chat-completion callables and kwarg selection."""

    def test_supports_response_format_when_parameter_exists(self):
        def completion_fn(*, messages, temperature, response_format):
            return None

        self.assertTrue(_supports_response_format(completion_fn))

    def test_does_not_support_response_format_when_missing(self):
        def completion_fn(*, messages, temperature):
            return None

        self.assertFalse(_supports_response_format(completion_fn))

    def test_fast_profile_sets_max_tokens_when_supported(self):
        def completion_fn(*, messages, temperature, max_tokens):
            return None

        self.assertEqual(
            _profile_generation_kwargs(completion_fn, "fast"),
            {"max_tokens": 192},
        )

    def test_non_fast_profile_does_not_set_generation_overrides(self):
        def completion_fn(*, messages, temperature, max_tokens):
            return None

        self.assertEqual(_profile_generation_kwargs(completion_fn, "polished"), {})

    def test_explicit_generation_kwargs_honors_supported_params(self):
        def completion_fn(*, messages, temperature, top_p, max_tokens):
            return None

        # top_k, repeat_penalty and min_p are not accepted by completion_fn,
        # so only top_p and max_tokens should survive the filtering.
        selected = _explicit_generation_kwargs(
            completion_fn,
            top_p=0.9,
            top_k=40,
            max_tokens=128,
            repeat_penalty=1.1,
            min_p=0.05,
        )
        self.assertEqual(selected, {"top_p": 0.9, "max_tokens": 128})
|
|
class _WarmupClient:
|
|
def __init__(self, response_payload: dict):
|
|
self.response_payload = response_payload
|
|
self.calls = []
|
|
|
|
def create_chat_completion(
|
|
self,
|
|
*,
|
|
messages,
|
|
temperature,
|
|
response_format=None,
|
|
max_tokens=None,
|
|
):
|
|
self.calls.append(
|
|
{
|
|
"messages": messages,
|
|
"temperature": temperature,
|
|
"response_format": response_format,
|
|
"max_tokens": max_tokens,
|
|
}
|
|
)
|
|
return self.response_payload
|
|
|
|
|
|
class LlamaWarmupTests(unittest.TestCase):
    """Warmup and single-pass processing behaviour of ``LlamaProcessor``."""

    @staticmethod
    def _processor_with(content):
        # Build a LlamaProcessor without running __init__ (which would load
        # the real model), backed by a recording fake client.
        proc = object.__new__(LlamaProcessor)
        proc.client = _WarmupClient(
            {"choices": [{"message": {"content": content}}]}
        )
        return proc

    def test_warmup_uses_json_mode_and_low_token_budget(self):
        proc = self._processor_with('{"cleaned_text":"ok"}')

        proc.warmup(profile="fast")

        self.assertEqual(len(proc.client.calls), 1)
        call = proc.client.calls[0]
        self.assertEqual(call["temperature"], 0.0)
        self.assertEqual(call["response_format"], {"type": "json_object"})
        self.assertEqual(call["max_tokens"], 32)
        user_content = call["messages"][1]["content"]
        for fragment in (
            "<request>",
            "<transcript>warmup</transcript>",
            "<language>auto</language>",
        ):
            self.assertIn(fragment, user_content)

    def test_warmup_raises_on_non_json_response(self):
        proc = self._processor_with("not-json")

        with self.assertRaisesRegex(RuntimeError, "expected JSON"):
            proc.warmup(profile="default")

    def test_process_with_metrics_uses_single_completion_timing_shape(self):
        proc = self._processor_with('{"cleaned_text":"friday"}')

        cleaned_text, timings = proc.process_with_metrics(
            "thursday, I mean friday",
            lang="en",
            dictionary_context="",
            profile="default",
        )

        self.assertEqual(cleaned_text, "friday")
        self.assertEqual(len(proc.client.calls), 1)
        call = proc.client.calls[0]
        self.assertEqual(call["messages"][0]["content"], aiprocess.SYSTEM_PROMPT)
        self.assertIn('{"cleaned_text":"..."}', call["messages"][1]["content"])
        # Single-pass pipeline: pass1 never runs, pass2 carries all the time.
        self.assertEqual(timings.pass1_ms, 0.0)
        self.assertGreater(timings.pass2_ms, 0.0)
        self.assertEqual(timings.pass2_ms, timings.total_ms)
|
|
class ModelChecksumTests(unittest.TestCase):
    """Validation of the managed model's SHA-256 checksum."""

    def test_accepts_expected_checksum_case_insensitive(self):
        # Upper-casing the expected digest must not trigger a mismatch.
        _assert_expected_model_checksum(MODEL_SHA256.upper())

    def test_rejects_unexpected_checksum(self):
        bogus_digest = "0" * 64
        with self.assertRaisesRegex(RuntimeError, "checksum mismatch"):
            _assert_expected_model_checksum(bogus_digest)
|
|
class RequestPayloadTests(unittest.TestCase):
    """Shape of the request payload and its XML prompt rendering."""

    def test_build_request_payload_with_dictionary(self):
        built = _build_request_payload("hello", lang="en", dictionary_context="Docker")
        self.assertEqual(built["language"], "en")
        self.assertEqual(built["transcript"], "hello")
        self.assertEqual(built["dictionary"], "Docker")

    def test_build_request_payload_omits_empty_dictionary(self):
        # Whitespace-only dictionary context counts as empty.
        built = _build_request_payload("hello", lang="en", dictionary_context=" ")
        self.assertEqual(built["language"], "en")
        self.assertEqual(built["transcript"], "hello")
        self.assertNotIn("dictionary", built)

    def test_user_prompt_is_xml_and_escapes_literals(self):
        built = _build_request_payload(
            'keep <transcript> and "quotes"',
            lang="en",
            dictionary_context="Docker & systemd",
        )
        rendered = _build_user_prompt_xml(built)
        for fragment in (
            "<request>",
            "<language>en</language>",
            "<transcript>",
            "&",
            "<output_contract>",
        ):
            self.assertIn(fragment, rendered)
|
|
class _Response:
|
|
def __init__(self, payload: bytes):
|
|
self.payload = payload
|
|
self.offset = 0
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc, tb):
|
|
return False
|
|
|
|
def getheader(self, name: str):
|
|
if name.lower() == "content-length":
|
|
return str(len(self.payload))
|
|
return None
|
|
|
|
def read(self, size: int = -1) -> bytes:
|
|
if self.offset >= len(self.payload):
|
|
return b""
|
|
if size < 0:
|
|
chunk = self.payload[self.offset :]
|
|
self.offset = len(self.payload)
|
|
return chunk
|
|
chunk = self.payload[self.offset : self.offset + size]
|
|
self.offset += len(chunk)
|
|
return chunk
|
|
|
|
|
|
class EnsureModelTests(unittest.TestCase):
    """Download/validation flow of ``ensure_model`` and the read-only probe."""

    def test_existing_valid_model_skips_download(self):
        payload = b"valid-model"
        digest = sha256(payload).hexdigest()
        with tempfile.TemporaryDirectory() as workdir:
            model_path = Path(workdir) / "model.gguf"
            model_path.write_bytes(payload)
            with patch.object(aiprocess, "MODEL_PATH", model_path), \
                    patch.object(aiprocess, "MODEL_DIR", model_path.parent), \
                    patch.object(aiprocess, "MODEL_SHA256", digest), \
                    patch("aiprocess.urllib.request.urlopen") as urlopen:
                out = ensure_model()

            self.assertEqual(out, model_path)
            urlopen.assert_not_called()

    def test_existing_invalid_model_triggers_redownload(self):
        cached = b"bad-model"
        downloaded = b"good-model"
        expected_digest = sha256(downloaded).hexdigest()
        with tempfile.TemporaryDirectory() as workdir:
            model_path = Path(workdir) / "model.gguf"
            model_path.write_bytes(cached)
            with patch.object(aiprocess, "MODEL_PATH", model_path), \
                    patch.object(aiprocess, "MODEL_DIR", model_path.parent), \
                    patch.object(aiprocess, "MODEL_SHA256", expected_digest), \
                    patch(
                        "aiprocess.urllib.request.urlopen",
                        return_value=_Response(downloaded),
                    ) as urlopen:
                out = ensure_model()

            self.assertEqual(out, model_path)
            self.assertEqual(model_path.read_bytes(), downloaded)
            urlopen.assert_called_once()

    def test_invalid_cached_model_and_redownload_failure_raises_clear_error(self):
        with tempfile.TemporaryDirectory() as workdir:
            model_path = Path(workdir) / "model.gguf"
            model_path.write_bytes(b"bad-model")
            with patch.object(aiprocess, "MODEL_PATH", model_path), \
                    patch.object(aiprocess, "MODEL_DIR", model_path.parent), \
                    patch.object(aiprocess, "MODEL_SHA256", "f" * 64), \
                    patch(
                        "aiprocess.urllib.request.urlopen",
                        side_effect=RuntimeError("network down"),
                    ):
                with self.assertRaisesRegex(
                    RuntimeError, "cached model checksum mismatch and redownload failed"
                ):
                    ensure_model()

    def test_probe_managed_model_is_read_only_for_valid_cache(self):
        payload = b"valid-model"
        digest = sha256(payload).hexdigest()
        with tempfile.TemporaryDirectory() as workdir:
            model_path = Path(workdir) / "model.gguf"
            model_path.write_bytes(payload)
            with patch.object(aiprocess, "MODEL_PATH", model_path), \
                    patch.object(aiprocess, "MODEL_SHA256", digest), \
                    patch("aiprocess.urllib.request.urlopen") as urlopen:
                result = probe_managed_model()

            self.assertEqual(result.status, "ready")
            self.assertIn("ready", result.message)
            urlopen.assert_not_called()

    def test_probe_managed_model_reports_missing_cache(self):
        with tempfile.TemporaryDirectory() as workdir:
            # Intentionally never create the file on disk.
            model_path = Path(workdir) / "model.gguf"
            with patch.object(aiprocess, "MODEL_PATH", model_path):
                result = probe_managed_model()

            self.assertEqual(result.status, "missing")
            self.assertIn(str(model_path), result.message)

    def test_probe_managed_model_reports_invalid_checksum(self):
        with tempfile.TemporaryDirectory() as workdir:
            model_path = Path(workdir) / "model.gguf"
            model_path.write_bytes(b"bad-model")
            with patch.object(aiprocess, "MODEL_PATH", model_path), \
                    patch.object(aiprocess, "MODEL_SHA256", "f" * 64):
                result = probe_managed_model()

            self.assertEqual(result.status, "invalid")
            self.assertIn("checksum mismatch", result.message)
|
|
class ExternalApiProcessorTests(unittest.TestCase):
    """Configuration and request behaviour of ``ExternalApiProcessor``."""

    @staticmethod
    def _make_processor():
        # Standard OpenAI-shaped configuration shared by every test below.
        return ExternalApiProcessor(
            provider="openai",
            base_url="https://api.openai.com/v1",
            model="gpt-4o-mini",
            api_key_env_var="AMAN_EXTERNAL_API_KEY",
            timeout_ms=1000,
            max_retries=0,
        )

    def test_requires_api_key_env_var(self):
        with patch.dict(os.environ, {}, clear=True):
            with self.assertRaisesRegex(RuntimeError, "missing external api key"):
                self._make_processor()

    def test_process_uses_chat_completion_endpoint(self):
        response_body = json.dumps(
            {"choices": [{"message": {"content": '{"cleaned_text":"clean"}'}}]}
        ).encode("utf-8")
        with patch.dict(os.environ, {"AMAN_EXTERNAL_API_KEY": "test-key"}, clear=True), \
                patch(
                    "aiprocess.urllib.request.urlopen",
                    return_value=_Response(response_body),
                ) as urlopen:
            processor = self._make_processor()
            out = processor.process("raw text", dictionary_context="Docker")

        self.assertEqual(out, "clean")
        request = urlopen.call_args[0][0]
        self.assertTrue(request.full_url.endswith("/chat/completions"))

    def test_warmup_is_a_noop(self):
        with patch.dict(os.environ, {"AMAN_EXTERNAL_API_KEY": "test-key"}, clear=True):
            processor = self._make_processor()
            with patch("aiprocess.urllib.request.urlopen") as urlopen:
                processor.warmup(profile="fast")

            urlopen.assert_not_called()
|
|
# Allow running this file directly: `python tests/test_aiprocess.py`.
if __name__ == "__main__":
    unittest.main()