Enforce strict config schema and clean examples

This commit is contained in:
Thales Maciel 2026-02-26 16:31:00 -03:00
parent 0df8c356af
commit 67fc8d1701
4 changed files with 61 additions and 46 deletions

View file

@ -87,6 +87,8 @@ Create `~/.config/aman/config.json` (or let `aman` create it automatically on fi
Recording input can be a device index (preferred) or a substring of the device
name.
Config validation is strict: unknown fields are rejected with a startup error.
Hotkey notes:
- Use one key plus optional modifiers (for example `Cmd+m`, `Super+m`, `Ctrl+space`).

View file

@ -35,8 +35,5 @@
"Kubernetes", "Kubernetes",
"PostgreSQL" "PostgreSQL"
] ]
},
"domain_inference": {
"enabled": true
} }
} }

View file

@ -119,23 +119,27 @@ def validate(cfg: Config) -> None:
cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms) cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms)
def _from_dict(data: dict[str, Any], cfg: Config) -> Config: def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
has_sections = any( _reject_unknown_keys(
key in data data,
for key in ( {"daemon", "recording", "stt", "injection", "vocabulary"},
"daemon", parent="",
"recording",
"stt",
"injection",
"vocabulary",
) )
)
if has_sections:
daemon = _ensure_dict(data.get("daemon"), "daemon") daemon = _ensure_dict(data.get("daemon"), "daemon")
recording = _ensure_dict(data.get("recording"), "recording") recording = _ensure_dict(data.get("recording"), "recording")
stt = _ensure_dict(data.get("stt"), "stt") stt = _ensure_dict(data.get("stt"), "stt")
injection = _ensure_dict(data.get("injection"), "injection") injection = _ensure_dict(data.get("injection"), "injection")
vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary") vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
_reject_unknown_keys(daemon, {"hotkey"}, parent="daemon")
_reject_unknown_keys(recording, {"input"}, parent="recording")
_reject_unknown_keys(stt, {"model", "device"}, parent="stt")
_reject_unknown_keys(
injection,
{"backend", "remove_transcription_from_clipboard"},
parent="injection",
)
_reject_unknown_keys(vocabulary, {"replacements", "terms"}, parent="vocabulary")
if "hotkey" in daemon: if "hotkey" in daemon:
cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey") cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey")
if "input" in recording: if "input" in recording:
@ -157,7 +161,13 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
cfg.vocabulary.terms = _as_terms(vocabulary["terms"]) cfg.vocabulary.terms = _as_terms(vocabulary["terms"])
return cfg return cfg
return cfg
def _reject_unknown_keys(value: dict[str, Any], allowed: set[str], *, parent: str) -> None:
for key in value.keys():
if key in allowed:
continue
field = f"{parent}.{key}" if parent else key
raise ValueError(f"unknown config field: {field}")
def _ensure_dict(value: Any, field_name: str) -> dict[str, Any]: def _ensure_dict(value: Any, field_name: str) -> dict[str, Any]:

View file

@ -113,7 +113,7 @@ class ConfigTests(unittest.TestCase):
with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"): with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"):
load(str(path)) load(str(path))
def test_unknown_top_level_fields_are_ignored(self): def test_unknown_top_level_fields_raise(self):
payload = { payload = {
"custom_a": {"enabled": True}, "custom_a": {"enabled": True},
"custom_b": {"nested": "value"}, "custom_b": {"nested": "value"},
@ -123,10 +123,8 @@ class ConfigTests(unittest.TestCase):
path = Path(td) / "config.json" path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8") path.write_text(json.dumps(payload), encoding="utf-8")
cfg = load(str(path)) with self.assertRaisesRegex(ValueError, "unknown config field: custom_a"):
load(str(path))
self.assertEqual(cfg.daemon.hotkey, "Cmd+m")
self.assertEqual(cfg.injection.backend, "clipboard")
def test_conflicting_replacements_raise(self): def test_conflicting_replacements_raise(self):
payload = { payload = {
@ -178,15 +176,23 @@ class ConfigTests(unittest.TestCase):
with self.assertRaisesRegex(ValueError, "wildcard"): with self.assertRaisesRegex(ValueError, "wildcard"):
load(str(path)) load(str(path))
def test_unknown_vocabulary_fields_are_ignored(self): def test_unknown_vocabulary_fields_raise(self):
payload = {"vocabulary": {"custom_limit": 100, "custom_extra": 200, "terms": ["Docker"]}} payload = {"vocabulary": {"custom_limit": 100, "custom_extra": 200, "terms": ["Docker"]}}
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json" path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8") path.write_text(json.dumps(payload), encoding="utf-8")
cfg = load(str(path)) with self.assertRaisesRegex(ValueError, "unknown config field: vocabulary.custom_limit"):
load(str(path))
self.assertEqual(cfg.vocabulary.terms, ["Docker"]) def test_unknown_nested_stt_field_raises(self):
payload = {"stt": {"model": "base", "device": "cpu", "language": "en"}}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "unknown config field: stt.language"):
load(str(path))
if __name__ == "__main__": if __name__ == "__main__":