diff --git a/README.md b/README.md
index 6884ea5..021cce8 100644
--- a/README.md
+++ b/README.md
@@ -87,6 +87,8 @@ Create `~/.config/aman/config.json` (or let `aman` create it automatically on fi
 
 Recording input can be a device index (preferred) or a substring of the device name.
 
+Config validation is strict: unknown fields are rejected with a startup error.
+
 Hotkey notes:
 
 - Use one key plus optional modifiers (for example `Cmd+m`, `Super+m`, `Ctrl+space`).
diff --git a/config.example.json b/config.example.json
index b72cfc9..9326250 100644
--- a/config.example.json
+++ b/config.example.json
@@ -35,8 +35,5 @@
       "Kubernetes",
       "PostgreSQL"
     ]
-  },
-  "domain_inference": {
-    "enabled": true
   }
 }
diff --git a/src/config.py b/src/config.py
index ab4d30a..0e79791 100644
--- a/src/config.py
+++ b/src/config.py
@@ -119,47 +119,57 @@ def validate(cfg: Config) -> None:
     cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms)
 
 def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
-    has_sections = any(
-        key in data
-        for key in (
-            "daemon",
-            "recording",
-            "stt",
-            "injection",
-            "vocabulary",
-        )
+    _reject_unknown_keys(
+        data,
+        {"daemon", "recording", "stt", "injection", "vocabulary"},
+        parent="",
     )
-    if has_sections:
-        daemon = _ensure_dict(data.get("daemon"), "daemon")
-        recording = _ensure_dict(data.get("recording"), "recording")
-        stt = _ensure_dict(data.get("stt"), "stt")
-        injection = _ensure_dict(data.get("injection"), "injection")
-        vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
+    daemon = _ensure_dict(data.get("daemon"), "daemon")
+    recording = _ensure_dict(data.get("recording"), "recording")
+    stt = _ensure_dict(data.get("stt"), "stt")
+    injection = _ensure_dict(data.get("injection"), "injection")
+    vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
 
-        if "hotkey" in daemon:
-            cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey")
-        if "input" in recording:
-            cfg.recording.input = _as_recording_input(recording["input"])
-        if "model" in stt:
-            cfg.stt.model = _as_nonempty_str(stt["model"], "stt.model")
-        if "device" in stt:
-            cfg.stt.device = _as_nonempty_str(stt["device"], "stt.device")
-        if "backend" in injection:
-            cfg.injection.backend = _as_nonempty_str(injection["backend"], "injection.backend")
-        if "remove_transcription_from_clipboard" in injection:
-            cfg.injection.remove_transcription_from_clipboard = _as_bool(
-                injection["remove_transcription_from_clipboard"],
-                "injection.remove_transcription_from_clipboard",
-            )
-        if "replacements" in vocabulary:
-            cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"])
-        if "terms" in vocabulary:
-            cfg.vocabulary.terms = _as_terms(vocabulary["terms"])
-        return cfg
+    _reject_unknown_keys(daemon, {"hotkey"}, parent="daemon")
+    _reject_unknown_keys(recording, {"input"}, parent="recording")
+    _reject_unknown_keys(stt, {"model", "device"}, parent="stt")
+    _reject_unknown_keys(
+        injection,
+        {"backend", "remove_transcription_from_clipboard"},
+        parent="injection",
+    )
+    _reject_unknown_keys(vocabulary, {"replacements", "terms"}, parent="vocabulary")
 
+    if "hotkey" in daemon:
+        cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey")
+    if "input" in recording:
+        cfg.recording.input = _as_recording_input(recording["input"])
+    if "model" in stt:
+        cfg.stt.model = _as_nonempty_str(stt["model"], "stt.model")
+    if "device" in stt:
+        cfg.stt.device = _as_nonempty_str(stt["device"], "stt.device")
+    if "backend" in injection:
+        cfg.injection.backend = _as_nonempty_str(injection["backend"], "injection.backend")
+    if "remove_transcription_from_clipboard" in injection:
+        cfg.injection.remove_transcription_from_clipboard = _as_bool(
+            injection["remove_transcription_from_clipboard"],
+            "injection.remove_transcription_from_clipboard",
+        )
+    if "replacements" in vocabulary:
+        cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"])
+    if "terms" in vocabulary:
+        cfg.vocabulary.terms = _as_terms(vocabulary["terms"])
     return cfg
 
 
+def _reject_unknown_keys(value: dict[str, Any], allowed: set[str], *, parent: str) -> None:
+    for key in value.keys():
+        if key in allowed:
+            continue
+        field = f"{parent}.{key}" if parent else key
+        raise ValueError(f"unknown config field: {field}")
+
+
 def _ensure_dict(value: Any, field_name: str) -> dict[str, Any]:
     if value is None:
         return {}
diff --git a/tests/test_config.py b/tests/test_config.py
index 1e6dd13..abff50e 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -113,7 +113,7 @@ class ConfigTests(unittest.TestCase):
             with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"):
                 load(str(path))
 
-    def test_unknown_top_level_fields_are_ignored(self):
+    def test_unknown_top_level_fields_raise(self):
         payload = {
             "custom_a": {"enabled": True},
             "custom_b": {"nested": "value"},
@@ -123,10 +123,8 @@ class ConfigTests(unittest.TestCase):
             path = Path(td) / "config.json"
             path.write_text(json.dumps(payload), encoding="utf-8")
 
-            cfg = load(str(path))
-
-        self.assertEqual(cfg.daemon.hotkey, "Cmd+m")
-        self.assertEqual(cfg.injection.backend, "clipboard")
+            with self.assertRaisesRegex(ValueError, "unknown config field: custom_a"):
+                load(str(path))
 
     def test_conflicting_replacements_raise(self):
         payload = {
@@ -178,15 +176,23 @@ class ConfigTests(unittest.TestCase):
             with self.assertRaisesRegex(ValueError, "wildcard"):
                 load(str(path))
 
-    def test_unknown_vocabulary_fields_are_ignored(self):
+    def test_unknown_vocabulary_fields_raise(self):
         payload = {"vocabulary": {"custom_limit": 100, "custom_extra": 200, "terms": ["Docker"]}}
         with tempfile.TemporaryDirectory() as td:
             path = Path(td) / "config.json"
             path.write_text(json.dumps(payload), encoding="utf-8")
 
-            cfg = load(str(path))
+            with self.assertRaisesRegex(ValueError, "unknown config field: vocabulary.custom_limit"):
+                load(str(path))
 
-        self.assertEqual(cfg.vocabulary.terms, ["Docker"])
+    def test_unknown_nested_stt_field_raises(self):
+        payload = {"stt": {"model": "base", "device": "cpu", "language": "en"}}
+        with tempfile.TemporaryDirectory() as td:
+            path = Path(td) / "config.json"
+            path.write_text(json.dumps(payload), encoding="utf-8")
+
+            with self.assertRaisesRegex(ValueError, "unknown config field: stt.language"):
+                load(str(path))
 
 
 if __name__ == "__main__":