Add multilingual STT support and config UI/runtime updates
This commit is contained in:
parent
ed950cb7c4
commit
4a69c3d333
26 changed files with 2207 additions and 465 deletions
266
src/config.py
266
src/config.py
|
|
@ -7,13 +7,26 @@ from typing import Any
|
|||
|
||||
from constants import DEFAULT_CONFIG_PATH
|
||||
from hotkey import split_hotkey
|
||||
from languages import DEFAULT_STT_LANGUAGE, normalize_stt_language
|
||||
|
||||
|
||||
CURRENT_CONFIG_VERSION = 1
|
||||
DEFAULT_HOTKEY = "Cmd+m"
|
||||
DEFAULT_STT_PROVIDER = "local_whisper"
|
||||
DEFAULT_STT_MODEL = "base"
|
||||
DEFAULT_STT_DEVICE = "cpu"
|
||||
DEFAULT_LLM_PROVIDER = "local_llama"
|
||||
DEFAULT_EXTERNAL_API_PROVIDER = "openai"
|
||||
DEFAULT_EXTERNAL_API_BASE_URL = "https://api.openai.com/v1"
|
||||
DEFAULT_EXTERNAL_API_MODEL = "gpt-4o-mini"
|
||||
DEFAULT_EXTERNAL_API_TIMEOUT_MS = 15000
|
||||
DEFAULT_EXTERNAL_API_MAX_RETRIES = 2
|
||||
DEFAULT_EXTERNAL_API_KEY_ENV_VAR = "AMAN_EXTERNAL_API_KEY"
|
||||
DEFAULT_INJECTION_BACKEND = "clipboard"
|
||||
DEFAULT_UX_PROFILE = "default"
|
||||
ALLOWED_STT_PROVIDERS = {"local_whisper"}
|
||||
ALLOWED_LLM_PROVIDERS = {"local_llama", "external_api"}
|
||||
ALLOWED_EXTERNAL_API_PROVIDERS = {"openai"}
|
||||
ALLOWED_INJECTION_BACKENDS = {"clipboard", "injection"}
|
||||
ALLOWED_UX_PROFILES = {"default", "fast", "polished"}
|
||||
WILDCARD_CHARS = set("*?[]{}")
|
||||
|
|
@ -47,8 +60,33 @@ class RecordingConfig:
|
|||
|
||||
@dataclass
|
||||
class SttConfig:
|
||||
provider: str = DEFAULT_STT_PROVIDER
|
||||
model: str = DEFAULT_STT_MODEL
|
||||
device: str = DEFAULT_STT_DEVICE
|
||||
language: str = DEFAULT_STT_LANGUAGE
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlmConfig:
|
||||
provider: str = DEFAULT_LLM_PROVIDER
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelsConfig:
|
||||
allow_custom_models: bool = False
|
||||
whisper_model_path: str = ""
|
||||
llm_model_path: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExternalApiConfig:
|
||||
enabled: bool = False
|
||||
provider: str = DEFAULT_EXTERNAL_API_PROVIDER
|
||||
base_url: str = DEFAULT_EXTERNAL_API_BASE_URL
|
||||
model: str = DEFAULT_EXTERNAL_API_MODEL
|
||||
timeout_ms: int = DEFAULT_EXTERNAL_API_TIMEOUT_MS
|
||||
max_retries: int = DEFAULT_EXTERNAL_API_MAX_RETRIES
|
||||
api_key_env_var: str = DEFAULT_EXTERNAL_API_KEY_ENV_VAR
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -82,9 +120,13 @@ class VocabularyConfig:
|
|||
|
||||
@dataclass
|
||||
class Config:
|
||||
config_version: int = CURRENT_CONFIG_VERSION
|
||||
daemon: DaemonConfig = field(default_factory=DaemonConfig)
|
||||
recording: RecordingConfig = field(default_factory=RecordingConfig)
|
||||
stt: SttConfig = field(default_factory=SttConfig)
|
||||
llm: LlmConfig = field(default_factory=LlmConfig)
|
||||
models: ModelsConfig = field(default_factory=ModelsConfig)
|
||||
external_api: ExternalApiConfig = field(default_factory=ExternalApiConfig)
|
||||
injection: InjectionConfig = field(default_factory=InjectionConfig)
|
||||
ux: UxConfig = field(default_factory=UxConfig)
|
||||
advanced: AdvancedConfig = field(default_factory=AdvancedConfig)
|
||||
|
|
@ -102,6 +144,7 @@ def load(path: str | None) -> Config:
|
|||
"must be a JSON object",
|
||||
'{"daemon":{"hotkey":"Super+m"}}',
|
||||
)
|
||||
data = _migrate_dict(data)
|
||||
cfg = _from_dict(data, cfg)
|
||||
validate(cfg)
|
||||
return cfg
|
||||
|
|
@ -128,6 +171,15 @@ def _write_default_config(path: Path, cfg: Config) -> None:
|
|||
|
||||
|
||||
def validate(cfg: Config) -> None:
|
||||
if not isinstance(cfg.config_version, int):
|
||||
_raise_cfg_error("config_version", "must be integer", '{"config_version":1}')
|
||||
if cfg.config_version != CURRENT_CONFIG_VERSION:
|
||||
_raise_cfg_error(
|
||||
"config_version",
|
||||
f"must be {CURRENT_CONFIG_VERSION}",
|
||||
f'{{"config_version":{CURRENT_CONFIG_VERSION}}}',
|
||||
)
|
||||
|
||||
hotkey = cfg.daemon.hotkey.strip()
|
||||
if not hotkey:
|
||||
_raise_cfg_error("daemon.hotkey", "cannot be empty", '{"daemon":{"hotkey":"Super+m"}}')
|
||||
|
|
@ -145,6 +197,16 @@ def validate(cfg: Config) -> None:
|
|||
'{"recording":{"input":"USB"}}',
|
||||
)
|
||||
|
||||
stt_provider = cfg.stt.provider.strip().lower()
|
||||
if stt_provider not in ALLOWED_STT_PROVIDERS:
|
||||
allowed = ", ".join(sorted(ALLOWED_STT_PROVIDERS))
|
||||
_raise_cfg_error(
|
||||
"stt.provider",
|
||||
f"must be one of: {allowed}",
|
||||
'{"stt":{"provider":"local_whisper"}}',
|
||||
)
|
||||
cfg.stt.provider = stt_provider
|
||||
|
||||
model = cfg.stt.model.strip()
|
||||
if not model:
|
||||
_raise_cfg_error("stt.model", "cannot be empty", '{"stt":{"model":"base"}}')
|
||||
|
|
@ -152,6 +214,113 @@ def validate(cfg: Config) -> None:
|
|||
device = cfg.stt.device.strip()
|
||||
if not device:
|
||||
_raise_cfg_error("stt.device", "cannot be empty", '{"stt":{"device":"cpu"}}')
|
||||
if not isinstance(cfg.stt.language, str):
|
||||
_raise_cfg_error("stt.language", "must be a string", '{"stt":{"language":"auto"}}')
|
||||
try:
|
||||
cfg.stt.language = normalize_stt_language(cfg.stt.language)
|
||||
except ValueError as exc:
|
||||
_raise_cfg_error(
|
||||
"stt.language",
|
||||
str(exc),
|
||||
'{"stt":{"language":"auto"}}',
|
||||
)
|
||||
|
||||
llm_provider = cfg.llm.provider.strip().lower()
|
||||
if llm_provider not in ALLOWED_LLM_PROVIDERS:
|
||||
allowed = ", ".join(sorted(ALLOWED_LLM_PROVIDERS))
|
||||
_raise_cfg_error(
|
||||
"llm.provider",
|
||||
f"must be one of: {allowed}",
|
||||
'{"llm":{"provider":"local_llama"}}',
|
||||
)
|
||||
cfg.llm.provider = llm_provider
|
||||
|
||||
if not isinstance(cfg.models.allow_custom_models, bool):
|
||||
_raise_cfg_error(
|
||||
"models.allow_custom_models",
|
||||
"must be boolean",
|
||||
'{"models":{"allow_custom_models":false}}',
|
||||
)
|
||||
if not isinstance(cfg.models.whisper_model_path, str):
|
||||
_raise_cfg_error(
|
||||
"models.whisper_model_path",
|
||||
"must be string",
|
||||
'{"models":{"whisper_model_path":""}}',
|
||||
)
|
||||
if not isinstance(cfg.models.llm_model_path, str):
|
||||
_raise_cfg_error(
|
||||
"models.llm_model_path",
|
||||
"must be string",
|
||||
'{"models":{"llm_model_path":""}}',
|
||||
)
|
||||
cfg.models.whisper_model_path = cfg.models.whisper_model_path.strip()
|
||||
cfg.models.llm_model_path = cfg.models.llm_model_path.strip()
|
||||
if not cfg.models.allow_custom_models:
|
||||
if cfg.models.whisper_model_path:
|
||||
_raise_cfg_error(
|
||||
"models.whisper_model_path",
|
||||
"requires models.allow_custom_models=true",
|
||||
'{"models":{"allow_custom_models":true,"whisper_model_path":"/path/model.bin"}}',
|
||||
)
|
||||
if cfg.models.llm_model_path:
|
||||
_raise_cfg_error(
|
||||
"models.llm_model_path",
|
||||
"requires models.allow_custom_models=true",
|
||||
'{"models":{"allow_custom_models":true,"llm_model_path":"/path/model.gguf"}}',
|
||||
)
|
||||
|
||||
if not isinstance(cfg.external_api.enabled, bool):
|
||||
_raise_cfg_error(
|
||||
"external_api.enabled",
|
||||
"must be boolean",
|
||||
'{"external_api":{"enabled":false}}',
|
||||
)
|
||||
external_provider = cfg.external_api.provider.strip().lower()
|
||||
if external_provider not in ALLOWED_EXTERNAL_API_PROVIDERS:
|
||||
allowed = ", ".join(sorted(ALLOWED_EXTERNAL_API_PROVIDERS))
|
||||
_raise_cfg_error(
|
||||
"external_api.provider",
|
||||
f"must be one of: {allowed}",
|
||||
'{"external_api":{"provider":"openai"}}',
|
||||
)
|
||||
cfg.external_api.provider = external_provider
|
||||
if not cfg.external_api.base_url.strip():
|
||||
_raise_cfg_error(
|
||||
"external_api.base_url",
|
||||
"cannot be empty",
|
||||
'{"external_api":{"base_url":"https://api.openai.com/v1"}}',
|
||||
)
|
||||
if not cfg.external_api.model.strip():
|
||||
_raise_cfg_error(
|
||||
"external_api.model",
|
||||
"cannot be empty",
|
||||
'{"external_api":{"model":"gpt-4o-mini"}}',
|
||||
)
|
||||
if not isinstance(cfg.external_api.timeout_ms, int) or cfg.external_api.timeout_ms <= 0:
|
||||
_raise_cfg_error(
|
||||
"external_api.timeout_ms",
|
||||
"must be a positive integer",
|
||||
'{"external_api":{"timeout_ms":15000}}',
|
||||
)
|
||||
if not isinstance(cfg.external_api.max_retries, int) or cfg.external_api.max_retries < 0:
|
||||
_raise_cfg_error(
|
||||
"external_api.max_retries",
|
||||
"must be a non-negative integer",
|
||||
'{"external_api":{"max_retries":2}}',
|
||||
)
|
||||
if not cfg.external_api.api_key_env_var.strip():
|
||||
_raise_cfg_error(
|
||||
"external_api.api_key_env_var",
|
||||
"cannot be empty",
|
||||
'{"external_api":{"api_key_env_var":"AMAN_EXTERNAL_API_KEY"}}',
|
||||
)
|
||||
|
||||
if cfg.llm.provider == "external_api" and not cfg.external_api.enabled:
|
||||
_raise_cfg_error(
|
||||
"llm.provider",
|
||||
"external_api provider requires external_api.enabled=true",
|
||||
'{"llm":{"provider":"external_api"},"external_api":{"enabled":true}}',
|
||||
)
|
||||
|
||||
backend = cfg.injection.backend.strip().lower()
|
||||
if backend not in ALLOWED_INJECTION_BACKENDS:
|
||||
|
|
@ -197,12 +366,27 @@ def validate(cfg: Config) -> None:
|
|||
def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||
_reject_unknown_keys(
|
||||
data,
|
||||
{"daemon", "recording", "stt", "injection", "vocabulary", "ux", "advanced"},
|
||||
{
|
||||
"config_version",
|
||||
"daemon",
|
||||
"recording",
|
||||
"stt",
|
||||
"llm",
|
||||
"models",
|
||||
"external_api",
|
||||
"injection",
|
||||
"vocabulary",
|
||||
"ux",
|
||||
"advanced",
|
||||
},
|
||||
parent="",
|
||||
)
|
||||
daemon = _ensure_dict(data.get("daemon"), "daemon")
|
||||
recording = _ensure_dict(data.get("recording"), "recording")
|
||||
stt = _ensure_dict(data.get("stt"), "stt")
|
||||
llm = _ensure_dict(data.get("llm"), "llm")
|
||||
models = _ensure_dict(data.get("models"), "models")
|
||||
external_api = _ensure_dict(data.get("external_api"), "external_api")
|
||||
injection = _ensure_dict(data.get("injection"), "injection")
|
||||
vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
|
||||
ux = _ensure_dict(data.get("ux"), "ux")
|
||||
|
|
@ -210,7 +394,18 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
|||
|
||||
_reject_unknown_keys(daemon, {"hotkey"}, parent="daemon")
|
||||
_reject_unknown_keys(recording, {"input"}, parent="recording")
|
||||
_reject_unknown_keys(stt, {"model", "device"}, parent="stt")
|
||||
_reject_unknown_keys(stt, {"provider", "model", "device", "language"}, parent="stt")
|
||||
_reject_unknown_keys(llm, {"provider"}, parent="llm")
|
||||
_reject_unknown_keys(
|
||||
models,
|
||||
{"allow_custom_models", "whisper_model_path", "llm_model_path"},
|
||||
parent="models",
|
||||
)
|
||||
_reject_unknown_keys(
|
||||
external_api,
|
||||
{"enabled", "provider", "base_url", "model", "timeout_ms", "max_retries", "api_key_env_var"},
|
||||
parent="external_api",
|
||||
)
|
||||
_reject_unknown_keys(
|
||||
injection,
|
||||
{"backend", "remove_transcription_from_clipboard"},
|
||||
|
|
@ -220,14 +415,44 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
|||
_reject_unknown_keys(ux, {"profile", "show_notifications"}, parent="ux")
|
||||
_reject_unknown_keys(advanced, {"strict_startup"}, parent="advanced")
|
||||
|
||||
if "config_version" in data:
|
||||
cfg.config_version = _as_int(data["config_version"], "config_version")
|
||||
if "hotkey" in daemon:
|
||||
cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey")
|
||||
if "input" in recording:
|
||||
cfg.recording.input = _as_recording_input(recording["input"])
|
||||
if "provider" in stt:
|
||||
cfg.stt.provider = _as_nonempty_str(stt["provider"], "stt.provider")
|
||||
if "model" in stt:
|
||||
cfg.stt.model = _as_nonempty_str(stt["model"], "stt.model")
|
||||
if "device" in stt:
|
||||
cfg.stt.device = _as_nonempty_str(stt["device"], "stt.device")
|
||||
if "language" in stt:
|
||||
cfg.stt.language = _as_nonempty_str(stt["language"], "stt.language")
|
||||
if "provider" in llm:
|
||||
cfg.llm.provider = _as_nonempty_str(llm["provider"], "llm.provider")
|
||||
if "allow_custom_models" in models:
|
||||
cfg.models.allow_custom_models = _as_bool(models["allow_custom_models"], "models.allow_custom_models")
|
||||
if "whisper_model_path" in models:
|
||||
cfg.models.whisper_model_path = _as_str(models["whisper_model_path"], "models.whisper_model_path")
|
||||
if "llm_model_path" in models:
|
||||
cfg.models.llm_model_path = _as_str(models["llm_model_path"], "models.llm_model_path")
|
||||
if "enabled" in external_api:
|
||||
cfg.external_api.enabled = _as_bool(external_api["enabled"], "external_api.enabled")
|
||||
if "provider" in external_api:
|
||||
cfg.external_api.provider = _as_nonempty_str(external_api["provider"], "external_api.provider")
|
||||
if "base_url" in external_api:
|
||||
cfg.external_api.base_url = _as_nonempty_str(external_api["base_url"], "external_api.base_url")
|
||||
if "model" in external_api:
|
||||
cfg.external_api.model = _as_nonempty_str(external_api["model"], "external_api.model")
|
||||
if "timeout_ms" in external_api:
|
||||
cfg.external_api.timeout_ms = _as_int(external_api["timeout_ms"], "external_api.timeout_ms")
|
||||
if "max_retries" in external_api:
|
||||
cfg.external_api.max_retries = _as_int(external_api["max_retries"], "external_api.max_retries")
|
||||
if "api_key_env_var" in external_api:
|
||||
cfg.external_api.api_key_env_var = _as_nonempty_str(
|
||||
external_api["api_key_env_var"], "external_api.api_key_env_var"
|
||||
)
|
||||
if "backend" in injection:
|
||||
cfg.injection.backend = _as_nonempty_str(injection["backend"], "injection.backend")
|
||||
if "remove_transcription_from_clipboard" in injection:
|
||||
|
|
@ -251,6 +476,31 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
|||
return cfg
|
||||
|
||||
|
||||
def _migrate_dict(data: dict[str, Any]) -> dict[str, Any]:
|
||||
migrated = dict(data)
|
||||
version = migrated.get("config_version")
|
||||
if version is None:
|
||||
migrated["config_version"] = CURRENT_CONFIG_VERSION
|
||||
return migrated
|
||||
if not isinstance(version, int):
|
||||
_raise_cfg_error("config_version", "must be integer", '{"config_version":1}')
|
||||
if version > CURRENT_CONFIG_VERSION:
|
||||
_raise_cfg_error(
|
||||
"config_version",
|
||||
f"unsupported future version {version}; expected <= {CURRENT_CONFIG_VERSION}",
|
||||
f'{{"config_version":{CURRENT_CONFIG_VERSION}}}',
|
||||
)
|
||||
if version <= 0:
|
||||
_raise_cfg_error(
|
||||
"config_version",
|
||||
"must be positive",
|
||||
f'{{"config_version":{CURRENT_CONFIG_VERSION}}}',
|
||||
)
|
||||
if version != CURRENT_CONFIG_VERSION:
|
||||
migrated["config_version"] = CURRENT_CONFIG_VERSION
|
||||
return migrated
|
||||
|
||||
|
||||
def _reject_unknown_keys(value: dict[str, Any], allowed: set[str], *, parent: str) -> None:
|
||||
for key in value.keys():
|
||||
if key in allowed:
|
||||
|
|
@ -275,6 +525,18 @@ def _as_nonempty_str(value: Any, field_name: str) -> str:
|
|||
return value
|
||||
|
||||
|
||||
def _as_str(value: Any, field_name: str) -> str:
|
||||
if not isinstance(value, str):
|
||||
_raise_cfg_error(field_name, "must be a string", f'{{"{field_name}":"value"}}')
|
||||
return value
|
||||
|
||||
|
||||
def _as_int(value: Any, field_name: str) -> int:
|
||||
if isinstance(value, bool) or not isinstance(value, int):
|
||||
_raise_cfg_error(field_name, "must be integer", f'{{"{field_name}":1}}')
|
||||
return value
|
||||
|
||||
|
||||
def _as_bool(value: Any, field_name: str) -> bool:
|
||||
if not isinstance(value, bool):
|
||||
_raise_cfg_error(field_name, "must be boolean", f'{{"{field_name}":true}}')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue