316 lines
12 KiB
Python
316 lines
12 KiB
Python
import json
import re
import sys
import tempfile
import unittest
from pathlib import Path

# Put the working tree's ``src`` directory at the front of the import path so
# ``config`` is imported from this checkout rather than an installed copy.
# ROOT is two directory levels above this file (presumably the repository
# root, with this file living in a tests/ directory -- confirm).
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"

if not any(entry == str(SRC) for entry in sys.path):
    sys.path[:0] = [str(SRC)]

# Must come after the sys.path adjustment above.
from config import CURRENT_CONFIG_VERSION, load, redacted_dict  # noqa: E402


class ConfigTests(unittest.TestCase):
    """Tests for ``config.load``: defaults, nested parsing, normalization and validation errors."""

    def setUp(self):
        # Fresh scratch directory per test. It is removed via addCleanup, i.e.
        # only after the test body finishes, so assertions can still inspect
        # files that ``load`` wrote.
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.tmp_dir = Path(tmp.name)

    def _write_config(self, payload):
        """Serialize *payload* as JSON to ``config.json`` in the scratch dir and return its path."""
        path = self.tmp_dir / "config.json"
        path.write_text(json.dumps(payload), encoding="utf-8")
        return path

    def _load_payload(self, payload):
        """Write *payload* to disk and return ``load`` applied to that file."""
        return load(str(self._write_config(payload)))

    def _assert_load_fails(self, payload, message):
        """Assert that loading *payload* raises ValueError whose text contains *message* literally."""
        path = self._write_config(payload)
        # Messages contain dotted field paths such as "daemon.hotkey"; without
        # escaping, each "." would be a regex wildcard matching any character.
        with self.assertRaisesRegex(ValueError, re.escape(message)):
            load(str(path))

    def test_defaults_when_file_missing(self):
        missing = self.tmp_dir / "nested" / "config.json"
        cfg = load(str(missing))

        self.assertEqual(cfg.daemon.hotkey, "Cmd+m")
        self.assertEqual(cfg.config_version, CURRENT_CONFIG_VERSION)
        self.assertEqual(cfg.recording.input, "")
        self.assertEqual(cfg.stt.provider, "local_whisper")
        self.assertEqual(cfg.stt.model, "base")
        self.assertEqual(cfg.stt.device, "cpu")
        self.assertEqual(cfg.stt.language, "auto")
        self.assertFalse(cfg.models.allow_custom_models)
        self.assertEqual(cfg.models.whisper_model_path, "")
        self.assertEqual(cfg.injection.backend, "clipboard")
        self.assertFalse(cfg.injection.remove_transcription_from_clipboard)
        self.assertTrue(cfg.safety.enabled)
        self.assertFalse(cfg.safety.strict)
        self.assertEqual(cfg.ux.profile, "default")
        self.assertTrue(cfg.ux.show_notifications)
        self.assertTrue(cfg.advanced.strict_startup)
        self.assertEqual(cfg.vocabulary.replacements, [])
        self.assertEqual(cfg.vocabulary.terms, [])

        # load() must create the missing parent directory and persist the
        # defaults in the same shape as redacted_dict(cfg).
        self.assertTrue(missing.exists())
        written = json.loads(missing.read_text(encoding="utf-8"))
        self.assertEqual(written, redacted_dict(cfg))

    def test_loads_nested_config(self):
        payload = {
            "config_version": CURRENT_CONFIG_VERSION,
            "daemon": {"hotkey": "Ctrl+space"},
            "recording": {"input": 3},
            "stt": {
                "provider": "local_whisper",
                "model": "small",
                "device": "cuda",
                "language": "English",
            },
            "models": {"allow_custom_models": False},
            "injection": {
                "backend": "injection",
                "remove_transcription_from_clipboard": True,
            },
            "safety": {
                "enabled": True,
                "strict": True,
            },
            "vocabulary": {
                "replacements": [
                    {"from": "Martha", "to": "Marta"},
                    {"from": "docker", "to": "Docker"},
                ],
                "terms": ["Systemd", "Kubernetes"],
            },
        }
        cfg = self._load_payload(payload)

        self.assertEqual(cfg.config_version, CURRENT_CONFIG_VERSION)
        self.assertEqual(cfg.daemon.hotkey, "Ctrl+space")
        self.assertEqual(cfg.recording.input, 3)
        self.assertEqual(cfg.stt.provider, "local_whisper")
        self.assertEqual(cfg.stt.model, "small")
        self.assertEqual(cfg.stt.device, "cuda")
        # A language name is expected to be normalized to its code.
        self.assertEqual(cfg.stt.language, "en")
        self.assertFalse(cfg.models.allow_custom_models)
        self.assertEqual(cfg.injection.backend, "injection")
        self.assertTrue(cfg.injection.remove_transcription_from_clipboard)
        self.assertTrue(cfg.safety.enabled)
        self.assertTrue(cfg.safety.strict)
        self.assertEqual(len(cfg.vocabulary.replacements), 2)
        self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha")
        self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta")
        self.assertEqual(cfg.vocabulary.replacements[1].source, "docker")
        self.assertEqual(cfg.vocabulary.replacements[1].target, "Docker")
        self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"])

    def test_super_modifier_hotkey_is_valid(self):
        cfg = self._load_payload({"daemon": {"hotkey": "Super+m"}})

        self.assertEqual(cfg.daemon.hotkey, "Super+m")

    def test_invalid_hotkey_missing_key_raises(self):
        self._assert_load_fails(
            {"daemon": {"hotkey": "Ctrl+Alt"}},
            "daemon.hotkey: is invalid: missing key",
        )

    def test_invalid_hotkey_multiple_keys_raises(self):
        self._assert_load_fails(
            {"daemon": {"hotkey": "Ctrl+a+b"}},
            "daemon.hotkey: is invalid: must include exactly one non-modifier key",
        )

    def test_invalid_injection_backend_raises(self):
        self._assert_load_fails(
            {"injection": {"backend": "invalid"}},
            "injection.backend",
        )

    def test_invalid_clipboard_remove_option_raises(self):
        self._assert_load_fails(
            {"injection": {"remove_transcription_from_clipboard": "yes"}},
            "injection.remove_transcription_from_clipboard",
        )

    def test_invalid_safety_enabled_option_raises(self):
        self._assert_load_fails(
            {"safety": {"enabled": "yes"}},
            "safety.enabled",
        )

    def test_invalid_safety_strict_option_raises(self):
        self._assert_load_fails(
            {"safety": {"strict": "yes"}},
            "safety.strict",
        )

    def test_unknown_safety_fields_raise(self):
        self._assert_load_fails(
            {"safety": {"enabled": True, "mode": "strict"}},
            "safety.mode: unknown config field",
        )

    def test_unknown_top_level_fields_raise(self):
        payload = {
            "custom_a": {"enabled": True},
            "custom_b": {"nested": "value"},
            "custom_c": 123,
        }
        # Only the first unknown key is expected in the message (presumably
        # keys are checked in insertion order -- confirm against config.load).
        self._assert_load_fails(payload, "custom_a: unknown config field")

    def test_conflicting_replacements_raise(self):
        payload = {
            "vocabulary": {
                "replacements": [
                    {"from": "Martha", "to": "Marta"},
                    {"from": "martha", "to": "Martha"},
                ]
            }
        }
        self._assert_load_fails(payload, "conflicting")

    def test_duplicate_rules_and_terms_are_deduplicated(self):
        payload = {
            "vocabulary": {
                "replacements": [
                    {"from": "docker", "to": "Docker"},
                    {"from": "DOCKER", "to": "Docker"},
                ],
                "terms": ["Systemd", "systemd"],
            }
        }
        cfg = self._load_payload(payload)

        # Case-insensitive duplicates collapse to the first occurrence.
        self.assertEqual(len(cfg.vocabulary.replacements), 1)
        self.assertEqual(cfg.vocabulary.replacements[0].source, "docker")
        self.assertEqual(cfg.vocabulary.replacements[0].target, "Docker")
        self.assertEqual(cfg.vocabulary.terms, ["Systemd"])

    def test_wildcard_term_raises(self):
        self._assert_load_fails(
            {"vocabulary": {"terms": ["Dock*"]}},
            "wildcard",
        )

    def test_unknown_vocabulary_fields_raise(self):
        self._assert_load_fails(
            {"vocabulary": {"custom_limit": 100, "custom_extra": 200, "terms": ["Docker"]}},
            "vocabulary.custom_limit: unknown config field",
        )

    def test_stt_language_accepts_auto(self):
        cfg = self._load_payload({"stt": {"model": "base", "device": "cpu", "language": "auto"}})

        self.assertEqual(cfg.stt.language, "auto")

    def test_invalid_stt_language_raises(self):
        self._assert_load_fails(
            {"stt": {"model": "base", "device": "cpu", "language": "klingon"}},
            "stt.language: unsupported language",
        )

    def test_non_string_stt_language_raises(self):
        self._assert_load_fails(
            {"stt": {"model": "base", "device": "cpu", "language": 123}},
            "stt.language: must be a string",
        )

    def test_unknown_nested_stt_field_raises(self):
        self._assert_load_fails(
            {"stt": {"model": "base", "device": "cpu", "custom": "value"}},
            "stt.custom: unknown config field",
        )

    def test_invalid_ux_profile_raises(self):
        self._assert_load_fails(
            {"ux": {"profile": "unknown"}},
            "ux.profile: must be one of",
        )

    def test_missing_config_version_is_migrated_to_current(self):
        payload = {
            "daemon": {"hotkey": "Super+m"},
            "stt": {"model": "base", "device": "cpu"},
        }
        cfg = self._load_payload(payload)

        self.assertEqual(cfg.config_version, CURRENT_CONFIG_VERSION)

    def test_legacy_llm_config_fields_raise(self):
        self._assert_load_fails(
            {"llm": {"provider": "local_llama"}},
            "llm: unknown config field",
        )


if __name__ == "__main__":
    # Executing this file directly hands control to the stdlib unittest
    # runner, which discovers the TestCase classes defined in this module.
    unittest.main(module="__main__")