Some checks failed
ci / test-and-build (push) Has been cancelled
Keep transcript-only runs eligible for alignment heuristics instead of bailing out when the ASR stage does not supply word timings. Build fallback AsrWord entries from the transcript so cue-based corrections like "i mean" still apply, while reusing the existing literal guard for verbatim phrases. Cover the new path in alignment and pipeline tests, and validate with python3 -m unittest tests.test_alignment_edits tests.test_pipeline_engine.
146 lines
5.1 KiB
Python
146 lines
5.1 KiB
Python
import sys
|
|
import unittest
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
|
|
# Make the project's ``src`` directory importable before the project imports
# below run. ROOT is the directory one level above the folder that contains
# this test file; SRC is its ``src`` child.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT.joinpath("src")

# Insert at the front of sys.path, and only once, so repeated imports of this
# module do not keep growing the path.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
|
|
|
|
from engine.pipeline import PipelineEngine
|
|
from stages.alignment_edits import AlignmentHeuristicEngine
|
|
from stages.asr_whisper import AsrResult, AsrSegment, AsrWord
|
|
from vocabulary import VocabularyEngine
|
|
from config import VocabularyConfig
|
|
|
|
|
|
class _FakeEditor:
|
|
def __init__(self, *, output_text: str | None = None):
|
|
self.calls = []
|
|
self.output_text = output_text
|
|
|
|
def rewrite(self, transcript, *, language, dictionary_context):
|
|
self.calls.append(
|
|
{
|
|
"transcript": transcript,
|
|
"language": language,
|
|
"dictionary_context": dictionary_context,
|
|
}
|
|
)
|
|
|
|
final_text = transcript if self.output_text is None else self.output_text
|
|
return SimpleNamespace(
|
|
final_text=final_text,
|
|
latency_ms=1.0,
|
|
pass1_ms=0.5,
|
|
pass2_ms=0.5,
|
|
)
|
|
|
|
|
|
class _FakeAsr:
    """ASR-stage test double that always returns the same word-timed result."""

    def transcribe(self, _audio):
        """Return a canned English result; the ``_audio`` argument is ignored."""
        transcript = "set alarm for 6 i mean 7"

        # (token, second positional arg, third positional arg) per word.
        # NOTE(review): these look like start/end times in seconds, and the
        # constant 0.9 passed below is presumably a confidence score; confirm
        # against the AsrWord definition.
        word_table = (
            ("set", 0.0, 0.1),
            ("alarm", 0.2, 0.3),
            ("for", 0.4, 0.5),
            ("6", 0.6, 0.7),
            ("i", 0.8, 0.9),
            ("mean", 1.0, 1.1),
            ("7", 1.2, 1.3),
        )
        timed_words = [AsrWord(token, begin, finish, 0.9) for token, begin, finish in word_table]

        # A single segment covering the whole utterance.
        whole_utterance = AsrSegment(text=transcript, start_s=0.0, end_s=1.3)

        return AsrResult(
            raw_text=transcript,
            language="en",
            latency_ms=5.0,
            words=timed_words,
            segments=[whole_utterance],
        )
|
|
|
|
|
|
class PipelineEngineTests(unittest.TestCase):
    """Exercise ``PipelineEngine`` with fake editor/ASR stages.

    Covers the alignment draft handed to the editor and the fact-guard
    outcome for both the audio and the transcript-only entry points.
    """

    @staticmethod
    def _make_pipeline(editor, *, asr_stage=None, **engine_options):
        """Build a ``PipelineEngine`` around ``editor``.

        The vocabulary is a ``VocabularyEngine`` over a no-argument
        ``VocabularyConfig`` and the alignment engine is a fresh
        ``AlignmentHeuristicEngine``. Extra keyword arguments are forwarded
        to ``PipelineEngine`` only when the caller supplies them.
        """
        return PipelineEngine(
            asr_stage=asr_stage,
            editor_stage=editor,
            vocabulary=VocabularyEngine(VocabularyConfig()),
            alignment_engine=AlignmentHeuristicEngine(),
            **engine_options,
        )

    def _assert_fact_guard_accepted(self, outcome):
        """Assert the fact guard accepted the output with zero violations."""
        self.assertEqual(outcome.fact_guard_action, "accepted")
        self.assertEqual(outcome.fact_guard_violations, 0)

    def test_alignment_draft_is_forwarded_to_editor(self):
        # Audio path: the fake ASR yields "set alarm for 6 i mean 7" with word
        # timings; the editor is expected to receive the corrected draft.
        fake_editor = _FakeEditor()
        engine = self._make_pipeline(fake_editor, asr_stage=_FakeAsr())

        outcome = engine.run_audio(object())

        first_call = fake_editor.calls[0]
        self.assertEqual(first_call["transcript"], "set alarm for 7")
        self.assertEqual(outcome.alignment_applied, 1)
        self.assertGreaterEqual(outcome.alignment_ms, 0.0)
        self._assert_fact_guard_accepted(outcome)

    def test_run_transcript_without_words_keeps_alignment_noop(self):
        # Transcript-only path with nothing to correct: the text must reach
        # the editor unchanged and no alignment edit may be counted.
        fake_editor = _FakeEditor()
        engine = self._make_pipeline(fake_editor)

        outcome = engine.run_transcript("hello world", language="en")

        first_call = fake_editor.calls[0]
        self.assertEqual(first_call["transcript"], "hello world")
        self.assertEqual(outcome.alignment_applied, 0)
        self._assert_fact_guard_accepted(outcome)

    def test_run_transcript_without_words_applies_i_mean_correction(self):
        # Transcript-only path with an "i mean" cue: the correction is
        # expected to apply even though no ASR stage is configured.
        fake_editor = _FakeEditor()
        engine = self._make_pipeline(fake_editor)

        outcome = engine.run_transcript("schedule for 5, i mean 6", language="en")

        first_call = fake_editor.calls[0]
        self.assertEqual(first_call["transcript"], "schedule for 6")
        self.assertEqual(outcome.output_text, "schedule for 6")
        self.assertEqual(outcome.alignment_applied, 1)
        self._assert_fact_guard_accepted(outcome)

    def test_fact_guard_fallbacks_when_editor_changes_number(self):
        # Non-strict safety: an editor that changes 7 -> 8 should be
        # overridden, with the output falling back to the input text.
        fake_editor = _FakeEditor(output_text="set alarm for 8")
        engine = self._make_pipeline(
            fake_editor,
            safety_enabled=True,
            safety_strict=False,
        )

        outcome = engine.run_transcript("set alarm for 7", language="en")

        self.assertEqual(outcome.output_text, "set alarm for 7")
        self.assertEqual(outcome.fact_guard_action, "fallback")
        self.assertGreaterEqual(outcome.fact_guard_violations, 1)

    def test_fact_guard_strict_rejects_number_change(self):
        # Strict safety: the same 7 -> 8 change must raise instead of
        # silently falling back.
        fake_editor = _FakeEditor(output_text="set alarm for 8")
        engine = self._make_pipeline(
            fake_editor,
            safety_enabled=True,
            safety_strict=True,
        )

        with self.assertRaisesRegex(RuntimeError, "fact guard rejected editor output"):
            engine.run_transcript("set alarm for 7", language="en")
|
|
|
|
|
|
# Allow running this test module directly as a script; unittest.main()
# discovers and runs the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()
|