Add Vosk keystroke eval tooling and findings
This commit is contained in:
parent
8c1f7c1e13
commit
510d280b74
15 changed files with 2219 additions and 0 deletions
220
src/aman.py
220
src/aman.py
|
|
@ -36,6 +36,22 @@ from recorder import stop_recording as stop_audio_recording
|
|||
from stages.asr_whisper import AsrResult, WhisperAsrStage
|
||||
from stages.editor_llama import LlamaEditorStage
|
||||
from vocabulary import VocabularyEngine
|
||||
from vosk_collect import (
|
||||
DEFAULT_CHANNELS,
|
||||
DEFAULT_FIXED_PHRASES_OUT_DIR,
|
||||
DEFAULT_FIXED_PHRASES_PATH,
|
||||
DEFAULT_SAMPLE_RATE,
|
||||
DEFAULT_SAMPLES_PER_PHRASE,
|
||||
CollectOptions,
|
||||
collect_fixed_phrases,
|
||||
)
|
||||
from vosk_eval import (
|
||||
DEFAULT_KEYSTROKE_EVAL_OUTPUT_DIR,
|
||||
DEFAULT_KEYSTROKE_INTENTS_PATH,
|
||||
DEFAULT_KEYSTROKE_LITERAL_MANIFEST_PATH,
|
||||
DEFAULT_KEYSTROKE_NATO_MANIFEST_PATH,
|
||||
run_vosk_keystroke_eval,
|
||||
)
|
||||
|
||||
|
||||
class State:
|
||||
|
|
@ -981,6 +997,88 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
)
|
||||
bench_parser.add_argument("-v", "--verbose", action="store_true", help="enable verbose logs")
|
||||
|
||||
collect_parser = subparsers.add_parser(
|
||||
"collect-fixed-phrases",
|
||||
help="internal: collect labeled fixed-phrase wav samples for command-stt exploration",
|
||||
)
|
||||
collect_parser.add_argument(
|
||||
"--phrases-file",
|
||||
default=str(DEFAULT_FIXED_PHRASES_PATH),
|
||||
help="path to fixed-phrase labels file (one phrase per line)",
|
||||
)
|
||||
collect_parser.add_argument(
|
||||
"--out-dir",
|
||||
default=str(DEFAULT_FIXED_PHRASES_OUT_DIR),
|
||||
help="output directory for samples/ and manifest.jsonl",
|
||||
)
|
||||
collect_parser.add_argument(
|
||||
"--samples-per-phrase",
|
||||
type=int,
|
||||
default=DEFAULT_SAMPLES_PER_PHRASE,
|
||||
help="number of recordings to capture per phrase",
|
||||
)
|
||||
collect_parser.add_argument(
|
||||
"--samplerate",
|
||||
type=int,
|
||||
default=DEFAULT_SAMPLE_RATE,
|
||||
help="sample rate for captured wav files",
|
||||
)
|
||||
collect_parser.add_argument(
|
||||
"--channels",
|
||||
type=int,
|
||||
default=DEFAULT_CHANNELS,
|
||||
help="number of input channels to capture",
|
||||
)
|
||||
collect_parser.add_argument(
|
||||
"--device",
|
||||
default="",
|
||||
help="optional recording device index or name substring",
|
||||
)
|
||||
collect_parser.add_argument(
|
||||
"--session-id",
|
||||
default="",
|
||||
help="optional session id; autogenerated when omitted",
|
||||
)
|
||||
collect_parser.add_argument(
|
||||
"--overwrite-session",
|
||||
action="store_true",
|
||||
help="allow writing samples for an existing session id",
|
||||
)
|
||||
collect_parser.add_argument("--json", action="store_true", help="print JSON summary output")
|
||||
collect_parser.add_argument("-v", "--verbose", action="store_true", help="enable verbose logs")
|
||||
|
||||
keystroke_eval_parser = subparsers.add_parser(
|
||||
"eval-vosk-keystrokes",
|
||||
help="internal: evaluate keystroke dictation datasets with literal and nato grammars",
|
||||
)
|
||||
keystroke_eval_parser.add_argument(
|
||||
"--literal-manifest",
|
||||
default=str(DEFAULT_KEYSTROKE_LITERAL_MANIFEST_PATH),
|
||||
help="path to literal keystroke manifest.jsonl",
|
||||
)
|
||||
keystroke_eval_parser.add_argument(
|
||||
"--nato-manifest",
|
||||
default=str(DEFAULT_KEYSTROKE_NATO_MANIFEST_PATH),
|
||||
help="path to nato keystroke manifest.jsonl",
|
||||
)
|
||||
keystroke_eval_parser.add_argument(
|
||||
"--intents",
|
||||
default=str(DEFAULT_KEYSTROKE_INTENTS_PATH),
|
||||
help="path to keystroke intents definition json",
|
||||
)
|
||||
keystroke_eval_parser.add_argument(
|
||||
"--output-dir",
|
||||
default=str(DEFAULT_KEYSTROKE_EVAL_OUTPUT_DIR),
|
||||
help="directory for run reports",
|
||||
)
|
||||
keystroke_eval_parser.add_argument(
|
||||
"--models-file",
|
||||
default="",
|
||||
help="optional json array of model specs [{name,path}]",
|
||||
)
|
||||
keystroke_eval_parser.add_argument("--json", action="store_true", help="print JSON summary output")
|
||||
keystroke_eval_parser.add_argument("-v", "--verbose", action="store_true", help="enable verbose logs")
|
||||
|
||||
eval_parser = subparsers.add_parser(
|
||||
"eval-models",
|
||||
help="evaluate model/parameter matrices against expected outputs",
|
||||
|
|
@ -1059,6 +1157,8 @@ def _parse_cli_args(argv: list[str]) -> argparse.Namespace:
|
|||
"doctor",
|
||||
"self-check",
|
||||
"bench",
|
||||
"collect-fixed-phrases",
|
||||
"eval-vosk-keystrokes",
|
||||
"eval-models",
|
||||
"build-heuristic-dataset",
|
||||
"sync-default-model",
|
||||
|
|
@ -1255,6 +1355,120 @@ def _bench_command(args: argparse.Namespace) -> int:
|
|||
return 0
|
||||
|
||||
|
||||
def _collect_fixed_phrases_command(args: argparse.Namespace) -> int:
    """CLI entry for `collect-fixed-phrases`: validate args, run collection, print a summary.

    Returns 0 on success, 1 on validation or collection failure (errors are
    logged, not raised, to keep CLI exit codes clean).
    """
    # Mirror _validate_options() checks up front so bad flags fail before
    # any audio device is touched.
    if args.samples_per_phrase < 1:
        logging.error("collect-fixed-phrases failed: --samples-per-phrase must be >= 1")
        return 1
    if args.samplerate < 1:
        logging.error("collect-fixed-phrases failed: --samplerate must be >= 1")
        return 1
    if args.channels < 1:
        logging.error("collect-fixed-phrases failed: --channels must be >= 1")
        return 1

    options = CollectOptions(
        phrases_file=Path(args.phrases_file),
        out_dir=Path(args.out_dir),
        samples_per_phrase=args.samples_per_phrase,
        samplerate=args.samplerate,
        channels=args.channels,
        # Blank CLI strings mean "not provided" and become None.
        device_spec=(args.device.strip() if args.device.strip() else None),
        session_id=(args.session_id.strip() if args.session_id.strip() else None),
        overwrite_session=bool(args.overwrite_session),
    )
    try:
        result = collect_fixed_phrases(options)
    except Exception as exc:
        logging.error("collect-fixed-phrases failed: %s", exc)
        return 1

    summary = {
        "session_id": result.session_id,
        "phrases": result.phrases,
        "samples_per_phrase": result.samples_per_phrase,
        "samples_target": result.samples_target,
        "samples_written": result.samples_written,
        "out_dir": str(result.out_dir),
        "manifest_path": str(result.manifest_path),
        "interrupted": result.interrupted,
    }
    if args.json:
        print(json.dumps(summary, indent=2, ensure_ascii=False))
    else:
        print(
            "collect-fixed-phrases summary: "
            f"session={result.session_id} "
            f"phrases={result.phrases} "
            f"samples_per_phrase={result.samples_per_phrase} "
            f"written={result.samples_written}/{result.samples_target} "
            f"interrupted={result.interrupted} "
            f"manifest={result.manifest_path}"
        )
    return 0
|
||||
|
||||
|
||||
def _eval_vosk_keystrokes_command(args: argparse.Namespace) -> int:
    """CLI entry for `eval-vosk-keystrokes`: run the eval and print a report.

    With --json the raw summary dict is printed; otherwise a human-readable
    report with per-grammar winners and per-model accuracy/latency lines.
    Returns 0 on success, 1 when the evaluation raises.
    """
    try:
        summary = run_vosk_keystroke_eval(
            literal_manifest=args.literal_manifest,
            nato_manifest=args.nato_manifest,
            intents_path=args.intents,
            output_dir=args.output_dir,
            # Empty string means "use built-in default model list".
            models_file=(args.models_file.strip() or None),
            verbose=args.verbose,
        )
    except Exception as exc:
        logging.error("eval-vosk-keystrokes failed: %s", exc)
        return 1

    if args.json:
        print(json.dumps(summary, indent=2, ensure_ascii=False))
        return 0

    print(
        "eval-vosk-keystrokes summary: "
        f"models={len(summary.get('models', []))} "
        f"output_dir={summary.get('output_dir', '')}"
    )
    # Winner entries may be absent/empty dicts; print only what exists.
    winners = summary.get("winners", {})
    literal_winner = winners.get("literal", {})
    nato_winner = winners.get("nato", {})
    overall_winner = winners.get("overall", {})
    if literal_winner:
        print(
            "winner[literal]: "
            f"{literal_winner.get('name', '')} "
            f"acc={float(literal_winner.get('intent_accuracy', 0.0)):.3f} "
            f"p50={float(literal_winner.get('latency_p50_ms', 0.0)):.1f}ms"
        )
    if nato_winner:
        print(
            "winner[nato]: "
            f"{nato_winner.get('name', '')} "
            f"acc={float(nato_winner.get('intent_accuracy', 0.0)):.3f} "
            f"p50={float(nato_winner.get('latency_p50_ms', 0.0)):.1f}ms"
        )
    if overall_winner:
        print(
            "winner[overall]: "
            f"{overall_winner.get('name', '')} "
            f"acc={float(overall_winner.get('avg_intent_accuracy', 0.0)):.3f} "
            f"p50={float(overall_winner.get('avg_latency_p50_ms', 0.0)):.1f}ms"
        )

    # One line per evaluated model comparing the two grammars side by side.
    for model in summary.get("models", []):
        literal = model.get("literal", {})
        nato = model.get("nato", {})
        print(
            f"{model.get('name', '')}: "
            f"literal_acc={float(literal.get('intent_accuracy', 0.0)):.3f} "
            f"literal_p50={float(literal.get('latency_ms', {}).get('p50', 0.0)):.1f}ms "
            f"nato_acc={float(nato.get('intent_accuracy', 0.0)):.3f} "
            f"nato_p50={float(nato.get('latency_ms', {}).get('p50', 0.0)):.1f}ms"
        )
    return 0
|
||||
|
||||
|
||||
def _eval_models_command(args: argparse.Namespace) -> int:
|
||||
try:
|
||||
report = run_model_eval(
|
||||
|
|
@ -1597,6 +1811,12 @@ def main(argv: list[str] | None = None) -> int:
|
|||
if args.command == "bench":
|
||||
_configure_logging(args.verbose)
|
||||
return _bench_command(args)
|
||||
if args.command == "collect-fixed-phrases":
|
||||
_configure_logging(args.verbose)
|
||||
return _collect_fixed_phrases_command(args)
|
||||
if args.command == "eval-vosk-keystrokes":
|
||||
_configure_logging(args.verbose)
|
||||
return _eval_vosk_keystrokes_command(args)
|
||||
if args.command == "eval-models":
|
||||
_configure_logging(args.verbose)
|
||||
return _eval_models_command(args)
|
||||
|
|
|
|||
329
src/vosk_collect.py
Normal file
329
src/vosk_collect.py
Normal file
|
|
@ -0,0 +1,329 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from recorder import list_input_devices, resolve_input_device
|
||||
|
||||
|
||||
DEFAULT_FIXED_PHRASES_PATH = Path("exploration/vosk/fixed_phrases/phrases.txt")
|
||||
DEFAULT_FIXED_PHRASES_OUT_DIR = Path("exploration/vosk/fixed_phrases")
|
||||
DEFAULT_SAMPLES_PER_PHRASE = 10
|
||||
DEFAULT_SAMPLE_RATE = 16000
|
||||
DEFAULT_CHANNELS = 1
|
||||
COLLECTOR_VERSION = "fixed-phrases-v1"
|
||||
|
||||
|
||||
@dataclass
class CollectOptions:
    """Configuration for one fixed-phrase collection session."""

    phrases_file: Path = DEFAULT_FIXED_PHRASES_PATH  # one phrase per line; '#' lines ignored
    out_dir: Path = DEFAULT_FIXED_PHRASES_OUT_DIR  # root for samples/ and manifest.jsonl
    samples_per_phrase: int = DEFAULT_SAMPLES_PER_PHRASE  # recordings to capture per phrase
    samplerate: int = DEFAULT_SAMPLE_RATE  # wav sample rate in Hz
    channels: int = DEFAULT_CHANNELS  # input channel count
    device_spec: str | int | None = None  # device index or name substring; None = default device
    session_id: str | None = None  # autogenerated UTC timestamp id when None/blank
    overwrite_session: bool = False  # allow adding samples to an existing session id
|
||||
|
||||
|
||||
@dataclass
class CollectResult:
    """Outcome of a collection run (also returned on early user interruption)."""

    session_id: str
    phrases: int  # number of distinct phrase labels
    samples_per_phrase: int
    samples_target: int  # phrases * samples_per_phrase
    samples_written: int  # actual recordings persisted (<= target when interrupted)
    out_dir: Path
    manifest_path: Path
    interrupted: bool  # True when the user quit before completing all samples
|
||||
|
||||
|
||||
def load_phrases(path: Path | str) -> list[str]:
    """Return the unique, non-empty phrase labels from *path* in first-seen order.

    Blank lines and lines starting with ``#`` are skipped.  Raises
    RuntimeError when the file is missing or contains no usable labels.
    """
    phrases_path = Path(path)
    if not phrases_path.exists():
        raise RuntimeError(f"phrases file does not exist: {phrases_path}")
    stripped = (line.strip() for line in phrases_path.read_text(encoding="utf-8").splitlines())
    usable = (text for text in stripped if text and not text.startswith("#"))
    # dict.fromkeys deduplicates while preserving the first occurrence order.
    phrases = list(dict.fromkeys(usable))
    if not phrases:
        raise RuntimeError(f"phrases file has no usable labels: {phrases_path}")
    return phrases
|
||||
|
||||
|
||||
def slugify_phrase(value: str) -> str:
    """Turn a phrase into a filesystem-safe slug, capped at 64 characters."""
    collapsed = re.sub(r"[^a-z0-9]+", "_", value.casefold())
    trimmed = collapsed.strip("_")
    # Fall back to a fixed name when nothing survives normalisation.
    return trimmed[:64] if trimmed else "phrase"
|
||||
|
||||
|
||||
def float_to_pcm16(audio: np.ndarray) -> np.ndarray:
    """Convert float audio in [-1, 1] to int16 PCM samples (clipping first)."""
    if audio.size <= 0:
        # Preserve the documented empty shape regardless of input rank.
        return np.zeros((0,), dtype=np.int16)
    bounded = np.clip(np.asarray(audio, dtype=np.float32), -1.0, 1.0)
    scaled = bounded * 32767.0
    return np.rint(scaled).astype(np.int16)
|
||||
|
||||
|
||||
def collect_fixed_phrases(
    options: CollectOptions,
    *,
    input_func: Callable[[str], str] = input,
    output_func: Callable[[str], None] = print,
    record_sample_fn: Callable[[CollectOptions, Callable[[str], str]], tuple[np.ndarray, int, int]]
    | None = None,
) -> CollectResult:
    """Interactively record labeled wav samples for each phrase in the phrase file.

    For every phrase, prompts the user, records until Enter is pressed,
    writes the wav under out_dir/samples/<slug>/ and appends a row to
    manifest.jsonl.  ``input_func`` / ``output_func`` / ``record_sample_fn``
    exist so tests can drive the loop without a terminal or microphone.
    Typing 'q'/'quit'/'exit' at a prompt returns early with interrupted=True.

    Raises RuntimeError on invalid options, a bad phrase file, slug
    collisions, or an existing session without --overwrite-session.
    """
    _validate_options(options)
    phrases = load_phrases(options.phrases_file)
    slug_map = _build_slug_map(phrases)
    session_id = _resolve_session_id(options.session_id)
    out_dir = options.out_dir.expanduser().resolve()
    samples_root = out_dir / "samples"
    manifest_path = out_dir / "manifest.jsonl"
    # Guard against silently mixing two runs under one session id.
    if not options.overwrite_session and _session_has_samples(samples_root, session_id):
        raise RuntimeError(
            f"session '{session_id}' already has samples in {samples_root}; use --overwrite-session"
        )
    out_dir.mkdir(parents=True, exist_ok=True)

    recorder = record_sample_fn or _record_sample_manual_stop
    target = len(phrases) * options.samples_per_phrase
    written = 0

    output_func(
        "collecting fixed-phrase samples: "
        f"session={session_id} phrases={len(phrases)} samples_per_phrase={options.samples_per_phrase}"
    )
    for phrase in phrases:
        slug = slug_map[phrase]
        phrase_dir = samples_root / slug
        phrase_dir.mkdir(parents=True, exist_ok=True)
        output_func(f'phrase: "{phrase}"')
        sample_index = 1
        # while (not for) because an empty capture retries the same index.
        while sample_index <= options.samples_per_phrase:
            choice = input_func(
                f"sample {sample_index}/{options.samples_per_phrase} - press Enter to start "
                "(or 'q' to stop this session): "
            ).strip()
            if choice.casefold() in {"q", "quit", "exit"}:
                output_func("collection interrupted by user")
                # Partial progress is still a valid result; manifest rows
                # already written stay on disk.
                return CollectResult(
                    session_id=session_id,
                    phrases=len(phrases),
                    samples_per_phrase=options.samples_per_phrase,
                    samples_target=target,
                    samples_written=written,
                    out_dir=out_dir,
                    manifest_path=manifest_path,
                    interrupted=True,
                )
            audio, frame_count, duration_ms = recorder(options, input_func)
            if frame_count <= 0:
                output_func("captured empty sample; retrying the same index")
                continue
            wav_path = phrase_dir / f"{session_id}__{sample_index:03d}.wav"
            _write_wav_file(wav_path, audio, samplerate=options.samplerate, channels=options.channels)
            row = {
                "session_id": session_id,
                "timestamp_utc": _utc_now_iso(),
                "phrase": phrase,
                "phrase_slug": slug,
                "sample_index": sample_index,
                "wav_path": _path_for_manifest(wav_path),
                "samplerate": options.samplerate,
                "channels": options.channels,
                "duration_ms": duration_ms,
                "frames": frame_count,
                "device_spec": options.device_spec,
                "collector_version": COLLECTOR_VERSION,
            }
            # Appended (and flushed) per sample so an interrupted run keeps
            # all completed recordings.
            _append_manifest_row(manifest_path, row)
            written += 1
            output_func(
                f"saved sample {written}/{target}: {row['wav_path']} "
                f"(duration_ms={duration_ms}, frames={frame_count})"
            )
            sample_index += 1
    return CollectResult(
        session_id=session_id,
        phrases=len(phrases),
        samples_per_phrase=options.samples_per_phrase,
        samples_target=target,
        samples_written=written,
        out_dir=out_dir,
        manifest_path=manifest_path,
        interrupted=False,
    )
|
||||
|
||||
|
||||
def _record_sample_manual_stop(
    options: CollectOptions,
    input_func: Callable[[str], str],
) -> tuple[np.ndarray, int, int]:
    """Record one sample from the input device until the user presses Enter.

    Opens a float32 InputStream, buffers chunks from the audio callback, and
    guarantees the stream is stopped and closed even when the prompt raises.
    Returns ``(audio, frame_count, duration_ms)`` where *audio* is a float32
    array shaped ``(frames, channels)``.
    """
    sd = _sounddevice()
    frames: list[np.ndarray] = []
    device = _resolve_device_or_raise(options.device_spec)

    def callback(indata, _frames, _time, _status):
        # Copy: sounddevice reuses the callback buffer between invocations.
        frames.append(indata.copy())

    stream = sd.InputStream(
        samplerate=options.samplerate,
        channels=options.channels,
        dtype="float32",
        device=device,
        callback=callback,
    )
    stream.start()
    try:
        input_func("recording... press Enter to stop: ")
    finally:
        # Stop and close independently so a stop failure never leaks the
        # stream, and the most informative error wins.
        stop_error = None
        try:
            stream.stop()
        except Exception as exc:  # pragma: no cover - exercised via recorder tests, hard to force here
            stop_error = exc
        try:
            stream.close()
        except Exception as exc:  # pragma: no cover - exercised via recorder tests, hard to force here
            if stop_error is None:
                raise
            raise RuntimeError(f"recording stop failed ({stop_error}) and close also failed ({exc})") from exc
        if stop_error is not None:
            raise stop_error

    audio = _flatten_frames(frames, channels=options.channels)
    # Fix: the previous `int(audio.shape[0]) if audio.ndim == 2 else
    # int(audio.shape[0])` had identical branches — _flatten_frames always
    # returns a 2-D array, so the first axis length is the frame count.
    frame_count = int(audio.shape[0])
    duration_ms = int(round((frame_count / float(options.samplerate)) * 1000.0))
    return audio, frame_count, duration_ms
|
||||
|
||||
|
||||
def _validate_options(options: CollectOptions) -> None:
    """Reject non-positive numeric collection settings with a RuntimeError."""
    numeric_settings = (
        ("samples_per_phrase", options.samples_per_phrase),
        ("samplerate", options.samplerate),
        ("channels", options.channels),
    )
    for setting_name, setting_value in numeric_settings:
        if setting_value < 1:
            raise RuntimeError(f"{setting_name} must be >= 1")
|
||||
|
||||
|
||||
def _resolve_session_id(value: str | None) -> str:
    """Return the caller-supplied id (stripped), or generate a UTC timestamp id."""
    cleaned = value.strip() if value else ""
    if not cleaned:
        return datetime.now(timezone.utc).strftime("session-%Y%m%dT%H%M%SZ")
    return cleaned
|
||||
|
||||
|
||||
def _build_slug_map(phrases: list[str]) -> dict[str, str]:
    """Map each phrase to its slug, rejecting any two phrases sharing a slug."""
    slug_owner: dict[str, str] = {}
    mapping: dict[str, str] = {}
    for phrase in phrases:
        slug = slugify_phrase(phrase)
        owner = slug_owner.get(slug)
        # Two different phrases collapsing to one slug would silently merge
        # their sample directories, so fail loudly instead.
        if owner is not None and owner != phrase:
            raise RuntimeError(
                f'phrases "{owner}" and "{phrase}" map to the same slug "{slug}"'
            )
        slug_owner[slug] = phrase
        mapping[phrase] = slug
    return mapping
|
||||
|
||||
|
||||
def _session_has_samples(samples_root: Path, session_id: str) -> bool:
    """True when any wav file for *session_id* already exists under *samples_root*."""
    if not samples_root.exists():
        return False
    # rglob is lazy; stop at the first match instead of listing everything.
    matches = samples_root.rglob(f"{session_id}__*.wav")
    return next(iter(matches), None) is not None
|
||||
|
||||
|
||||
def _flatten_frames(frames: list[np.ndarray], *, channels: int) -> np.ndarray:
    """Concatenate recorded callback chunks into one float32 (frames, channels) array."""
    if not frames:
        return np.zeros((0, channels), dtype=np.float32)
    stacked = np.concatenate(frames, axis=0)
    if stacked.ndim == 1:
        # Some backends deliver mono chunks as 1-D; normalise to a column.
        stacked = stacked.reshape(-1, 1)
    if stacked.ndim != 2:
        raise RuntimeError(f"unexpected recorded frame shape: {stacked.shape}")
    return np.asarray(stacked, dtype=np.float32)
|
||||
|
||||
|
||||
def _write_wav_file(path: Path, audio: np.ndarray, *, samplerate: int, channels: int) -> None:
    """Write float audio to *path* as a 16-bit PCM wav, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    pcm = float_to_pcm16(audio)
    with wave.open(str(path), "wb") as handle:
        handle.setnchannels(channels)
        handle.setsampwidth(2)  # 2 bytes per sample == 16-bit PCM
        handle.setframerate(samplerate)
        handle.writeframes(pcm.tobytes())
|
||||
|
||||
|
||||
def _append_manifest_row(manifest_path: Path, row: dict[str, object]) -> None:
    """Append *row* to the jsonl manifest, flushing so partial runs persist."""
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(row, ensure_ascii=False)
    with manifest_path.open("a", encoding="utf-8") as handle:
        handle.write(serialized + "\n")
        # Flush per row: an interrupted session must not lose recorded samples.
        handle.flush()
|
||||
|
||||
|
||||
def _path_for_manifest(path: Path) -> str:
    """Prefer a cwd-relative posix path for manifest rows; fall back to the path as-is."""
    try:
        relative = path.resolve().relative_to(Path.cwd().resolve())
    except Exception:
        # Path lives outside the working tree (or resolution failed):
        # record it verbatim rather than aborting the session.
        return path.as_posix()
    return relative.as_posix()
|
||||
|
||||
|
||||
def _utc_now_iso() -> str:
    """Current UTC time as ISO-8601 with millisecond precision and a 'Z' suffix."""
    stamp = datetime.now(timezone.utc).isoformat(timespec="milliseconds")
    return stamp.replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _resolve_device_or_raise(spec: str | int | None) -> int | None:
    """Resolve *spec* to a device index; raise only when an explicit spec misses.

    A blank/None spec is a request for the default device, so an unresolved
    lookup is fine there — only a device the user actually named must exist.
    """
    device = resolve_input_device(spec)
    if device is not None or not _is_explicit_device_spec(spec):
        return device
    raise RuntimeError(
        f"input device '{spec}' did not match any input device; available: {_available_inputs_summary()}"
    )
|
||||
|
||||
|
||||
def _is_explicit_device_spec(spec: str | int | None) -> bool:
    """True when the caller actually named a device (an index, or non-blank text)."""
    if isinstance(spec, int):
        return True
    if spec is None:
        return False
    return bool(str(spec).strip())
|
||||
|
||||
|
||||
def _available_inputs_summary(limit: int = 8) -> str:
    """Return a short comma-separated "index:name" listing of input devices.

    Truncates after *limit* entries with an ellipsis marker; returns
    "<none>" when no input devices are reported.
    """
    devices = list_input_devices()
    if not devices:
        return "<none>"
    items = [f"{d['index']}:{d['name']}" for d in devices[:limit]]
    if len(devices) > limit:
        items.append("...")
    return ", ".join(items)
|
||||
|
||||
|
||||
def _sounddevice():
    """Import and return the sounddevice module lazily.

    Deferred so non-recording CLI commands work without the audio stack
    installed; a missing package becomes a RuntimeError with install hints.
    """
    try:
        import sounddevice as sd  # type: ignore[import-not-found]
    except ModuleNotFoundError as exc:
        # NOTE(review): only ModuleNotFoundError is mapped; an OSError from a
        # missing PortAudio shared library would propagate raw — confirm intended.
        raise RuntimeError(
            "sounddevice is not installed; install dependencies with `uv sync --extra x11`"
        ) from exc
    return sd
|
||||
670
src/vosk_eval.py
Normal file
670
src/vosk_eval.py
Normal file
|
|
@ -0,0 +1,670 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import statistics
|
||||
import time
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
|
||||
DEFAULT_KEYSTROKE_INTENTS_PATH = Path("exploration/vosk/keystrokes/intents.json")
|
||||
DEFAULT_KEYSTROKE_LITERAL_MANIFEST_PATH = Path("exploration/vosk/keystrokes/literal/manifest.jsonl")
|
||||
DEFAULT_KEYSTROKE_NATO_MANIFEST_PATH = Path("exploration/vosk/keystrokes/nato/manifest.jsonl")
|
||||
DEFAULT_KEYSTROKE_EVAL_OUTPUT_DIR = Path("exploration/vosk/keystrokes/eval_runs")
|
||||
DEFAULT_KEYSTROKE_MODELS = [
|
||||
{
|
||||
"name": "vosk-small-en-us-0.15",
|
||||
"path": "/tmp/vosk-models/vosk-model-small-en-us-0.15",
|
||||
},
|
||||
{
|
||||
"name": "vosk-en-us-0.22-lgraph",
|
||||
"path": "/tmp/vosk-models/vosk-model-en-us-0.22-lgraph",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class IntentSpec:
    """One keystroke intent with its spoken form under each grammar."""

    intent_id: str  # unique identifier (validated for duplicates on load)
    literal_phrase: str  # spoken form in the literal grammar
    nato_phrase: str  # spoken form in the NATO-alphabet grammar
    letter: str  # validated on load to be one of d/b/p
    modifier: str  # validated on load to be ctrl/shift/ctrl+shift
||||
|
||||
|
||||
@dataclass(frozen=True)
class ModelSpec:
    """A named Vosk model and the directory it is stored in."""

    name: str  # display name used in reports and winner summaries
    path: Path  # model directory on disk (skipped when missing)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ManifestSample:
    """One labeled wav sample joined with its expected intent fields."""

    wav_path: Path  # resolved, existence-checked on load
    expected_phrase: str
    expected_intent: str  # intent_id from the matching IntentSpec
    expected_letter: str
    expected_modifier: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DecodedRow:
    """Per-sample decode outcome plus timing metrics."""

    wav_path: str
    expected_phrase: str
    hypothesis: str  # raw recognizer output text
    expected_intent: str
    predicted_intent: str | None  # None when the hypothesis matched no known phrase
    expected_letter: str
    predicted_letter: str | None
    expected_modifier: str
    predicted_modifier: str | None
    intent_match: bool  # predicted intent equals expected intent
    audio_ms: float
    decode_ms: float
    rtf: float | None  # presumably decode_ms relative to audio_ms; None when undefined — TODO confirm in decoder
    out_of_grammar: bool  # hypothesis fell outside the constrained grammar
|
||||
|
||||
|
||||
def run_vosk_keystroke_eval(
    *,
    literal_manifest: str | Path,
    nato_manifest: str | Path,
    intents_path: str | Path,
    output_dir: str | Path,
    models_file: str | Path | None = None,
    verbose: bool = False,
) -> dict[str, Any]:
    """Evaluate each configured Vosk model on both keystroke grammars.

    Loads intents and both manifests, evaluates every model whose path
    exists (missing models are recorded under "skipped_models"), picks
    winners, and writes the full summary to <output_dir>/<run_id>/summary.json.
    Returns the summary dict.  Raises RuntimeError when inputs are invalid
    or no model could be evaluated.
    """
    intents = load_keystroke_intents(intents_path)
    literal_index = build_phrase_to_intent_index(intents, grammar="literal")
    nato_index = build_phrase_to_intent_index(intents, grammar="nato")

    literal_samples = load_manifest_samples(literal_manifest, literal_index)
    nato_samples = load_manifest_samples(nato_manifest, nato_index)
    model_specs = load_model_specs(models_file)

    if not model_specs:
        raise RuntimeError("no model specs provided")

    # One timestamped subdirectory per run so reports never overwrite each other.
    run_id = datetime.now(timezone.utc).strftime("run-%Y%m%dT%H%M%SZ")
    base_output_dir = Path(output_dir)
    run_output_dir = (base_output_dir / run_id).resolve()
    run_output_dir.mkdir(parents=True, exist_ok=True)

    summary: dict[str, Any] = {
        "report_version": 1,
        "run_id": run_id,
        "literal_manifest": str(Path(literal_manifest)),
        "nato_manifest": str(Path(nato_manifest)),
        "intents_path": str(Path(intents_path)),
        "models_file": str(models_file) if models_file else "",
        "models": [],
        "skipped_models": [],
        "winners": {},
        "cross_grammar_delta": [],
        "output_dir": str(run_output_dir),
    }

    for model in model_specs:
        if not model.path.exists():
            # A missing model directory is reported, not fatal — other
            # models can still be evaluated.
            summary["skipped_models"].append(
                {
                    "name": model.name,
                    "path": str(model.path),
                    "reason": "model path does not exist",
                }
            )
            continue
        model_report = _evaluate_model(
            model,
            literal_samples=literal_samples,
            nato_samples=nato_samples,
            literal_index=literal_index,
            nato_index=nato_index,
            output_dir=run_output_dir,
            verbose=verbose,
        )
        summary["models"].append(model_report)

    if not summary["models"]:
        raise RuntimeError("no models were successfully evaluated")

    summary["winners"] = _pick_winners(summary["models"])
    summary["cross_grammar_delta"] = _cross_grammar_delta(summary["models"])
    summary_path = run_output_dir / "summary.json"
    # summary_path is embedded in the report before serialization on purpose.
    summary["summary_path"] = str(summary_path)
    summary_path.write_text(f"{json.dumps(summary, indent=2, ensure_ascii=False)}\n", encoding="utf-8")
    return summary
|
||||
|
||||
|
||||
def load_keystroke_intents(path: str | Path) -> list[IntentSpec]:
    """Load and validate keystroke intent definitions from a JSON array.

    Each entry must provide intent_id, literal_phrase, nato_phrase, a letter
    in {d, b, p}, and a modifier in {ctrl, shift, ctrl+shift}.  Duplicates
    (after normalization) of ids or of either phrase form are rejected.
    Raises RuntimeError on any structural problem or an empty file.
    """
    payload = _load_json(path, description="intents")
    if not isinstance(payload, list):
        raise RuntimeError("intents file must be a JSON array")

    intents: list[IntentSpec] = []
    # Normalized forms tracked so duplicates differing only in case/spacing are caught.
    seen_ids: set[str] = set()
    seen_literal: set[str] = set()
    seen_nato: set[str] = set()
    for idx, item in enumerate(payload):
        if not isinstance(item, dict):
            raise RuntimeError(f"intents[{idx}] must be an object")
        intent_id = str(item.get("intent_id", "")).strip()
        literal_phrase = str(item.get("literal_phrase", "")).strip()
        nato_phrase = str(item.get("nato_phrase", "")).strip()
        letter = str(item.get("letter", "")).strip().casefold()
        modifier = str(item.get("modifier", "")).strip().casefold()
        if not intent_id:
            raise RuntimeError(f"intents[{idx}].intent_id is required")
        if not literal_phrase:
            raise RuntimeError(f"intents[{idx}].literal_phrase is required")
        if not nato_phrase:
            raise RuntimeError(f"intents[{idx}].nato_phrase is required")
        if letter not in {"d", "b", "p"}:
            raise RuntimeError(f"intents[{idx}].letter must be one of d/b/p")
        if modifier not in {"ctrl", "shift", "ctrl+shift"}:
            raise RuntimeError(f"intents[{idx}].modifier must be ctrl/shift/ctrl+shift")

        norm_id = _norm(intent_id)
        norm_literal = _norm(literal_phrase)
        norm_nato = _norm(nato_phrase)
        if norm_id in seen_ids:
            raise RuntimeError(f"duplicate intent_id '{intent_id}'")
        if norm_literal in seen_literal:
            raise RuntimeError(f"duplicate literal_phrase '{literal_phrase}'")
        if norm_nato in seen_nato:
            raise RuntimeError(f"duplicate nato_phrase '{nato_phrase}'")
        seen_ids.add(norm_id)
        seen_literal.add(norm_literal)
        seen_nato.add(norm_nato)
        intents.append(
            IntentSpec(
                intent_id=intent_id,
                literal_phrase=literal_phrase,
                nato_phrase=nato_phrase,
                letter=letter,
                modifier=modifier,
            )
        )

    if not intents:
        raise RuntimeError("intents file is empty")
    return intents
|
||||
|
||||
|
||||
def build_phrase_to_intent_index(
    intents: list[IntentSpec],
    *,
    grammar: str,
) -> dict[str, IntentSpec]:
    """Index intents by normalized phrase for one grammar ("literal" or "nato").

    Raises RuntimeError for an unknown grammar or when two intents map to
    the same normalized phrase.
    """
    if grammar not in {"literal", "nato"}:
        raise RuntimeError(f"unsupported grammar type '{grammar}'")
    index: dict[str, IntentSpec] = {}
    for intent in intents:
        if grammar == "literal":
            phrase = intent.literal_phrase
        else:
            phrase = intent.nato_phrase
        normalized = _norm(phrase)
        if normalized in index:
            raise RuntimeError(f"duplicate phrase mapping for grammar {grammar}: '{phrase}'")
        index[normalized] = intent
    return index
|
||||
|
||||
|
||||
def load_manifest_samples(
    path: str | Path,
    phrase_index: dict[str, IntentSpec],
) -> list[ManifestSample]:
    """Load a jsonl manifest and join each row to its intent via *phrase_index*.

    Every row must be a JSON object with a "phrase" found in the index and a
    "wav_path" that resolves to an existing file.  Blank lines are skipped.
    Raises RuntimeError (with the offending 1-based line number) for any
    malformed row, unknown phrase, missing wav, or an empty manifest.
    """
    manifest_path = Path(path)
    if not manifest_path.exists():
        raise RuntimeError(f"manifest file does not exist: {manifest_path}")
    rows = manifest_path.read_text(encoding="utf-8").splitlines()
    samples: list[ManifestSample] = []
    for idx, raw in enumerate(rows, start=1):
        text = raw.strip()
        if not text:
            continue
        try:
            payload = json.loads(text)
        except Exception as exc:
            raise RuntimeError(f"invalid manifest json at line {idx}: {exc}") from exc
        if not isinstance(payload, dict):
            raise RuntimeError(f"manifest line {idx} must be an object")
        phrase = str(payload.get("phrase", "")).strip()
        wav_path_raw = str(payload.get("wav_path", "")).strip()
        if not phrase:
            raise RuntimeError(f"manifest line {idx} missing phrase")
        if not wav_path_raw:
            raise RuntimeError(f"manifest line {idx} missing wav_path")

        # The phrase is the join key: it must exist in the grammar index.
        spec = phrase_index.get(_norm(phrase))
        if spec is None:
            raise RuntimeError(
                f"manifest line {idx} phrase '{phrase}' does not exist in grammar index"
            )
        # Relative wav paths are resolved against the manifest's directory.
        wav_path = _resolve_manifest_wav_path(
            wav_path_raw,
            manifest_dir=manifest_path.parent,
        )
        if not wav_path.exists():
            raise RuntimeError(f"manifest line {idx} wav_path does not exist: {wav_path}")
        samples.append(
            ManifestSample(
                wav_path=wav_path,
                expected_phrase=phrase,
                expected_intent=spec.intent_id,
                expected_letter=spec.letter,
                expected_modifier=spec.modifier,
            )
        )
    if not samples:
        raise RuntimeError(f"manifest has no samples: {manifest_path}")
    return samples
|
||||
|
||||
|
||||
def load_model_specs(path: str | Path | None) -> list[ModelSpec]:
    """Load model specs from a JSON array file, or fall back to the defaults.

    Each entry needs a non-empty "name" (unique after normalization) and a
    "path"; relative paths are resolved against the models file's directory.
    Raises RuntimeError for structural problems.
    """
    if path is None:
        # No file given: use the built-in default model list verbatim.
        return [
            ModelSpec(
                name=str(row["name"]),
                path=Path(str(row["path"])).expanduser().resolve(),
            )
            for row in DEFAULT_KEYSTROKE_MODELS
        ]
    models_path = Path(path)
    payload = _load_json(models_path, description="model specs")
    if not isinstance(payload, list):
        raise RuntimeError("models file must be a JSON array")
    specs: list[ModelSpec] = []
    seen: set[str] = set()
    for idx, item in enumerate(payload):
        if not isinstance(item, dict):
            raise RuntimeError(f"models[{idx}] must be an object")
        name = str(item.get("name", "")).strip()
        path_raw = str(item.get("path", "")).strip()
        if not name:
            raise RuntimeError(f"models[{idx}].name is required")
        if not path_raw:
            raise RuntimeError(f"models[{idx}].path is required")
        key = _norm(name)
        if key in seen:
            raise RuntimeError(f"duplicate model name '{name}' in models file")
        seen.add(key)
        model_path = Path(path_raw).expanduser()
        if not model_path.is_absolute():
            # Relative entries are anchored at the models file, not the cwd.
            model_path = (models_path.parent / model_path).resolve()
        else:
            model_path = model_path.resolve()
        specs.append(ModelSpec(name=name, path=model_path))
    return specs
|
||||
|
||||
|
||||
def summarize_decoded_rows(rows: list[DecodedRow]) -> dict[str, Any]:
    """Aggregate per-sample decode rows into a metrics report dict.

    Reports intent accuracy, unknown rate, out-of-grammar count, decode
    latency (avg/p50/p95), average real-time factor, per-intent/modifier/
    letter accuracy breakdowns, confusion tables, and the 20 most frequent
    (expected phrase, hypothesis) mismatch pairs.
    """
    if not rows:
        # Keep the same shape for the empty case so consumers need no guards.
        return {
            "samples": 0,
            "intent_match_count": 0,
            "intent_accuracy": 0.0,
            "unknown_count": 0,
            "unknown_rate": 0.0,
            "out_of_grammar_count": 0,
            "latency_ms": {"avg": 0.0, "p50": 0.0, "p95": 0.0},
            "rtf_avg": 0.0,
            "intent_breakdown": {},
            "modifier_breakdown": {},
            "letter_breakdown": {},
            "intent_confusion": {},
            "letter_confusion": {},
            "top_raw_mismatches": [],
        }

    total = len(rows)
    matched = sum(1 for row in rows if row.intent_match)
    unknown = sum(1 for row in rows if row.predicted_intent is None)
    oog = sum(1 for row in rows if row.out_of_grammar)

    latencies = sorted(row.decode_ms for row in rows)
    p50 = statistics.median(latencies)
    # Nearest-rank p95 over the sorted latencies.
    p95 = latencies[int(round((len(latencies) - 1) * 0.95))]

    rtf_values = [row.rtf for row in rows if row.rtf is not None]
    rtf_avg = float(sum(rtf_values) / len(rtf_values)) if rtf_values else 0.0

    intent_breakdown: dict[str, dict[str, float | int]] = {}
    modifier_breakdown: dict[str, dict[str, float | int]] = {}
    letter_breakdown: dict[str, dict[str, float | int]] = {}
    intent_confusion: dict[str, dict[str, int]] = {}
    letter_confusion: dict[str, dict[str, int]] = {}
    mismatch_counts: dict[tuple[str, str], int] = {}

    for row in rows:
        _inc_metric_bucket(intent_breakdown, row.expected_intent, row.intent_match)
        _inc_metric_bucket(modifier_breakdown, row.expected_modifier, row.intent_match)
        _inc_metric_bucket(letter_breakdown, row.expected_letter, row.intent_match)
        # Missing predictions are tallied under a sentinel column.
        _inc_confusion(intent_confusion, row.expected_intent, row.predicted_intent or "__none__")
        _inc_confusion(letter_confusion, row.expected_letter, row.predicted_letter or "__none__")
        if not row.intent_match:
            pair = (row.expected_phrase, row.hypothesis)
            mismatch_counts[pair] = mismatch_counts.get(pair, 0) + 1

    for table in (intent_breakdown, modifier_breakdown, letter_breakdown):
        _finalize_metric_buckets(table)

    worst = sorted(mismatch_counts.items(), key=lambda item: item[1], reverse=True)[:20]
    top_raw_mismatches = [
        {"expected_phrase": phrase, "hypothesis": hyp, "count": count}
        for (phrase, hyp), count in worst
    ]

    return {
        "samples": total,
        "intent_match_count": matched,
        "intent_accuracy": matched / total,
        "unknown_count": unknown,
        "unknown_rate": unknown / total,
        "out_of_grammar_count": oog,
        "latency_ms": {
            "avg": sum(latencies) / total,
            "p50": p50,
            "p95": p95,
        },
        "rtf_avg": rtf_avg,
        "intent_breakdown": intent_breakdown,
        "modifier_breakdown": modifier_breakdown,
        "letter_breakdown": letter_breakdown,
        "intent_confusion": intent_confusion,
        "letter_confusion": letter_confusion,
        "top_raw_mismatches": top_raw_mismatches,
    }
||||
def _evaluate_model(
    model: ModelSpec,
    *,
    literal_samples: list[ManifestSample],
    nato_samples: list[ManifestSample],
    literal_index: dict[str, IntentSpec],
    nato_index: dict[str, IntentSpec],
    output_dir: Path,
    verbose: bool,
) -> dict[str, Any]:
    """Evaluate one Vosk model against the 'literal' and 'nato' grammars.

    Loads the model from ``model.path`` once, decodes every manifest sample
    with a grammar-constrained recognizer, writes a per-sample JSONL report
    per grammar into *output_dir*, and returns per-grammar metric reports
    plus an ``overall`` block averaging intent accuracy and p50 latency.

    Raises:
        RuntimeError: when any non-empty hypothesis falls outside its
            grammar — decoding is grammar-constrained, so this signals a
            setup problem rather than a model miss.
    """
    _ModelClass, recognizer_factory = _load_vosk_bindings()
    started = time.perf_counter()
    vosk_model = _ModelClass(str(model.path))
    # Model load time is reported separately and excluded from decode latency.
    model_load_ms = (time.perf_counter() - started) * 1000.0

    grammar_reports: dict[str, Any] = {}
    for grammar, samples, index in (
        ("literal", literal_samples, literal_index),
        ("nato", nato_samples, nato_index),
    ):
        phrases = _phrases_for_grammar(index.values(), grammar=grammar)
        # Normalized phrase set used below to flag out-of-grammar hypotheses.
        norm_allowed = {_norm(item) for item in phrases}
        decoded: list[DecodedRow] = []
        for sample in samples:
            hypothesis, audio_ms, decode_ms = _decode_sample_with_grammar(
                recognizer_factory,
                vosk_model,
                sample.wav_path,
                phrases,
            )
            # Map the normalized hypothesis back to an intent spec
            # (assumes *index* is keyed by _norm-normalized phrases — TODO confirm at call site).
            hyp_norm = _norm(hypothesis)
            spec = index.get(hyp_norm)
            predicted_intent = spec.intent_id if spec is not None else None
            predicted_letter = spec.letter if spec is not None else None
            predicted_modifier = spec.modifier if spec is not None else None
            # An empty hypothesis counts as "unknown", never out-of-grammar.
            out_of_grammar = bool(hyp_norm) and hyp_norm not in norm_allowed
            decoded.append(
                DecodedRow(
                    wav_path=str(sample.wav_path),
                    expected_phrase=sample.expected_phrase,
                    hypothesis=hypothesis,
                    expected_intent=sample.expected_intent,
                    predicted_intent=predicted_intent,
                    expected_letter=sample.expected_letter,
                    predicted_letter=predicted_letter,
                    expected_modifier=sample.expected_modifier,
                    predicted_modifier=predicted_modifier,
                    intent_match=sample.expected_intent == predicted_intent,
                    audio_ms=audio_ms,
                    decode_ms=decode_ms,
                    # Real-time factor; undefined for zero-length audio.
                    rtf=(decode_ms / audio_ms) if audio_ms > 0 else None,
                    out_of_grammar=out_of_grammar,
                )
            )

        report = summarize_decoded_rows(decoded)
        if report["out_of_grammar_count"] > 0:
            raise RuntimeError(
                f"model '{model.name}' produced {report['out_of_grammar_count']} out-of-grammar "
                f"hypotheses for grammar '{grammar}'"
            )

        # Per-sample JSONL sits next to the summary for later inspection.
        sample_path = output_dir / f"{grammar}__{_safe_filename(model.name)}__samples.jsonl"
        _write_samples_report(sample_path, decoded)
        report["samples_report"] = str(sample_path)
        if verbose:
            print(
                f"vosk-eval[{model.name}][{grammar}]: "
                f"acc={report['intent_accuracy']:.3f} "
                f"p50={report['latency_ms']['p50']:.1f}ms "
                f"p95={report['latency_ms']['p95']:.1f}ms"
            )
        grammar_reports[grammar] = report

    literal_acc = float(grammar_reports["literal"]["intent_accuracy"])
    nato_acc = float(grammar_reports["nato"]["intent_accuracy"])
    literal_p50 = float(grammar_reports["literal"]["latency_ms"]["p50"])
    nato_p50 = float(grammar_reports["nato"]["latency_ms"]["p50"])
    # Unweighted mean across the two grammars.
    overall_accuracy = (literal_acc + nato_acc) / 2.0
    overall_latency_p50 = (literal_p50 + nato_p50) / 2.0
    return {
        "name": model.name,
        "path": str(model.path),
        "model_load_ms": model_load_ms,
        "literal": grammar_reports["literal"],
        "nato": grammar_reports["nato"],
        "overall": {
            "avg_intent_accuracy": overall_accuracy,
            "avg_latency_p50_ms": overall_latency_p50,
        },
    }
||||
def _decode_sample_with_grammar(
|
||||
recognizer_factory: Callable[[Any, float, str], Any],
|
||||
vosk_model: Any,
|
||||
wav_path: Path,
|
||||
phrases: list[str],
|
||||
) -> tuple[str, float, float]:
|
||||
with wave.open(str(wav_path), "rb") as handle:
|
||||
channels = handle.getnchannels()
|
||||
sample_width = handle.getsampwidth()
|
||||
sample_rate = float(handle.getframerate())
|
||||
frame_count = handle.getnframes()
|
||||
payload = handle.readframes(frame_count)
|
||||
if channels != 1 or sample_width != 2:
|
||||
raise RuntimeError(
|
||||
f"unsupported wav format for {wav_path}: channels={channels} sample_width={sample_width}"
|
||||
)
|
||||
recognizer = recognizer_factory(vosk_model, sample_rate, json.dumps(phrases))
|
||||
if hasattr(recognizer, "SetWords"):
|
||||
recognizer.SetWords(False)
|
||||
started = time.perf_counter()
|
||||
recognizer.AcceptWaveform(payload)
|
||||
result = recognizer.FinalResult()
|
||||
decode_ms = (time.perf_counter() - started) * 1000.0
|
||||
audio_ms = (frame_count / sample_rate) * 1000.0
|
||||
try:
|
||||
text = str(json.loads(result).get("text", "")).strip()
|
||||
except Exception:
|
||||
text = ""
|
||||
return text, audio_ms, decode_ms
|
||||
|
||||
|
||||
def _pick_winners(models: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
winners: dict[str, Any] = {}
|
||||
for grammar in ("literal", "nato"):
|
||||
ranked = sorted(
|
||||
models,
|
||||
key=lambda item: (
|
||||
float(item[grammar]["intent_accuracy"]),
|
||||
-float(item[grammar]["latency_ms"]["p50"]),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
best = ranked[0]
|
||||
winners[grammar] = {
|
||||
"name": best["name"],
|
||||
"intent_accuracy": best[grammar]["intent_accuracy"],
|
||||
"latency_p50_ms": best[grammar]["latency_ms"]["p50"],
|
||||
}
|
||||
ranked_overall = sorted(
|
||||
models,
|
||||
key=lambda item: (
|
||||
float(item["overall"]["avg_intent_accuracy"]),
|
||||
-float(item["overall"]["avg_latency_p50_ms"]),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
winners["overall"] = {
|
||||
"name": ranked_overall[0]["name"],
|
||||
"avg_intent_accuracy": ranked_overall[0]["overall"]["avg_intent_accuracy"],
|
||||
"avg_latency_p50_ms": ranked_overall[0]["overall"]["avg_latency_p50_ms"],
|
||||
}
|
||||
return winners
|
||||
|
||||
|
||||
def _cross_grammar_delta(models: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
rows: list[dict[str, Any]] = []
|
||||
for model in models:
|
||||
literal_acc = float(model["literal"]["intent_accuracy"])
|
||||
nato_acc = float(model["nato"]["intent_accuracy"])
|
||||
rows.append(
|
||||
{
|
||||
"name": model["name"],
|
||||
"intent_accuracy_delta_nato_minus_literal": nato_acc - literal_acc,
|
||||
"literal_intent_accuracy": literal_acc,
|
||||
"nato_intent_accuracy": nato_acc,
|
||||
}
|
||||
)
|
||||
rows.sort(key=lambda item: item["intent_accuracy_delta_nato_minus_literal"], reverse=True)
|
||||
return rows
|
||||
|
||||
|
||||
def _write_samples_report(path: Path, rows: list[DecodedRow]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("w", encoding="utf-8") as handle:
|
||||
for row in rows:
|
||||
payload = {
|
||||
"wav_path": row.wav_path,
|
||||
"expected_phrase": row.expected_phrase,
|
||||
"hypothesis": row.hypothesis,
|
||||
"expected_intent": row.expected_intent,
|
||||
"predicted_intent": row.predicted_intent,
|
||||
"expected_letter": row.expected_letter,
|
||||
"predicted_letter": row.predicted_letter,
|
||||
"expected_modifier": row.expected_modifier,
|
||||
"predicted_modifier": row.predicted_modifier,
|
||||
"intent_match": row.intent_match,
|
||||
"audio_ms": row.audio_ms,
|
||||
"decode_ms": row.decode_ms,
|
||||
"rtf": row.rtf,
|
||||
"out_of_grammar": row.out_of_grammar,
|
||||
}
|
||||
handle.write(f"{json.dumps(payload, ensure_ascii=False)}\n")
|
||||
|
||||
|
||||
def _load_vosk_bindings() -> tuple[Any, Callable[[Any, float, str], Any]]:
|
||||
try:
|
||||
from vosk import KaldiRecognizer, Model, SetLogLevel # type: ignore[import-not-found]
|
||||
except ModuleNotFoundError as exc:
|
||||
raise RuntimeError(
|
||||
"vosk is not installed; run with `uv run --with vosk aman eval-vosk-keystrokes ...`"
|
||||
) from exc
|
||||
SetLogLevel(-1)
|
||||
return Model, KaldiRecognizer
|
||||
|
||||
|
||||
def _phrases_for_grammar(
|
||||
specs: Iterable[IntentSpec],
|
||||
*,
|
||||
grammar: str,
|
||||
) -> list[str]:
|
||||
if grammar not in {"literal", "nato"}:
|
||||
raise RuntimeError(f"unsupported grammar type '{grammar}'")
|
||||
out: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for spec in specs:
|
||||
phrase = spec.literal_phrase if grammar == "literal" else spec.nato_phrase
|
||||
key = _norm(phrase)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(phrase)
|
||||
return sorted(out)
|
||||
|
||||
|
||||
def _inc_metric_bucket(table: dict[str, dict[str, float | int]], key: str, matched: bool) -> None:
|
||||
bucket = table.setdefault(key, {"total": 0, "matches": 0, "accuracy": 0.0})
|
||||
bucket["total"] = int(bucket["total"]) + 1
|
||||
if matched:
|
||||
bucket["matches"] = int(bucket["matches"]) + 1
|
||||
|
||||
|
||||
def _finalize_metric_buckets(table: dict[str, dict[str, float | int]]) -> None:
|
||||
for bucket in table.values():
|
||||
total = int(bucket["total"])
|
||||
matches = int(bucket["matches"])
|
||||
bucket["accuracy"] = (matches / total) if total else 0.0
|
||||
|
||||
|
||||
def _inc_confusion(table: dict[str, dict[str, int]], expected: str, predicted: str) -> None:
|
||||
row = table.setdefault(expected, {})
|
||||
row[predicted] = int(row.get(predicted, 0)) + 1
|
||||
|
||||
|
||||
def _safe_filename(value: str) -> str:
|
||||
out = []
|
||||
for ch in value:
|
||||
if ch.isalnum() or ch in {"-", "_", "."}:
|
||||
out.append(ch)
|
||||
else:
|
||||
out.append("_")
|
||||
return "".join(out).strip("_") or "model"
|
||||
|
||||
|
||||
def _load_json(path: str | Path, *, description: str) -> Any:
|
||||
data_path = Path(path)
|
||||
if not data_path.exists():
|
||||
raise RuntimeError(f"{description} file does not exist: {data_path}")
|
||||
try:
|
||||
return json.loads(data_path.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"invalid {description} json '{data_path}': {exc}") from exc
|
||||
|
||||
|
||||
def _resolve_manifest_wav_path(raw_value: str, *, manifest_dir: Path) -> Path:
|
||||
candidate = Path(raw_value).expanduser()
|
||||
if candidate.is_absolute():
|
||||
return candidate.resolve()
|
||||
cwd_candidate = (Path.cwd() / candidate).resolve()
|
||||
if cwd_candidate.exists():
|
||||
return cwd_candidate
|
||||
manifest_candidate = (manifest_dir / candidate).resolve()
|
||||
if manifest_candidate.exists():
|
||||
return manifest_candidate
|
||||
return cwd_candidate
|
||||
|
||||
|
||||
def _norm(value: str) -> str:
|
||||
return " ".join((value or "").strip().casefold().split())
|
||||
Loading…
Add table
Add a link
Reference in a new issue