aman/tests/test_pipeline_engine.py
Thales Maciel 8c1f7c1e13
Some checks failed
ci / test-and-build (push) Has been cancelled
Add benchmark-driven model promotion workflow and pipeline stages
2026-02-28 15:12:33 -03:00

129 lines
4.3 KiB
Python

import sys
import unittest
from pathlib import Path
from types import SimpleNamespace
# Make the project's ``src`` directory importable so the imports below resolve
# when this file is run directly, without installing the package.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    # Prepend so the in-tree sources take precedence over any other copy.
    sys.path.insert(0, str(SRC))
from engine.pipeline import PipelineEngine
from stages.alignment_edits import AlignmentHeuristicEngine
from stages.asr_whisper import AsrResult, AsrSegment, AsrWord
from vocabulary import VocabularyEngine
from config import VocabularyConfig
class _FakeEditor:
def __init__(self, *, output_text: str | None = None):
self.calls = []
self.output_text = output_text
def rewrite(self, transcript, *, language, dictionary_context):
self.calls.append(
{
"transcript": transcript,
"language": language,
"dictionary_context": dictionary_context,
}
)
final_text = transcript if self.output_text is None else self.output_text
return SimpleNamespace(
final_text=final_text,
latency_ms=1.0,
pass1_ms=0.5,
pass2_ms=0.5,
)
class _FakeAsr:
    """ASR-stage test double that always returns the same canned transcription."""

    def transcribe(self, _audio):
        """Return a fixed ``AsrResult`` for a spoken self-correction; ``_audio`` is ignored."""
        # Single source for the utterance so the raw text and the segment
        # text cannot drift apart.
        raw_text = "set alarm for 6 i mean 7"
        # NOTE(review): positional AsrWord arguments look like
        # (text, start, end, confidence) — confirm against stages.asr_whisper.
        words = [
            AsrWord("set", 0.0, 0.1, 0.9),
            AsrWord("alarm", 0.2, 0.3, 0.9),
            AsrWord("for", 0.4, 0.5, 0.9),
            AsrWord("6", 0.6, 0.7, 0.9),
            AsrWord("i", 0.8, 0.9, 0.9),
            AsrWord("mean", 1.0, 1.1, 0.9),
            AsrWord("7", 1.2, 1.3, 0.9),
        ]
        segments = [AsrSegment(text=raw_text, start_s=0.0, end_s=1.3)]
        return AsrResult(
            raw_text=raw_text,
            language="en",
            latency_ms=5.0,
            words=words,
            segments=segments,
        )
class PipelineEngineTests(unittest.TestCase):
    """Exercise ``PipelineEngine`` end to end using fake ASR and editor stages."""

    @staticmethod
    def _build_pipeline(editor, *, asr_stage=None, **engine_options):
        """Build a ``PipelineEngine`` with default vocabulary and alignment engine.

        ``engine_options`` is forwarded untouched, so tests that do not pass
        ``safety_enabled`` / ``safety_strict`` keep the engine's own defaults.
        """
        return PipelineEngine(
            asr_stage=asr_stage,
            editor_stage=editor,
            vocabulary=VocabularyEngine(VocabularyConfig()),
            alignment_engine=AlignmentHeuristicEngine(),
            **engine_options,
        )

    def test_alignment_draft_is_forwarded_to_editor(self):
        """The editor receives the alignment-corrected draft, not the raw ASR text."""
        editor = _FakeEditor()
        pipeline = self._build_pipeline(editor, asr_stage=_FakeAsr())

        # The fake ASR stage ignores its input, so any object works as audio.
        result = pipeline.run_audio(object())

        self.assertEqual(editor.calls[0]["transcript"], "set alarm for 7")
        self.assertEqual(result.alignment_applied, 1)
        self.assertGreaterEqual(result.alignment_ms, 0.0)
        self.assertEqual(result.fact_guard_action, "accepted")
        self.assertEqual(result.fact_guard_violations, 0)

    def test_run_transcript_without_words_keeps_alignment_noop(self):
        """A plain transcript reaches the editor unchanged with no alignment edits."""
        editor = _FakeEditor()
        pipeline = self._build_pipeline(editor)

        result = pipeline.run_transcript("hello world", language="en")

        self.assertEqual(editor.calls[0]["transcript"], "hello world")
        self.assertEqual(result.alignment_applied, 0)
        self.assertEqual(result.fact_guard_action, "accepted")
        self.assertEqual(result.fact_guard_violations, 0)

    def test_fact_guard_fallbacks_when_editor_changes_number(self):
        """With non-strict safety, a changed number falls back to the input text."""
        editor = _FakeEditor(output_text="set alarm for 8")
        pipeline = self._build_pipeline(
            editor,
            safety_enabled=True,
            safety_strict=False,
        )

        result = pipeline.run_transcript("set alarm for 7", language="en")

        self.assertEqual(result.output_text, "set alarm for 7")
        self.assertEqual(result.fact_guard_action, "fallback")
        self.assertGreaterEqual(result.fact_guard_violations, 1)

    def test_fact_guard_strict_rejects_number_change(self):
        """With strict safety, a changed number raises ``RuntimeError``."""
        editor = _FakeEditor(output_text="set alarm for 8")
        pipeline = self._build_pipeline(
            editor,
            safety_enabled=True,
            safety_strict=True,
        )

        with self.assertRaisesRegex(RuntimeError, "fact guard rejected editor output"):
            pipeline.run_transcript("set alarm for 7", language="en")
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()