148 lines
5.3 KiB
Python
148 lines
5.3 KiB
Python
import json
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
# Repository root: this test module sits in a directory directly below it.
ROOT = Path(__file__).resolve().parent.parent
# Source tree added to sys.path so ``vosk_collect`` can be imported in-place.
SRC = ROOT / "src"

# Prepend (so the in-repo copy wins) only if the entry is not already there.
if not any(entry == str(SRC) for entry in sys.path):
    sys.path.insert(0, str(SRC))
|
|
|
|
from vosk_collect import CollectOptions, collect_fixed_phrases, float_to_pcm16, load_phrases, slugify_phrase
|
|
|
|
|
|
class VoskCollectTests(unittest.TestCase):
    """Unit tests for the fixed-phrase collection helpers in ``vosk_collect``.

    The collection tests inject fake ``input_func`` / ``record_sample_fn``
    callables, so no real prompting or audio capture takes place.
    """

    def test_load_phrases_ignores_blank_comment_and_deduplicates(self):
        """Comment/blank lines are skipped; duplicates keep first-seen order."""
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "phrases.txt"
            path.write_text(
                (
                    "# heading\n"
                    "\n"
                    "close app\n"
                    "take a screenshot\n"
                    "close app\n"
                    " \n"  # whitespace-only line must be ignored too
                ),
                encoding="utf-8",
            )
            phrases = load_phrases(path)
            self.assertEqual(phrases, ["close app", "take a screenshot"])

    def test_load_phrases_empty_after_filtering_raises(self):
        """A file with only comments/blank lines is rejected with RuntimeError."""
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "phrases.txt"
            path.write_text("# only comments\n\n", encoding="utf-8")
            with self.assertRaisesRegex(RuntimeError, "no usable labels"):
                load_phrases(path)

    def test_slugify_phrase_is_deterministic(self):
        """Slugs are lowercase, underscore-joined, and drop punctuation."""
        self.assertEqual(slugify_phrase("Take a Screenshot"), "take_a_screenshot")
        self.assertEqual(slugify_phrase("close-app!!!"), "close_app")

    def test_float_to_pcm16_clamps_audio_bounds(self):
        """Float samples convert to int16 and saturate at +/-32767."""
        values = np.asarray([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0], dtype=np.float32)
        out = float_to_pcm16(values)
        self.assertEqual(out.dtype, np.int16)
        # The expected range is symmetric: -32768 must never be produced.
        self.assertGreaterEqual(int(out.min()), -32767)
        self.assertLessEqual(int(out.max()), 32767)
        # Out-of-range inputs (+/-2.0) saturate at the bounds.
        self.assertEqual(int(out[0]), -32767)
        self.assertEqual(int(out[-1]), 32767)
        # Silence must stay silence.
        self.assertEqual(int(out[3]), 0)

    def test_collect_fixed_phrases_writes_manifest_and_wavs(self):
        """A full session writes one manifest row and one WAV file per sample."""
        with tempfile.TemporaryDirectory() as td:
            root = Path(td)
            phrases_path = root / "phrases.txt"
            out_dir = root / "dataset"
            phrases_path.write_text("close app\ntake a screenshot\n", encoding="utf-8")
            options = CollectOptions(
                phrases_file=phrases_path,
                out_dir=out_dir,
                samples_per_phrase=2,
                samplerate=16000,
                channels=1,
                session_id="session-1",
            )

            # One scripted (empty) answer per prompt; a fifth prompt would make
            # pop() raise IndexError instead of silently looping.
            answers = ["", "", "", ""]

            def fake_input(_prompt: str) -> str:
                return answers.pop(0)

            def fake_record(_options: CollectOptions, _input_func):
                # Constant low-level audio. The two integers are passed back to
                # the collector as-is (presumably frame count and duration in ms:
                # 320 frames at 16 kHz is 20 ms -- TODO confirm the contract).
                audio = np.ones((320, 1), dtype=np.float32) * 0.1
                return audio, 320, 20

            result = collect_fixed_phrases(
                options,
                input_func=fake_input,
                output_func=lambda _line: None,
                record_sample_fn=fake_record,
            )

            self.assertFalse(result.interrupted)
            self.assertEqual(result.samples_written, 4)

            manifest = out_dir / "manifest.jsonl"
            rows = [
                json.loads(line)
                for line in manifest.read_text(encoding="utf-8").splitlines()
                if line.strip()
            ]
            self.assertEqual(len(rows), 4)

            required = {
                "session_id",
                "timestamp_utc",
                "phrase",
                "phrase_slug",
                "sample_index",
                "wav_path",
                "samplerate",
                "channels",
                "duration_ms",
                "frames",
                "device_spec",
                "collector_version",
            }
            # Validate the schema of every row (not just the first) and name
            # the missing keys on failure.
            for index, row in enumerate(rows):
                missing = required.difference(row)
                self.assertFalse(missing, f"row {index} missing keys: {sorted(missing)}")

            # Every phrase from the phrases file was actually recorded.
            self.assertEqual(
                {row["phrase"] for row in rows},
                {"close app", "take a screenshot"},
            )

            # wav_path is joined onto the temp root; if it is absolute, the
            # ``/`` operator returns it unchanged.
            wav_paths = [root / Path(row["wav_path"]) for row in rows]
            # Each sample must be stored in its own file (no overwrites).
            self.assertEqual(len(set(wav_paths)), len(wav_paths))
            for wav_path in wav_paths:
                self.assertTrue(wav_path.exists(), f"missing wav: {wav_path}")

    def test_collect_fixed_phrases_refuses_existing_session_without_overwrite(self):
        """Re-running the same session id into the same out_dir is rejected."""
        with tempfile.TemporaryDirectory() as td:
            root = Path(td)
            phrases_path = root / "phrases.txt"
            out_dir = root / "dataset"
            phrases_path.write_text("close app\n", encoding="utf-8")
            options = CollectOptions(
                phrases_file=phrases_path,
                out_dir=out_dir,
                samples_per_phrase=1,
                samplerate=16000,
                channels=1,
                session_id="session-1",
            )

            def fake_record(_options: CollectOptions, _input_func):
                audio = np.ones((160, 1), dtype=np.float32) * 0.2
                return audio, 160, 10

            # First run populates the session.
            collect_fixed_phrases(
                options,
                input_func=lambda _prompt: "",
                output_func=lambda _line: None,
                record_sample_fn=fake_record,
            )
            # Second run with identical options must refuse rather than clobber.
            with self.assertRaisesRegex(RuntimeError, "already has samples"):
                collect_fixed_phrases(
                    options,
                    input_func=lambda _prompt: "",
                    output_func=lambda _line: None,
                    record_sample_fn=fake_record,
                )
|
|
|
|
|
|
# Run the test cases when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|