import sys
import unittest
from pathlib import Path

# Put the sibling ``src`` directory (one level above this file's folder) on
# ``sys.path`` so the ``stages`` package below can be imported.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

from stages.asr_whisper import WhisperAsrStage


class _Word:
    """Stand-in for a word-level timing entry returned by a model."""

    def __init__(self, word: str, start: float, end: float, probability: float = 0.9):
        self.word = word
        self.start = start
        self.end = end
        self.probability = probability


class _Segment:
    """Stand-in for a transcription segment with optional word entries."""

    def __init__(self, text: str, start: float, end: float, words=None):
        self.text = text
        self.start = start
        self.end = end
        # Any falsy ``words`` value (None, empty list) becomes a fresh list.
        self.words = words if words else []


class _ModelWithWordTimestamps:
    """Fake model whose ``transcribe`` accepts a ``word_timestamps`` keyword.

    The keyword arguments of the most recent call are stored in ``kwargs``
    so tests can inspect what the stage passed in.
    """

    def __init__(self):
        self.kwargs = {}

    def transcribe(self, _audio, language=None, vad_filter=None, word_timestamps=False):
        """Record the call's keywords and return one two-word segment."""
        self.kwargs = dict(
            language=language,
            vad_filter=vad_filter,
            word_timestamps=word_timestamps,
        )
        hello = _Word("hello", 0.0, 0.3)
        world = _Word("world", 0.31, 0.6)
        segment = _Segment("hello world", 0.0, 0.6, words=[hello, world])
        # Returned as a (segments, info) pair; info is an empty dict here.
        return [segment], {}


class _ModelWithoutWordTimestamps:
    """Fake model whose ``transcribe`` has no ``word_timestamps`` keyword.

    Passing ``word_timestamps`` to this model raises ``TypeError`` because
    the parameter is absent from the signature.
    """

    def __init__(self):
        self.kwargs = {}

    def transcribe(self, _audio, language=None, vad_filter=None):
        """Record the call's keywords and return one segment without words."""
        self.kwargs = dict(language=language, vad_filter=vad_filter)
        segment = _Segment("hello", 0.0, 0.2, words=[])
        return [segment], {}


class WhisperAsrStageTests(unittest.TestCase):
    """Checks how ``WhisperAsrStage.transcribe`` drives the two fake models."""

    def test_transcribe_requests_word_timestamps_when_supported(self):
        # The fake accepts ``word_timestamps``; the stage is expected to set
        # it truthy and surface both words in its result.
        fake_model = _ModelWithWordTimestamps()
        asr = WhisperAsrStage(fake_model, configured_language="auto")

        outcome = asr.transcribe(object())

        self.assertTrue(fake_model.kwargs["word_timestamps"])
        self.assertEqual(outcome.raw_text, "hello world")
        self.assertEqual(len(outcome.words), 2)
        first_word = outcome.words[0]
        self.assertEqual(first_word.text, "hello")
        self.assertGreaterEqual(first_word.start_s, 0.0)

    def test_transcribe_skips_word_timestamps_when_not_supported(self):
        # The fake lacks the keyword, so the call must succeed without it
        # and the result must carry no word entries.
        fake_model = _ModelWithoutWordTimestamps()
        asr = WhisperAsrStage(fake_model, configured_language="auto")

        outcome = asr.transcribe(object())

        self.assertNotIn("word_timestamps", fake_model.kwargs)
        self.assertEqual(outcome.raw_text, "hello")
        self.assertEqual(outcome.words, [])


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()