aman/tests/test_asr_whisper.py
Thales Maciel 8c1f7c1e13
Some checks failed
ci / test-and-build (push) Has been cancelled
Add benchmark-driven model promotion workflow and pipeline stages
2026-02-28 15:12:33 -03:00

80 lines
2.4 KiB
Python

import sys
import unittest
from pathlib import Path
# Put the sibling ``src`` directory on the import path so the ``stages``
# package imported below resolves when this file is run directly.
# ``parents[1]`` is the directory one level above the one holding this file.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
# Only insert once, so repeated imports of this module do not grow sys.path.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
from stages.asr_whisper import WhisperAsrStage
class _Word:
    """Minimal word record used by the fake models in this test module.

    Stores the ``word`` text, its ``start`` and ``end`` values and a
    ``probability`` that defaults to ``0.9``.
    NOTE(review): presumably mirrors the per-word objects yielded by the real
    ASR backend — confirm against ``stages.asr_whisper``.
    """

    def __init__(self, word: str, start: float, end: float, probability: float = 0.9):
        self.word = word
        self.start = start
        self.end = end
        self.probability = probability
class _Segment:
    """Minimal segment record returned by the fake models in this module.

    Stores the segment ``text``, its ``start`` and ``end`` values and a
    ``words`` sequence.
    """

    def __init__(self, text: str, start: float, end: float, words=None):
        self.text = text
        self.start = start
        self.end = end
        # A missing (or empty) ``words`` argument becomes a fresh list, so
        # instances never share a default list object.
        self.words = words or []
class _ModelWithWordTimestamps:
    """Fake model whose ``transcribe`` accepts a ``word_timestamps`` keyword.

    The options of the most recent ``transcribe`` call are kept in
    ``self.kwargs`` so a test can inspect what the caller requested.
    """

    def __init__(self):
        # Empty until ``transcribe`` has been called at least once.
        self.kwargs = {}

    def transcribe(self, _audio, language=None, vad_filter=None, word_timestamps=False):
        """Record the call options and return a canned transcription.

        ``_audio`` is ignored. Returns a 2-tuple: a one-element list holding
        a "hello world" segment with two words, and an empty dict.
        """
        self.kwargs = {
            "language": language,
            "vad_filter": vad_filter,
            "word_timestamps": word_timestamps,
        }
        words = [_Word("hello", 0.0, 0.3), _Word("world", 0.31, 0.6)]
        return [_Segment("hello world", 0.0, 0.6, words=words)], {}
class _ModelWithoutWordTimestamps:
    """Fake model whose ``transcribe`` has no ``word_timestamps`` parameter.

    The options of the most recent ``transcribe`` call are kept in
    ``self.kwargs`` so a test can inspect what the caller requested.
    """

    def __init__(self):
        # Empty until ``transcribe`` has been called at least once.
        self.kwargs = {}

    def transcribe(self, _audio, language=None, vad_filter=None):
        """Record the call options and return a canned transcription.

        ``_audio`` is ignored. Returns a 2-tuple: a one-element list holding
        a "hello" segment with no words, and an empty dict.
        """
        self.kwargs = {
            "language": language,
            "vad_filter": vad_filter,
        }
        return [_Segment("hello", 0.0, 0.2, words=[])], {}
class WhisperAsrStageTests(unittest.TestCase):
    """Expectations for how ``WhisperAsrStage`` drives the model it wraps.

    Each test builds the stage around a fake model with
    ``configured_language="auto"`` and transcribes a placeholder object.
    """

    def test_transcribe_requests_word_timestamps_when_supported(self):
        """A model accepting ``word_timestamps`` is asked for them."""
        model = _ModelWithWordTimestamps()
        stage = WhisperAsrStage(model, configured_language="auto")

        # The fake ignores its audio argument, so any object will do.
        result = stage.transcribe(object())

        self.assertTrue(model.kwargs["word_timestamps"])
        self.assertEqual(result.raw_text, "hello world")
        self.assertEqual(len(result.words), 2)
        self.assertEqual(result.words[0].text, "hello")
        self.assertGreaterEqual(result.words[0].start_s, 0.0)

    def test_transcribe_skips_word_timestamps_when_not_supported(self):
        """A model lacking the parameter is called without ``word_timestamps``."""
        model = _ModelWithoutWordTimestamps()
        stage = WhisperAsrStage(model, configured_language="auto")

        result = stage.transcribe(object())

        # The fake records only the keywords its signature accepts; the call
        # itself would raise TypeError if ``word_timestamps`` were passed.
        self.assertNotIn("word_timestamps", model.kwargs)
        self.assertEqual(result.raw_text, "hello")
        self.assertEqual(result.words, [])
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()