549 lines
19 KiB
Python
549 lines
19 KiB
Python
from __future__ import annotations

import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, NoReturn

from constants import DEFAULT_CONFIG_PATH
from hotkey import split_hotkey
from languages import DEFAULT_STT_LANGUAGE, normalize_stt_language
|
|
|
|
|
|
# Schema version that validate() requires and _migrate_dict() upgrades to.
CURRENT_CONFIG_VERSION = 1

# Dataclass defaults used for any field a config file leaves out.
DEFAULT_HOTKEY = "Cmd+m"
DEFAULT_STT_PROVIDER = "local_whisper"
DEFAULT_STT_MODEL = "base"
DEFAULT_STT_DEVICE = "cpu"
DEFAULT_INJECTION_BACKEND = "clipboard"
DEFAULT_UX_PROFILE = "default"

# Closed sets of accepted values; validate() strips and lower-cases the
# configured value before testing membership.
ALLOWED_STT_PROVIDERS = {"local_whisper"}
ALLOWED_INJECTION_BACKENDS = {"injection", "clipboard"}
ALLOWED_UX_PROFILES = {"polished", "fast", "default"}

# Characters that vocabulary terms / replacement sources may not contain.
WILDCARD_CHARS = {"*", "?", "[", "]", "{", "}"}
|
|
|
|
|
|
@dataclass
class ConfigValidationError(ValueError):
    """Raised when a config file or ``Config`` object fails validation.

    Attributes:
        field: Dotted path of the offending setting, e.g. ``"stt.model"``.
        reason: Human-readable description of what is wrong.
        example_fix: Optional snippet showing a valid value; empty if none.
    """

    field: str
    reason: str
    example_fix: str = ""

    def __post_init__(self) -> None:
        # The dataclass-generated __init__ never calls ValueError.__init__, so
        # keyword construction left ``self.args`` empty. Exceptions are
        # pickled as ``cls(*args)``, which then failed with a TypeError about
        # missing fields. Populate ``args`` so copies/pickles round-trip.
        super().__init__(self.field, self.reason, self.example_fix)

    def __str__(self) -> str:
        if self.example_fix:
            return f"{self.field}: {self.reason}. example: {self.example_fix}"
        return f"{self.field}: {self.reason}"
|
|
|
|
|
|
def _raise_cfg_error(field: str, reason: str, example_fix: str = "") -> NoReturn:
    """Raise ``ConfigValidationError`` for ``field``; this function never returns.

    Annotated ``NoReturn`` rather than ``None`` so type checkers treat code
    after a call as unreachable (e.g. after a failed ``isinstance`` guard).
    """
    raise ConfigValidationError(field=field, reason=reason, example_fix=example_fix)
|
|
|
|
|
|
@dataclass
class DaemonConfig:
    """Settings for the ``daemon`` config section."""

    # Hotkey string; validate() strips it and checks it with split_hotkey().
    hotkey: str = field(default=DEFAULT_HOTKEY)
|
|
|
|
|
|
@dataclass
class RecordingConfig:
    """Settings for the ``recording`` config section."""

    # Input selector: string, integer, or None (booleans are rejected by
    # validate()). What the empty-string default selects is decided by the
    # code that consumes this value, not here.
    input: str | int | None = field(default="")
|
|
|
|
|
|
@dataclass
class SttConfig:
    """Settings for the ``stt`` config section."""

    # Stripped and lower-cased by validate(); must be in ALLOWED_STT_PROVIDERS.
    provider: str = field(default=DEFAULT_STT_PROVIDER)
    # validate() only requires a non-blank string.
    model: str = field(default=DEFAULT_STT_MODEL)
    # validate() only requires a non-blank string.
    device: str = field(default=DEFAULT_STT_DEVICE)
    # Replaced by validate() with normalize_stt_language(language).
    language: str = field(default=DEFAULT_STT_LANGUAGE)
|
|
|
|
|
|
@dataclass
class ModelsConfig:
    """Settings for the ``models`` config section."""

    # Gate for whisper_model_path: validate() rejects a non-empty path unless
    # this is True.
    allow_custom_models: bool = field(default=False)
    # Stripped by validate(); empty means no custom model path.
    whisper_model_path: str = field(default="")
|
|
|
|
|
|
@dataclass
class InjectionConfig:
    """Settings for the ``injection`` config section."""

    # Stripped and lower-cased by validate(); must be in
    # ALLOWED_INJECTION_BACKENDS.
    backend: str = field(default=DEFAULT_INJECTION_BACKEND)
    # Boolean flag; this module only type-checks it.
    remove_transcription_from_clipboard: bool = field(default=False)
|
|
|
|
|
|
@dataclass
class SafetyConfig:
    """Settings for the ``safety`` config section.

    Both fields are boolean flags that this module only type-checks; their
    effect is implemented by other code.
    """

    enabled: bool = field(default=True)
    strict: bool = field(default=False)
|
|
|
|
|
|
@dataclass
class UxConfig:
    """Settings for the ``ux`` config section."""

    # Stripped and lower-cased by validate(); must be in ALLOWED_UX_PROFILES.
    profile: str = field(default=DEFAULT_UX_PROFILE)
    # Boolean flag; this module only type-checks it.
    show_notifications: bool = field(default=True)
|
|
|
|
|
|
@dataclass
class AdvancedConfig:
    """Settings for the ``advanced`` config section."""

    # Boolean flag; this module only type-checks it.
    strict_startup: bool = field(default=True)
|
|
|
|
|
|
@dataclass
class VocabularyReplacement:
    """One vocabulary rewrite rule: replace ``source`` with ``target``.

    In the JSON config the pair is written as ``{"from": ..., "to": ...}``;
    both fields are required.
    """

    # Text to match; may not contain WILDCARD_CHARS (checked in validate()).
    source: str
    # Replacement text.
    target: str
|
|
|
|
|
|
@dataclass
class VocabularyConfig:
    """Settings for the ``vocabulary`` config section.

    Both lists default to a fresh empty list per instance, and validate()
    strips and de-duplicates their entries.
    """

    replacements: list[VocabularyReplacement] = field(default_factory=list)
    terms: list[str] = field(default_factory=list)
|
|
|
|
|
|
@dataclass
class Config:
    """Complete configuration: schema version plus one attribute per JSON section.

    Every section defaults to a freshly constructed instance of its dataclass,
    so ``Config()`` is a full default configuration and no two instances
    share section objects.
    """

    # validate() requires this to equal CURRENT_CONFIG_VERSION.
    config_version: int = CURRENT_CONFIG_VERSION

    # Section objects; attribute names match the top-level JSON keys.
    daemon: DaemonConfig = field(default_factory=DaemonConfig)
    recording: RecordingConfig = field(default_factory=RecordingConfig)
    stt: SttConfig = field(default_factory=SttConfig)
    models: ModelsConfig = field(default_factory=ModelsConfig)
    injection: InjectionConfig = field(default_factory=InjectionConfig)
    safety: SafetyConfig = field(default_factory=SafetyConfig)
    ux: UxConfig = field(default_factory=UxConfig)
    advanced: AdvancedConfig = field(default_factory=AdvancedConfig)
    vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
|
|
|
|
|
|
def load(path: str | Path | None) -> Config:
    """Load, migrate and validate the config file, creating it when absent.

    Args:
        path: Config file location. ``None`` or an empty string selects
            ``DEFAULT_CONFIG_PATH``. ``Path`` objects are accepted as well,
            matching ``save``'s signature.

    Returns:
        The validated (and in-place normalized) ``Config``. When the file
        exists, fields it omits keep their dataclass defaults. When it does
        not exist, a default ``Config`` is validated, written to that
        location, and returned.

    Raises:
        ConfigValidationError: if the file's top level is not a JSON object,
            or if migration, parsing, or validation fails.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    cfg = Config()
    target = Path(path) if path else DEFAULT_CONFIG_PATH

    if not target.exists():
        # No file yet: validate the defaults and write them out.
        validate(cfg)
        _write_default_config(target, cfg)
        return cfg

    data = json.loads(target.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        _raise_cfg_error(
            "config",
            "must be a JSON object",
            '{"daemon":{"hotkey":"Super+m"}}',
        )
    cfg = _from_dict(_migrate_dict(data), cfg)
    validate(cfg)
    return cfg
|
|
|
|
|
|
def save(path: str | Path | None, cfg: Config) -> Path:
    """Validate ``cfg``, write it as JSON, and return the path written.

    ``None`` or an empty string selects ``DEFAULT_CONFIG_PATH``. Validation
    runs before anything is written, so an invalid config raises
    ``ConfigValidationError`` without touching the file. Because ``validate``
    normalizes ``cfg`` in place (e.g. lower-cases ``stt.provider``), the
    normalized values are what get written.
    """
    validate(cfg)
    destination = DEFAULT_CONFIG_PATH if not path else Path(path)
    _write_default_config(destination, cfg)
    return destination
|
|
|
|
|
|
def redacted_dict(cfg: Config) -> dict[str, Any]:
    """Return ``cfg`` converted to nested plain dicts/lists via ``dataclasses.asdict``.

    NOTE(review): despite the name, no fields are currently redacted; every
    value is copied as-is.
    """
    plain = asdict(cfg)
    return plain
|
|
|
|
|
|
def _write_default_config(path: Path, cfg: Config) -> None:
    """Write ``cfg`` to ``path`` as 2-space-indented JSON, creating parent dirs.

    Despite the name, this serializes whatever ``cfg`` holds; both ``load``
    (for a missing file) and ``save`` call it.

    Vocabulary replacements are emitted with ``"from"``/``"to"`` keys, which
    is the shape ``_as_replacements`` parses. ``dataclasses.asdict`` would
    produce ``"source"``/``"target"`` instead, so a config saved with any
    replacement could not be loaded back.
    """
    payload = redacted_dict(cfg)
    payload["vocabulary"]["replacements"] = [
        {"from": item.source, "to": item.target} for item in cfg.vocabulary.replacements
    ]
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(f"{json.dumps(payload, indent=2)}\n", encoding="utf-8")
|
|
|
|
|
|
def validate(cfg: Config) -> None:
    """Check every field of ``cfg`` and normalize several of them in place.

    Normalizations applied as a side effect:

    * ``stt.provider``, ``injection.backend`` and ``ux.profile`` are stripped
      and lower-cased;
    * ``stt.language`` is replaced by ``normalize_stt_language``'s result;
    * ``models.whisper_model_path`` is stripped;
    * vocabulary replacements and terms are stripped and de-duplicated.

    Fields are checked in a fixed order and the first failure is raised, so a
    config with several problems reports only the earliest one.

    Raises:
        ConfigValidationError: on the first invalid field.
    """
    # --- config_version -------------------------------------------------
    # bool is a subclass of int, so reject it explicitly (as _as_int does);
    # otherwise True would compare equal to version 1 and be accepted.
    if isinstance(cfg.config_version, bool) or not isinstance(cfg.config_version, int):
        _raise_cfg_error("config_version", "must be integer", '{"config_version":1}')
    if cfg.config_version != CURRENT_CONFIG_VERSION:
        _raise_cfg_error(
            "config_version",
            f"must be {CURRENT_CONFIG_VERSION}",
            f'{{"config_version":{CURRENT_CONFIG_VERSION}}}',
        )

    # String fields below are type-checked before .strip(), so a non-string
    # value assigned programmatically raises ConfigValidationError rather
    # than AttributeError.

    # --- daemon ---------------------------------------------------------
    if not isinstance(cfg.daemon.hotkey, str):
        _raise_cfg_error("daemon.hotkey", "must be a string", '{"daemon":{"hotkey":"Super+m"}}')
    hotkey = cfg.daemon.hotkey.strip()
    if not hotkey:
        _raise_cfg_error("daemon.hotkey", "cannot be empty", '{"daemon":{"hotkey":"Super+m"}}')
    try:
        split_hotkey(hotkey)
    except ValueError as exc:  # pragma: no cover - behavior exercised in tests
        _raise_cfg_error("daemon.hotkey", f"is invalid: {exc}", '{"daemon":{"hotkey":"Super+m"}}')

    # --- recording ------------------------------------------------------
    recording_input = cfg.recording.input
    # Checked first because bool would otherwise pass the int check below.
    if isinstance(recording_input, bool):
        _raise_cfg_error("recording.input", "cannot be boolean", '{"recording":{"input":1}}')
    if recording_input is not None and not isinstance(recording_input, (str, int)):
        _raise_cfg_error(
            "recording.input",
            "must be string, integer, or null",
            '{"recording":{"input":"USB"}}',
        )

    # --- stt ------------------------------------------------------------
    if not isinstance(cfg.stt.provider, str):
        _raise_cfg_error("stt.provider", "must be a string", '{"stt":{"provider":"local_whisper"}}')
    stt_provider = cfg.stt.provider.strip().lower()
    if stt_provider not in ALLOWED_STT_PROVIDERS:
        allowed = ", ".join(sorted(ALLOWED_STT_PROVIDERS))
        _raise_cfg_error(
            "stt.provider",
            f"must be one of: {allowed}",
            '{"stt":{"provider":"local_whisper"}}',
        )
    cfg.stt.provider = stt_provider

    if not isinstance(cfg.stt.model, str):
        _raise_cfg_error("stt.model", "must be a string", '{"stt":{"model":"base"}}')
    if not cfg.stt.model.strip():
        _raise_cfg_error("stt.model", "cannot be empty", '{"stt":{"model":"base"}}')

    if not isinstance(cfg.stt.device, str):
        _raise_cfg_error("stt.device", "must be a string", '{"stt":{"device":"cpu"}}')
    if not cfg.stt.device.strip():
        _raise_cfg_error("stt.device", "cannot be empty", '{"stt":{"device":"cpu"}}')

    if not isinstance(cfg.stt.language, str):
        _raise_cfg_error("stt.language", "must be a string", '{"stt":{"language":"auto"}}')
    try:
        cfg.stt.language = normalize_stt_language(cfg.stt.language)
    except ValueError as exc:
        _raise_cfg_error(
            "stt.language",
            str(exc),
            '{"stt":{"language":"auto"}}',
        )

    # --- models ---------------------------------------------------------
    if not isinstance(cfg.models.allow_custom_models, bool):
        _raise_cfg_error(
            "models.allow_custom_models",
            "must be boolean",
            '{"models":{"allow_custom_models":false}}',
        )
    if not isinstance(cfg.models.whisper_model_path, str):
        _raise_cfg_error(
            "models.whisper_model_path",
            "must be string",
            '{"models":{"whisper_model_path":""}}',
        )
    cfg.models.whisper_model_path = cfg.models.whisper_model_path.strip()
    # A custom model path is only allowed when explicitly opted in.
    if cfg.models.whisper_model_path and not cfg.models.allow_custom_models:
        _raise_cfg_error(
            "models.whisper_model_path",
            "requires models.allow_custom_models=true",
            '{"models":{"allow_custom_models":true,"whisper_model_path":"/path/model.bin"}}',
        )

    # --- injection ------------------------------------------------------
    if not isinstance(cfg.injection.backend, str):
        _raise_cfg_error("injection.backend", "must be a string", '{"injection":{"backend":"clipboard"}}')
    backend = cfg.injection.backend.strip().lower()
    if backend not in ALLOWED_INJECTION_BACKENDS:
        allowed = ", ".join(sorted(ALLOWED_INJECTION_BACKENDS))
        _raise_cfg_error(
            "injection.backend",
            f"must be one of: {allowed}",
            '{"injection":{"backend":"clipboard"}}',
        )
    cfg.injection.backend = backend
    if not isinstance(cfg.injection.remove_transcription_from_clipboard, bool):
        _raise_cfg_error(
            "injection.remove_transcription_from_clipboard",
            "must be boolean",
            '{"injection":{"remove_transcription_from_clipboard":false}}',
        )

    # --- safety ---------------------------------------------------------
    if not isinstance(cfg.safety.enabled, bool):
        _raise_cfg_error(
            "safety.enabled",
            "must be boolean",
            '{"safety":{"enabled":true}}',
        )
    if not isinstance(cfg.safety.strict, bool):
        _raise_cfg_error(
            "safety.strict",
            "must be boolean",
            '{"safety":{"strict":false}}',
        )

    # --- ux -------------------------------------------------------------
    if not isinstance(cfg.ux.profile, str):
        _raise_cfg_error("ux.profile", "must be a string", '{"ux":{"profile":"default"}}')
    profile = cfg.ux.profile.strip().lower()
    if profile not in ALLOWED_UX_PROFILES:
        allowed = ", ".join(sorted(ALLOWED_UX_PROFILES))
        _raise_cfg_error(
            "ux.profile",
            f"must be one of: {allowed}",
            '{"ux":{"profile":"default"}}',
        )
    cfg.ux.profile = profile
    if not isinstance(cfg.ux.show_notifications, bool):
        _raise_cfg_error(
            "ux.show_notifications",
            "must be boolean",
            '{"ux":{"show_notifications":true}}',
        )

    # --- advanced -------------------------------------------------------
    if not isinstance(cfg.advanced.strict_startup, bool):
        _raise_cfg_error(
            "advanced.strict_startup",
            "must be boolean",
            '{"advanced":{"strict_startup":true}}',
        )

    # --- vocabulary -----------------------------------------------------
    cfg.vocabulary.replacements = _validate_replacements(cfg.vocabulary.replacements)
    cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms)
|
|
|
|
def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
    """Overlay the settings present in ``data`` onto ``cfg`` and return it.

    ``data`` is the (migrated) top-level JSON object. Unknown keys at any
    level are rejected, each section must be an object (``null`` counts as
    empty), and every present field is type-checked by its ``_as_*``
    converter. Fields absent from ``data`` keep whatever value ``cfg``
    already holds. ``cfg`` is mutated in place; cross-field rules are left to
    ``validate``.

    Raises:
        ConfigValidationError: for the first unknown key, non-object section,
            or wrongly typed value encountered.
    """
    # (section, ((key, converter), ...)) in the order checks must run, which
    # decides which error is reported first. Each converter is called as
    # converter(raw_value, "<section>.<key>"). JSON keys equal the dataclass
    # attribute names, so values can be assigned with setattr.
    schema = (
        ("daemon", (("hotkey", _as_nonempty_str),)),
        ("recording", (("input", lambda raw, _name: _as_recording_input(raw)),)),
        (
            "stt",
            (
                ("provider", _as_nonempty_str),
                ("model", _as_nonempty_str),
                ("device", _as_nonempty_str),
                ("language", _as_nonempty_str),
            ),
        ),
        (
            "models",
            (
                ("allow_custom_models", _as_bool),
                ("whisper_model_path", _as_str),
            ),
        ),
        (
            "injection",
            (
                ("backend", _as_nonempty_str),
                ("remove_transcription_from_clipboard", _as_bool),
            ),
        ),
        ("safety", (("enabled", _as_bool), ("strict", _as_bool))),
        (
            "vocabulary",
            (
                ("replacements", lambda raw, _name: _as_replacements(raw)),
                ("terms", lambda raw, _name: _as_terms(raw)),
            ),
        ),
        ("ux", (("profile", _as_nonempty_str), ("show_notifications", _as_bool))),
        ("advanced", (("strict_startup", _as_bool),)),
    )

    top_level_keys = {"config_version"} | {section for section, _ in schema}
    _reject_unknown_keys(data, top_level_keys, parent="")

    raw_sections = {section: _ensure_dict(data.get(section), section) for section, _ in schema}
    for section, fields_spec in schema:
        _reject_unknown_keys(raw_sections[section], {key for key, _ in fields_spec}, parent=section)

    if "config_version" in data:
        cfg.config_version = _as_int(data["config_version"], "config_version")

    for section, fields_spec in schema:
        raw_section = raw_sections[section]
        target = getattr(cfg, section)
        for key, convert in fields_spec:
            if key in raw_section:
                setattr(target, key, convert(raw_section[key], f"{section}.{key}"))
    return cfg
|
|
|
|
|
|
def _migrate_dict(data: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of ``data`` with ``config_version`` brought up to date.

    A missing (or ``null``) version is treated as current. Versions that are
    not integers, are not positive, or are newer than
    ``CURRENT_CONFIG_VERSION`` are rejected. Older versions are only
    relabelled as current; no per-version field transformations exist here.
    ``data`` itself is not modified.

    Raises:
        ConfigValidationError: for an invalid or unsupported version.
    """
    result = {**data}
    version = result.get("config_version")
    if version is not None:
        if not isinstance(version, int):
            _raise_cfg_error("config_version", "must be integer", '{"config_version":1}')
        if version > CURRENT_CONFIG_VERSION:
            _raise_cfg_error(
                "config_version",
                f"unsupported future version {version}; expected <= {CURRENT_CONFIG_VERSION}",
                f'{{"config_version":{CURRENT_CONFIG_VERSION}}}',
            )
        if version <= 0:
            _raise_cfg_error(
                "config_version",
                "must be positive",
                f'{{"config_version":{CURRENT_CONFIG_VERSION}}}',
            )
    # Only overwrite when the value differs, so an equal value is kept as-is.
    if version != CURRENT_CONFIG_VERSION:
        result["config_version"] = CURRENT_CONFIG_VERSION
    return result
