Add benchmark-driven model promotion workflow and pipeline stages
Some checks failed
ci / test-and-build (push) Has been cancelled
Some checks failed
ci / test-and-build (push) Has been cancelled
This commit is contained in:
parent
98b13d1069
commit
8c1f7c1e13
38 changed files with 5300 additions and 503 deletions
|
|
@ -15,8 +15,11 @@ if str(SRC) not in sys.path:
|
|||
import aiprocess
|
||||
from aiprocess import (
|
||||
ExternalApiProcessor,
|
||||
LlamaProcessor,
|
||||
_assert_expected_model_checksum,
|
||||
_build_request_payload,
|
||||
_build_user_prompt_xml,
|
||||
_explicit_generation_kwargs,
|
||||
_extract_cleaned_text,
|
||||
_profile_generation_kwargs,
|
||||
_supports_response_format,
|
||||
|
|
@ -114,6 +117,75 @@ class SupportsResponseFormatTests(unittest.TestCase):
|
|||
|
||||
self.assertEqual(kwargs, {})
|
||||
|
||||
def test_explicit_generation_kwargs_honors_supported_params(self):
|
||||
def chat_completion(*, messages, temperature, top_p, max_tokens):
|
||||
return None
|
||||
|
||||
kwargs = _explicit_generation_kwargs(
|
||||
chat_completion,
|
||||
top_p=0.9,
|
||||
top_k=40,
|
||||
max_tokens=128,
|
||||
repeat_penalty=1.1,
|
||||
min_p=0.05,
|
||||
)
|
||||
self.assertEqual(kwargs, {"top_p": 0.9, "max_tokens": 128})
|
||||
|
||||
|
||||
class _WarmupClient:
|
||||
def __init__(self, response_payload: dict):
|
||||
self.response_payload = response_payload
|
||||
self.calls = []
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
*,
|
||||
messages,
|
||||
temperature,
|
||||
response_format=None,
|
||||
max_tokens=None,
|
||||
):
|
||||
self.calls.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"response_format": response_format,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
)
|
||||
return self.response_payload
|
||||
|
||||
|
||||
class LlamaWarmupTests(unittest.TestCase):
|
||||
def test_warmup_uses_json_mode_and_low_token_budget(self):
|
||||
processor = object.__new__(LlamaProcessor)
|
||||
client = _WarmupClient(
|
||||
{"choices": [{"message": {"content": '{"cleaned_text":"ok"}'}}]}
|
||||
)
|
||||
processor.client = client
|
||||
|
||||
processor.warmup(profile="fast")
|
||||
|
||||
self.assertEqual(len(client.calls), 1)
|
||||
call = client.calls[0]
|
||||
self.assertEqual(call["temperature"], 0.0)
|
||||
self.assertEqual(call["response_format"], {"type": "json_object"})
|
||||
self.assertEqual(call["max_tokens"], 32)
|
||||
user_content = call["messages"][1]["content"]
|
||||
self.assertIn("<request>", user_content)
|
||||
self.assertIn("<transcript>warmup</transcript>", user_content)
|
||||
self.assertIn("<language>auto</language>", user_content)
|
||||
|
||||
def test_warmup_raises_on_non_json_response(self):
|
||||
processor = object.__new__(LlamaProcessor)
|
||||
client = _WarmupClient(
|
||||
{"choices": [{"message": {"content": "not-json"}}]}
|
||||
)
|
||||
processor.client = client
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "expected JSON"):
|
||||
processor.warmup(profile="default")
|
||||
|
||||
|
||||
class ModelChecksumTests(unittest.TestCase):
|
||||
def test_accepts_expected_checksum_case_insensitive(self):
|
||||
|
|
@ -137,6 +209,19 @@ class RequestPayloadTests(unittest.TestCase):
|
|||
self.assertEqual(payload["transcript"], "hello")
|
||||
self.assertNotIn("dictionary", payload)
|
||||
|
||||
def test_user_prompt_is_xml_and_escapes_literals(self):
|
||||
payload = _build_request_payload(
|
||||
'keep <transcript> and "quotes"',
|
||||
lang="en",
|
||||
dictionary_context="Docker & systemd",
|
||||
)
|
||||
xml = _build_user_prompt_xml(payload)
|
||||
self.assertIn("<request>", xml)
|
||||
self.assertIn("<language>en</language>", xml)
|
||||
self.assertIn("<transcript>", xml)
|
||||
self.assertIn("&", xml)
|
||||
self.assertIn("<output_contract>", xml)
|
||||
|
||||
|
||||
class _Response:
|
||||
def __init__(self, payload: bytes):
|
||||
|
|
@ -254,6 +339,21 @@ class ExternalApiProcessorTests(unittest.TestCase):
|
|||
request = urlopen.call_args[0][0]
|
||||
self.assertTrue(request.full_url.endswith("/chat/completions"))
|
||||
|
||||
def test_warmup_is_a_noop(self):
|
||||
with patch.dict(os.environ, {"AMAN_EXTERNAL_API_KEY": "test-key"}, clear=True):
|
||||
processor = ExternalApiProcessor(
|
||||
provider="openai",
|
||||
base_url="https://api.openai.com/v1",
|
||||
model="gpt-4o-mini",
|
||||
api_key_env_var="AMAN_EXTERNAL_API_KEY",
|
||||
timeout_ms=1000,
|
||||
max_retries=0,
|
||||
)
|
||||
with patch("aiprocess.urllib.request.urlopen") as urlopen:
|
||||
processor.warmup(profile="fast")
|
||||
|
||||
urlopen.assert_not_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
72
tests/test_alignment_edits.py
Normal file
72
tests/test_alignment_edits.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from stages.alignment_edits import AlignmentHeuristicEngine
|
||||
from stages.asr_whisper import AsrWord
|
||||
|
||||
|
||||
def _words(tokens: list[str], step: float = 0.2) -> list[AsrWord]:
|
||||
out: list[AsrWord] = []
|
||||
start = 0.0
|
||||
for token in tokens:
|
||||
out.append(
|
||||
AsrWord(
|
||||
text=token,
|
||||
start_s=start,
|
||||
end_s=start + 0.1,
|
||||
prob=0.9,
|
||||
)
|
||||
)
|
||||
start += step
|
||||
return out
|
||||
|
||||
|
||||
class AlignmentHeuristicEngineTests(unittest.TestCase):
|
||||
def test_returns_original_when_no_words_available(self):
|
||||
engine = AlignmentHeuristicEngine()
|
||||
|
||||
result = engine.apply("hello world", [])
|
||||
|
||||
self.assertEqual(result.draft_text, "hello world")
|
||||
self.assertEqual(result.applied_count, 0)
|
||||
self.assertEqual(result.decisions, [])
|
||||
|
||||
def test_applies_i_mean_tail_correction(self):
|
||||
engine = AlignmentHeuristicEngine()
|
||||
words = _words(["set", "alarm", "for", "6", "i", "mean", "7"])
|
||||
|
||||
result = engine.apply("set alarm for 6 i mean 7", words)
|
||||
|
||||
self.assertEqual(result.draft_text, "set alarm for 7")
|
||||
self.assertEqual(result.applied_count, 1)
|
||||
self.assertTrue(any(item.rule_id == "cue_correction" for item in result.decisions))
|
||||
|
||||
def test_preserves_literal_i_mean_context(self):
|
||||
engine = AlignmentHeuristicEngine()
|
||||
words = _words(["write", "exactly", "i", "mean", "this", "sincerely"])
|
||||
|
||||
result = engine.apply("write exactly i mean this sincerely", words)
|
||||
|
||||
self.assertEqual(result.draft_text, "write exactly i mean this sincerely")
|
||||
self.assertEqual(result.applied_count, 0)
|
||||
self.assertGreaterEqual(result.skipped_count, 1)
|
||||
|
||||
def test_collapses_exact_restart_repetition(self):
|
||||
engine = AlignmentHeuristicEngine()
|
||||
words = _words(["please", "send", "it", "please", "send", "it"])
|
||||
|
||||
result = engine.apply("please send it please send it", words)
|
||||
|
||||
self.assertEqual(result.draft_text, "please send it")
|
||||
self.assertEqual(result.applied_count, 1)
|
||||
self.assertTrue(any(item.rule_id == "restart_repeat" for item in result.decisions))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -111,11 +111,21 @@ class FakeUnsupportedLanguageModel:
|
|||
class FakeAIProcessor:
|
||||
def __init__(self):
|
||||
self.last_kwargs = {}
|
||||
self.warmup_calls = []
|
||||
self.warmup_error = None
|
||||
self.process_error = None
|
||||
|
||||
def process(self, text, lang="auto", **_kwargs):
|
||||
if self.process_error is not None:
|
||||
raise self.process_error
|
||||
self.last_kwargs = {"lang": lang, **_kwargs}
|
||||
return text
|
||||
|
||||
def warmup(self, profile="default"):
|
||||
self.warmup_calls.append(profile)
|
||||
if self.warmup_error:
|
||||
raise self.warmup_error
|
||||
|
||||
|
||||
class FakeAudio:
|
||||
def __init__(self, size: int):
|
||||
|
|
@ -212,6 +222,32 @@ class DaemonTests(unittest.TestCase):
|
|||
|
||||
self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard", False)])
|
||||
|
||||
@patch("aman.stop_audio_recording", return_value=FakeAudio(8))
|
||||
@patch("aman.start_audio_recording", return_value=(object(), object()))
|
||||
def test_editor_failure_aborts_output_injection(self, _start_mock, _stop_mock):
|
||||
desktop = FakeDesktop()
|
||||
model = FakeModel(text="hello world")
|
||||
ai_processor = FakeAIProcessor()
|
||||
ai_processor.process_error = RuntimeError("editor boom")
|
||||
|
||||
daemon = self._build_daemon(
|
||||
desktop,
|
||||
model,
|
||||
verbose=False,
|
||||
ai_processor=ai_processor,
|
||||
)
|
||||
daemon._start_stop_worker = (
|
||||
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
|
||||
stream, record, trigger, process_audio
|
||||
)
|
||||
)
|
||||
|
||||
daemon.toggle()
|
||||
daemon.toggle()
|
||||
|
||||
self.assertEqual(desktop.inject_calls, [])
|
||||
self.assertEqual(daemon.get_state(), aman.State.IDLE)
|
||||
|
||||
def test_transcribe_skips_hints_when_model_does_not_support_them(self):
|
||||
desktop = FakeDesktop()
|
||||
model = FakeModel(text="hello")
|
||||
|
|
@ -242,7 +278,7 @@ class DaemonTests(unittest.TestCase):
|
|||
self.assertEqual(used_lang, "auto")
|
||||
self.assertIn("Docker", model.last_kwargs["hotwords"])
|
||||
self.assertIn("Systemd", model.last_kwargs["hotwords"])
|
||||
self.assertIn("Preferred vocabulary", model.last_kwargs["initial_prompt"])
|
||||
self.assertIsNone(model.last_kwargs["initial_prompt"])
|
||||
|
||||
def test_transcribe_uses_configured_language_hint(self):
|
||||
desktop = FakeDesktop()
|
||||
|
|
@ -300,7 +336,7 @@ class DaemonTests(unittest.TestCase):
|
|||
daemon_verbose = self._build_daemon(desktop, FakeModel(), cfg=cfg, verbose=True)
|
||||
self.assertTrue(daemon_verbose.log_transcript)
|
||||
|
||||
def test_ai_processor_is_initialized_during_daemon_init(self):
|
||||
def test_editor_stage_is_initialized_during_daemon_init(self):
|
||||
desktop = FakeDesktop()
|
||||
with patch("aman._build_whisper_model", return_value=FakeModel()), patch(
|
||||
"aman.LlamaProcessor", return_value=FakeAIProcessor()
|
||||
|
|
@ -308,7 +344,47 @@ class DaemonTests(unittest.TestCase):
|
|||
daemon = aman.Daemon(self._config(), desktop, verbose=True)
|
||||
|
||||
processor_cls.assert_called_once_with(verbose=True, model_path=None)
|
||||
self.assertIsNotNone(daemon.ai_processor)
|
||||
self.assertIsNotNone(daemon.editor_stage)
|
||||
|
||||
def test_editor_stage_is_warmed_up_during_daemon_init(self):
|
||||
desktop = FakeDesktop()
|
||||
ai_processor = FakeAIProcessor()
|
||||
with patch("aman._build_whisper_model", return_value=FakeModel()), patch(
|
||||
"aman.LlamaProcessor", return_value=ai_processor
|
||||
):
|
||||
daemon = aman.Daemon(self._config(), desktop, verbose=False)
|
||||
|
||||
self.assertIs(daemon.editor_stage._processor, ai_processor)
|
||||
self.assertEqual(ai_processor.warmup_calls, ["default"])
|
||||
|
||||
def test_editor_stage_warmup_failure_is_fatal_with_strict_startup(self):
|
||||
desktop = FakeDesktop()
|
||||
cfg = self._config()
|
||||
cfg.advanced.strict_startup = True
|
||||
ai_processor = FakeAIProcessor()
|
||||
ai_processor.warmup_error = RuntimeError("warmup boom")
|
||||
with patch("aman._build_whisper_model", return_value=FakeModel()), patch(
|
||||
"aman.LlamaProcessor", return_value=ai_processor
|
||||
):
|
||||
with self.assertRaisesRegex(RuntimeError, "editor stage warmup failed"):
|
||||
aman.Daemon(cfg, desktop, verbose=False)
|
||||
|
||||
def test_editor_stage_warmup_failure_is_non_fatal_without_strict_startup(self):
|
||||
desktop = FakeDesktop()
|
||||
cfg = self._config()
|
||||
cfg.advanced.strict_startup = False
|
||||
ai_processor = FakeAIProcessor()
|
||||
ai_processor.warmup_error = RuntimeError("warmup boom")
|
||||
with patch("aman._build_whisper_model", return_value=FakeModel()), patch(
|
||||
"aman.LlamaProcessor", return_value=ai_processor
|
||||
):
|
||||
with self.assertLogs(level="WARNING") as logs:
|
||||
daemon = aman.Daemon(cfg, desktop, verbose=False)
|
||||
|
||||
self.assertIs(daemon.editor_stage._processor, ai_processor)
|
||||
self.assertTrue(
|
||||
any("continuing because advanced.strict_startup=false" in line for line in logs.output)
|
||||
)
|
||||
|
||||
@patch("aman.stop_audio_recording", return_value=FakeAudio(8))
|
||||
@patch("aman.start_audio_recording", return_value=(object(), object()))
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import sys
|
|||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
|
|
@ -92,6 +93,20 @@ class _RetrySetupDesktop(_FakeDesktop):
|
|||
on_quit()
|
||||
|
||||
|
||||
class _FakeBenchEditorStage:
|
||||
def warmup(self):
|
||||
return
|
||||
|
||||
def rewrite(self, transcript, *, language, dictionary_context):
|
||||
_ = dictionary_context
|
||||
return SimpleNamespace(
|
||||
final_text=f"[{language}] {transcript.strip()}",
|
||||
latency_ms=1.0,
|
||||
pass1_ms=0.5,
|
||||
pass2_ms=0.5,
|
||||
)
|
||||
|
||||
|
||||
class AmanCliTests(unittest.TestCase):
|
||||
def test_parse_cli_args_defaults_to_run_command(self):
|
||||
args = aman._parse_cli_args(["--dry-run"])
|
||||
|
|
@ -111,6 +126,85 @@ class AmanCliTests(unittest.TestCase):
|
|||
self.assertEqual(args.command, "self-check")
|
||||
self.assertTrue(args.json)
|
||||
|
||||
def test_parse_cli_args_bench_command(self):
|
||||
args = aman._parse_cli_args(
|
||||
["bench", "--text", "hello", "--repeat", "2", "--warmup", "0", "--json"]
|
||||
)
|
||||
|
||||
self.assertEqual(args.command, "bench")
|
||||
self.assertEqual(args.text, "hello")
|
||||
self.assertEqual(args.repeat, 2)
|
||||
self.assertEqual(args.warmup, 0)
|
||||
self.assertTrue(args.json)
|
||||
|
||||
def test_parse_cli_args_bench_requires_input(self):
|
||||
with self.assertRaises(SystemExit):
|
||||
aman._parse_cli_args(["bench"])
|
||||
|
||||
def test_parse_cli_args_eval_models_command(self):
|
||||
args = aman._parse_cli_args(
|
||||
["eval-models", "--dataset", "benchmarks/cleanup_dataset.jsonl", "--matrix", "benchmarks/model_matrix.small_first.json"]
|
||||
)
|
||||
self.assertEqual(args.command, "eval-models")
|
||||
self.assertEqual(args.dataset, "benchmarks/cleanup_dataset.jsonl")
|
||||
self.assertEqual(args.matrix, "benchmarks/model_matrix.small_first.json")
|
||||
self.assertEqual(args.heuristic_dataset, "")
|
||||
self.assertEqual(args.heuristic_weight, 0.25)
|
||||
self.assertEqual(args.report_version, 2)
|
||||
|
||||
def test_parse_cli_args_eval_models_with_heuristic_options(self):
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"eval-models",
|
||||
"--dataset",
|
||||
"benchmarks/cleanup_dataset.jsonl",
|
||||
"--matrix",
|
||||
"benchmarks/model_matrix.small_first.json",
|
||||
"--heuristic-dataset",
|
||||
"benchmarks/heuristics_dataset.jsonl",
|
||||
"--heuristic-weight",
|
||||
"0.4",
|
||||
"--report-version",
|
||||
"2",
|
||||
]
|
||||
)
|
||||
self.assertEqual(args.heuristic_dataset, "benchmarks/heuristics_dataset.jsonl")
|
||||
self.assertEqual(args.heuristic_weight, 0.4)
|
||||
self.assertEqual(args.report_version, 2)
|
||||
|
||||
def test_parse_cli_args_build_heuristic_dataset_command(self):
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"build-heuristic-dataset",
|
||||
"--input",
|
||||
"benchmarks/heuristics_dataset.raw.jsonl",
|
||||
"--output",
|
||||
"benchmarks/heuristics_dataset.jsonl",
|
||||
]
|
||||
)
|
||||
self.assertEqual(args.command, "build-heuristic-dataset")
|
||||
self.assertEqual(args.input, "benchmarks/heuristics_dataset.raw.jsonl")
|
||||
self.assertEqual(args.output, "benchmarks/heuristics_dataset.jsonl")
|
||||
|
||||
def test_parse_cli_args_sync_default_model_command(self):
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"sync-default-model",
|
||||
"--report",
|
||||
"benchmarks/results/latest.json",
|
||||
"--artifacts",
|
||||
"benchmarks/model_artifacts.json",
|
||||
"--constants",
|
||||
"src/constants.py",
|
||||
"--check",
|
||||
]
|
||||
)
|
||||
self.assertEqual(args.command, "sync-default-model")
|
||||
self.assertEqual(args.report, "benchmarks/results/latest.json")
|
||||
self.assertEqual(args.artifacts, "benchmarks/model_artifacts.json")
|
||||
self.assertEqual(args.constants, "src/constants.py")
|
||||
self.assertTrue(args.check)
|
||||
|
||||
def test_version_command_prints_version(self):
|
||||
out = io.StringIO()
|
||||
args = aman._parse_cli_args(["version"])
|
||||
|
|
@ -145,6 +239,259 @@ class AmanCliTests(unittest.TestCase):
|
|||
self.assertEqual(exit_code, 2)
|
||||
self.assertIn("[FAIL] config.load", out.getvalue())
|
||||
|
||||
def test_bench_command_json_output(self):
|
||||
args = aman._parse_cli_args(["bench", "--text", "hello", "--repeat", "2", "--warmup", "0", "--json"])
|
||||
out = io.StringIO()
|
||||
with patch("aman.load", return_value=Config()), patch(
|
||||
"aman._build_editor_stage", return_value=_FakeBenchEditorStage()
|
||||
), patch("sys.stdout", out):
|
||||
exit_code = aman._bench_command(args)
|
||||
|
||||
self.assertEqual(exit_code, 0)
|
||||
payload = json.loads(out.getvalue())
|
||||
self.assertEqual(payload["measured_runs"], 2)
|
||||
self.assertEqual(payload["summary"]["runs"], 2)
|
||||
self.assertEqual(len(payload["runs"]), 2)
|
||||
self.assertEqual(payload["editor_backend"], "local_llama_builtin")
|
||||
self.assertIn("avg_alignment_ms", payload["summary"])
|
||||
self.assertIn("avg_fact_guard_ms", payload["summary"])
|
||||
self.assertIn("alignment_applied", payload["runs"][0])
|
||||
self.assertIn("fact_guard_action", payload["runs"][0])
|
||||
|
||||
def test_bench_command_supports_text_file_input(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
text_file = Path(td) / "input.txt"
|
||||
text_file.write_text("hello from file", encoding="utf-8")
|
||||
args = aman._parse_cli_args(
|
||||
["bench", "--text-file", str(text_file), "--repeat", "1", "--warmup", "0", "--print-output"]
|
||||
)
|
||||
out = io.StringIO()
|
||||
with patch("aman.load", return_value=Config()), patch(
|
||||
"aman._build_editor_stage", return_value=_FakeBenchEditorStage()
|
||||
), patch("sys.stdout", out):
|
||||
exit_code = aman._bench_command(args)
|
||||
|
||||
self.assertEqual(exit_code, 0)
|
||||
self.assertIn("[auto] hello from file", out.getvalue())
|
||||
|
||||
def test_bench_command_rejects_empty_input(self):
|
||||
args = aman._parse_cli_args(["bench", "--text", " "])
|
||||
with patch("aman.load", return_value=Config()), patch(
|
||||
"aman._build_editor_stage", return_value=_FakeBenchEditorStage()
|
||||
):
|
||||
exit_code = aman._bench_command(args)
|
||||
|
||||
self.assertEqual(exit_code, 1)
|
||||
|
||||
def test_bench_command_rejects_non_positive_repeat(self):
|
||||
args = aman._parse_cli_args(["bench", "--text", "hello", "--repeat", "0"])
|
||||
with patch("aman.load", return_value=Config()), patch(
|
||||
"aman._build_editor_stage", return_value=_FakeBenchEditorStage()
|
||||
):
|
||||
exit_code = aman._bench_command(args)
|
||||
|
||||
self.assertEqual(exit_code, 1)
|
||||
|
||||
def test_eval_models_command_writes_report(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
output_path = Path(td) / "report.json"
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"eval-models",
|
||||
"--dataset",
|
||||
"benchmarks/cleanup_dataset.jsonl",
|
||||
"--matrix",
|
||||
"benchmarks/model_matrix.small_first.json",
|
||||
"--output",
|
||||
str(output_path),
|
||||
"--json",
|
||||
]
|
||||
)
|
||||
out = io.StringIO()
|
||||
fake_report = {
|
||||
"models": [{"name": "base", "best_param_set": {"latency_ms": {"p50": 1000.0}, "quality": {"hybrid_score_avg": 0.8, "parse_valid_rate": 1.0}}}],
|
||||
"winner_recommendation": {"name": "base", "reason": "test"},
|
||||
}
|
||||
with patch("aman.run_model_eval", return_value=fake_report), patch("sys.stdout", out):
|
||||
exit_code = aman._eval_models_command(args)
|
||||
self.assertEqual(exit_code, 0)
|
||||
self.assertTrue(output_path.exists())
|
||||
payload = json.loads(output_path.read_text(encoding="utf-8"))
|
||||
self.assertEqual(payload["winner_recommendation"]["name"], "base")
|
||||
|
||||
def test_eval_models_command_forwards_heuristic_arguments(self):
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"eval-models",
|
||||
"--dataset",
|
||||
"benchmarks/cleanup_dataset.jsonl",
|
||||
"--matrix",
|
||||
"benchmarks/model_matrix.small_first.json",
|
||||
"--heuristic-dataset",
|
||||
"benchmarks/heuristics_dataset.jsonl",
|
||||
"--heuristic-weight",
|
||||
"0.35",
|
||||
"--report-version",
|
||||
"2",
|
||||
"--json",
|
||||
]
|
||||
)
|
||||
out = io.StringIO()
|
||||
fake_report = {
|
||||
"models": [{"name": "base", "best_param_set": {}}],
|
||||
"winner_recommendation": {"name": "base", "reason": "ok"},
|
||||
}
|
||||
with patch("aman.run_model_eval", return_value=fake_report) as run_eval_mock, patch(
|
||||
"sys.stdout", out
|
||||
):
|
||||
exit_code = aman._eval_models_command(args)
|
||||
self.assertEqual(exit_code, 0)
|
||||
run_eval_mock.assert_called_once_with(
|
||||
"benchmarks/cleanup_dataset.jsonl",
|
||||
"benchmarks/model_matrix.small_first.json",
|
||||
heuristic_dataset_path="benchmarks/heuristics_dataset.jsonl",
|
||||
heuristic_weight=0.35,
|
||||
report_version=2,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
def test_build_heuristic_dataset_command_json_output(self):
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"build-heuristic-dataset",
|
||||
"--input",
|
||||
"benchmarks/heuristics_dataset.raw.jsonl",
|
||||
"--output",
|
||||
"benchmarks/heuristics_dataset.jsonl",
|
||||
"--json",
|
||||
]
|
||||
)
|
||||
out = io.StringIO()
|
||||
summary = {
|
||||
"raw_rows": 4,
|
||||
"written_rows": 4,
|
||||
"generated_word_rows": 2,
|
||||
"output_path": "benchmarks/heuristics_dataset.jsonl",
|
||||
}
|
||||
with patch("aman.build_heuristic_dataset", return_value=summary), patch("sys.stdout", out):
|
||||
exit_code = aman._build_heuristic_dataset_command(args)
|
||||
self.assertEqual(exit_code, 0)
|
||||
payload = json.loads(out.getvalue())
|
||||
self.assertEqual(payload["written_rows"], 4)
|
||||
|
||||
def test_sync_default_model_command_updates_constants(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
report_path = Path(td) / "latest.json"
|
||||
artifacts_path = Path(td) / "artifacts.json"
|
||||
constants_path = Path(td) / "constants.py"
|
||||
report_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"winner_recommendation": {
|
||||
"name": "test-model",
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
artifacts_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"name": "test-model",
|
||||
"filename": "winner.gguf",
|
||||
"url": "https://example.invalid/winner.gguf",
|
||||
"sha256": "a" * 64,
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
constants_path.write_text(
|
||||
(
|
||||
'MODEL_NAME = "old.gguf"\n'
|
||||
'MODEL_URL = "https://example.invalid/old.gguf"\n'
|
||||
'MODEL_SHA256 = "' + ("b" * 64) + '"\n'
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"sync-default-model",
|
||||
"--report",
|
||||
str(report_path),
|
||||
"--artifacts",
|
||||
str(artifacts_path),
|
||||
"--constants",
|
||||
str(constants_path),
|
||||
]
|
||||
)
|
||||
exit_code = aman._sync_default_model_command(args)
|
||||
self.assertEqual(exit_code, 0)
|
||||
updated = constants_path.read_text(encoding="utf-8")
|
||||
self.assertIn('MODEL_NAME = "winner.gguf"', updated)
|
||||
self.assertIn('MODEL_URL = "https://example.invalid/winner.gguf"', updated)
|
||||
self.assertIn('MODEL_SHA256 = "' + ("a" * 64) + '"', updated)
|
||||
|
||||
def test_sync_default_model_command_check_mode_returns_2_on_drift(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
report_path = Path(td) / "latest.json"
|
||||
artifacts_path = Path(td) / "artifacts.json"
|
||||
constants_path = Path(td) / "constants.py"
|
||||
report_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"winner_recommendation": {
|
||||
"name": "test-model",
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
artifacts_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"name": "test-model",
|
||||
"filename": "winner.gguf",
|
||||
"url": "https://example.invalid/winner.gguf",
|
||||
"sha256": "a" * 64,
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
constants_path.write_text(
|
||||
(
|
||||
'MODEL_NAME = "old.gguf"\n'
|
||||
'MODEL_URL = "https://example.invalid/old.gguf"\n'
|
||||
'MODEL_SHA256 = "' + ("b" * 64) + '"\n'
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"sync-default-model",
|
||||
"--report",
|
||||
str(report_path),
|
||||
"--artifacts",
|
||||
str(artifacts_path),
|
||||
"--constants",
|
||||
str(constants_path),
|
||||
"--check",
|
||||
]
|
||||
)
|
||||
exit_code = aman._sync_default_model_command(args)
|
||||
self.assertEqual(exit_code, 2)
|
||||
updated = constants_path.read_text(encoding="utf-8")
|
||||
self.assertIn('MODEL_NAME = "old.gguf"', updated)
|
||||
|
||||
def test_init_command_creates_default_config(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
|
|
|
|||
80
tests/test_asr_whisper.py
Normal file
80
tests/test_asr_whisper.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from stages.asr_whisper import WhisperAsrStage
|
||||
|
||||
|
||||
class _Word:
|
||||
def __init__(self, word: str, start: float, end: float, probability: float = 0.9):
|
||||
self.word = word
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.probability = probability
|
||||
|
||||
|
||||
class _Segment:
|
||||
def __init__(self, text: str, start: float, end: float, words=None):
|
||||
self.text = text
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.words = words or []
|
||||
|
||||
|
||||
class _ModelWithWordTimestamps:
|
||||
def __init__(self):
|
||||
self.kwargs = {}
|
||||
|
||||
def transcribe(self, _audio, language=None, vad_filter=None, word_timestamps=False):
|
||||
self.kwargs = {
|
||||
"language": language,
|
||||
"vad_filter": vad_filter,
|
||||
"word_timestamps": word_timestamps,
|
||||
}
|
||||
words = [_Word("hello", 0.0, 0.3), _Word("world", 0.31, 0.6)]
|
||||
return [_Segment("hello world", 0.0, 0.6, words=words)], {}
|
||||
|
||||
|
||||
class _ModelWithoutWordTimestamps:
|
||||
def __init__(self):
|
||||
self.kwargs = {}
|
||||
|
||||
def transcribe(self, _audio, language=None, vad_filter=None):
|
||||
self.kwargs = {
|
||||
"language": language,
|
||||
"vad_filter": vad_filter,
|
||||
}
|
||||
return [_Segment("hello", 0.0, 0.2, words=[])], {}
|
||||
|
||||
|
||||
class WhisperAsrStageTests(unittest.TestCase):
|
||||
def test_transcribe_requests_word_timestamps_when_supported(self):
|
||||
model = _ModelWithWordTimestamps()
|
||||
stage = WhisperAsrStage(model, configured_language="auto")
|
||||
|
||||
result = stage.transcribe(object())
|
||||
|
||||
self.assertTrue(model.kwargs["word_timestamps"])
|
||||
self.assertEqual(result.raw_text, "hello world")
|
||||
self.assertEqual(len(result.words), 2)
|
||||
self.assertEqual(result.words[0].text, "hello")
|
||||
self.assertGreaterEqual(result.words[0].start_s, 0.0)
|
||||
|
||||
def test_transcribe_skips_word_timestamps_when_not_supported(self):
|
||||
model = _ModelWithoutWordTimestamps()
|
||||
stage = WhisperAsrStage(model, configured_language="auto")
|
||||
|
||||
result = stage.transcribe(object())
|
||||
|
||||
self.assertNotIn("word_timestamps", model.kwargs)
|
||||
self.assertEqual(result.raw_text, "hello")
|
||||
self.assertEqual(result.words, [])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -25,14 +25,12 @@ class ConfigTests(unittest.TestCase):
|
|||
self.assertEqual(cfg.stt.model, "base")
|
||||
self.assertEqual(cfg.stt.device, "cpu")
|
||||
self.assertEqual(cfg.stt.language, "auto")
|
||||
self.assertEqual(cfg.llm.provider, "local_llama")
|
||||
self.assertFalse(cfg.models.allow_custom_models)
|
||||
self.assertEqual(cfg.models.whisper_model_path, "")
|
||||
self.assertEqual(cfg.models.llm_model_path, "")
|
||||
self.assertFalse(cfg.external_api.enabled)
|
||||
self.assertEqual(cfg.external_api.provider, "openai")
|
||||
self.assertEqual(cfg.injection.backend, "clipboard")
|
||||
self.assertFalse(cfg.injection.remove_transcription_from_clipboard)
|
||||
self.assertTrue(cfg.safety.enabled)
|
||||
self.assertFalse(cfg.safety.strict)
|
||||
self.assertEqual(cfg.ux.profile, "default")
|
||||
self.assertTrue(cfg.ux.show_notifications)
|
||||
self.assertTrue(cfg.advanced.strict_startup)
|
||||
|
|
@ -54,13 +52,15 @@ class ConfigTests(unittest.TestCase):
|
|||
"device": "cuda",
|
||||
"language": "English",
|
||||
},
|
||||
"llm": {"provider": "local_llama"},
|
||||
"models": {"allow_custom_models": False},
|
||||
"external_api": {"enabled": False},
|
||||
"injection": {
|
||||
"backend": "injection",
|
||||
"remove_transcription_from_clipboard": True,
|
||||
},
|
||||
"safety": {
|
||||
"enabled": True,
|
||||
"strict": True,
|
||||
},
|
||||
"vocabulary": {
|
||||
"replacements": [
|
||||
{"from": "Martha", "to": "Marta"},
|
||||
|
|
@ -82,9 +82,10 @@ class ConfigTests(unittest.TestCase):
|
|||
self.assertEqual(cfg.stt.model, "small")
|
||||
self.assertEqual(cfg.stt.device, "cuda")
|
||||
self.assertEqual(cfg.stt.language, "en")
|
||||
self.assertEqual(cfg.llm.provider, "local_llama")
|
||||
self.assertEqual(cfg.injection.backend, "injection")
|
||||
self.assertTrue(cfg.injection.remove_transcription_from_clipboard)
|
||||
self.assertTrue(cfg.safety.enabled)
|
||||
self.assertTrue(cfg.safety.strict)
|
||||
self.assertEqual(len(cfg.vocabulary.replacements), 2)
|
||||
self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha")
|
||||
self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta")
|
||||
|
|
@ -138,6 +139,33 @@ class ConfigTests(unittest.TestCase):
|
|||
with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"):
|
||||
load(str(path))
|
||||
|
||||
def test_invalid_safety_enabled_option_raises(self):
|
||||
payload = {"safety": {"enabled": "yes"}}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "safety.enabled"):
|
||||
load(str(path))
|
||||
|
||||
def test_invalid_safety_strict_option_raises(self):
|
||||
payload = {"safety": {"strict": "yes"}}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "safety.strict"):
|
||||
load(str(path))
|
||||
|
||||
def test_unknown_safety_fields_raise(self):
|
||||
payload = {"safety": {"enabled": True, "mode": "strict"}}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "safety.mode: unknown config field"):
|
||||
load(str(path))
|
||||
|
||||
def test_unknown_top_level_fields_raise(self):
|
||||
payload = {
|
||||
"custom_a": {"enabled": True},
|
||||
|
|
@ -269,10 +297,9 @@ class ConfigTests(unittest.TestCase):
|
|||
|
||||
self.assertEqual(cfg.config_version, CURRENT_CONFIG_VERSION)
|
||||
|
||||
def test_external_llm_requires_external_api_enabled(self):
|
||||
def test_legacy_llm_config_fields_raise(self):
|
||||
payload = {
|
||||
"llm": {"provider": "external_api"},
|
||||
"external_api": {"enabled": False},
|
||||
"llm": {"provider": "local_llama"},
|
||||
}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
|
|
@ -280,7 +307,7 @@ class ConfigTests(unittest.TestCase):
|
|||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"llm.provider: external_api provider requires external_api.enabled=true",
|
||||
"llm: unknown config field",
|
||||
):
|
||||
load(str(path))
|
||||
|
||||
|
|
|
|||
|
|
@ -23,37 +23,20 @@ class ConfigUiRuntimeModeTests(unittest.TestCase):
|
|||
|
||||
def test_infer_runtime_mode_detects_expert_overrides(self):
|
||||
cfg = Config()
|
||||
cfg.llm.provider = "external_api"
|
||||
cfg.external_api.enabled = True
|
||||
cfg.models.allow_custom_models = True
|
||||
self.assertEqual(infer_runtime_mode(cfg), RUNTIME_MODE_EXPERT)
|
||||
|
||||
def test_apply_canonical_runtime_defaults_resets_expert_fields(self):
|
||||
cfg = Config()
|
||||
cfg.stt.provider = "local_whisper"
|
||||
cfg.llm.provider = "external_api"
|
||||
cfg.external_api.enabled = True
|
||||
cfg.external_api.base_url = "https://example.local/v1"
|
||||
cfg.external_api.model = "custom-model"
|
||||
cfg.external_api.api_key_env_var = "CUSTOM_KEY"
|
||||
cfg.external_api.timeout_ms = 321
|
||||
cfg.external_api.max_retries = 8
|
||||
cfg.models.allow_custom_models = True
|
||||
cfg.models.whisper_model_path = "/tmp/custom-whisper.bin"
|
||||
cfg.models.llm_model_path = "/tmp/custom-model.gguf"
|
||||
|
||||
apply_canonical_runtime_defaults(cfg)
|
||||
|
||||
self.assertEqual(cfg.stt.provider, "local_whisper")
|
||||
self.assertEqual(cfg.llm.provider, "local_llama")
|
||||
self.assertFalse(cfg.external_api.enabled)
|
||||
self.assertEqual(cfg.external_api.base_url, "https://api.openai.com/v1")
|
||||
self.assertEqual(cfg.external_api.model, "gpt-4o-mini")
|
||||
self.assertEqual(cfg.external_api.api_key_env_var, "AMAN_EXTERNAL_API_KEY")
|
||||
self.assertEqual(cfg.external_api.timeout_ms, 15000)
|
||||
self.assertEqual(cfg.external_api.max_retries, 2)
|
||||
self.assertFalse(cfg.models.allow_custom_models)
|
||||
self.assertEqual(cfg.models.whisper_model_path, "")
|
||||
self.assertEqual(cfg.models.llm_model_path, "")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
86
tests/test_fact_guard.py
Normal file
86
tests/test_fact_guard.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from stages.fact_guard import FactGuardEngine
|
||||
|
||||
|
||||
class FactGuardEngineTests(unittest.TestCase):
    """Behavioral checks for FactGuardEngine.apply across its guard modes."""

    def _apply(self, source, candidate, *, enabled, strict):
        # One fresh engine per invocation keeps every case independent.
        return FactGuardEngine().apply(
            source,
            candidate,
            enabled=enabled,
            strict=strict,
        )

    def test_disabled_guard_accepts_candidate(self):
        """With the guard off, the candidate text passes through untouched."""
        outcome = self._apply(
            "set alarm for 7",
            "set alarm for 8",
            enabled=False,
            strict=False,
        )

        self.assertEqual(outcome.action, "accepted")
        self.assertEqual(outcome.final_text, "set alarm for 8")
        self.assertEqual(outcome.violations_count, 0)

    def test_fallbacks_on_number_change(self):
        """A changed numeral triggers a fallback to the source text."""
        outcome = self._apply(
            "set alarm for 7",
            "set alarm for 8",
            enabled=True,
            strict=False,
        )

        self.assertEqual(outcome.action, "fallback")
        self.assertEqual(outcome.final_text, "set alarm for 7")
        self.assertGreaterEqual(outcome.violations_count, 1)

    def test_fallbacks_on_name_change(self):
        """A changed proper name triggers a fallback to the source text."""
        outcome = self._apply(
            "invite Marta tomorrow",
            "invite Martha tomorrow",
            enabled=True,
            strict=False,
        )

        self.assertEqual(outcome.action, "fallback")
        self.assertEqual(outcome.final_text, "invite Marta tomorrow")
        self.assertGreaterEqual(outcome.violations_count, 1)

    def test_accepts_style_only_rewrite(self):
        """Capitalization/punctuation-only edits are accepted as-is."""
        outcome = self._apply(
            "please send the report",
            "Please send the report.",
            enabled=True,
            strict=False,
        )

        self.assertEqual(outcome.action, "accepted")
        self.assertEqual(outcome.final_text, "Please send the report.")
        self.assertEqual(outcome.violations_count, 0)

    def test_strict_mode_rejects_large_lexical_additions(self):
        """Strict mode rejects (not just falls back on) heavy invented content."""
        outcome = self._apply(
            "send the report",
            "send the report and include two extra paragraphs with assumptions",
            enabled=True,
            strict=True,
        )

        self.assertEqual(outcome.action, "rejected")
        self.assertEqual(outcome.final_text, "send the report")
        self.assertGreaterEqual(outcome.violations_count, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
137
tests/test_model_eval.py
Normal file
137
tests/test_model_eval.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
import model_eval
|
||||
|
||||
|
||||
class _FakeProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
_ = (args, kwargs)
|
||||
|
||||
def warmup(self, **kwargs):
|
||||
_ = kwargs
|
||||
return
|
||||
|
||||
def process(self, text, **kwargs):
|
||||
_ = kwargs
|
||||
return text.strip()
|
||||
|
||||
|
||||
class ModelEvalTests(unittest.TestCase):
    """Tests for the model_eval dataset loaders and benchmark report pipeline.

    All fixtures are written to a TemporaryDirectory as JSONL/JSON files, and
    the heavyweight LlamaProcessor is patched out with _FakeProcessor so the
    evaluation runs without a real model.
    """

    def test_load_eval_dataset_validates_required_fields(self):
        """A minimal row carrying id/input_text/expected_output loads as one case."""
        with tempfile.TemporaryDirectory() as td:
            dataset = Path(td) / "dataset.jsonl"
            dataset.write_text(
                '{"id":"c1","input_text":"hello","expected_output":"hello"}\n',
                encoding="utf-8",
            )
            cases = model_eval.load_eval_dataset(dataset)
            self.assertEqual(len(cases), 1)
            self.assertEqual(cases[0].case_id, "c1")

    def test_run_model_eval_produces_report(self):
        """Full run over baseline + one candidate yields a versioned report with scores."""
        with tempfile.TemporaryDirectory() as td:
            # The model file only needs to exist; _FakeProcessor never reads it.
            model_file = Path(td) / "fake.gguf"
            model_file.write_text("fake", encoding="utf-8")
            dataset = Path(td) / "dataset.jsonl"
            dataset.write_text(
                (
                    '{"id":"c1","input_text":"hello world","expected_output":"hello world"}\n'
                    '{"id":"c2","input_text":"hello","expected_output":"hello"}\n'
                ),
                encoding="utf-8",
            )
            matrix = Path(td) / "matrix.json"
            heuristic_dataset = Path(td) / "heuristics.jsonl"
            # Benchmark matrix: one baseline, one candidate with a 2x1 param grid.
            matrix.write_text(
                json.dumps(
                    {
                        "warmup_runs": 0,
                        "measured_runs": 1,
                        "timeout_sec": 30,
                        "baseline_model": {
                            "name": "base",
                            "provider": "local_llama",
                            "model_path": str(model_file),
                            "profile": "default",
                            "param_grid": {"temperature": [0.0]},
                        },
                        "candidate_models": [
                            {
                                "name": "small",
                                "provider": "local_llama",
                                "model_path": str(model_file),
                                "profile": "fast",
                                "param_grid": {"temperature": [0.0, 0.1], "max_tokens": [96]},
                            }
                        ],
                    }
                ),
                encoding="utf-8",
            )
            # Two heuristic cases: one where the "i mean" cue must fire, one where
            # it must NOT fire (literal usage) — exercising both rule directions.
            heuristic_dataset.write_text(
                (
                    '{"id":"h1","transcript":"set alarm for 6 i mean 7","words":[{"text":"set","start_s":0.0,"end_s":0.1,"prob":0.9},{"text":"alarm","start_s":0.2,"end_s":0.3,"prob":0.9},{"text":"for","start_s":0.4,"end_s":0.5,"prob":0.9},{"text":"6","start_s":0.6,"end_s":0.7,"prob":0.9},{"text":"i","start_s":0.8,"end_s":0.9,"prob":0.9},{"text":"mean","start_s":1.0,"end_s":1.1,"prob":0.9},{"text":"7","start_s":1.2,"end_s":1.3,"prob":0.9}],"expected_aligned_text":"set alarm for 7","expected":{"applied_min":1,"required_rule_ids":["cue_correction"]},"tags":["i_mean_correction"]}\n'
                    '{"id":"h2","transcript":"write exactly i mean this sincerely","words":[{"text":"write","start_s":0.0,"end_s":0.1,"prob":0.9},{"text":"exactly","start_s":0.2,"end_s":0.3,"prob":0.9},{"text":"i","start_s":0.4,"end_s":0.5,"prob":0.9},{"text":"mean","start_s":0.6,"end_s":0.7,"prob":0.9},{"text":"this","start_s":0.8,"end_s":0.9,"prob":0.9},{"text":"sincerely","start_s":1.0,"end_s":1.1,"prob":0.9}],"expected_aligned_text":"write exactly i mean this sincerely","expected":{"required_rule_ids":[],"forbidden_rule_ids":["cue_correction"]},"tags":["i_mean_literal"]}\n'
                ),
                encoding="utf-8",
            )
            with patch("model_eval.LlamaProcessor", _FakeProcessor):
                report = model_eval.run_model_eval(
                    dataset,
                    matrix,
                    heuristic_dataset_path=heuristic_dataset,
                    heuristic_weight=0.3,
                    report_version=2,
                    verbose=False,
                )

            self.assertEqual(report["report_version"], 2)
            self.assertIn("models", report)
            # Baseline + one candidate.
            self.assertEqual(len(report["models"]), 2)
            self.assertIn("winner_recommendation", report)
            self.assertIn("heuristic_eval", report)
            self.assertEqual(report["heuristic_eval"]["cases"], 2)
            self.assertIn("combined_score", report["models"][0]["best_param_set"])
            summary = model_eval.format_model_eval_summary(report)
            self.assertIn("model eval summary", summary)

    def test_load_heuristic_dataset_validates_required_fields(self):
        """A row without an explicit `expected` block still loads, with applied_min=0."""
        with tempfile.TemporaryDirectory() as td:
            dataset = Path(td) / "heuristics.jsonl"
            dataset.write_text(
                '{"id":"h1","transcript":"hello world","words":[{"text":"hello","start_s":0.0,"end_s":0.1},{"text":"world","start_s":0.2,"end_s":0.3}],"expected_aligned_text":"hello world"}\n',
                encoding="utf-8",
            )
            cases = model_eval.load_heuristic_dataset(dataset)
            self.assertEqual(len(cases), 1)
            self.assertEqual(cases[0].case_id, "h1")
            # Default lower bound when `expected` is omitted.
            self.assertEqual(cases[0].expected.applied_min, 0)

    def test_build_heuristic_dataset_generates_words_when_missing(self):
        """Rows lacking word timings get synthetic words generated on build."""
        with tempfile.TemporaryDirectory() as td:
            source = Path(td) / "heuristics.raw.jsonl"
            output = Path(td) / "heuristics.jsonl"
            source.write_text(
                '{"id":"h1","transcript":"please send it","expected_aligned_text":"please send it","expected":{"applied_min":0}}\n',
                encoding="utf-8",
            )
            summary = model_eval.build_heuristic_dataset(source, output)
            self.assertEqual(summary["written_rows"], 1)
            self.assertEqual(summary["generated_word_rows"], 1)
            loaded = model_eval.load_heuristic_dataset(output)
            self.assertEqual(len(loaded), 1)
            self.assertGreaterEqual(len(loaded[0].words), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
129
tests/test_pipeline_engine.py
Normal file
129
tests/test_pipeline_engine.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from engine.pipeline import PipelineEngine
|
||||
from stages.alignment_edits import AlignmentHeuristicEngine
|
||||
from stages.asr_whisper import AsrResult, AsrSegment, AsrWord
|
||||
from vocabulary import VocabularyEngine
|
||||
from config import VocabularyConfig
|
||||
|
||||
|
||||
class _FakeEditor:
|
||||
def __init__(self, *, output_text: str | None = None):
|
||||
self.calls = []
|
||||
self.output_text = output_text
|
||||
|
||||
def rewrite(self, transcript, *, language, dictionary_context):
|
||||
self.calls.append(
|
||||
{
|
||||
"transcript": transcript,
|
||||
"language": language,
|
||||
"dictionary_context": dictionary_context,
|
||||
}
|
||||
)
|
||||
|
||||
final_text = transcript if self.output_text is None else self.output_text
|
||||
return SimpleNamespace(
|
||||
final_text=final_text,
|
||||
latency_ms=1.0,
|
||||
pass1_ms=0.5,
|
||||
pass2_ms=0.5,
|
||||
)
|
||||
|
||||
|
||||
class _FakeAsr:
    """ASR-stage double returning one fixed utterance: "set alarm for 6 i mean 7"."""

    # (text, start_s, end_s) per word; every word gets the same 0.9 probability.
    _WORD_SPECS = (
        ("set", 0.0, 0.1),
        ("alarm", 0.2, 0.3),
        ("for", 0.4, 0.5),
        ("6", 0.6, 0.7),
        ("i", 0.8, 0.9),
        ("mean", 1.0, 1.1),
        ("7", 1.2, 1.3),
    )

    def transcribe(self, _audio):
        # The audio argument is ignored; the result is always the same utterance.
        transcript = "set alarm for 6 i mean 7"
        word_objs = [
            AsrWord(text, start, end, 0.9) for text, start, end in self._WORD_SPECS
        ]
        segment_objs = [AsrSegment(text=transcript, start_s=0.0, end_s=1.3)]
        return AsrResult(
            raw_text=transcript,
            language="en",
            latency_ms=5.0,
            words=word_objs,
            segments=segment_objs,
        )
|
||||
|
||||
|
||||
class PipelineEngineTests(unittest.TestCase):
    """Wiring tests for PipelineEngine: ASR -> alignment -> editor -> fact guard."""

    def test_alignment_draft_is_forwarded_to_editor(self):
        """Alignment rewrites the ASR draft before the editor stage sees it."""
        editor = _FakeEditor()
        pipeline = PipelineEngine(
            asr_stage=_FakeAsr(),
            editor_stage=editor,
            vocabulary=VocabularyEngine(VocabularyConfig()),
            alignment_engine=AlignmentHeuristicEngine(),
        )

        result = pipeline.run_audio(object())

        # The "i mean 7" self-correction collapses "set alarm for 6 i mean 7"
        # into "set alarm for 7" before the editor is called.
        self.assertEqual(editor.calls[0]["transcript"], "set alarm for 7")
        self.assertEqual(result.alignment_applied, 1)
        self.assertGreaterEqual(result.alignment_ms, 0.0)
        self.assertEqual(result.fact_guard_action, "accepted")
        self.assertEqual(result.fact_guard_violations, 0)

    def test_run_transcript_without_words_keeps_alignment_noop(self):
        """Plain-text input carries no word timings, so alignment applies nothing."""
        editor = _FakeEditor()
        pipeline = PipelineEngine(
            asr_stage=None,
            editor_stage=editor,
            vocabulary=VocabularyEngine(VocabularyConfig()),
            alignment_engine=AlignmentHeuristicEngine(),
        )

        result = pipeline.run_transcript("hello world", language="en")

        self.assertEqual(editor.calls[0]["transcript"], "hello world")
        self.assertEqual(result.alignment_applied, 0)
        self.assertEqual(result.fact_guard_action, "accepted")
        self.assertEqual(result.fact_guard_violations, 0)

    def test_fact_guard_fallbacks_when_editor_changes_number(self):
        """Non-strict safety falls back to the source text when the editor alters a numeral."""
        editor = _FakeEditor(output_text="set alarm for 8")
        pipeline = PipelineEngine(
            asr_stage=None,
            editor_stage=editor,
            vocabulary=VocabularyEngine(VocabularyConfig()),
            alignment_engine=AlignmentHeuristicEngine(),
            safety_enabled=True,
            safety_strict=False,
        )

        result = pipeline.run_transcript("set alarm for 7", language="en")

        # Output is the unedited source, not the editor's altered text.
        self.assertEqual(result.output_text, "set alarm for 7")
        self.assertEqual(result.fact_guard_action, "fallback")
        self.assertGreaterEqual(result.fact_guard_violations, 1)

    def test_fact_guard_strict_rejects_number_change(self):
        """Strict safety raises instead of silently falling back on a factual change."""
        editor = _FakeEditor(output_text="set alarm for 8")
        pipeline = PipelineEngine(
            asr_stage=None,
            editor_stage=editor,
            vocabulary=VocabularyEngine(VocabularyConfig()),
            alignment_engine=AlignmentHeuristicEngine(),
            safety_enabled=True,
            safety_strict=True,
        )

        with self.assertRaisesRegex(RuntimeError, "fact guard rejected editor output"):
            pipeline.run_transcript("set alarm for 7", language="en")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue