72 lines
2.3 KiB
Python
72 lines
2.3 KiB
Python
import sys
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
# Repository root: the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT.joinpath("src")

# Put ``src`` at the front of the import path (only once) so the ``stages``
# imports below resolve to the in-repo package ahead of later sys.path entries.
if str(SRC) not in sys.path:
    sys.path[:0] = [str(SRC)]
|
|
|
|
from stages.alignment_edits import AlignmentHeuristicEngine
|
|
from stages.asr_whisper import AsrWord
|
|
|
|
|
|
def _words(tokens: list[str], step: float = 0.2) -> list[AsrWord]:
    """Fabricate one ``AsrWord`` per token on an evenly spaced timeline.

    The first token starts at 0.0 s and each following token starts ``step``
    seconds after the previous one (accumulated by repeated addition). Every
    word lasts 0.1 s and carries a fixed probability of 0.9.
    """

    def _timeline():
        # Pair each token with its start time, advancing the clock afterwards.
        clock = 0.0
        for token in tokens:
            yield token, clock
            clock += step

    return [
        AsrWord(text=text, start_s=begin, end_s=begin + 0.1, prob=0.9)
        for text, begin in _timeline()
    ]
|
|
|
|
|
|
class AlignmentHeuristicEngineTests(unittest.TestCase):
    """Behavioral tests for ``AlignmentHeuristicEngine.apply``.

    Each test passes a transcript string plus synthetic word timings (built by
    ``_words``) and inspects the returned result's ``draft_text``,
    ``applied_count``, ``skipped_count`` and ``decisions``.
    """

    def setUp(self):
        # unittest calls setUp before every test, so each test still works
        # with its own newly constructed engine.
        self.engine = AlignmentHeuristicEngine()

    @staticmethod
    def _rule_ids(result):
        """Return the ``rule_id`` of every decision recorded in *result*."""
        return [item.rule_id for item in result.decisions]

    def test_returns_original_when_no_words_available(self):
        """An empty word list leaves the text untouched and records nothing."""
        result = self.engine.apply("hello world", [])

        self.assertEqual(result.draft_text, "hello world")
        self.assertEqual(result.applied_count, 0)
        self.assertEqual(result.decisions, [])

    def test_applies_i_mean_tail_correction(self):
        """A trailing '<old> i mean <new>' self-correction keeps only <new>."""
        words = _words(["set", "alarm", "for", "6", "i", "mean", "7"])

        result = self.engine.apply("set alarm for 6 i mean 7", words)

        self.assertEqual(result.draft_text, "set alarm for 7")
        self.assertEqual(result.applied_count, 1)
        # assertIn (instead of assertTrue(any(...))) prints the actual rule ids
        # on failure rather than just "False is not true".
        self.assertIn("cue_correction", self._rule_ids(result))

    def test_preserves_literal_i_mean_context(self):
        """A literal mid-sentence 'i mean' is kept and counted as skipped."""
        words = _words(["write", "exactly", "i", "mean", "this", "sincerely"])

        result = self.engine.apply("write exactly i mean this sincerely", words)

        self.assertEqual(result.draft_text, "write exactly i mean this sincerely")
        self.assertEqual(result.applied_count, 0)
        self.assertGreaterEqual(result.skipped_count, 1)

    def test_collapses_exact_restart_repetition(self):
        """A phrase repeated verbatim back-to-back collapses to one copy."""
        words = _words(["please", "send", "it", "please", "send", "it"])

        result = self.engine.apply("please send it please send it", words)

        self.assertEqual(result.draft_text, "please send it")
        self.assertEqual(result.applied_count, 1)
        self.assertIn("restart_repeat", self._rule_ids(result))
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly (``python <this file>``); the test
    # program discovers the TestCase classes defined in ``__main__``.
    unittest.main(module="__main__")
|