# File: aman/tests/test_vosk_eval.py (Python)

import json
import sys
import tempfile
import unittest
import wave
from pathlib import Path
from unittest.mock import patch
# Put the repository's ``src`` directory (a sibling of this tests directory)
# at the front of ``sys.path`` so ``vosk_eval`` imports without installation.
# The membership check keeps repeated imports from adding duplicate entries.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT.joinpath("src")
if not any(entry == str(SRC) for entry in sys.path):
    sys.path.insert(0, str(SRC))
from vosk_eval import (
DecodedRow,
build_phrase_to_intent_index,
load_keystroke_intents,
run_vosk_keystroke_eval,
summarize_decoded_rows,
)
# Canonical single-intent record shared by the loader, index and end-to-end tests.
_CTRL_D_INTENT = {
    "intent_id": "ctrl+d",
    "literal_phrase": "control d",
    "nato_phrase": "control delta",
    "letter": "d",
    "modifier": "ctrl",
}


def _write_json(path: Path, payload) -> None:
    """Serialize *payload* to *path* as UTF-8 JSON."""
    path.write_text(json.dumps(payload), encoding="utf-8")


def _write_jsonl_row(path: Path, row: dict) -> None:
    """Write *row* to *path* as a one-line JSONL manifest (newline-terminated)."""
    path.write_text(json.dumps(row) + "\n", encoding="utf-8")


def _write_eval_inputs(
    root: Path,
    *,
    literal_manifest: Path,
    nato_manifest: Path,
    wav_ref: str,
):
    """Create every input file ``run_vosk_keystroke_eval`` needs for one fake model.

    Writes ``root/intents.json`` (the ``ctrl+d`` intent), one manifest row per
    grammar pointing at *wav_ref*, an empty ``root/model`` directory and a
    ``root/models.json`` that registers it under the name ``"fake"``.
    *wav_ref* is written verbatim, so it may be absolute or manifest-relative.

    Returns:
        ``(intents_path, models_file)``.
    """
    intents_path = root / "intents.json"
    model_dir = root / "model"
    model_dir.mkdir(parents=True, exist_ok=True)
    _write_json(intents_path, [_CTRL_D_INTENT])
    _write_jsonl_row(literal_manifest, {"phrase": "control d", "wav_path": wav_ref})
    _write_jsonl_row(nato_manifest, {"phrase": "control delta", "wav_path": wav_ref})
    models_file = root / "models.json"
    _write_json(models_file, [{"name": "fake", "path": str(model_dir)}])
    return intents_path, models_file


class _FakeModel:
    """Stand-in for the Vosk model class returned by ``_load_vosk_bindings``; ignores its path."""

    def __init__(self, _path: str):
        return


class VoskEvalTests(unittest.TestCase):
    """Tests for intent loading, phrase indexing, summarisation and the eval runner in ``vosk_eval``."""

    def test_load_keystroke_intents_parses_valid_payload(self):
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "intents.json"
            _write_json(path, [_CTRL_D_INTENT])
            intents = load_keystroke_intents(path)
            self.assertEqual(len(intents), 1)
            self.assertEqual(intents[0].intent_id, "ctrl+d")

    def test_load_keystroke_intents_rejects_duplicate_literal_phrase(self):
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "intents.json"
            # Second record reuses "control d" as its literal phrase on purpose.
            _write_json(
                path,
                [
                    _CTRL_D_INTENT,
                    {
                        "intent_id": "ctrl+b",
                        "literal_phrase": "control d",
                        "nato_phrase": "control bravo",
                        "letter": "b",
                        "modifier": "ctrl",
                    },
                ],
            )
            with self.assertRaisesRegex(RuntimeError, "duplicate literal_phrase"):
                load_keystroke_intents(path)

    def test_build_phrase_to_intent_index_uses_grammar_variant(self):
        intents = [load_keystroke_intents_from_inline(**_CTRL_D_INTENT)]
        literal = build_phrase_to_intent_index(intents, grammar="literal")
        nato = build_phrase_to_intent_index(intents, grammar="nato")
        self.assertIn("control d", literal)
        self.assertIn("control delta", nato)

    def test_summarize_decoded_rows_reports_confusions(self):
        rows = [
            # Correct decode.
            DecodedRow(
                wav_path="a.wav",
                expected_phrase="control d",
                hypothesis="control d",
                expected_intent="ctrl+d",
                predicted_intent="ctrl+d",
                expected_letter="d",
                predicted_letter="d",
                expected_modifier="ctrl",
                predicted_modifier="ctrl",
                intent_match=True,
                audio_ms=1000.0,
                decode_ms=100.0,
                rtf=0.1,
                out_of_grammar=False,
            ),
            # Wrong letter: ctrl+b confused with ctrl+p.
            DecodedRow(
                wav_path="b.wav",
                expected_phrase="control b",
                hypothesis="control p",
                expected_intent="ctrl+b",
                predicted_intent="ctrl+p",
                expected_letter="b",
                predicted_letter="p",
                expected_modifier="ctrl",
                predicted_modifier="ctrl",
                intent_match=False,
                audio_ms=1000.0,
                decode_ms=120.0,
                rtf=0.12,
                out_of_grammar=False,
            ),
            # Empty hypothesis: no predicted intent at all.
            DecodedRow(
                wav_path="c.wav",
                expected_phrase="control p",
                hypothesis="",
                expected_intent="ctrl+p",
                predicted_intent=None,
                expected_letter="p",
                predicted_letter=None,
                expected_modifier="ctrl",
                predicted_modifier=None,
                intent_match=False,
                audio_ms=1000.0,
                decode_ms=90.0,
                rtf=0.09,
                out_of_grammar=False,
            ),
        ]
        summary = summarize_decoded_rows(rows)
        self.assertEqual(summary["samples"], 3)
        self.assertAlmostEqual(summary["intent_accuracy"], 1 / 3, places=6)
        self.assertEqual(summary["unknown_count"], 1)
        self.assertEqual(summary["intent_confusion"]["ctrl+b"]["ctrl+p"], 1)
        self.assertEqual(summary["letter_confusion"]["p"]["__none__"], 1)
        self.assertGreaterEqual(len(summary["top_raw_mismatches"]), 1)

    def test_run_vosk_keystroke_eval_hard_fails_model_with_out_of_grammar_output(self):
        with tempfile.TemporaryDirectory() as td:
            root = Path(td)
            literal_manifest = root / "literal.jsonl"
            nato_manifest = root / "nato.jsonl"
            wav_path = root / "sample.wav"
            _write_silence_wav(wav_path, samplerate=16000, frames=800)
            intents_path, models_file = _write_eval_inputs(
                root,
                literal_manifest=literal_manifest,
                nato_manifest=nato_manifest,
                wav_ref=str(wav_path),
            )

            class _FakeRecognizer:
                """Recognizer whose final text is never part of the grammar."""

                def __init__(self, _model, _rate, _grammar_json):
                    return

                def SetWords(self, _enabled: bool):
                    return

                def AcceptWaveform(self, _payload: bytes):
                    return True

                def FinalResult(self):
                    return json.dumps({"text": "outside hypothesis"})

            with patch("vosk_eval._load_vosk_bindings", return_value=(_FakeModel, _FakeRecognizer)):
                with self.assertRaisesRegex(RuntimeError, "out-of-grammar"):
                    run_vosk_keystroke_eval(
                        literal_manifest=literal_manifest,
                        nato_manifest=nato_manifest,
                        intents_path=intents_path,
                        output_dir=root / "out",
                        models_file=models_file,
                        verbose=False,
                    )

    def test_run_vosk_keystroke_eval_resolves_manifest_relative_wav_paths(self):
        with tempfile.TemporaryDirectory() as td:
            root = Path(td)
            manifests_dir = root / "manifests"
            samples_dir = manifests_dir / "samples"
            samples_dir.mkdir(parents=True, exist_ok=True)
            _write_silence_wav(samples_dir / "sample.wav", samplerate=16000, frames=800)
            literal_manifest = manifests_dir / "literal.jsonl"
            nato_manifest = manifests_dir / "nato.jsonl"
            # Relative to the manifest directory, not to the process CWD.
            intents_path, models_file = _write_eval_inputs(
                root,
                literal_manifest=literal_manifest,
                nato_manifest=nato_manifest,
                wav_ref="samples/sample.wav",
            )

            class _FakeRecognizer:
                """Recognizer that echoes the first grammar phrase it was built with."""

                def __init__(self, _model, _rate, grammar_json):
                    phrases = json.loads(grammar_json)
                    self._text = str(phrases[0]) if phrases else ""

                def SetWords(self, _enabled: bool):
                    return

                def AcceptWaveform(self, _payload: bytes):
                    return True

                def FinalResult(self):
                    return json.dumps({"text": self._text})

            with patch("vosk_eval._load_vosk_bindings", return_value=(_FakeModel, _FakeRecognizer)):
                summary = run_vosk_keystroke_eval(
                    literal_manifest=literal_manifest,
                    nato_manifest=nato_manifest,
                    intents_path=intents_path,
                    output_dir=root / "out",
                    models_file=models_file,
                    verbose=False,
                )
            self.assertEqual(summary["models"][0]["literal"]["intent_accuracy"], 1.0)
            self.assertEqual(summary["models"][0]["nato"]["intent_accuracy"], 1.0)
def load_keystroke_intents_from_inline(
    intent_id: str,
    literal_phrase: str,
    nato_phrase: str,
    letter: str,
    modifier: str,
):
    """Build a single keystroke intent from individual field values.

    The fields are packed into a one-record payload and passed through
    ``load_keystroke_intents_from_json``. The result is therefore parsed by
    ``load_keystroke_intents``, exactly like an intent read from a file.

    Returns:
        The first element of the loaded intent list.
    """
    record = dict(
        intent_id=intent_id,
        literal_phrase=literal_phrase,
        nato_phrase=nato_phrase,
        letter=letter,
        modifier=modifier,
    )
    parsed = load_keystroke_intents_from_json([record])
    return parsed[0]
def load_keystroke_intents_from_json(payload):
    """Write *payload* to a scratch ``intents.json`` and load it with ``load_keystroke_intents``.

    The scratch directory is deleted before this returns. This assumes
    ``load_keystroke_intents`` reads the file eagerly (callers index and
    ``len()`` the result) — confirm if that loader ever becomes lazy.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        intents_file = Path(scratch_dir) / "intents.json"
        serialized = json.dumps(payload)
        intents_file.write_text(serialized, encoding="utf-8")
        loaded = load_keystroke_intents(intents_file)
    return loaded
def _write_silence_wav(path: Path, *, samplerate: int, frames: int):
    """Write a mono, 16-bit PCM WAV file holding *frames* samples of silence.

    Missing parent directories of *path* are created first.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # One zeroed 2-byte sample per frame (mono, 16-bit).
    silent_pcm = b"\x00\x00" * frames
    with wave.open(str(path), "wb") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(2)
        writer.setframerate(samplerate)
        writer.writeframes(silent_pcm)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()