Add Vosk keystroke eval tooling and findings
This commit is contained in:
parent
8c1f7c1e13
commit
510d280b74
15 changed files with 2219 additions and 0 deletions
|
|
@ -141,6 +141,64 @@ class AmanCliTests(unittest.TestCase):
|
|||
with self.assertRaises(SystemExit):
|
||||
aman._parse_cli_args(["bench"])
|
||||
|
||||
def test_parse_cli_args_collect_fixed_phrases_command(self):
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"collect-fixed-phrases",
|
||||
"--phrases-file",
|
||||
"exploration/vosk/fixed_phrases/phrases.txt",
|
||||
"--out-dir",
|
||||
"exploration/vosk/fixed_phrases",
|
||||
"--samples-per-phrase",
|
||||
"10",
|
||||
"--samplerate",
|
||||
"16000",
|
||||
"--channels",
|
||||
"1",
|
||||
"--device",
|
||||
"2",
|
||||
"--session-id",
|
||||
"session-123",
|
||||
"--overwrite-session",
|
||||
"--json",
|
||||
]
|
||||
)
|
||||
self.assertEqual(args.command, "collect-fixed-phrases")
|
||||
self.assertEqual(args.phrases_file, "exploration/vosk/fixed_phrases/phrases.txt")
|
||||
self.assertEqual(args.out_dir, "exploration/vosk/fixed_phrases")
|
||||
self.assertEqual(args.samples_per_phrase, 10)
|
||||
self.assertEqual(args.samplerate, 16000)
|
||||
self.assertEqual(args.channels, 1)
|
||||
self.assertEqual(args.device, "2")
|
||||
self.assertEqual(args.session_id, "session-123")
|
||||
self.assertTrue(args.overwrite_session)
|
||||
self.assertTrue(args.json)
|
||||
|
||||
def test_parse_cli_args_eval_vosk_keystrokes_command(self):
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"eval-vosk-keystrokes",
|
||||
"--literal-manifest",
|
||||
"exploration/vosk/keystrokes/literal/manifest.jsonl",
|
||||
"--nato-manifest",
|
||||
"exploration/vosk/keystrokes/nato/manifest.jsonl",
|
||||
"--intents",
|
||||
"exploration/vosk/keystrokes/intents.json",
|
||||
"--output-dir",
|
||||
"exploration/vosk/keystrokes/eval_runs",
|
||||
"--models-file",
|
||||
"exploration/vosk/keystrokes/models.json",
|
||||
"--json",
|
||||
]
|
||||
)
|
||||
self.assertEqual(args.command, "eval-vosk-keystrokes")
|
||||
self.assertEqual(args.literal_manifest, "exploration/vosk/keystrokes/literal/manifest.jsonl")
|
||||
self.assertEqual(args.nato_manifest, "exploration/vosk/keystrokes/nato/manifest.jsonl")
|
||||
self.assertEqual(args.intents, "exploration/vosk/keystrokes/intents.json")
|
||||
self.assertEqual(args.output_dir, "exploration/vosk/keystrokes/eval_runs")
|
||||
self.assertEqual(args.models_file, "exploration/vosk/keystrokes/models.json")
|
||||
self.assertTrue(args.json)
|
||||
|
||||
def test_parse_cli_args_eval_models_command(self):
|
||||
args = aman._parse_cli_args(
|
||||
["eval-models", "--dataset", "benchmarks/cleanup_dataset.jsonl", "--matrix", "benchmarks/model_matrix.small_first.json"]
|
||||
|
|
@ -379,6 +437,83 @@ class AmanCliTests(unittest.TestCase):
|
|||
payload = json.loads(out.getvalue())
|
||||
self.assertEqual(payload["written_rows"], 4)
|
||||
|
||||
def test_collect_fixed_phrases_command_rejects_non_positive_samples_per_phrase(self):
|
||||
args = aman._parse_cli_args(
|
||||
["collect-fixed-phrases", "--samples-per-phrase", "0"]
|
||||
)
|
||||
exit_code = aman._collect_fixed_phrases_command(args)
|
||||
self.assertEqual(exit_code, 1)
|
||||
|
||||
def test_collect_fixed_phrases_command_json_output(self):
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"collect-fixed-phrases",
|
||||
"--phrases-file",
|
||||
"exploration/vosk/fixed_phrases/phrases.txt",
|
||||
"--out-dir",
|
||||
"exploration/vosk/fixed_phrases",
|
||||
"--samples-per-phrase",
|
||||
"2",
|
||||
"--json",
|
||||
]
|
||||
)
|
||||
out = io.StringIO()
|
||||
fake_result = SimpleNamespace(
|
||||
session_id="session-1",
|
||||
phrases=2,
|
||||
samples_per_phrase=2,
|
||||
samples_target=4,
|
||||
samples_written=4,
|
||||
out_dir=Path("/tmp/out"),
|
||||
manifest_path=Path("/tmp/out/manifest.jsonl"),
|
||||
interrupted=False,
|
||||
)
|
||||
with patch("aman.collect_fixed_phrases", return_value=fake_result), patch("sys.stdout", out):
|
||||
exit_code = aman._collect_fixed_phrases_command(args)
|
||||
self.assertEqual(exit_code, 0)
|
||||
payload = json.loads(out.getvalue())
|
||||
self.assertEqual(payload["session_id"], "session-1")
|
||||
self.assertEqual(payload["samples_written"], 4)
|
||||
self.assertFalse(payload["interrupted"])
|
||||
|
||||
def test_eval_vosk_keystrokes_command_json_output(self):
|
||||
args = aman._parse_cli_args(
|
||||
[
|
||||
"eval-vosk-keystrokes",
|
||||
"--literal-manifest",
|
||||
"exploration/vosk/keystrokes/literal/manifest.jsonl",
|
||||
"--nato-manifest",
|
||||
"exploration/vosk/keystrokes/nato/manifest.jsonl",
|
||||
"--intents",
|
||||
"exploration/vosk/keystrokes/intents.json",
|
||||
"--output-dir",
|
||||
"exploration/vosk/keystrokes/eval_runs",
|
||||
"--json",
|
||||
]
|
||||
)
|
||||
out = io.StringIO()
|
||||
fake_summary = {
|
||||
"models": [
|
||||
{
|
||||
"name": "vosk-small-en-us-0.15",
|
||||
"literal": {"intent_accuracy": 1.0, "latency_ms": {"p50": 30.0}},
|
||||
"nato": {"intent_accuracy": 0.9, "latency_ms": {"p50": 35.0}},
|
||||
}
|
||||
],
|
||||
"winners": {
|
||||
"literal": {"name": "vosk-small-en-us-0.15", "intent_accuracy": 1.0, "latency_p50_ms": 30.0},
|
||||
"nato": {"name": "vosk-small-en-us-0.15", "intent_accuracy": 0.9, "latency_p50_ms": 35.0},
|
||||
"overall": {"name": "vosk-small-en-us-0.15", "avg_intent_accuracy": 0.95, "avg_latency_p50_ms": 32.5},
|
||||
},
|
||||
"output_dir": "exploration/vosk/keystrokes/eval_runs/run-1",
|
||||
}
|
||||
with patch("aman.run_vosk_keystroke_eval", return_value=fake_summary), patch("sys.stdout", out):
|
||||
exit_code = aman._eval_vosk_keystrokes_command(args)
|
||||
self.assertEqual(exit_code, 0)
|
||||
payload = json.loads(out.getvalue())
|
||||
self.assertEqual(payload["models"][0]["name"], "vosk-small-en-us-0.15")
|
||||
self.assertEqual(payload["winners"]["overall"]["name"], "vosk-small-en-us-0.15")
|
||||
|
||||
def test_sync_default_model_command_updates_constants(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
report_path = Path(td) / "latest.json"
|
||||
|
|
|
|||
148
tests/test_vosk_collect.py
Normal file
148
tests/test_vosk_collect.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from vosk_collect import CollectOptions, collect_fixed_phrases, float_to_pcm16, load_phrases, slugify_phrase
|
||||
|
||||
|
||||
class VoskCollectTests(unittest.TestCase):
|
||||
def test_load_phrases_ignores_blank_comment_and_deduplicates(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "phrases.txt"
|
||||
path.write_text(
|
||||
(
|
||||
"# heading\n"
|
||||
"\n"
|
||||
"close app\n"
|
||||
"take a screenshot\n"
|
||||
"close app\n"
|
||||
" \n"
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
phrases = load_phrases(path)
|
||||
self.assertEqual(phrases, ["close app", "take a screenshot"])
|
||||
|
||||
def test_load_phrases_empty_after_filtering_raises(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "phrases.txt"
|
||||
path.write_text("# only comments\n\n", encoding="utf-8")
|
||||
with self.assertRaisesRegex(RuntimeError, "no usable labels"):
|
||||
load_phrases(path)
|
||||
|
||||
def test_slugify_phrase_is_deterministic(self):
|
||||
self.assertEqual(slugify_phrase("Take a Screenshot"), "take_a_screenshot")
|
||||
self.assertEqual(slugify_phrase("close-app!!!"), "close_app")
|
||||
|
||||
def test_float_to_pcm16_clamps_audio_bounds(self):
|
||||
values = np.asarray([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0], dtype=np.float32)
|
||||
out = float_to_pcm16(values)
|
||||
self.assertEqual(out.dtype, np.int16)
|
||||
self.assertGreaterEqual(int(out.min()), -32767)
|
||||
self.assertLessEqual(int(out.max()), 32767)
|
||||
self.assertEqual(int(out[0]), -32767)
|
||||
self.assertEqual(int(out[-1]), 32767)
|
||||
|
||||
def test_collect_fixed_phrases_writes_manifest_and_wavs(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
root = Path(td)
|
||||
phrases_path = root / "phrases.txt"
|
||||
out_dir = root / "dataset"
|
||||
phrases_path.write_text("close app\ntake a screenshot\n", encoding="utf-8")
|
||||
options = CollectOptions(
|
||||
phrases_file=phrases_path,
|
||||
out_dir=out_dir,
|
||||
samples_per_phrase=2,
|
||||
samplerate=16000,
|
||||
channels=1,
|
||||
session_id="session-1",
|
||||
)
|
||||
answers = ["", "", "", ""]
|
||||
|
||||
def fake_input(_prompt: str) -> str:
|
||||
return answers.pop(0)
|
||||
|
||||
def fake_record(_options: CollectOptions, _input_func):
|
||||
audio = np.ones((320, 1), dtype=np.float32) * 0.1
|
||||
return audio, 320, 20
|
||||
|
||||
result = collect_fixed_phrases(
|
||||
options,
|
||||
input_func=fake_input,
|
||||
output_func=lambda _line: None,
|
||||
record_sample_fn=fake_record,
|
||||
)
|
||||
|
||||
self.assertFalse(result.interrupted)
|
||||
self.assertEqual(result.samples_written, 4)
|
||||
manifest = out_dir / "manifest.jsonl"
|
||||
rows = [
|
||||
json.loads(line)
|
||||
for line in manifest.read_text(encoding="utf-8").splitlines()
|
||||
if line.strip()
|
||||
]
|
||||
self.assertEqual(len(rows), 4)
|
||||
required = {
|
||||
"session_id",
|
||||
"timestamp_utc",
|
||||
"phrase",
|
||||
"phrase_slug",
|
||||
"sample_index",
|
||||
"wav_path",
|
||||
"samplerate",
|
||||
"channels",
|
||||
"duration_ms",
|
||||
"frames",
|
||||
"device_spec",
|
||||
"collector_version",
|
||||
}
|
||||
self.assertTrue(required.issubset(rows[0].keys()))
|
||||
wav_paths = [root / Path(row["wav_path"]) for row in rows]
|
||||
for wav_path in wav_paths:
|
||||
self.assertTrue(wav_path.exists(), f"missing wav: {wav_path}")
|
||||
|
||||
def test_collect_fixed_phrases_refuses_existing_session_without_overwrite(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
root = Path(td)
|
||||
phrases_path = root / "phrases.txt"
|
||||
out_dir = root / "dataset"
|
||||
phrases_path.write_text("close app\n", encoding="utf-8")
|
||||
options = CollectOptions(
|
||||
phrases_file=phrases_path,
|
||||
out_dir=out_dir,
|
||||
samples_per_phrase=1,
|
||||
samplerate=16000,
|
||||
channels=1,
|
||||
session_id="session-1",
|
||||
)
|
||||
|
||||
def fake_record(_options: CollectOptions, _input_func):
|
||||
audio = np.ones((160, 1), dtype=np.float32) * 0.2
|
||||
return audio, 160, 10
|
||||
|
||||
collect_fixed_phrases(
|
||||
options,
|
||||
input_func=lambda _prompt: "",
|
||||
output_func=lambda _line: None,
|
||||
record_sample_fn=fake_record,
|
||||
)
|
||||
with self.assertRaisesRegex(RuntimeError, "already has samples"):
|
||||
collect_fixed_phrases(
|
||||
options,
|
||||
input_func=lambda _prompt: "",
|
||||
output_func=lambda _line: None,
|
||||
record_sample_fn=fake_record,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
327
tests/test_vosk_eval.py
Normal file
327
tests/test_vosk_eval.py
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from vosk_eval import (
|
||||
DecodedRow,
|
||||
build_phrase_to_intent_index,
|
||||
load_keystroke_intents,
|
||||
run_vosk_keystroke_eval,
|
||||
summarize_decoded_rows,
|
||||
)
|
||||
|
||||
|
||||
class VoskEvalTests(unittest.TestCase):
|
||||
def test_load_keystroke_intents_parses_valid_payload(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "intents.json"
|
||||
path.write_text(
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"intent_id": "ctrl+d",
|
||||
"literal_phrase": "control d",
|
||||
"nato_phrase": "control delta",
|
||||
"letter": "d",
|
||||
"modifier": "ctrl",
|
||||
}
|
||||
]
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
intents = load_keystroke_intents(path)
|
||||
self.assertEqual(len(intents), 1)
|
||||
self.assertEqual(intents[0].intent_id, "ctrl+d")
|
||||
|
||||
def test_load_keystroke_intents_rejects_duplicate_literal_phrase(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "intents.json"
|
||||
path.write_text(
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"intent_id": "ctrl+d",
|
||||
"literal_phrase": "control d",
|
||||
"nato_phrase": "control delta",
|
||||
"letter": "d",
|
||||
"modifier": "ctrl",
|
||||
},
|
||||
{
|
||||
"intent_id": "ctrl+b",
|
||||
"literal_phrase": "control d",
|
||||
"nato_phrase": "control bravo",
|
||||
"letter": "b",
|
||||
"modifier": "ctrl",
|
||||
},
|
||||
]
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
with self.assertRaisesRegex(RuntimeError, "duplicate literal_phrase"):
|
||||
load_keystroke_intents(path)
|
||||
|
||||
def test_build_phrase_to_intent_index_uses_grammar_variant(self):
|
||||
intents = [
|
||||
load_keystroke_intents_from_inline(
|
||||
"ctrl+d",
|
||||
"control d",
|
||||
"control delta",
|
||||
"d",
|
||||
"ctrl",
|
||||
)
|
||||
]
|
||||
literal = build_phrase_to_intent_index(intents, grammar="literal")
|
||||
nato = build_phrase_to_intent_index(intents, grammar="nato")
|
||||
self.assertIn("control d", literal)
|
||||
self.assertIn("control delta", nato)
|
||||
|
||||
def test_summarize_decoded_rows_reports_confusions(self):
|
||||
rows = [
|
||||
DecodedRow(
|
||||
wav_path="a.wav",
|
||||
expected_phrase="control d",
|
||||
hypothesis="control d",
|
||||
expected_intent="ctrl+d",
|
||||
predicted_intent="ctrl+d",
|
||||
expected_letter="d",
|
||||
predicted_letter="d",
|
||||
expected_modifier="ctrl",
|
||||
predicted_modifier="ctrl",
|
||||
intent_match=True,
|
||||
audio_ms=1000.0,
|
||||
decode_ms=100.0,
|
||||
rtf=0.1,
|
||||
out_of_grammar=False,
|
||||
),
|
||||
DecodedRow(
|
||||
wav_path="b.wav",
|
||||
expected_phrase="control b",
|
||||
hypothesis="control p",
|
||||
expected_intent="ctrl+b",
|
||||
predicted_intent="ctrl+p",
|
||||
expected_letter="b",
|
||||
predicted_letter="p",
|
||||
expected_modifier="ctrl",
|
||||
predicted_modifier="ctrl",
|
||||
intent_match=False,
|
||||
audio_ms=1000.0,
|
||||
decode_ms=120.0,
|
||||
rtf=0.12,
|
||||
out_of_grammar=False,
|
||||
),
|
||||
DecodedRow(
|
||||
wav_path="c.wav",
|
||||
expected_phrase="control p",
|
||||
hypothesis="",
|
||||
expected_intent="ctrl+p",
|
||||
predicted_intent=None,
|
||||
expected_letter="p",
|
||||
predicted_letter=None,
|
||||
expected_modifier="ctrl",
|
||||
predicted_modifier=None,
|
||||
intent_match=False,
|
||||
audio_ms=1000.0,
|
||||
decode_ms=90.0,
|
||||
rtf=0.09,
|
||||
out_of_grammar=False,
|
||||
),
|
||||
]
|
||||
summary = summarize_decoded_rows(rows)
|
||||
self.assertEqual(summary["samples"], 3)
|
||||
self.assertAlmostEqual(summary["intent_accuracy"], 1 / 3, places=6)
|
||||
self.assertEqual(summary["unknown_count"], 1)
|
||||
self.assertEqual(summary["intent_confusion"]["ctrl+b"]["ctrl+p"], 1)
|
||||
self.assertEqual(summary["letter_confusion"]["p"]["__none__"], 1)
|
||||
self.assertGreaterEqual(len(summary["top_raw_mismatches"]), 1)
|
||||
|
||||
def test_run_vosk_keystroke_eval_hard_fails_model_with_out_of_grammar_output(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
root = Path(td)
|
||||
literal_manifest = root / "literal.jsonl"
|
||||
nato_manifest = root / "nato.jsonl"
|
||||
intents_path = root / "intents.json"
|
||||
output_dir = root / "out"
|
||||
model_dir = root / "model"
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
wav_path = root / "sample.wav"
|
||||
_write_silence_wav(wav_path, samplerate=16000, frames=800)
|
||||
|
||||
intents_path.write_text(
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"intent_id": "ctrl+d",
|
||||
"literal_phrase": "control d",
|
||||
"nato_phrase": "control delta",
|
||||
"letter": "d",
|
||||
"modifier": "ctrl",
|
||||
}
|
||||
]
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
literal_manifest.write_text(
|
||||
json.dumps({"phrase": "control d", "wav_path": str(wav_path)}) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
nato_manifest.write_text(
|
||||
json.dumps({"phrase": "control delta", "wav_path": str(wav_path)}) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
models_file = root / "models.json"
|
||||
models_file.write_text(
|
||||
json.dumps([{"name": "fake", "path": str(model_dir)}]),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
class _FakeModel:
|
||||
def __init__(self, _path: str):
|
||||
return
|
||||
|
||||
class _FakeRecognizer:
|
||||
def __init__(self, _model, _rate, _grammar_json):
|
||||
return
|
||||
|
||||
def SetWords(self, _enabled: bool):
|
||||
return
|
||||
|
||||
def AcceptWaveform(self, _payload: bytes):
|
||||
return True
|
||||
|
||||
def FinalResult(self):
|
||||
return json.dumps({"text": "outside hypothesis"})
|
||||
|
||||
with patch("vosk_eval._load_vosk_bindings", return_value=(_FakeModel, _FakeRecognizer)):
|
||||
with self.assertRaisesRegex(RuntimeError, "out-of-grammar"):
|
||||
run_vosk_keystroke_eval(
|
||||
literal_manifest=literal_manifest,
|
||||
nato_manifest=nato_manifest,
|
||||
intents_path=intents_path,
|
||||
output_dir=output_dir,
|
||||
models_file=models_file,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
def test_run_vosk_keystroke_eval_resolves_manifest_relative_wav_paths(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
root = Path(td)
|
||||
manifests_dir = root / "manifests"
|
||||
samples_dir = manifests_dir / "samples"
|
||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||
wav_path = samples_dir / "sample.wav"
|
||||
_write_silence_wav(wav_path, samplerate=16000, frames=800)
|
||||
|
||||
literal_manifest = manifests_dir / "literal.jsonl"
|
||||
nato_manifest = manifests_dir / "nato.jsonl"
|
||||
intents_path = root / "intents.json"
|
||||
output_dir = root / "out"
|
||||
model_dir = root / "model"
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
intents_path.write_text(
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"intent_id": "ctrl+d",
|
||||
"literal_phrase": "control d",
|
||||
"nato_phrase": "control delta",
|
||||
"letter": "d",
|
||||
"modifier": "ctrl",
|
||||
}
|
||||
]
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
relative_wav = "samples/sample.wav"
|
||||
literal_manifest.write_text(
|
||||
json.dumps({"phrase": "control d", "wav_path": relative_wav}) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
nato_manifest.write_text(
|
||||
json.dumps({"phrase": "control delta", "wav_path": relative_wav}) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
models_file = root / "models.json"
|
||||
models_file.write_text(
|
||||
json.dumps([{"name": "fake", "path": str(model_dir)}]),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
class _FakeModel:
|
||||
def __init__(self, _path: str):
|
||||
return
|
||||
|
||||
class _FakeRecognizer:
|
||||
def __init__(self, _model, _rate, grammar_json):
|
||||
phrases = json.loads(grammar_json)
|
||||
self._text = str(phrases[0]) if phrases else ""
|
||||
|
||||
def SetWords(self, _enabled: bool):
|
||||
return
|
||||
|
||||
def AcceptWaveform(self, _payload: bytes):
|
||||
return True
|
||||
|
||||
def FinalResult(self):
|
||||
return json.dumps({"text": self._text})
|
||||
|
||||
with patch("vosk_eval._load_vosk_bindings", return_value=(_FakeModel, _FakeRecognizer)):
|
||||
summary = run_vosk_keystroke_eval(
|
||||
literal_manifest=literal_manifest,
|
||||
nato_manifest=nato_manifest,
|
||||
intents_path=intents_path,
|
||||
output_dir=output_dir,
|
||||
models_file=models_file,
|
||||
verbose=False,
|
||||
)
|
||||
self.assertEqual(summary["models"][0]["literal"]["intent_accuracy"], 1.0)
|
||||
self.assertEqual(summary["models"][0]["nato"]["intent_accuracy"], 1.0)
|
||||
|
||||
|
||||
def load_keystroke_intents_from_inline(
|
||||
intent_id: str,
|
||||
literal_phrase: str,
|
||||
nato_phrase: str,
|
||||
letter: str,
|
||||
modifier: str,
|
||||
):
|
||||
return load_keystroke_intents_from_json(
|
||||
[
|
||||
{
|
||||
"intent_id": intent_id,
|
||||
"literal_phrase": literal_phrase,
|
||||
"nato_phrase": nato_phrase,
|
||||
"letter": letter,
|
||||
"modifier": modifier,
|
||||
}
|
||||
]
|
||||
)[0]
|
||||
|
||||
|
||||
def load_keystroke_intents_from_json(payload):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "intents.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
return load_keystroke_intents(path)
|
||||
|
||||
|
||||
def _write_silence_wav(path: Path, *, samplerate: int, frames: int):
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with wave.open(str(path), "wb") as handle:
|
||||
handle.setnchannels(1)
|
||||
handle.setsampwidth(2)
|
||||
handle.setframerate(samplerate)
|
||||
handle.writeframes(b"\x00\x00" * frames)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue