# aman/tests/test_vosk_collect.py
#
# Source-viewer metadata: 148 lines, 5.3 KiB, Python.

import json
import sys
import tempfile
import unittest
from pathlib import Path
import numpy as np
# The tests import ``vosk_collect`` straight from the repository's ``src`` tree
# rather than from an installed package, so make that directory importable
# before the ``from vosk_collect import ...`` line below runs.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT.joinpath("src")
if str(SRC) not in sys.path:
    # Prepend (rather than append) so the in-repo sources are searched first.
    sys.path.insert(0, str(SRC))
from vosk_collect import CollectOptions, collect_fixed_phrases, float_to_pcm16, load_phrases, slugify_phrase
class VoskCollectTests(unittest.TestCase):
    """Unit tests for the fixed-phrase collector helpers in ``vosk_collect``."""

    def test_load_phrases_ignores_blank_comment_and_deduplicates(self):
        """Comment lines, blank/whitespace-only lines and repeated phrases are dropped.

        The surviving phrases keep their first-occurrence order.
        """
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "phrases.txt"
            path.write_text(
                (
                    "# heading\n"
                    "\n"
                    "close app\n"
                    "take a screenshot\n"
                    "close app\n"
                    " \n"
                ),
                encoding="utf-8",
            )
            phrases = load_phrases(path)
            self.assertEqual(phrases, ["close app", "take a screenshot"])

    def test_load_phrases_empty_after_filtering_raises(self):
        """A file containing only comments and blank lines raises RuntimeError."""
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "phrases.txt"
            path.write_text("# only comments\n\n", encoding="utf-8")
            with self.assertRaisesRegex(RuntimeError, "no usable labels"):
                load_phrases(path)

    def test_slugify_phrase_is_deterministic(self):
        """Slugs are lowercased, separators become ``_`` and trailing punctuation is dropped."""
        self.assertEqual(slugify_phrase("Take a Screenshot"), "take_a_screenshot")
        self.assertEqual(slugify_phrase("close-app!!!"), "close_app")
        # Repeated calls on the same input must agree (the property the test name claims).
        self.assertEqual(slugify_phrase("close-app!!!"), slugify_phrase("close-app!!!"))

    def test_float_to_pcm16_clamps_audio_bounds(self):
        """Float samples convert to int16 and saturate at the symmetric range [-32767, 32767]."""
        values = np.asarray([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0], dtype=np.float32)
        out = float_to_pcm16(values)
        self.assertEqual(out.dtype, np.int16)
        self.assertGreaterEqual(int(out.min()), -32767)
        self.assertLessEqual(int(out.max()), 32767)
        # Inputs outside [-1.0, 1.0] must saturate rather than wrap around.
        self.assertEqual(int(out[0]), -32767)
        self.assertEqual(int(out[-1]), 32767)
        # Full-scale inputs reach the same limits, and silence stays silent.
        self.assertEqual(int(out[1]), -32767)
        self.assertEqual(int(out[5]), 32767)
        self.assertEqual(int(out[3]), 0)

    def test_collect_fixed_phrases_writes_manifest_and_wavs(self):
        """Two phrases x two samples produce four complete manifest rows, each backed by a WAV file."""
        with tempfile.TemporaryDirectory() as td:
            root = Path(td)
            phrases_path = root / "phrases.txt"
            out_dir = root / "dataset"
            phrases_path.write_text("close app\ntake a screenshot\n", encoding="utf-8")
            options = CollectOptions(
                phrases_file=phrases_path,
                out_dir=out_dir,
                samples_per_phrase=2,
                samplerate=16000,
                channels=1,
                session_id="session-1",
            )
            # Four scripted empty answers (presumably one per sample prompt — verify
            # against collect_fixed_phrases). pop(0) raises IndexError if the
            # collector asks for more input than scripted.
            answers = ["", "", "", ""]

            def fake_input(_prompt: str) -> str:
                return answers.pop(0)

            def fake_record(_options: CollectOptions, _input_func):
                # Stand-in for microphone capture. The tuple is presumably
                # (audio, frames, duration_ms): 320 frames at 16 kHz is 20 ms.
                audio = np.ones((320, 1), dtype=np.float32) * 0.1
                return audio, 320, 20

            result = collect_fixed_phrases(
                options,
                input_func=fake_input,
                output_func=lambda _line: None,
                record_sample_fn=fake_record,
            )
            self.assertFalse(result.interrupted)
            self.assertEqual(result.samples_written, 4)

            manifest = out_dir / "manifest.jsonl"
            rows = [
                json.loads(line)
                for line in manifest.read_text(encoding="utf-8").splitlines()
                if line.strip()
            ]
            self.assertEqual(len(rows), 4)

            required = {
                "session_id",
                "timestamp_utc",
                "phrase",
                "phrase_slug",
                "sample_index",
                "wav_path",
                "samplerate",
                "channels",
                "duration_ms",
                "frames",
                "device_spec",
                "collector_version",
            }
            # Every row must carry the full schema, not just the first one; report
            # exactly which keys are absent when it does not.
            for index, row in enumerate(rows):
                with self.subTest(row=index):
                    missing = required - set(row)
                    self.assertFalse(missing, f"row {index} missing keys: {sorted(missing)}")

            # ``root / p`` leaves an absolute ``p`` unchanged, so this check holds whether
            # the manifest stores root-relative or absolute paths.
            for row in rows:
                wav_path = root / Path(row["wav_path"])
                self.assertTrue(wav_path.exists(), f"missing wav: {wav_path}")

    def test_collect_fixed_phrases_refuses_existing_session_without_overwrite(self):
        """A second run with the same ``session_id`` into the same ``out_dir`` raises RuntimeError."""
        with tempfile.TemporaryDirectory() as td:
            root = Path(td)
            phrases_path = root / "phrases.txt"
            out_dir = root / "dataset"
            phrases_path.write_text("close app\n", encoding="utf-8")
            options = CollectOptions(
                phrases_file=phrases_path,
                out_dir=out_dir,
                samples_per_phrase=1,
                samplerate=16000,
                channels=1,
                session_id="session-1",
            )

            def fake_record(_options: CollectOptions, _input_func):
                audio = np.ones((160, 1), dtype=np.float32) * 0.2
                return audio, 160, 10

            first = collect_fixed_phrases(
                options,
                input_func=lambda _prompt: "",
                output_func=lambda _line: None,
                record_sample_fn=fake_record,
            )
            # Make sure the first run really populated the session; otherwise the
            # refusal below would not be exercising the "existing samples" path.
            self.assertEqual(first.samples_written, 1)

            with self.assertRaisesRegex(RuntimeError, "already has samples"):
                collect_fixed_phrases(
                    options,
                    input_func=lambda _prompt: "",
                    output_func=lambda _line: None,
                    record_sample_fn=fake_record,
                )
if __name__ == "__main__":
    # Running this file directly (``python test_vosk_collect.py``) hands control to
    # unittest's command-line runner, which discovers the TestCase classes above.
    unittest.main()