aman/tests/test_alignment_edits.py
Thales Maciel 8c1f7c1e13
Some checks failed
ci / test-and-build (push) Has been cancelled
Add benchmark-driven model promotion workflow and pipeline stages
2026-02-28 15:12:33 -03:00

72 lines
2.3 KiB
Python

import sys
import unittest
from pathlib import Path
# Make the project's ``src`` directory importable without installing the
# package: ROOT is the directory two levels above this (resolved) file.
ROOT = Path(__file__).resolve().parent.parent
SRC = ROOT.joinpath("src")
# Prepend (rather than append) so the in-tree sources win over any copy that
# might be installed elsewhere; skip if the entry is already present.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
from stages.alignment_edits import AlignmentHeuristicEngine
from stages.asr_whisper import AsrWord
def _words(tokens: list[str], step: float = 0.2) -> list[AsrWord]:
out: list[AsrWord] = []
start = 0.0
for token in tokens:
out.append(
AsrWord(
text=token,
start_s=start,
end_s=start + 0.1,
prob=0.9,
)
)
start += step
return out
class AlignmentHeuristicEngineTests(unittest.TestCase):
    """Behavioral tests for ``AlignmentHeuristicEngine.apply``."""

    def setUp(self):
        # A fresh engine for every test method, so no state is shared.
        self.engine = AlignmentHeuristicEngine()

    def _apply_spoken(self, text: str):
        """Apply the engine to ``text``, with one synthetic ASR word per token.

        The word list is derived from ``text`` itself, so the transcript and
        its word timings cannot drift apart.
        """
        return self.engine.apply(text, _words(text.split()))

    def test_returns_original_when_no_words_available(self):
        outcome = self.engine.apply("hello world", [])
        self.assertEqual(outcome.draft_text, "hello world")
        self.assertEqual(outcome.applied_count, 0)
        self.assertEqual(outcome.decisions, [])

    def test_applies_i_mean_tail_correction(self):
        outcome = self._apply_spoken("set alarm for 6 i mean 7")
        self.assertEqual(outcome.draft_text, "set alarm for 7")
        self.assertEqual(outcome.applied_count, 1)
        self.assertIn(
            "cue_correction",
            [decision.rule_id for decision in outcome.decisions],
        )

    def test_preserves_literal_i_mean_context(self):
        phrase = "write exactly i mean this sincerely"
        outcome = self._apply_spoken(phrase)
        # Nothing may be rewritten, but at least one candidate must be
        # recorded as deliberately skipped.
        self.assertEqual(outcome.draft_text, phrase)
        self.assertEqual(outcome.applied_count, 0)
        self.assertGreaterEqual(outcome.skipped_count, 1)

    def test_collapses_exact_restart_repetition(self):
        outcome = self._apply_spoken("please send it please send it")
        self.assertEqual(outcome.draft_text, "please send it")
        self.assertEqual(outcome.applied_count, 1)
        self.assertIn(
            "restart_repeat",
            [decision.rule_id for decision in outcome.decisions],
        )
if __name__ == "__main__":
    # Support running this module directly, e.g. `python tests/test_alignment_edits.py`.
    unittest.main()