"""Unit tests for the phrase-collection helpers exported by ``vosk_collect``."""

import json
import sys
import tempfile
import unittest
from pathlib import Path

import numpy as np

# ROOT is the parent of the directory holding this file; its "src" folder is
# put at the front of sys.path (once) so ``vosk_collect`` can be imported.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

from vosk_collect import (
    CollectOptions,
    collect_fixed_phrases,
    float_to_pcm16,
    load_phrases,
    slugify_phrase,
)


def _discard_output(_line):
    """Output sink that ignores whatever line it is handed."""
    return None


def _make_fake_recorder(frames, level, duration_ms):
    """Return a recorder stub that yields a fresh constant-valued buffer.

    The stub returns ``(audio, frames, duration_ms)`` where ``audio`` is a
    float32 array of shape ``(frames, 1)`` filled with ``level``.
    """

    def fake_record(_options: CollectOptions, _input_func):
        audio = np.ones((frames, 1), dtype=np.float32) * level
        return audio, frames, duration_ms

    return fake_record


class VoskCollectTests(unittest.TestCase):
    """Exercise phrase loading, slug generation, PCM conversion and collection."""

    def test_load_phrases_ignores_blank_comment_and_deduplicates(self):
        # A comment line, empty/whitespace-only lines and a repeated phrase
        # must all be dropped; the first-seen order of the rest is expected.
        file_body = (
            "# heading\n"
            "\n"
            "close app\n"
            "take a screenshot\n"
            "close app\n"
            " \n"
        )
        with tempfile.TemporaryDirectory() as tmp_dir:
            phrase_file = Path(tmp_dir) / "phrases.txt"
            phrase_file.write_text(file_body, encoding="utf-8")
            loaded = load_phrases(phrase_file)
            self.assertEqual(loaded, ["close app", "take a screenshot"])

    def test_load_phrases_empty_after_filtering_raises(self):
        # A file containing only a comment and a blank line has no labels.
        with tempfile.TemporaryDirectory() as tmp_dir:
            phrase_file = Path(tmp_dir) / "phrases.txt"
            phrase_file.write_text("# only comments\n\n", encoding="utf-8")
            with self.assertRaisesRegex(RuntimeError, "no usable labels"):
                load_phrases(phrase_file)

    def test_slugify_phrase_is_deterministic(self):
        expectations = (
            ("Take a Screenshot", "take_a_screenshot"),
            ("close-app!!!", "close_app"),
        )
        for phrase, expected_slug in expectations:
            self.assertEqual(slugify_phrase(phrase), expected_slug)

    def test_float_to_pcm16_clamps_audio_bounds(self):
        # The first and last inputs lie outside [-1.0, 1.0] and must saturate.
        samples = np.asarray(
            [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0],
            dtype=np.float32,
        )
        pcm = float_to_pcm16(samples)

        self.assertEqual(pcm.dtype, np.int16)
        lowest = int(pcm.min())
        highest = int(pcm.max())
        self.assertGreaterEqual(lowest, -32767)
        self.assertLessEqual(highest, 32767)
        self.assertEqual(int(pcm[0]), -32767)
        self.assertEqual(int(pcm[-1]), 32767)

    def test_collect_fixed_phrases_writes_manifest_and_wavs(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            workspace = Path(tmp_dir)
            phrase_file = workspace / "phrases.txt"
            dataset_dir = workspace / "dataset"
            phrase_file.write_text(
                "close app\ntake a screenshot\n",
                encoding="utf-8",
            )

            options = CollectOptions(
                phrases_file=phrase_file,
                out_dir=dataset_dir,
                samples_per_phrase=2,
                samplerate=16000,
                channels=1,
                session_id="session-1",
            )

            # One canned (empty) reply per expected sample: 2 phrases x 2.
            queued_replies = ["", "", "", ""]

            def fake_input(_prompt: str) -> str:
                return queued_replies.pop(0)

            outcome = collect_fixed_phrases(
                options,
                input_func=fake_input,
                output_func=_discard_output,
                record_sample_fn=_make_fake_recorder(320, 0.1, 20),
            )

            self.assertFalse(outcome.interrupted)
            self.assertEqual(outcome.samples_written, 4)

            # Parse the JSON-lines manifest, skipping blank lines.
            manifest_path = dataset_dir / "manifest.jsonl"
            manifest_text = manifest_path.read_text(encoding="utf-8")
            records = []
            for raw_line in manifest_text.splitlines():
                if raw_line.strip():
                    records.append(json.loads(raw_line))
            self.assertEqual(len(records), 4)

            required = {
                "session_id",
                "timestamp_utc",
                "phrase",
                "phrase_slug",
                "sample_index",
                "wav_path",
                "samplerate",
                "channels",
                "duration_ms",
                "frames",
                "device_spec",
                "collector_version",
            }
            first_record_keys = records[0].keys()
            self.assertTrue(required.issubset(first_record_keys))

            # Manifest wav paths are resolved against the temporary root.
            expected_wavs = [
                workspace / Path(record["wav_path"]) for record in records
            ]
            for wav_file in expected_wavs:
                self.assertTrue(wav_file.exists(), f"missing wav: {wav_file}")

    def test_collect_fixed_phrases_refuses_existing_session_without_overwrite(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            workspace = Path(tmp_dir)
            phrase_file = workspace / "phrases.txt"
            dataset_dir = workspace / "dataset"
            phrase_file.write_text("close app\n", encoding="utf-8")

            options = CollectOptions(
                phrases_file=phrase_file,
                out_dir=dataset_dir,
                samples_per_phrase=1,
                samplerate=16000,
                channels=1,
                session_id="session-1",
            )
            recorder = _make_fake_recorder(160, 0.2, 10)

            def run_collection():
                return collect_fixed_phrases(
                    options,
                    input_func=lambda _prompt: "",
                    output_func=_discard_output,
                    record_sample_fn=recorder,
                )

            # First pass populates the session; the identical second pass
            # must be rejected.
            run_collection()
            with self.assertRaisesRegex(RuntimeError, "already has samples"):
                run_collection()


if __name__ == "__main__":
    unittest.main()