Add benchmark-driven model promotion workflow and pipeline stages
Some checks failed
ci / test-and-build (push) Has been cancelled
This commit is contained in:
parent
98b13d1069
commit
8c1f7c1e13
38 changed files with 5300 additions and 503 deletions
80
tests/test_asr_whisper.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
# Make the repository's ``src`` directory importable so the test can be run
# straight from a checkout. ROOT is the directory one level above the folder
# that contains this file.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    # Prepend so modules under ``src`` win over any same-named module that
    # appears later on the path.
    sys.path.insert(0, str(SRC))
|
||||
|
||||
from stages.asr_whisper import WhisperAsrStage
|
||||
|
||||
|
||||
class _Word:
    """Minimal fake of a single timed word attached to a transcription segment.

    Only carries the four attributes set below; the fake models in this file
    build instances of it and hand them to ``_Segment``.
    """

    def __init__(self, word: str, start: float, end: float, probability: float = 0.9) -> None:
        """Store the word text, its start/end times, and a probability value.

        Args:
            word: Text of the word.
            start: Start time of the word.
            end: End time of the word.
            probability: Value stored as ``probability``; defaults to ``0.9``.
        """
        self.word = word
        self.start = start
        self.end = end
        self.probability = probability
|
||||
|
||||
|
||||
class _Segment:
    """Minimal fake of one transcription segment returned by a fake model."""

    def __init__(self, text: str, start: float, end: float, words=None) -> None:
        """Store the segment text, its time span, and its optional words.

        Args:
            text: Text of the segment.
            start: Start time of the segment.
            end: End time of the segment.
            words: Optional sequence of word objects. ``None`` (or any other
                falsy value) is stored as a fresh empty list.
        """
        self.text = text
        self.start = start
        self.end = end
        # Normalise "no words" to an empty list so consumers can always iterate.
        self.words = words or []
|
||||
|
||||
|
||||
class _ModelWithWordTimestamps:
    """Fake ASR model whose ``transcribe`` accepts a ``word_timestamps`` keyword.

    The keyword arguments of the most recent ``transcribe`` call are kept in
    ``self.kwargs`` so a test can inspect what the caller passed.
    """

    def __init__(self) -> None:
        """Start with no recorded call."""
        self.kwargs = {}

    def transcribe(self, _audio, language=None, vad_filter=None, word_timestamps=False):
        """Record the keyword arguments and return one canned segment.

        Args:
            _audio: Ignored.
            language: Recorded under ``"language"``.
            vad_filter: Recorded under ``"vad_filter"``.
            word_timestamps: Recorded under ``"word_timestamps"``.

        Returns:
            A ``(segments, info)`` pair: a list holding a single "hello world"
            segment with two words, and an empty dict.
        """
        self.kwargs = {
            "language": language,
            "vad_filter": vad_filter,
            "word_timestamps": word_timestamps,
        }
        words = [_Word("hello", 0.0, 0.3), _Word("world", 0.31, 0.6)]
        return [_Segment("hello world", 0.0, 0.6, words=words)], {}
|
||||
|
||||
|
||||
class _ModelWithoutWordTimestamps:
    """Fake ASR model whose ``transcribe`` has no ``word_timestamps`` parameter.

    Passing ``word_timestamps`` to it would raise ``TypeError``. The keyword
    arguments of the most recent call are kept in ``self.kwargs``.
    """

    def __init__(self) -> None:
        """Start with no recorded call."""
        self.kwargs = {}

    def transcribe(self, _audio, language=None, vad_filter=None):
        """Record the keyword arguments and return one canned segment.

        Args:
            _audio: Ignored.
            language: Recorded under ``"language"``.
            vad_filter: Recorded under ``"vad_filter"``.

        Returns:
            A ``(segments, info)`` pair: a list holding a single "hello"
            segment with no words, and an empty dict.
        """
        # Deliberately no "word_timestamps" key: the test asserts its absence.
        self.kwargs = {
            "language": language,
            "vad_filter": vad_filter,
        }
        return [_Segment("hello", 0.0, 0.2, words=[])], {}
|
||||
|
||||
|
||||
class WhisperAsrStageTests(unittest.TestCase):
    """Checks which keyword arguments ``WhisperAsrStage.transcribe`` passes to its model.

    Both tests use in-file fake models that record the keyword arguments they
    receive, so no real ASR model is loaded.
    """

    def test_transcribe_requests_word_timestamps_when_supported(self):
        """A model that accepts ``word_timestamps`` gets it set and yields words."""
        model = _ModelWithWordTimestamps()
        stage = WhisperAsrStage(model, configured_language="auto")

        # The fake model ignores its audio argument, so any object will do.
        result = stage.transcribe(object())

        self.assertTrue(model.kwargs["word_timestamps"])
        self.assertEqual(result.raw_text, "hello world")
        self.assertEqual(len(result.words), 2)
        self.assertEqual(result.words[0].text, "hello")
        self.assertGreaterEqual(result.words[0].start_s, 0.0)

    def test_transcribe_skips_word_timestamps_when_not_supported(self):
        """A model without a ``word_timestamps`` parameter is called without it."""
        model = _ModelWithoutWordTimestamps()
        stage = WhisperAsrStage(model, configured_language="auto")

        # The fake model ignores its audio argument, so any object will do.
        result = stage.transcribe(object())

        # This fake never records a "word_timestamps" key, and its transcribe
        # would raise TypeError if the stage had passed that keyword.
        self.assertNotIn("word_timestamps", model.kwargs)
        self.assertEqual(result.raw_text, "hello")
        self.assertEqual(result.words, [])
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the tests in this module when the file is executed directly.
    unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue