# Source-viewer metadata (not part of the program): 327 lines, 12 KiB, Python.
import json
import sys
import tempfile
import unittest
import wave
from pathlib import Path
from unittest.mock import patch

# Parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
# Directory holding the importable project sources.
SRC = ROOT / "src"
# Prepend ``src`` to the import path exactly once so the project module
# imported below can be resolved.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
# Project module under test. This import has to stay below the ``sys.path``
# adjustment, hence the E402 suppression.
from vosk_eval import (  # noqa: E402
    DecodedRow,
    build_phrase_to_intent_index,
    load_keystroke_intents,
    run_vosk_keystroke_eval,
    summarize_decoded_rows,
)


class VoskEvalTests(unittest.TestCase):
    """Tests for the ``vosk_eval`` helpers imported by this module.

    Covers intent loading, the phrase-to-intent index, decoded-row
    summaries and ``run_vosk_keystroke_eval`` driven by fake Vosk bindings.
    """

    # Minimal well-formed intent record reused by several fixtures.
    _CTRL_D_INTENT = {
        "intent_id": "ctrl+d",
        "literal_phrase": "control d",
        "nato_phrase": "control delta",
        "letter": "d",
        "modifier": "ctrl",
    }

    @staticmethod
    def _write_json(path, payload):
        """Serialise *payload* as UTF-8 encoded JSON into *path*."""
        path.write_text(json.dumps(payload), encoding="utf-8")

    @staticmethod
    def _write_manifest(path, phrase, wav_path):
        """Write a one-row JSONL manifest pairing *phrase* with *wav_path*."""
        row = {"phrase": phrase, "wav_path": str(wav_path)}
        path.write_text(json.dumps(row) + "\n", encoding="utf-8")

    def test_load_keystroke_intents_parses_valid_payload(self):
        """One valid record yields one intent carrying its ``intent_id``."""
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "intents.json"
            self._write_json(path, [self._CTRL_D_INTENT])

            intents = load_keystroke_intents(path)

            self.assertEqual(len(intents), 1)
            self.assertEqual(intents[0].intent_id, "ctrl+d")

    def test_load_keystroke_intents_rejects_duplicate_literal_phrase(self):
        """Two records sharing a ``literal_phrase`` raise ``RuntimeError``."""
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "intents.json"
            self._write_json(
                path,
                [
                    self._CTRL_D_INTENT,
                    {
                        "intent_id": "ctrl+b",
                        # Deliberately collides with the ctrl+d record.
                        "literal_phrase": "control d",
                        "nato_phrase": "control bravo",
                        "letter": "b",
                        "modifier": "ctrl",
                    },
                ],
            )

            with self.assertRaisesRegex(RuntimeError, "duplicate literal_phrase"):
                load_keystroke_intents(path)

    def test_build_phrase_to_intent_index_uses_grammar_variant(self):
        """The index is keyed by the phrase of the requested grammar."""
        intents = [
            load_keystroke_intents_from_inline(
                "ctrl+d",
                "control d",
                "control delta",
                "d",
                "ctrl",
            )
        ]

        literal = build_phrase_to_intent_index(intents, grammar="literal")
        nato = build_phrase_to_intent_index(intents, grammar="nato")

        self.assertIn("control d", literal)
        self.assertIn("control delta", nato)

    def test_summarize_decoded_rows_reports_confusions(self):
        """A hit, a wrong intent and an empty hypothesis are all tallied."""
        rows = [
            # Correct decode.
            DecodedRow(
                wav_path="a.wav",
                expected_phrase="control d",
                hypothesis="control d",
                expected_intent="ctrl+d",
                predicted_intent="ctrl+d",
                expected_letter="d",
                predicted_letter="d",
                expected_modifier="ctrl",
                predicted_modifier="ctrl",
                intent_match=True,
                audio_ms=1000.0,
                decode_ms=100.0,
                rtf=0.1,
                out_of_grammar=False,
            ),
            # Wrong letter: "b" decoded as "p".
            DecodedRow(
                wav_path="b.wav",
                expected_phrase="control b",
                hypothesis="control p",
                expected_intent="ctrl+b",
                predicted_intent="ctrl+p",
                expected_letter="b",
                predicted_letter="p",
                expected_modifier="ctrl",
                predicted_modifier="ctrl",
                intent_match=False,
                audio_ms=1000.0,
                decode_ms=120.0,
                rtf=0.12,
                out_of_grammar=False,
            ),
            # Empty hypothesis: nothing predicted at all.
            DecodedRow(
                wav_path="c.wav",
                expected_phrase="control p",
                hypothesis="",
                expected_intent="ctrl+p",
                predicted_intent=None,
                expected_letter="p",
                predicted_letter=None,
                expected_modifier="ctrl",
                predicted_modifier=None,
                intent_match=False,
                audio_ms=1000.0,
                decode_ms=90.0,
                rtf=0.09,
                out_of_grammar=False,
            ),
        ]

        summary = summarize_decoded_rows(rows)

        self.assertEqual(summary["samples"], 3)
        self.assertAlmostEqual(summary["intent_accuracy"], 1 / 3, places=6)
        self.assertEqual(summary["unknown_count"], 1)
        self.assertEqual(summary["intent_confusion"]["ctrl+b"]["ctrl+p"], 1)
        self.assertEqual(summary["letter_confusion"]["p"]["__none__"], 1)
        self.assertGreaterEqual(len(summary["top_raw_mismatches"]), 1)

    def test_run_vosk_keystroke_eval_hard_fails_model_with_out_of_grammar_output(self):
        """A hypothesis outside the grammar makes the run raise."""
        with tempfile.TemporaryDirectory() as td:
            root = Path(td)
            literal_manifest = root / "literal.jsonl"
            nato_manifest = root / "nato.jsonl"
            intents_path = root / "intents.json"
            output_dir = root / "out"
            model_dir = root / "model"
            model_dir.mkdir(parents=True, exist_ok=True)
            wav_path = root / "sample.wav"
            _write_silence_wav(wav_path, samplerate=16000, frames=800)

            self._write_json(intents_path, [self._CTRL_D_INTENT])
            self._write_manifest(literal_manifest, "control d", wav_path)
            self._write_manifest(nato_manifest, "control delta", wav_path)
            models_file = root / "models.json"
            self._write_json(models_file, [{"name": "fake", "path": str(model_dir)}])

            class _FakeModel:
                """Stand-in for the Vosk model class; ignores its path."""

                def __init__(self, _path: str):
                    return

            class _FakeRecognizer:
                """Recognizer stub that always answers with a fixed phrase."""

                def __init__(self, _model, _rate, _grammar_json):
                    return

                def SetWords(self, _enabled: bool):
                    return

                def AcceptWaveform(self, _payload: bytes):
                    return True

                def FinalResult(self):
                    # Not one of the phrases declared in the intents file.
                    return json.dumps({"text": "outside hypothesis"})

            with patch(
                "vosk_eval._load_vosk_bindings",
                return_value=(_FakeModel, _FakeRecognizer),
            ):
                with self.assertRaisesRegex(RuntimeError, "out-of-grammar"):
                    run_vosk_keystroke_eval(
                        literal_manifest=literal_manifest,
                        nato_manifest=nato_manifest,
                        intents_path=intents_path,
                        output_dir=output_dir,
                        models_file=models_file,
                        verbose=False,
                    )

    def test_run_vosk_keystroke_eval_resolves_manifest_relative_wav_paths(self):
        """Relative ``wav_path`` entries resolve against the manifest folder."""
        with tempfile.TemporaryDirectory() as td:
            root = Path(td)
            manifests_dir = root / "manifests"
            samples_dir = manifests_dir / "samples"
            samples_dir.mkdir(parents=True, exist_ok=True)
            wav_path = samples_dir / "sample.wav"
            _write_silence_wav(wav_path, samplerate=16000, frames=800)

            literal_manifest = manifests_dir / "literal.jsonl"
            nato_manifest = manifests_dir / "nato.jsonl"
            intents_path = root / "intents.json"
            output_dir = root / "out"
            model_dir = root / "model"
            model_dir.mkdir(parents=True, exist_ok=True)

            self._write_json(intents_path, [self._CTRL_D_INTENT])
            # Path is relative to the manifests directory, not to the CWD.
            relative_wav = "samples/sample.wav"
            self._write_manifest(literal_manifest, "control d", relative_wav)
            self._write_manifest(nato_manifest, "control delta", relative_wav)
            models_file = root / "models.json"
            self._write_json(models_file, [{"name": "fake", "path": str(model_dir)}])

            class _FakeModel:
                """Stand-in for the Vosk model class; ignores its path."""

                def __init__(self, _path: str):
                    return

            class _FakeRecognizer:
                """Recognizer stub that echoes the first grammar phrase."""

                def __init__(self, _model, _rate, grammar_json):
                    phrases = json.loads(grammar_json)
                    self._text = str(phrases[0]) if phrases else ""

                def SetWords(self, _enabled: bool):
                    return

                def AcceptWaveform(self, _payload: bytes):
                    return True

                def FinalResult(self):
                    return json.dumps({"text": self._text})

            with patch(
                "vosk_eval._load_vosk_bindings",
                return_value=(_FakeModel, _FakeRecognizer),
            ):
                summary = run_vosk_keystroke_eval(
                    literal_manifest=literal_manifest,
                    nato_manifest=nato_manifest,
                    intents_path=intents_path,
                    output_dir=output_dir,
                    models_file=models_file,
                    verbose=False,
                )

            self.assertEqual(summary["models"][0]["literal"]["intent_accuracy"], 1.0)
            self.assertEqual(summary["models"][0]["nato"]["intent_accuracy"], 1.0)


def load_keystroke_intents_from_inline(
    intent_id: str,
    literal_phrase: str,
    nato_phrase: str,
    letter: str,
    modifier: str,
):
    """Build a single intent from inline field values.

    The five arguments are packed into one record and passed through
    ``load_keystroke_intents_from_json``, so the intent is produced by the
    same loader the tests exercise.

    Returns:
        The first (and only) element of the loaded intents.
    """
    record = {
        "intent_id": intent_id,
        "literal_phrase": literal_phrase,
        "nato_phrase": nato_phrase,
        "letter": letter,
        "modifier": modifier,
    }
    return load_keystroke_intents_from_json([record])[0]


def load_keystroke_intents_from_json(payload):
    """Load intents from an in-memory, JSON-serialisable *payload*.

    The payload is written to ``intents.json`` inside a temporary directory
    and that file is handed to ``load_keystroke_intents``.

    Returns:
        Whatever ``load_keystroke_intents`` returns for the written file.
    """
    with tempfile.TemporaryDirectory() as td:
        path = Path(td) / "intents.json"
        path.write_text(json.dumps(payload), encoding="utf-8")
        # Load before leaving the ``with`` block: the directory, and the
        # file in it, are removed on exit.
        return load_keystroke_intents(path)


def _write_silence_wav(path: Path, *, samplerate: int, frames: int) -> None:
    """Write a mono, 16-bit WAV file containing only zero-valued samples.

    Args:
        path: Destination file; missing parent directories are created.
        samplerate: Frame rate stored in the WAV header.
        frames: Number of frames to write.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with wave.open(str(path), "wb") as handle:
        handle.setnchannels(1)
        handle.setsampwidth(2)  # two bytes per sample
        handle.setframerate(samplerate)
        # One frame is one 2-byte sample here, so this emits `frames` frames.
        handle.writeframes(b"\x00\x00" * frames)


if __name__ == "__main__":
    # Run the tests when this file is executed directly as a script.
    unittest.main()