80 lines
2.4 KiB
Python
80 lines
2.4 KiB
Python
import sys
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
# Make <repo>/src importable so the stage under test can be imported from a
# plain source checkout. ROOT is the parent of this file's directory
# (presumably this file lives in something like <repo>/tests/ — confirm).
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT.joinpath("src")

if str(SRC) not in sys.path:
    # Index 0: earlier sys.path entries take precedence during import.
    sys.path.insert(0, str(SRC))
|
|
|
|
from stages.asr_whisper import WhisperAsrStage
|
|
|
|
|
|
class _Word:
|
|
def __init__(self, word: str, start: float, end: float, probability: float = 0.9):
|
|
self.word = word
|
|
self.start = start
|
|
self.end = end
|
|
self.probability = probability
|
|
|
|
|
|
class _Segment:
|
|
def __init__(self, text: str, start: float, end: float, words=None):
|
|
self.text = text
|
|
self.start = start
|
|
self.end = end
|
|
self.words = words or []
|
|
|
|
|
|
class _ModelWithWordTimestamps:
|
|
def __init__(self):
|
|
self.kwargs = {}
|
|
|
|
def transcribe(self, _audio, language=None, vad_filter=None, word_timestamps=False):
|
|
self.kwargs = {
|
|
"language": language,
|
|
"vad_filter": vad_filter,
|
|
"word_timestamps": word_timestamps,
|
|
}
|
|
words = [_Word("hello", 0.0, 0.3), _Word("world", 0.31, 0.6)]
|
|
return [_Segment("hello world", 0.0, 0.6, words=words)], {}
|
|
|
|
|
|
class _ModelWithoutWordTimestamps:
|
|
def __init__(self):
|
|
self.kwargs = {}
|
|
|
|
def transcribe(self, _audio, language=None, vad_filter=None):
|
|
self.kwargs = {
|
|
"language": language,
|
|
"vad_filter": vad_filter,
|
|
}
|
|
return [_Segment("hello", 0.0, 0.2, words=[])], {}
|
|
|
|
|
|
class WhisperAsrStageTests(unittest.TestCase):
    """Checks how WhisperAsrStage requests word-level timestamps from its model."""

    def test_transcribe_requests_word_timestamps_when_supported(self):
        fake_model = _ModelWithWordTimestamps()
        result = WhisperAsrStage(fake_model, configured_language="auto").transcribe(object())

        # The fake recorded the keywords it was called with; the stage must
        # have asked for word timestamps.
        self.assertTrue(fake_model.kwargs["word_timestamps"])
        self.assertEqual(result.raw_text, "hello world")
        self.assertEqual(len(result.words), 2)
        first_word = result.words[0]
        self.assertEqual(first_word.text, "hello")
        self.assertGreaterEqual(first_word.start_s, 0.0)

    def test_transcribe_skips_word_timestamps_when_not_supported(self):
        fake_model = _ModelWithoutWordTimestamps()
        result = WhisperAsrStage(fake_model, configured_language="auto").transcribe(object())

        # This fake's transcribe() has no word_timestamps parameter, so
        # forwarding it as a keyword would already have raised TypeError above;
        # the recorded kwargs never contain that key by construction.
        self.assertNotIn("word_timestamps", fake_model.kwargs)
        self.assertEqual(result.raw_text, "hello")
        self.assertEqual(result.words, [])
|
|
|
|
|
|
if __name__ == "__main__":
    # Running the file directly discovers and runs every TestCase in this module.
    unittest.main(module="__main__")
|