Add benchmark-driven model promotion workflow and pipeline stages
Some checks failed
ci / test-and-build (push) Has been cancelled
Some checks failed
ci / test-and-build (push) Has been cancelled
This commit is contained in:
parent
98b13d1069
commit
8c1f7c1e13
38 changed files with 5300 additions and 503 deletions
129
tests/test_pipeline_engine.py
Normal file
129
tests/test_pipeline_engine.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from engine.pipeline import PipelineEngine
|
||||
from stages.alignment_edits import AlignmentHeuristicEngine
|
||||
from stages.asr_whisper import AsrResult, AsrSegment, AsrWord
|
||||
from vocabulary import VocabularyEngine
|
||||
from config import VocabularyConfig
|
||||
|
||||
|
||||
class _FakeEditor:
|
||||
def __init__(self, *, output_text: str | None = None):
|
||||
self.calls = []
|
||||
self.output_text = output_text
|
||||
|
||||
def rewrite(self, transcript, *, language, dictionary_context):
|
||||
self.calls.append(
|
||||
{
|
||||
"transcript": transcript,
|
||||
"language": language,
|
||||
"dictionary_context": dictionary_context,
|
||||
}
|
||||
)
|
||||
|
||||
final_text = transcript if self.output_text is None else self.output_text
|
||||
return SimpleNamespace(
|
||||
final_text=final_text,
|
||||
latency_ms=1.0,
|
||||
pass1_ms=0.5,
|
||||
pass2_ms=0.5,
|
||||
)
|
||||
|
||||
|
||||
class _FakeAsr:
|
||||
def transcribe(self, _audio):
|
||||
words = [
|
||||
AsrWord("set", 0.0, 0.1, 0.9),
|
||||
AsrWord("alarm", 0.2, 0.3, 0.9),
|
||||
AsrWord("for", 0.4, 0.5, 0.9),
|
||||
AsrWord("6", 0.6, 0.7, 0.9),
|
||||
AsrWord("i", 0.8, 0.9, 0.9),
|
||||
AsrWord("mean", 1.0, 1.1, 0.9),
|
||||
AsrWord("7", 1.2, 1.3, 0.9),
|
||||
]
|
||||
segments = [AsrSegment(text="set alarm for 6 i mean 7", start_s=0.0, end_s=1.3)]
|
||||
return AsrResult(
|
||||
raw_text="set alarm for 6 i mean 7",
|
||||
language="en",
|
||||
latency_ms=5.0,
|
||||
words=words,
|
||||
segments=segments,
|
||||
)
|
||||
|
||||
|
||||
class PipelineEngineTests(unittest.TestCase):
|
||||
def test_alignment_draft_is_forwarded_to_editor(self):
|
||||
editor = _FakeEditor()
|
||||
pipeline = PipelineEngine(
|
||||
asr_stage=_FakeAsr(),
|
||||
editor_stage=editor,
|
||||
vocabulary=VocabularyEngine(VocabularyConfig()),
|
||||
alignment_engine=AlignmentHeuristicEngine(),
|
||||
)
|
||||
|
||||
result = pipeline.run_audio(object())
|
||||
|
||||
self.assertEqual(editor.calls[0]["transcript"], "set alarm for 7")
|
||||
self.assertEqual(result.alignment_applied, 1)
|
||||
self.assertGreaterEqual(result.alignment_ms, 0.0)
|
||||
self.assertEqual(result.fact_guard_action, "accepted")
|
||||
self.assertEqual(result.fact_guard_violations, 0)
|
||||
|
||||
def test_run_transcript_without_words_keeps_alignment_noop(self):
|
||||
editor = _FakeEditor()
|
||||
pipeline = PipelineEngine(
|
||||
asr_stage=None,
|
||||
editor_stage=editor,
|
||||
vocabulary=VocabularyEngine(VocabularyConfig()),
|
||||
alignment_engine=AlignmentHeuristicEngine(),
|
||||
)
|
||||
|
||||
result = pipeline.run_transcript("hello world", language="en")
|
||||
|
||||
self.assertEqual(editor.calls[0]["transcript"], "hello world")
|
||||
self.assertEqual(result.alignment_applied, 0)
|
||||
self.assertEqual(result.fact_guard_action, "accepted")
|
||||
self.assertEqual(result.fact_guard_violations, 0)
|
||||
|
||||
def test_fact_guard_fallbacks_when_editor_changes_number(self):
|
||||
editor = _FakeEditor(output_text="set alarm for 8")
|
||||
pipeline = PipelineEngine(
|
||||
asr_stage=None,
|
||||
editor_stage=editor,
|
||||
vocabulary=VocabularyEngine(VocabularyConfig()),
|
||||
alignment_engine=AlignmentHeuristicEngine(),
|
||||
safety_enabled=True,
|
||||
safety_strict=False,
|
||||
)
|
||||
|
||||
result = pipeline.run_transcript("set alarm for 7", language="en")
|
||||
|
||||
self.assertEqual(result.output_text, "set alarm for 7")
|
||||
self.assertEqual(result.fact_guard_action, "fallback")
|
||||
self.assertGreaterEqual(result.fact_guard_violations, 1)
|
||||
|
||||
def test_fact_guard_strict_rejects_number_change(self):
|
||||
editor = _FakeEditor(output_text="set alarm for 8")
|
||||
pipeline = PipelineEngine(
|
||||
asr_stage=None,
|
||||
editor_stage=editor,
|
||||
vocabulary=VocabularyEngine(VocabularyConfig()),
|
||||
alignment_engine=AlignmentHeuristicEngine(),
|
||||
safety_enabled=True,
|
||||
safety_strict=True,
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "fact guard rejected editor output"):
|
||||
pipeline.run_transcript("set alarm for 7", language="en")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue