Stop shipping code that implied Aman supported a two-pass editor, external API cleanup, or a Wayland scaffold when the runtime only exercises single-pass local cleanup on X11.

Collapse aiprocess to the active single-pass Llama contract, delete desktop_wayland and the empty wayland extra, and make model_eval reject pass1_/pass2_ tuning keys while keeping pass1_ms/pass2_ms as report compatibility fields.

Remove the unused pillow dependency, switch to SPDX-style license metadata, and clean setuptools build state before packaging so deleted modules do not leak into wheels. Update the methodology and repo guidance docs, and add focused tests for desktop adapter selection, stale param rejection, and portable wheel contents.

Validate with uv lock, python3 -m unittest discover -s tests -p 'test_*.py', python3 -m py_compile src/*.py tests/*.py, and python3 -m build --wheel --sdist --no-isolation.
364 lines
12 KiB
Python
import sys
|
|
import tempfile
|
|
import unittest
|
|
from hashlib import sha256
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
ROOT = Path(__file__).resolve().parents[1]
|
|
SRC = ROOT / "src"
|
|
if str(SRC) not in sys.path:
|
|
sys.path.insert(0, str(SRC))
|
|
|
|
import aiprocess
|
|
from aiprocess import (
|
|
LlamaProcessor,
|
|
_assert_expected_model_checksum,
|
|
_build_request_payload,
|
|
_build_user_prompt_xml,
|
|
_explicit_generation_kwargs,
|
|
_extract_cleaned_text,
|
|
_profile_generation_kwargs,
|
|
_supports_response_format,
|
|
ensure_model,
|
|
probe_managed_model,
|
|
)
|
|
from constants import MODEL_SHA256
|
|
|
|
|
|
class ExtractCleanedTextTests(unittest.TestCase):
    """Tests for _extract_cleaned_text's strict JSON output contract."""

    @staticmethod
    def _completion(content):
        # Wrap raw assistant output in the llama.cpp chat-completion shape.
        return {"choices": [{"message": {"content": content}}]}

    def test_extracts_cleaned_text_from_json_object(self):
        payload = self._completion(
            '{"cleaned_text":"Hello <transcript>literal</transcript> world"}'
        )

        self.assertEqual(
            _extract_cleaned_text(payload),
            "Hello <transcript>literal</transcript> world",
        )

    def test_extracts_cleaned_text_from_json_string(self):
        # A bare JSON string (not an object) is also accepted as output.
        payload = self._completion('"He said \\\"hello\\\""')

        self.assertEqual(_extract_cleaned_text(payload), 'He said "hello"')

    def test_rejects_non_json_output(self):
        payload = self._completion("<transcript>Hello</transcript>")

        with self.assertRaisesRegex(RuntimeError, "expected JSON"):
            _extract_cleaned_text(payload)

    def test_rejects_json_without_required_key(self):
        payload = self._completion('{"text":"hello"}')

        with self.assertRaisesRegex(RuntimeError, "missing cleaned_text"):
            _extract_cleaned_text(payload)
|
|
class SupportsResponseFormatTests(unittest.TestCase):
    """Tests for capability probing of chat-completion callables."""

    def test_supports_response_format_when_parameter_exists(self):
        def fake_completion(*, messages, temperature, response_format):
            return None

        self.assertTrue(_supports_response_format(fake_completion))

    def test_does_not_support_response_format_when_missing(self):
        def fake_completion(*, messages, temperature):
            return None

        self.assertFalse(_supports_response_format(fake_completion))

    def test_fast_profile_sets_max_tokens_when_supported(self):
        def fake_completion(*, messages, temperature, max_tokens):
            return None

        self.assertEqual(
            _profile_generation_kwargs(fake_completion, "fast"),
            {"max_tokens": 192},
        )

    def test_non_fast_profile_does_not_set_generation_overrides(self):
        def fake_completion(*, messages, temperature, max_tokens):
            return None

        self.assertEqual(
            _profile_generation_kwargs(fake_completion, "polished"), {}
        )

    def test_explicit_generation_kwargs_honors_supported_params(self):
        # Only top_p and max_tokens exist on the callable; the rest must be
        # silently dropped rather than forwarded.
        def fake_completion(*, messages, temperature, top_p, max_tokens):
            return None

        kwargs = _explicit_generation_kwargs(
            fake_completion,
            top_p=0.9,
            top_k=40,
            max_tokens=128,
            repeat_penalty=1.1,
            min_p=0.05,
        )

        self.assertEqual(kwargs, {"top_p": 0.9, "max_tokens": 128})
|
|
class _WarmupClient:
|
|
def __init__(self, response_payload: dict):
|
|
self.response_payload = response_payload
|
|
self.calls = []
|
|
|
|
def create_chat_completion(
|
|
self,
|
|
*,
|
|
messages,
|
|
temperature,
|
|
response_format=None,
|
|
max_tokens=None,
|
|
):
|
|
self.calls.append(
|
|
{
|
|
"messages": messages,
|
|
"temperature": temperature,
|
|
"response_format": response_format,
|
|
"max_tokens": max_tokens,
|
|
}
|
|
)
|
|
return self.response_payload
|
|
|
|
|
|
class LlamaWarmupTests(unittest.TestCase):
    """Tests for LlamaProcessor warmup and single-pass timing shape."""

    @staticmethod
    def _processor_with_content(content):
        # Build a LlamaProcessor without running __init__ (which would load
        # a real model), wiring in a recording fake client instead.
        processor = object.__new__(LlamaProcessor)
        processor.client = _WarmupClient(
            {"choices": [{"message": {"content": content}}]}
        )
        return processor

    def test_warmup_uses_json_mode_and_low_token_budget(self):
        processor = self._processor_with_content('{"cleaned_text":"ok"}')

        processor.warmup(profile="fast")

        calls = processor.client.calls
        self.assertEqual(len(calls), 1)
        call = calls[0]
        self.assertEqual(call["temperature"], 0.0)
        self.assertEqual(call["response_format"], {"type": "json_object"})
        self.assertEqual(call["max_tokens"], 32)
        user_content = call["messages"][1]["content"]
        self.assertIn("<request>", user_content)
        self.assertIn("<transcript>warmup</transcript>", user_content)
        self.assertIn("<language>auto</language>", user_content)

    def test_warmup_raises_on_non_json_response(self):
        processor = self._processor_with_content("not-json")

        with self.assertRaisesRegex(RuntimeError, "expected JSON"):
            processor.warmup(profile="default")

    def test_process_with_metrics_uses_single_completion_timing_shape(self):
        processor = self._processor_with_content('{"cleaned_text":"friday"}')

        cleaned_text, timings = processor.process_with_metrics(
            "thursday, I mean friday",
            lang="en",
            dictionary_context="",
            profile="default",
        )

        self.assertEqual(cleaned_text, "friday")
        calls = processor.client.calls
        self.assertEqual(len(calls), 1)
        call = calls[0]
        self.assertEqual(call["messages"][0]["content"], aiprocess.SYSTEM_PROMPT)
        self.assertIn('{"cleaned_text":"..."}', call["messages"][1]["content"])
        # Single-pass contract: pass1 is a zeroed compatibility field and the
        # whole wall time is attributed to pass2.
        self.assertEqual(timings.pass1_ms, 0.0)
        self.assertGreater(timings.pass2_ms, 0.0)
        self.assertEqual(timings.pass2_ms, timings.total_ms)
|
|
class ModelChecksumTests(unittest.TestCase):
    """Tests for the model-file checksum guard."""

    def test_accepts_expected_checksum_case_insensitive(self):
        # Must not raise even when the digest casing differs from the pin.
        _assert_expected_model_checksum(MODEL_SHA256.upper())

    def test_rejects_unexpected_checksum(self):
        with self.assertRaisesRegex(RuntimeError, "checksum mismatch"):
            _assert_expected_model_checksum("0" * 64)
|
|
|
class RequestPayloadTests(unittest.TestCase):
    """Tests for request payload construction and XML prompt rendering."""

    def test_build_request_payload_with_dictionary(self):
        payload = _build_request_payload(
            "hello", lang="en", dictionary_context="Docker"
        )

        self.assertEqual(payload["language"], "en")
        self.assertEqual(payload["transcript"], "hello")
        self.assertEqual(payload["dictionary"], "Docker")

    def test_build_request_payload_omits_empty_dictionary(self):
        # Whitespace-only dictionary context must be dropped entirely.
        payload = _build_request_payload("hello", lang="en", dictionary_context=" ")

        self.assertEqual(payload["language"], "en")
        self.assertEqual(payload["transcript"], "hello")
        self.assertNotIn("dictionary", payload)

    def test_user_prompt_is_xml_and_escapes_literals(self):
        payload = _build_request_payload(
            'keep <transcript> and "quotes"',
            lang="en",
            dictionary_context="Docker & systemd",
        )

        xml = _build_user_prompt_xml(payload)

        for fragment in (
            "<request>",
            "<language>en</language>",
            "<transcript>",
            "&",
            "<output_contract>",
        ):
            self.assertIn(fragment, xml)
|
|
|
class _Response:
|
|
def __init__(self, payload: bytes):
|
|
self.payload = payload
|
|
self.offset = 0
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc, tb):
|
|
return False
|
|
|
|
def getheader(self, name: str):
|
|
if name.lower() == "content-length":
|
|
return str(len(self.payload))
|
|
return None
|
|
|
|
def read(self, size: int = -1) -> bytes:
|
|
if self.offset >= len(self.payload):
|
|
return b""
|
|
if size < 0:
|
|
chunk = self.payload[self.offset :]
|
|
self.offset = len(self.payload)
|
|
return chunk
|
|
chunk = self.payload[self.offset : self.offset + size]
|
|
self.offset += len(chunk)
|
|
return chunk
|
|
|
|
|
|
class EnsureModelTests(unittest.TestCase):
    """Tests for ensure_model / probe_managed_model cache and download paths."""

    def test_existing_valid_model_skips_download(self):
        blob = b"valid-model"
        digest = sha256(blob).hexdigest()
        with tempfile.TemporaryDirectory() as tmp:
            target = Path(tmp) / "model.gguf"
            target.write_bytes(blob)
            with patch.object(aiprocess, "MODEL_PATH", target), \
                    patch.object(aiprocess, "MODEL_DIR", target.parent), \
                    patch.object(aiprocess, "MODEL_SHA256", digest), \
                    patch("aiprocess.urllib.request.urlopen") as urlopen:
                out = ensure_model()

            self.assertEqual(out, target)
            urlopen.assert_not_called()

    def test_existing_invalid_model_triggers_redownload(self):
        stale = b"bad-model"
        fresh = b"good-model"
        digest = sha256(fresh).hexdigest()
        with tempfile.TemporaryDirectory() as tmp:
            target = Path(tmp) / "model.gguf"
            target.write_bytes(stale)
            with patch.object(aiprocess, "MODEL_PATH", target), \
                    patch.object(aiprocess, "MODEL_DIR", target.parent), \
                    patch.object(aiprocess, "MODEL_SHA256", digest), \
                    patch(
                        "aiprocess.urllib.request.urlopen",
                        return_value=_Response(fresh),
                    ) as urlopen:
                out = ensure_model()

            self.assertEqual(out, target)
            # The stale cache must be replaced by the downloaded bytes.
            self.assertEqual(target.read_bytes(), fresh)
            urlopen.assert_called_once()

    def test_invalid_cached_model_and_redownload_failure_raises_clear_error(self):
        with tempfile.TemporaryDirectory() as tmp:
            target = Path(tmp) / "model.gguf"
            target.write_bytes(b"bad-model")
            with patch.object(aiprocess, "MODEL_PATH", target), \
                    patch.object(aiprocess, "MODEL_DIR", target.parent), \
                    patch.object(aiprocess, "MODEL_SHA256", "f" * 64), \
                    patch(
                        "aiprocess.urllib.request.urlopen",
                        side_effect=RuntimeError("network down"),
                    ):
                with self.assertRaisesRegex(
                    RuntimeError, "cached model checksum mismatch and redownload failed"
                ):
                    ensure_model()

    def test_probe_managed_model_is_read_only_for_valid_cache(self):
        blob = b"valid-model"
        digest = sha256(blob).hexdigest()
        with tempfile.TemporaryDirectory() as tmp:
            target = Path(tmp) / "model.gguf"
            target.write_bytes(blob)
            with patch.object(aiprocess, "MODEL_PATH", target), \
                    patch.object(aiprocess, "MODEL_SHA256", digest), \
                    patch("aiprocess.urllib.request.urlopen") as urlopen:
                result = probe_managed_model()

            self.assertEqual(result.status, "ready")
            self.assertIn("ready", result.message)
            urlopen.assert_not_called()

    def test_probe_managed_model_reports_missing_cache(self):
        with tempfile.TemporaryDirectory() as tmp:
            target = Path(tmp) / "model.gguf"
            # Note: the file is deliberately never created.
            with patch.object(aiprocess, "MODEL_PATH", target):
                result = probe_managed_model()

            self.assertEqual(result.status, "missing")
            self.assertIn(str(target), result.message)

    def test_probe_managed_model_reports_invalid_checksum(self):
        with tempfile.TemporaryDirectory() as tmp:
            target = Path(tmp) / "model.gguf"
            target.write_bytes(b"bad-model")
            with patch.object(aiprocess, "MODEL_PATH", target), \
                    patch.object(aiprocess, "MODEL_SHA256", "f" * 64):
                result = probe_managed_model()

            self.assertEqual(result.status, "invalid")
            self.assertIn("checksum mismatch", result.message)
|
|
|
if __name__ == "__main__":
    # Allow running this module directly: python tests/test_aiprocess.py
    unittest.main()