Add benchmark-driven model promotion workflow and pipeline stages
Some checks failed: ci / test-and-build (push) has been cancelled

Thales Maciel 2026-02-28 15:12:33 -03:00
parent 98b13d1069
commit 8c1f7c1e13
38 changed files with 5300 additions and 503 deletions


@@ -15,8 +15,11 @@ if str(SRC) not in sys.path:
import aiprocess
from aiprocess import (
ExternalApiProcessor,
LlamaProcessor,
_assert_expected_model_checksum,
_build_request_payload,
_build_user_prompt_xml,
_explicit_generation_kwargs,
_extract_cleaned_text,
_profile_generation_kwargs,
_supports_response_format,
@@ -114,6 +117,75 @@ class SupportsResponseFormatTests(unittest.TestCase):
self.assertEqual(kwargs, {})
def test_explicit_generation_kwargs_honors_supported_params(self):
def chat_completion(*, messages, temperature, top_p, max_tokens):
return None
kwargs = _explicit_generation_kwargs(
chat_completion,
top_p=0.9,
top_k=40,
max_tokens=128,
repeat_penalty=1.1,
min_p=0.05,
)
self.assertEqual(kwargs, {"top_p": 0.9, "max_tokens": 128})
class _WarmupClient:
def __init__(self, response_payload: dict):
self.response_payload = response_payload
self.calls = []
def create_chat_completion(
self,
*,
messages,
temperature,
response_format=None,
max_tokens=None,
):
self.calls.append(
{
"messages": messages,
"temperature": temperature,
"response_format": response_format,
"max_tokens": max_tokens,
}
)
return self.response_payload
class LlamaWarmupTests(unittest.TestCase):
def test_warmup_uses_json_mode_and_low_token_budget(self):
processor = object.__new__(LlamaProcessor)
client = _WarmupClient(
{"choices": [{"message": {"content": '{"cleaned_text":"ok"}'}}]}
)
processor.client = client
processor.warmup(profile="fast")
self.assertEqual(len(client.calls), 1)
call = client.calls[0]
self.assertEqual(call["temperature"], 0.0)
self.assertEqual(call["response_format"], {"type": "json_object"})
self.assertEqual(call["max_tokens"], 32)
user_content = call["messages"][1]["content"]
self.assertIn("<request>", user_content)
self.assertIn("<transcript>warmup</transcript>", user_content)
self.assertIn("<language>auto</language>", user_content)
def test_warmup_raises_on_non_json_response(self):
processor = object.__new__(LlamaProcessor)
client = _WarmupClient(
{"choices": [{"message": {"content": "not-json"}}]}
)
processor.client = client
with self.assertRaisesRegex(RuntimeError, "expected JSON"):
processor.warmup(profile="default")
class ModelChecksumTests(unittest.TestCase):
def test_accepts_expected_checksum_case_insensitive(self):
@@ -137,6 +209,19 @@ class RequestPayloadTests(unittest.TestCase):
self.assertEqual(payload["transcript"], "hello")
self.assertNotIn("dictionary", payload)
def test_user_prompt_is_xml_and_escapes_literals(self):
payload = _build_request_payload(
'keep <transcript> and "quotes"',
lang="en",
dictionary_context="Docker & systemd",
)
xml = _build_user_prompt_xml(payload)
self.assertIn("<request>", xml)
self.assertIn("<language>en</language>", xml)
self.assertIn("&lt;transcript&gt;", xml)
self.assertIn("&amp;", xml)
self.assertIn("<output_contract>", xml)
class _Response:
def __init__(self, payload: bytes):
@@ -254,6 +339,21 @@ class ExternalApiProcessorTests(unittest.TestCase):
request = urlopen.call_args[0][0]
self.assertTrue(request.full_url.endswith("/chat/completions"))
def test_warmup_is_a_noop(self):
with patch.dict(os.environ, {"AMAN_EXTERNAL_API_KEY": "test-key"}, clear=True):
processor = ExternalApiProcessor(
provider="openai",
base_url="https://api.openai.com/v1",
model="gpt-4o-mini",
api_key_env_var="AMAN_EXTERNAL_API_KEY",
timeout_ms=1000,
max_retries=0,
)
with patch("aiprocess.urllib.request.urlopen") as urlopen:
processor.warmup(profile="fast")
urlopen.assert_not_called()
if __name__ == "__main__":
unittest.main()
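The honors-supported-params test above implies that explicit sampling overrides are filtered against the client callable's signature. A minimal sketch of that idea, assuming the real helper in src/aiprocess.py detects support the same way:

import inspect

def _explicit_generation_kwargs_sketch(chat_completion, **overrides):
    # Keep only the overrides the chat-completion callable declares;
    # unsupported sampling params (top_k, repeat_penalty, min_p, ...) are dropped.
    supported = inspect.signature(chat_completion).parameters
    return {name: value for name, value in overrides.items() if name in supported}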


@@ -0,0 +1,72 @@
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))
from stages.alignment_edits import AlignmentHeuristicEngine
from stages.asr_whisper import AsrWord
def _words(tokens: list[str], step: float = 0.2) -> list[AsrWord]:
out: list[AsrWord] = []
start = 0.0
for token in tokens:
out.append(
AsrWord(
text=token,
start_s=start,
end_s=start + 0.1,
prob=0.9,
)
)
start += step
return out
class AlignmentHeuristicEngineTests(unittest.TestCase):
def test_returns_original_when_no_words_available(self):
engine = AlignmentHeuristicEngine()
result = engine.apply("hello world", [])
self.assertEqual(result.draft_text, "hello world")
self.assertEqual(result.applied_count, 0)
self.assertEqual(result.decisions, [])
def test_applies_i_mean_tail_correction(self):
engine = AlignmentHeuristicEngine()
words = _words(["set", "alarm", "for", "6", "i", "mean", "7"])
result = engine.apply("set alarm for 6 i mean 7", words)
self.assertEqual(result.draft_text, "set alarm for 7")
self.assertEqual(result.applied_count, 1)
self.assertTrue(any(item.rule_id == "cue_correction" for item in result.decisions))
def test_preserves_literal_i_mean_context(self):
engine = AlignmentHeuristicEngine()
words = _words(["write", "exactly", "i", "mean", "this", "sincerely"])
result = engine.apply("write exactly i mean this sincerely", words)
self.assertEqual(result.draft_text, "write exactly i mean this sincerely")
self.assertEqual(result.applied_count, 0)
self.assertGreaterEqual(result.skipped_count, 1)
def test_collapses_exact_restart_repetition(self):
engine = AlignmentHeuristicEngine()
words = _words(["please", "send", "it", "please", "send", "it"])
result = engine.apply("please send it please send it", words)
self.assertEqual(result.draft_text, "please send it")
self.assertEqual(result.applied_count, 1)
self.assertTrue(any(item.rule_id == "restart_repeat" for item in result.decisions))
if __name__ == "__main__":
unittest.main()
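These cases outline the cue-correction rule: a trailing "X i mean Y" replaces X with Y, while literal-dictation contexts are left alone. A toy sketch of the core substitution (the real AlignmentHeuristicEngine also consults word timings, confidences, and decision records, which this ignores):

import re

def apply_i_mean_tail_correction(text: str) -> str:
    # "... for 6 i mean 7" -> "... for 7"; only fires when a single token
    # follows the cue at the end of the utterance, so literal phrases like
    # "write exactly i mean this sincerely" are left untouched.
    match = re.search(r"\b(\S+) i mean (\S+)$", text)
    if match is None:
        return text
    return text[: match.start()] + match.group(2)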


@@ -111,11 +111,21 @@ class FakeUnsupportedLanguageModel:
class FakeAIProcessor:
def __init__(self):
self.last_kwargs = {}
self.warmup_calls = []
self.warmup_error = None
self.process_error = None
def process(self, text, lang="auto", **_kwargs):
if self.process_error is not None:
raise self.process_error
self.last_kwargs = {"lang": lang, **_kwargs}
return text
def warmup(self, profile="default"):
self.warmup_calls.append(profile)
if self.warmup_error:
raise self.warmup_error
class FakeAudio:
def __init__(self, size: int):
@@ -212,6 +222,32 @@ class DaemonTests(unittest.TestCase):
self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard", False)])
@patch("aman.stop_audio_recording", return_value=FakeAudio(8))
@patch("aman.start_audio_recording", return_value=(object(), object()))
def test_editor_failure_aborts_output_injection(self, _start_mock, _stop_mock):
desktop = FakeDesktop()
model = FakeModel(text="hello world")
ai_processor = FakeAIProcessor()
ai_processor.process_error = RuntimeError("editor boom")
daemon = self._build_daemon(
desktop,
model,
verbose=False,
ai_processor=ai_processor,
)
daemon._start_stop_worker = (
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
stream, record, trigger, process_audio
)
)
daemon.toggle()
daemon.toggle()
self.assertEqual(desktop.inject_calls, [])
self.assertEqual(daemon.get_state(), aman.State.IDLE)
def test_transcribe_skips_hints_when_model_does_not_support_them(self):
desktop = FakeDesktop()
model = FakeModel(text="hello")
@@ -242,7 +278,7 @@ class DaemonTests(unittest.TestCase):
self.assertEqual(used_lang, "auto")
self.assertIn("Docker", model.last_kwargs["hotwords"])
self.assertIn("Systemd", model.last_kwargs["hotwords"])
self.assertIn("Preferred vocabulary", model.last_kwargs["initial_prompt"])
self.assertIsNone(model.last_kwargs["initial_prompt"])
def test_transcribe_uses_configured_language_hint(self):
desktop = FakeDesktop()
@@ -300,7 +336,7 @@ class DaemonTests(unittest.TestCase):
daemon_verbose = self._build_daemon(desktop, FakeModel(), cfg=cfg, verbose=True)
self.assertTrue(daemon_verbose.log_transcript)
def test_ai_processor_is_initialized_during_daemon_init(self):
def test_editor_stage_is_initialized_during_daemon_init(self):
desktop = FakeDesktop()
with patch("aman._build_whisper_model", return_value=FakeModel()), patch(
"aman.LlamaProcessor", return_value=FakeAIProcessor()
@@ -308,7 +344,47 @@ class DaemonTests(unittest.TestCase):
daemon = aman.Daemon(self._config(), desktop, verbose=True)
processor_cls.assert_called_once_with(verbose=True, model_path=None)
self.assertIsNotNone(daemon.ai_processor)
self.assertIsNotNone(daemon.editor_stage)
def test_editor_stage_is_warmed_up_during_daemon_init(self):
desktop = FakeDesktop()
ai_processor = FakeAIProcessor()
with patch("aman._build_whisper_model", return_value=FakeModel()), patch(
"aman.LlamaProcessor", return_value=ai_processor
):
daemon = aman.Daemon(self._config(), desktop, verbose=False)
self.assertIs(daemon.editor_stage._processor, ai_processor)
self.assertEqual(ai_processor.warmup_calls, ["default"])
def test_editor_stage_warmup_failure_is_fatal_with_strict_startup(self):
desktop = FakeDesktop()
cfg = self._config()
cfg.advanced.strict_startup = True
ai_processor = FakeAIProcessor()
ai_processor.warmup_error = RuntimeError("warmup boom")
with patch("aman._build_whisper_model", return_value=FakeModel()), patch(
"aman.LlamaProcessor", return_value=ai_processor
):
with self.assertRaisesRegex(RuntimeError, "editor stage warmup failed"):
aman.Daemon(cfg, desktop, verbose=False)
def test_editor_stage_warmup_failure_is_non_fatal_without_strict_startup(self):
desktop = FakeDesktop()
cfg = self._config()
cfg.advanced.strict_startup = False
ai_processor = FakeAIProcessor()
ai_processor.warmup_error = RuntimeError("warmup boom")
with patch("aman._build_whisper_model", return_value=FakeModel()), patch(
"aman.LlamaProcessor", return_value=ai_processor
):
with self.assertLogs(level="WARNING") as logs:
daemon = aman.Daemon(cfg, desktop, verbose=False)
self.assertIs(daemon.editor_stage._processor, ai_processor)
self.assertTrue(
any("continuing because advanced.strict_startup=false" in line for line in logs.output)
)
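The three warmup tests describe the startup contract: warm the editor stage once with the configured profile and escalate failures only when advanced.strict_startup is set. A hypothetical sketch of that branch (the names and log text mirror the assertions above; the real wiring lives in aman.Daemon and may differ):

import logging

def _warmup_editor_stage_sketch(processor, *, profile: str, strict_startup: bool) -> None:
    try:
        processor.warmup(profile=profile)
    except Exception as exc:
        if strict_startup:
            raise RuntimeError("editor stage warmup failed") from exc
        logging.getLogger(__name__).warning(
            "editor stage warmup failed (%s); continuing because "
            "advanced.strict_startup=false",
            exc,
        )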
@patch("aman.stop_audio_recording", return_value=FakeAudio(8))
@patch("aman.start_audio_recording", return_value=(object(), object()))


@@ -4,6 +4,7 @@ import sys
import tempfile
import unittest
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
ROOT = Path(__file__).resolve().parents[1]
@@ -92,6 +93,20 @@ class _RetrySetupDesktop(_FakeDesktop):
on_quit()
class _FakeBenchEditorStage:
def warmup(self):
return
def rewrite(self, transcript, *, language, dictionary_context):
_ = dictionary_context
return SimpleNamespace(
final_text=f"[{language}] {transcript.strip()}",
latency_ms=1.0,
pass1_ms=0.5,
pass2_ms=0.5,
)
class AmanCliTests(unittest.TestCase):
def test_parse_cli_args_defaults_to_run_command(self):
args = aman._parse_cli_args(["--dry-run"])
@@ -111,6 +126,85 @@ class AmanCliTests(unittest.TestCase):
self.assertEqual(args.command, "self-check")
self.assertTrue(args.json)
def test_parse_cli_args_bench_command(self):
args = aman._parse_cli_args(
["bench", "--text", "hello", "--repeat", "2", "--warmup", "0", "--json"]
)
self.assertEqual(args.command, "bench")
self.assertEqual(args.text, "hello")
self.assertEqual(args.repeat, 2)
self.assertEqual(args.warmup, 0)
self.assertTrue(args.json)
def test_parse_cli_args_bench_requires_input(self):
with self.assertRaises(SystemExit):
aman._parse_cli_args(["bench"])
def test_parse_cli_args_eval_models_command(self):
args = aman._parse_cli_args(
["eval-models", "--dataset", "benchmarks/cleanup_dataset.jsonl", "--matrix", "benchmarks/model_matrix.small_first.json"]
)
self.assertEqual(args.command, "eval-models")
self.assertEqual(args.dataset, "benchmarks/cleanup_dataset.jsonl")
self.assertEqual(args.matrix, "benchmarks/model_matrix.small_first.json")
self.assertEqual(args.heuristic_dataset, "")
self.assertEqual(args.heuristic_weight, 0.25)
self.assertEqual(args.report_version, 2)
def test_parse_cli_args_eval_models_with_heuristic_options(self):
args = aman._parse_cli_args(
[
"eval-models",
"--dataset",
"benchmarks/cleanup_dataset.jsonl",
"--matrix",
"benchmarks/model_matrix.small_first.json",
"--heuristic-dataset",
"benchmarks/heuristics_dataset.jsonl",
"--heuristic-weight",
"0.4",
"--report-version",
"2",
]
)
self.assertEqual(args.heuristic_dataset, "benchmarks/heuristics_dataset.jsonl")
self.assertEqual(args.heuristic_weight, 0.4)
self.assertEqual(args.report_version, 2)
def test_parse_cli_args_build_heuristic_dataset_command(self):
args = aman._parse_cli_args(
[
"build-heuristic-dataset",
"--input",
"benchmarks/heuristics_dataset.raw.jsonl",
"--output",
"benchmarks/heuristics_dataset.jsonl",
]
)
self.assertEqual(args.command, "build-heuristic-dataset")
self.assertEqual(args.input, "benchmarks/heuristics_dataset.raw.jsonl")
self.assertEqual(args.output, "benchmarks/heuristics_dataset.jsonl")
def test_parse_cli_args_sync_default_model_command(self):
args = aman._parse_cli_args(
[
"sync-default-model",
"--report",
"benchmarks/results/latest.json",
"--artifacts",
"benchmarks/model_artifacts.json",
"--constants",
"src/constants.py",
"--check",
]
)
self.assertEqual(args.command, "sync-default-model")
self.assertEqual(args.report, "benchmarks/results/latest.json")
self.assertEqual(args.artifacts, "benchmarks/model_artifacts.json")
self.assertEqual(args.constants, "src/constants.py")
self.assertTrue(args.check)
def test_version_command_prints_version(self):
out = io.StringIO()
args = aman._parse_cli_args(["version"])
@@ -145,6 +239,259 @@ class AmanCliTests(unittest.TestCase):
self.assertEqual(exit_code, 2)
self.assertIn("[FAIL] config.load", out.getvalue())
def test_bench_command_json_output(self):
args = aman._parse_cli_args(["bench", "--text", "hello", "--repeat", "2", "--warmup", "0", "--json"])
out = io.StringIO()
with patch("aman.load", return_value=Config()), patch(
"aman._build_editor_stage", return_value=_FakeBenchEditorStage()
), patch("sys.stdout", out):
exit_code = aman._bench_command(args)
self.assertEqual(exit_code, 0)
payload = json.loads(out.getvalue())
self.assertEqual(payload["measured_runs"], 2)
self.assertEqual(payload["summary"]["runs"], 2)
self.assertEqual(len(payload["runs"]), 2)
self.assertEqual(payload["editor_backend"], "local_llama_builtin")
self.assertIn("avg_alignment_ms", payload["summary"])
self.assertIn("avg_fact_guard_ms", payload["summary"])
self.assertIn("alignment_applied", payload["runs"][0])
self.assertIn("fact_guard_action", payload["runs"][0])
def test_bench_command_supports_text_file_input(self):
with tempfile.TemporaryDirectory() as td:
text_file = Path(td) / "input.txt"
text_file.write_text("hello from file", encoding="utf-8")
args = aman._parse_cli_args(
["bench", "--text-file", str(text_file), "--repeat", "1", "--warmup", "0", "--print-output"]
)
out = io.StringIO()
with patch("aman.load", return_value=Config()), patch(
"aman._build_editor_stage", return_value=_FakeBenchEditorStage()
), patch("sys.stdout", out):
exit_code = aman._bench_command(args)
self.assertEqual(exit_code, 0)
self.assertIn("[auto] hello from file", out.getvalue())
def test_bench_command_rejects_empty_input(self):
args = aman._parse_cli_args(["bench", "--text", " "])
with patch("aman.load", return_value=Config()), patch(
"aman._build_editor_stage", return_value=_FakeBenchEditorStage()
):
exit_code = aman._bench_command(args)
self.assertEqual(exit_code, 1)
def test_bench_command_rejects_non_positive_repeat(self):
args = aman._parse_cli_args(["bench", "--text", "hello", "--repeat", "0"])
with patch("aman.load", return_value=Config()), patch(
"aman._build_editor_stage", return_value=_FakeBenchEditorStage()
):
exit_code = aman._bench_command(args)
self.assertEqual(exit_code, 1)
def test_eval_models_command_writes_report(self):
with tempfile.TemporaryDirectory() as td:
output_path = Path(td) / "report.json"
args = aman._parse_cli_args(
[
"eval-models",
"--dataset",
"benchmarks/cleanup_dataset.jsonl",
"--matrix",
"benchmarks/model_matrix.small_first.json",
"--output",
str(output_path),
"--json",
]
)
out = io.StringIO()
fake_report = {
"models": [{"name": "base", "best_param_set": {"latency_ms": {"p50": 1000.0}, "quality": {"hybrid_score_avg": 0.8, "parse_valid_rate": 1.0}}}],
"winner_recommendation": {"name": "base", "reason": "test"},
}
with patch("aman.run_model_eval", return_value=fake_report), patch("sys.stdout", out):
exit_code = aman._eval_models_command(args)
self.assertEqual(exit_code, 0)
self.assertTrue(output_path.exists())
payload = json.loads(output_path.read_text(encoding="utf-8"))
self.assertEqual(payload["winner_recommendation"]["name"], "base")
def test_eval_models_command_forwards_heuristic_arguments(self):
args = aman._parse_cli_args(
[
"eval-models",
"--dataset",
"benchmarks/cleanup_dataset.jsonl",
"--matrix",
"benchmarks/model_matrix.small_first.json",
"--heuristic-dataset",
"benchmarks/heuristics_dataset.jsonl",
"--heuristic-weight",
"0.35",
"--report-version",
"2",
"--json",
]
)
out = io.StringIO()
fake_report = {
"models": [{"name": "base", "best_param_set": {}}],
"winner_recommendation": {"name": "base", "reason": "ok"},
}
with patch("aman.run_model_eval", return_value=fake_report) as run_eval_mock, patch(
"sys.stdout", out
):
exit_code = aman._eval_models_command(args)
self.assertEqual(exit_code, 0)
run_eval_mock.assert_called_once_with(
"benchmarks/cleanup_dataset.jsonl",
"benchmarks/model_matrix.small_first.json",
heuristic_dataset_path="benchmarks/heuristics_dataset.jsonl",
heuristic_weight=0.35,
report_version=2,
verbose=False,
)
def test_build_heuristic_dataset_command_json_output(self):
args = aman._parse_cli_args(
[
"build-heuristic-dataset",
"--input",
"benchmarks/heuristics_dataset.raw.jsonl",
"--output",
"benchmarks/heuristics_dataset.jsonl",
"--json",
]
)
out = io.StringIO()
summary = {
"raw_rows": 4,
"written_rows": 4,
"generated_word_rows": 2,
"output_path": "benchmarks/heuristics_dataset.jsonl",
}
with patch("aman.build_heuristic_dataset", return_value=summary), patch("sys.stdout", out):
exit_code = aman._build_heuristic_dataset_command(args)
self.assertEqual(exit_code, 0)
payload = json.loads(out.getvalue())
self.assertEqual(payload["written_rows"], 4)
def test_sync_default_model_command_updates_constants(self):
with tempfile.TemporaryDirectory() as td:
report_path = Path(td) / "latest.json"
artifacts_path = Path(td) / "artifacts.json"
constants_path = Path(td) / "constants.py"
report_path.write_text(
json.dumps(
{
"winner_recommendation": {
"name": "test-model",
}
}
),
encoding="utf-8",
)
artifacts_path.write_text(
json.dumps(
{
"models": [
{
"name": "test-model",
"filename": "winner.gguf",
"url": "https://example.invalid/winner.gguf",
"sha256": "a" * 64,
}
]
}
),
encoding="utf-8",
)
constants_path.write_text(
(
'MODEL_NAME = "old.gguf"\n'
'MODEL_URL = "https://example.invalid/old.gguf"\n'
'MODEL_SHA256 = "' + ("b" * 64) + '"\n'
),
encoding="utf-8",
)
args = aman._parse_cli_args(
[
"sync-default-model",
"--report",
str(report_path),
"--artifacts",
str(artifacts_path),
"--constants",
str(constants_path),
]
)
exit_code = aman._sync_default_model_command(args)
self.assertEqual(exit_code, 0)
updated = constants_path.read_text(encoding="utf-8")
self.assertIn('MODEL_NAME = "winner.gguf"', updated)
self.assertIn('MODEL_URL = "https://example.invalid/winner.gguf"', updated)
self.assertIn('MODEL_SHA256 = "' + ("a" * 64) + '"', updated)
def test_sync_default_model_command_check_mode_returns_2_on_drift(self):
with tempfile.TemporaryDirectory() as td:
report_path = Path(td) / "latest.json"
artifacts_path = Path(td) / "artifacts.json"
constants_path = Path(td) / "constants.py"
report_path.write_text(
json.dumps(
{
"winner_recommendation": {
"name": "test-model",
}
}
),
encoding="utf-8",
)
artifacts_path.write_text(
json.dumps(
{
"models": [
{
"name": "test-model",
"filename": "winner.gguf",
"url": "https://example.invalid/winner.gguf",
"sha256": "a" * 64,
}
]
}
),
encoding="utf-8",
)
constants_path.write_text(
(
'MODEL_NAME = "old.gguf"\n'
'MODEL_URL = "https://example.invalid/old.gguf"\n'
'MODEL_SHA256 = "' + ("b" * 64) + '"\n'
),
encoding="utf-8",
)
args = aman._parse_cli_args(
[
"sync-default-model",
"--report",
str(report_path),
"--artifacts",
str(artifacts_path),
"--constants",
str(constants_path),
"--check",
]
)
exit_code = aman._sync_default_model_command(args)
self.assertEqual(exit_code, 2)
updated = constants_path.read_text(encoding="utf-8")
self.assertIn('MODEL_NAME = "old.gguf"', updated)
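Both sync-default-model tests describe the promotion step: resolve the benchmark winner from the report, look up its artifact entry, and either rewrite the three model constants or exit 2 when --check detects drift. A hypothetical sketch of that logic (the actual command in aman likely validates inputs and formats output more carefully):

import json
import re
from pathlib import Path

def sync_default_model_sketch(report: Path, artifacts: Path, constants: Path, *, check: bool) -> int:
    winner = json.loads(report.read_text(encoding="utf-8"))["winner_recommendation"]["name"]
    entry = next(
        model
        for model in json.loads(artifacts.read_text(encoding="utf-8"))["models"]
        if model["name"] == winner
    )
    original = constants.read_text(encoding="utf-8")
    updated = original
    for key, value in (
        ("MODEL_NAME", entry["filename"]),
        ("MODEL_URL", entry["url"]),
        ("MODEL_SHA256", entry["sha256"]),
    ):
        # Rewrite lines of the form: KEY = "..."
        updated = re.sub(rf'^{key} = ".*"$', f'{key} = "{value}"', updated, flags=re.MULTILINE)
    if check:
        return 0 if updated == original else 2
    constants.write_text(updated, encoding="utf-8")
    return 0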
def test_init_command_creates_default_config(self):
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"

tests/test_asr_whisper.py (new file, 80 lines)

@@ -0,0 +1,80 @@
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))
from stages.asr_whisper import WhisperAsrStage
class _Word:
def __init__(self, word: str, start: float, end: float, probability: float = 0.9):
self.word = word
self.start = start
self.end = end
self.probability = probability
class _Segment:
def __init__(self, text: str, start: float, end: float, words=None):
self.text = text
self.start = start
self.end = end
self.words = words or []
class _ModelWithWordTimestamps:
def __init__(self):
self.kwargs = {}
def transcribe(self, _audio, language=None, vad_filter=None, word_timestamps=False):
self.kwargs = {
"language": language,
"vad_filter": vad_filter,
"word_timestamps": word_timestamps,
}
words = [_Word("hello", 0.0, 0.3), _Word("world", 0.31, 0.6)]
return [_Segment("hello world", 0.0, 0.6, words=words)], {}
class _ModelWithoutWordTimestamps:
def __init__(self):
self.kwargs = {}
def transcribe(self, _audio, language=None, vad_filter=None):
self.kwargs = {
"language": language,
"vad_filter": vad_filter,
}
return [_Segment("hello", 0.0, 0.2, words=[])], {}
class WhisperAsrStageTests(unittest.TestCase):
def test_transcribe_requests_word_timestamps_when_supported(self):
model = _ModelWithWordTimestamps()
stage = WhisperAsrStage(model, configured_language="auto")
result = stage.transcribe(object())
self.assertTrue(model.kwargs["word_timestamps"])
self.assertEqual(result.raw_text, "hello world")
self.assertEqual(len(result.words), 2)
self.assertEqual(result.words[0].text, "hello")
self.assertGreaterEqual(result.words[0].start_s, 0.0)
def test_transcribe_skips_word_timestamps_when_not_supported(self):
model = _ModelWithoutWordTimestamps()
stage = WhisperAsrStage(model, configured_language="auto")
result = stage.transcribe(object())
self.assertNotIn("word_timestamps", model.kwargs)
self.assertEqual(result.raw_text, "hello")
self.assertEqual(result.words, [])
if __name__ == "__main__":
unittest.main()
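The pair of fake models above implies the stage feature-detects word-level timestamps from the model's transcribe signature instead of assuming the kwarg exists. A minimal sketch of that check, assuming WhisperAsrStage really does use signature inspection:

import inspect

def _build_transcribe_kwargs_sketch(model, *, language, vad_filter: bool) -> dict:
    kwargs = {"language": language, "vad_filter": vad_filter}
    # Only request word timestamps when the backing model supports the kwarg.
    if "word_timestamps" in inspect.signature(model.transcribe).parameters:
        kwargs["word_timestamps"] = True
    return kwargs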


@@ -25,14 +25,12 @@ class ConfigTests(unittest.TestCase):
self.assertEqual(cfg.stt.model, "base")
self.assertEqual(cfg.stt.device, "cpu")
self.assertEqual(cfg.stt.language, "auto")
self.assertEqual(cfg.llm.provider, "local_llama")
self.assertFalse(cfg.models.allow_custom_models)
self.assertEqual(cfg.models.whisper_model_path, "")
self.assertEqual(cfg.models.llm_model_path, "")
self.assertFalse(cfg.external_api.enabled)
self.assertEqual(cfg.external_api.provider, "openai")
self.assertEqual(cfg.injection.backend, "clipboard")
self.assertFalse(cfg.injection.remove_transcription_from_clipboard)
self.assertTrue(cfg.safety.enabled)
self.assertFalse(cfg.safety.strict)
self.assertEqual(cfg.ux.profile, "default")
self.assertTrue(cfg.ux.show_notifications)
self.assertTrue(cfg.advanced.strict_startup)
@@ -54,13 +52,15 @@
"device": "cuda",
"language": "English",
},
"llm": {"provider": "local_llama"},
"models": {"allow_custom_models": False},
"external_api": {"enabled": False},
"injection": {
"backend": "injection",
"remove_transcription_from_clipboard": True,
},
"safety": {
"enabled": True,
"strict": True,
},
"vocabulary": {
"replacements": [
{"from": "Martha", "to": "Marta"},
@@ -82,9 +82,10 @@ class ConfigTests(unittest.TestCase):
self.assertEqual(cfg.stt.model, "small")
self.assertEqual(cfg.stt.device, "cuda")
self.assertEqual(cfg.stt.language, "en")
self.assertEqual(cfg.llm.provider, "local_llama")
self.assertEqual(cfg.injection.backend, "injection")
self.assertTrue(cfg.injection.remove_transcription_from_clipboard)
self.assertTrue(cfg.safety.enabled)
self.assertTrue(cfg.safety.strict)
self.assertEqual(len(cfg.vocabulary.replacements), 2)
self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha")
self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta")
@@ -138,6 +139,33 @@ class ConfigTests(unittest.TestCase):
with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"):
load(str(path))
def test_invalid_safety_enabled_option_raises(self):
payload = {"safety": {"enabled": "yes"}}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "safety.enabled"):
load(str(path))
def test_invalid_safety_strict_option_raises(self):
payload = {"safety": {"strict": "yes"}}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "safety.strict"):
load(str(path))
def test_unknown_safety_fields_raise(self):
payload = {"safety": {"enabled": True, "mode": "strict"}}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "safety.mode: unknown config field"):
load(str(path))
def test_unknown_top_level_fields_raise(self):
payload = {
"custom_a": {"enabled": True},
@@ -269,10 +297,9 @@ class ConfigTests(unittest.TestCase):
self.assertEqual(cfg.config_version, CURRENT_CONFIG_VERSION)
def test_external_llm_requires_external_api_enabled(self):
def test_legacy_llm_config_fields_raise(self):
payload = {
"llm": {"provider": "external_api"},
"external_api": {"enabled": False},
"llm": {"provider": "local_llama"},
}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
@@ -280,7 +307,7 @@ class ConfigTests(unittest.TestCase):
with self.assertRaisesRegex(
ValueError,
"llm.provider: external_api provider requires external_api.enabled=true",
"llm: unknown config field",
):
load(str(path))


@@ -23,37 +23,20 @@ class ConfigUiRuntimeModeTests(unittest.TestCase):
def test_infer_runtime_mode_detects_expert_overrides(self):
cfg = Config()
cfg.llm.provider = "external_api"
cfg.external_api.enabled = True
cfg.models.allow_custom_models = True
self.assertEqual(infer_runtime_mode(cfg), RUNTIME_MODE_EXPERT)
def test_apply_canonical_runtime_defaults_resets_expert_fields(self):
cfg = Config()
cfg.stt.provider = "local_whisper"
cfg.llm.provider = "external_api"
cfg.external_api.enabled = True
cfg.external_api.base_url = "https://example.local/v1"
cfg.external_api.model = "custom-model"
cfg.external_api.api_key_env_var = "CUSTOM_KEY"
cfg.external_api.timeout_ms = 321
cfg.external_api.max_retries = 8
cfg.models.allow_custom_models = True
cfg.models.whisper_model_path = "/tmp/custom-whisper.bin"
cfg.models.llm_model_path = "/tmp/custom-model.gguf"
apply_canonical_runtime_defaults(cfg)
self.assertEqual(cfg.stt.provider, "local_whisper")
self.assertEqual(cfg.llm.provider, "local_llama")
self.assertFalse(cfg.external_api.enabled)
self.assertEqual(cfg.external_api.base_url, "https://api.openai.com/v1")
self.assertEqual(cfg.external_api.model, "gpt-4o-mini")
self.assertEqual(cfg.external_api.api_key_env_var, "AMAN_EXTERNAL_API_KEY")
self.assertEqual(cfg.external_api.timeout_ms, 15000)
self.assertEqual(cfg.external_api.max_retries, 2)
self.assertFalse(cfg.models.allow_custom_models)
self.assertEqual(cfg.models.whisper_model_path, "")
self.assertEqual(cfg.models.llm_model_path, "")
if __name__ == "__main__":

tests/test_fact_guard.py (new file, 86 lines)

@@ -0,0 +1,86 @@
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))
from stages.fact_guard import FactGuardEngine
class FactGuardEngineTests(unittest.TestCase):
def test_disabled_guard_accepts_candidate(self):
guard = FactGuardEngine()
result = guard.apply(
"set alarm for 7",
"set alarm for 8",
enabled=False,
strict=False,
)
self.assertEqual(result.action, "accepted")
self.assertEqual(result.final_text, "set alarm for 8")
self.assertEqual(result.violations_count, 0)
def test_fallbacks_on_number_change(self):
guard = FactGuardEngine()
result = guard.apply(
"set alarm for 7",
"set alarm for 8",
enabled=True,
strict=False,
)
self.assertEqual(result.action, "fallback")
self.assertEqual(result.final_text, "set alarm for 7")
self.assertGreaterEqual(result.violations_count, 1)
def test_fallbacks_on_name_change(self):
guard = FactGuardEngine()
result = guard.apply(
"invite Marta tomorrow",
"invite Martha tomorrow",
enabled=True,
strict=False,
)
self.assertEqual(result.action, "fallback")
self.assertEqual(result.final_text, "invite Marta tomorrow")
self.assertGreaterEqual(result.violations_count, 1)
def test_accepts_style_only_rewrite(self):
guard = FactGuardEngine()
result = guard.apply(
"please send the report",
"Please send the report.",
enabled=True,
strict=False,
)
self.assertEqual(result.action, "accepted")
self.assertEqual(result.final_text, "Please send the report.")
self.assertEqual(result.violations_count, 0)
def test_strict_mode_rejects_large_lexical_additions(self):
guard = FactGuardEngine()
result = guard.apply(
"send the report",
"send the report and include two extra paragraphs with assumptions",
enabled=True,
strict=True,
)
self.assertEqual(result.action, "rejected")
self.assertEqual(result.final_text, "send the report")
self.assertGreaterEqual(result.violations_count, 1)
if __name__ == "__main__":
unittest.main()
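These cases pin down the guard's core invariant: an editor rewrite may change style but not the numbers (or names) in the transcript, and strict mode rejects instead of falling back. A toy number-only sketch of that decision (the real FactGuardEngine also tracks proper names and large lexical additions):

import re

def _numbers(text: str) -> list[str]:
    return sorted(re.findall(r"\d+(?:[.,]\d+)?", text))

def guard_numbers_sketch(source: str, candidate: str, *, enabled: bool, strict: bool):
    # Returns (action, final_text) in the spirit of FactGuardEngine.apply.
    if not enabled or _numbers(source) == _numbers(candidate):
        return "accepted", candidate
    return ("rejected" if strict else "fallback"), source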

tests/test_model_eval.py (new file, 137 lines)

@@ -0,0 +1,137 @@
import json
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))
import model_eval
class _FakeProcessor:
def __init__(self, *args, **kwargs):
_ = (args, kwargs)
def warmup(self, **kwargs):
_ = kwargs
return
def process(self, text, **kwargs):
_ = kwargs
return text.strip()
class ModelEvalTests(unittest.TestCase):
def test_load_eval_dataset_validates_required_fields(self):
with tempfile.TemporaryDirectory() as td:
dataset = Path(td) / "dataset.jsonl"
dataset.write_text(
'{"id":"c1","input_text":"hello","expected_output":"hello"}\n',
encoding="utf-8",
)
cases = model_eval.load_eval_dataset(dataset)
self.assertEqual(len(cases), 1)
self.assertEqual(cases[0].case_id, "c1")
def test_run_model_eval_produces_report(self):
with tempfile.TemporaryDirectory() as td:
model_file = Path(td) / "fake.gguf"
model_file.write_text("fake", encoding="utf-8")
dataset = Path(td) / "dataset.jsonl"
dataset.write_text(
(
'{"id":"c1","input_text":"hello world","expected_output":"hello world"}\n'
'{"id":"c2","input_text":"hello","expected_output":"hello"}\n'
),
encoding="utf-8",
)
matrix = Path(td) / "matrix.json"
heuristic_dataset = Path(td) / "heuristics.jsonl"
matrix.write_text(
json.dumps(
{
"warmup_runs": 0,
"measured_runs": 1,
"timeout_sec": 30,
"baseline_model": {
"name": "base",
"provider": "local_llama",
"model_path": str(model_file),
"profile": "default",
"param_grid": {"temperature": [0.0]},
},
"candidate_models": [
{
"name": "small",
"provider": "local_llama",
"model_path": str(model_file),
"profile": "fast",
"param_grid": {"temperature": [0.0, 0.1], "max_tokens": [96]},
}
],
}
),
encoding="utf-8",
)
heuristic_dataset.write_text(
(
'{"id":"h1","transcript":"set alarm for 6 i mean 7","words":[{"text":"set","start_s":0.0,"end_s":0.1,"prob":0.9},{"text":"alarm","start_s":0.2,"end_s":0.3,"prob":0.9},{"text":"for","start_s":0.4,"end_s":0.5,"prob":0.9},{"text":"6","start_s":0.6,"end_s":0.7,"prob":0.9},{"text":"i","start_s":0.8,"end_s":0.9,"prob":0.9},{"text":"mean","start_s":1.0,"end_s":1.1,"prob":0.9},{"text":"7","start_s":1.2,"end_s":1.3,"prob":0.9}],"expected_aligned_text":"set alarm for 7","expected":{"applied_min":1,"required_rule_ids":["cue_correction"]},"tags":["i_mean_correction"]}\n'
'{"id":"h2","transcript":"write exactly i mean this sincerely","words":[{"text":"write","start_s":0.0,"end_s":0.1,"prob":0.9},{"text":"exactly","start_s":0.2,"end_s":0.3,"prob":0.9},{"text":"i","start_s":0.4,"end_s":0.5,"prob":0.9},{"text":"mean","start_s":0.6,"end_s":0.7,"prob":0.9},{"text":"this","start_s":0.8,"end_s":0.9,"prob":0.9},{"text":"sincerely","start_s":1.0,"end_s":1.1,"prob":0.9}],"expected_aligned_text":"write exactly i mean this sincerely","expected":{"required_rule_ids":[],"forbidden_rule_ids":["cue_correction"]},"tags":["i_mean_literal"]}\n'
),
encoding="utf-8",
)
with patch("model_eval.LlamaProcessor", _FakeProcessor):
report = model_eval.run_model_eval(
dataset,
matrix,
heuristic_dataset_path=heuristic_dataset,
heuristic_weight=0.3,
report_version=2,
verbose=False,
)
self.assertEqual(report["report_version"], 2)
self.assertIn("models", report)
self.assertEqual(len(report["models"]), 2)
self.assertIn("winner_recommendation", report)
self.assertIn("heuristic_eval", report)
self.assertEqual(report["heuristic_eval"]["cases"], 2)
self.assertIn("combined_score", report["models"][0]["best_param_set"])
summary = model_eval.format_model_eval_summary(report)
self.assertIn("model eval summary", summary)
def test_load_heuristic_dataset_validates_required_fields(self):
with tempfile.TemporaryDirectory() as td:
dataset = Path(td) / "heuristics.jsonl"
dataset.write_text(
'{"id":"h1","transcript":"hello world","words":[{"text":"hello","start_s":0.0,"end_s":0.1},{"text":"world","start_s":0.2,"end_s":0.3}],"expected_aligned_text":"hello world"}\n',
encoding="utf-8",
)
cases = model_eval.load_heuristic_dataset(dataset)
self.assertEqual(len(cases), 1)
self.assertEqual(cases[0].case_id, "h1")
self.assertEqual(cases[0].expected.applied_min, 0)
def test_build_heuristic_dataset_generates_words_when_missing(self):
with tempfile.TemporaryDirectory() as td:
source = Path(td) / "heuristics.raw.jsonl"
output = Path(td) / "heuristics.jsonl"
source.write_text(
'{"id":"h1","transcript":"please send it","expected_aligned_text":"please send it","expected":{"applied_min":0}}\n',
encoding="utf-8",
)
summary = model_eval.build_heuristic_dataset(source, output)
self.assertEqual(summary["written_rows"], 1)
self.assertEqual(summary["generated_word_rows"], 1)
loaded = model_eval.load_heuristic_dataset(output)
self.assertEqual(len(loaded), 1)
self.assertGreaterEqual(len(loaded[0].words), 1)
if __name__ == "__main__":
unittest.main()
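The dataset tests establish the JSONL contract: one case per line with id, input_text, and expected_output required. A minimal loader sketch under that assumption (the validation in model_eval may be stricter):

import json
from dataclasses import dataclass
from pathlib import Path

@dataclass
class EvalCase:
    case_id: str
    input_text: str
    expected_output: str

def load_eval_dataset_sketch(path: Path) -> list[EvalCase]:
    cases: list[EvalCase] = []
    for line_no, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
        if not line.strip():
            continue
        row = json.loads(line)
        missing = [key for key in ("id", "input_text", "expected_output") if key not in row]
        if missing:
            raise ValueError(f"{path}:{line_no}: missing required fields {missing}")
        cases.append(EvalCase(row["id"], row["input_text"], row["expected_output"]))
    return cases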


@@ -0,0 +1,129 @@
import sys
import unittest
from pathlib import Path
from types import SimpleNamespace
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))
from engine.pipeline import PipelineEngine
from stages.alignment_edits import AlignmentHeuristicEngine
from stages.asr_whisper import AsrResult, AsrSegment, AsrWord
from vocabulary import VocabularyEngine
from config import VocabularyConfig
class _FakeEditor:
def __init__(self, *, output_text: str | None = None):
self.calls = []
self.output_text = output_text
def rewrite(self, transcript, *, language, dictionary_context):
self.calls.append(
{
"transcript": transcript,
"language": language,
"dictionary_context": dictionary_context,
}
)
final_text = transcript if self.output_text is None else self.output_text
return SimpleNamespace(
final_text=final_text,
latency_ms=1.0,
pass1_ms=0.5,
pass2_ms=0.5,
)
class _FakeAsr:
def transcribe(self, _audio):
words = [
AsrWord("set", 0.0, 0.1, 0.9),
AsrWord("alarm", 0.2, 0.3, 0.9),
AsrWord("for", 0.4, 0.5, 0.9),
AsrWord("6", 0.6, 0.7, 0.9),
AsrWord("i", 0.8, 0.9, 0.9),
AsrWord("mean", 1.0, 1.1, 0.9),
AsrWord("7", 1.2, 1.3, 0.9),
]
segments = [AsrSegment(text="set alarm for 6 i mean 7", start_s=0.0, end_s=1.3)]
return AsrResult(
raw_text="set alarm for 6 i mean 7",
language="en",
latency_ms=5.0,
words=words,
segments=segments,
)
class PipelineEngineTests(unittest.TestCase):
def test_alignment_draft_is_forwarded_to_editor(self):
editor = _FakeEditor()
pipeline = PipelineEngine(
asr_stage=_FakeAsr(),
editor_stage=editor,
vocabulary=VocabularyEngine(VocabularyConfig()),
alignment_engine=AlignmentHeuristicEngine(),
)
result = pipeline.run_audio(object())
self.assertEqual(editor.calls[0]["transcript"], "set alarm for 7")
self.assertEqual(result.alignment_applied, 1)
self.assertGreaterEqual(result.alignment_ms, 0.0)
self.assertEqual(result.fact_guard_action, "accepted")
self.assertEqual(result.fact_guard_violations, 0)
def test_run_transcript_without_words_keeps_alignment_noop(self):
editor = _FakeEditor()
pipeline = PipelineEngine(
asr_stage=None,
editor_stage=editor,
vocabulary=VocabularyEngine(VocabularyConfig()),
alignment_engine=AlignmentHeuristicEngine(),
)
result = pipeline.run_transcript("hello world", language="en")
self.assertEqual(editor.calls[0]["transcript"], "hello world")
self.assertEqual(result.alignment_applied, 0)
self.assertEqual(result.fact_guard_action, "accepted")
self.assertEqual(result.fact_guard_violations, 0)
def test_fact_guard_fallbacks_when_editor_changes_number(self):
editor = _FakeEditor(output_text="set alarm for 8")
pipeline = PipelineEngine(
asr_stage=None,
editor_stage=editor,
vocabulary=VocabularyEngine(VocabularyConfig()),
alignment_engine=AlignmentHeuristicEngine(),
safety_enabled=True,
safety_strict=False,
)
result = pipeline.run_transcript("set alarm for 7", language="en")
self.assertEqual(result.output_text, "set alarm for 7")
self.assertEqual(result.fact_guard_action, "fallback")
self.assertGreaterEqual(result.fact_guard_violations, 1)
def test_fact_guard_strict_rejects_number_change(self):
editor = _FakeEditor(output_text="set alarm for 8")
pipeline = PipelineEngine(
asr_stage=None,
editor_stage=editor,
vocabulary=VocabularyEngine(VocabularyConfig()),
alignment_engine=AlignmentHeuristicEngine(),
safety_enabled=True,
safety_strict=True,
)
with self.assertRaisesRegex(RuntimeError, "fact guard rejected editor output"):
pipeline.run_transcript("set alarm for 7", language="en")
if __name__ == "__main__":
unittest.main()
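Read together, the four tests fix the stage ordering and the safety semantics: alignment heuristics produce the draft the editor sees, and the fact guard arbitrates the editor's output, falling back when safety is enabled and raising in strict mode. A hypothetical end-to-end sketch of that flow (attribute names like fact_guard, safety_enabled, and safety_strict are guesses; the real PipelineEngine also applies vocabulary replacements and records per-stage latencies):

def run_transcript_sketch(pipeline, raw_text: str, *, language: str, words=()) -> str:
    # 1. Alignment heuristics rewrite the raw transcript into a draft.
    alignment = pipeline.alignment_engine.apply(raw_text, list(words))
    # 2. The LLM editor rewrites the draft.
    edited = pipeline.editor_stage.rewrite(
        alignment.draft_text, language=language, dictionary_context=""
    )
    # 3. The fact guard compares draft vs. edit and decides what to emit.
    guarded = pipeline.fact_guard.apply(
        alignment.draft_text,
        edited.final_text,
        enabled=pipeline.safety_enabled,
        strict=pipeline.safety_strict,
    )
    if guarded.action == "rejected":
        raise RuntimeError("fact guard rejected editor output")
    return guarded.final_text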