|
|
|
|
|
|
def _reject_unknown_keys(value: dict[str, Any], allowed: set[str], *, parent: str) -> None:
|
|
for key in value.keys():
|
|
if key in allowed:
|
|
continue
|
|
field = f"{parent}.{key}" if parent else key
|
|
_raise_cfg_error(field, "unknown config field", "remove this key from the config")
|
|
|
|
|
|
def _ensure_dict(value: Any, field_name: str) -> dict[str, Any]:
|
|
if value is None:
|
|
return {}
|
|
if not isinstance(value, dict):
|
|
_raise_cfg_error(field_name, "must be an object", f'{{"{field_name}":{{...}}}}')
|
|
return value
|
|
|
|
|
|
def _as_nonempty_str(value: Any, field_name: str) -> str:
|
|
if not isinstance(value, str):
|
|
_raise_cfg_error(field_name, "must be a string", f'{{"{field_name}":"value"}}')
|
|
if not value.strip():
|
|
_raise_cfg_error(field_name, "cannot be empty", f'{{"{field_name}":"value"}}')
|
|
return value
|
|
|
|
|
|
def _as_str(value: Any, field_name: str) -> str:
|
|
if not isinstance(value, str):
|
|
_raise_cfg_error(field_name, "must be a string", f'{{"{field_name}":"value"}}')
|
|
return value
|
|
|
|
|
|
def _as_int(value: Any, field_name: str) -> int:
|
|
if isinstance(value, bool) or not isinstance(value, int):
|
|
_raise_cfg_error(field_name, "must be integer", f'{{"{field_name}":1}}')
|
|
return value
|
|
|
|
|
|
def _as_bool(value: Any, field_name: str) -> bool:
|
|
if not isinstance(value, bool):
|
|
_raise_cfg_error(field_name, "must be boolean", f'{{"{field_name}":true}}')
|
|
return value
|
|
|
|
|
|
def _as_recording_input(value: Any) -> str | int | None:
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, bool):
|
|
_raise_cfg_error("recording.input", "cannot be boolean", '{"recording":{"input":1}}')
|
|
if isinstance(value, (str, int)):
|
|
return value
|
|
_raise_cfg_error(
|
|
"recording.input",
|
|
"must be string, integer, or null",
|
|
'{"recording":{"input":"USB"}}',
|
|
)
|
|
|
|
|
|
def _as_replacements(value: Any) -> list[VocabularyReplacement]:
|
|
if not isinstance(value, list):
|
|
_raise_cfg_error(
|
|
"vocabulary.replacements",
|
|
"must be a list",
|
|
'{"vocabulary":{"replacements":[{"from":"A","to":"B"}]}}',
|
|
)
|
|
replacements: list[VocabularyReplacement] = []
|
|
for i, item in enumerate(value):
|
|
if not isinstance(item, dict):
|
|
_raise_cfg_error(
|
|
f"vocabulary.replacements[{i}]",
|
|
"must be an object",
|
|
'{"vocabulary":{"replacements":[{"from":"A","to":"B"}]}}',
|
|
)
|
|
if "from" not in item:
|
|
_raise_cfg_error(
|
|
f"vocabulary.replacements[{i}].from",
|
|
"is required",
|
|
'{"vocabulary":{"replacements":[{"from":"A","to":"B"}]}}',
|
|
)
|
|
if "to" not in item:
|
|
_raise_cfg_error(
|
|
f"vocabulary.replacements[{i}].to",
|
|
"is required",
|
|
'{"vocabulary":{"replacements":[{"from":"A","to":"B"}]}}',
|
|
)
|
|
source = _as_nonempty_str(item["from"], f"vocabulary.replacements[{i}].from")
|
|
target = _as_nonempty_str(item["to"], f"vocabulary.replacements[{i}].to")
|
|
replacements.append(VocabularyReplacement(source=source, target=target))
|
|
return replacements
|
|
|
|
|
|
def _as_terms(value: Any) -> list[str]:
|
|
if not isinstance(value, list):
|
|
_raise_cfg_error("vocabulary.terms", "must be a list", '{"vocabulary":{"terms":["Docker"]}}')
|
|
terms: list[str] = []
|
|
for i, item in enumerate(value):
|
|
terms.append(_as_nonempty_str(item, f"vocabulary.terms[{i}]"))
|
|
return terms
|
|
|
|
|
|
def _validate_replacements(value: list[VocabularyReplacement]) -> list[VocabularyReplacement]:
|
|
deduped: list[VocabularyReplacement] = []
|
|
seen: dict[str, str] = {}
|
|
for i, item in enumerate(value):
|
|
source = item.source.strip()
|
|
target = item.target.strip()
|
|
if not source:
|
|
_raise_cfg_error(f"vocabulary.replacements[{i}].from", "cannot be empty")
|
|
if not target:
|
|
_raise_cfg_error(f"vocabulary.replacements[{i}].to", "cannot be empty")
|
|
if source == target:
|
|
_raise_cfg_error(f"vocabulary.replacements[{i}]", "cannot map a term to itself")
|
|
if "\n" in source or "\n" in target:
|
|
_raise_cfg_error(f"vocabulary.replacements[{i}]", "cannot contain newlines")
|
|
if any(ch in source for ch in WILDCARD_CHARS):
|
|
_raise_cfg_error(
|
|
f"vocabulary.replacements[{i}].from",
|
|
"cannot contain wildcard characters",
|
|
)
|
|
|
|
source_key = _normalize_key(source)
|
|
target_key = _normalize_key(target)
|
|
prev_target = seen.get(source_key)
|
|
if prev_target is None:
|
|
seen[source_key] = target
|
|
deduped.append(VocabularyReplacement(source=source, target=target))
|
|
continue
|
|
if _normalize_key(prev_target) != target_key:
|
|
_raise_cfg_error(
|
|
"vocabulary.replacements",
|
|
f"has conflicting entries for '{source}'",
|
|
)
|
|
return deduped
|
|
|
|
|
|
def _validate_terms(value: list[str]) -> list[str]:
|
|
deduped: list[str] = []
|
|
seen: set[str] = set()
|
|
for i, term in enumerate(value):
|
|
cleaned = term.strip()
|
|
if not cleaned:
|
|
_raise_cfg_error(f"vocabulary.terms[{i}]", "cannot be empty")
|
|
if "\n" in cleaned:
|
|
_raise_cfg_error(f"vocabulary.terms[{i}]", "cannot contain newlines")
|
|
if any(ch in cleaned for ch in WILDCARD_CHARS):
|
|
_raise_cfg_error(f"vocabulary.terms[{i}]", "cannot contain wildcard characters")
|
|
key = _normalize_key(cleaned)
|
|
if key in seen:
|
|
continue
|
|
seen.add(key)
|
|
deduped.append(cleaned)
|
|
return deduped
|
|
|
|
|
|
def _normalize_key(value: str) -> str:
|
|
return " ".join(value.casefold().split())
|