Add vocabulary correction pipeline and example config

This commit is contained in:
Thales Maciel 2026-02-25 10:03:32 -03:00
parent f9224621fa
commit c3503fbbde
9 changed files with 865 additions and 23 deletions

View file

@ -12,7 +12,11 @@ DEFAULT_HOTKEY = "Cmd+m"
# Speech-to-text defaults.
DEFAULT_STT_MODEL = "base"
DEFAULT_STT_DEVICE = "cpu"
DEFAULT_INJECTION_BACKEND = "clipboard"
# Default cap applied to both vocabulary.max_rules and vocabulary.max_terms.
DEFAULT_VOCAB_LIMIT = 500
DEFAULT_DOMAIN_INFERENCE_MODE = "auto"
ALLOWED_INJECTION_BACKENDS = {"clipboard", "injection"}
# "auto" is currently the only accepted domain_inference.mode value.
ALLOWED_DOMAIN_INFERENCE_MODES = {"auto"}
# Glob-style characters rejected in vocabulary replacement sources and terms.
WILDCARD_CHARS = set("*?[]{}")
@dataclass
@ -46,6 +50,26 @@ class LoggingConfig:
log_transcript: bool = False
@dataclass
class VocabularyReplacement:
    """One vocabulary correction rule: replace *source* text with *target*."""

    # Phrase to be replaced (parsed from the config key "from").
    source: str
    # Replacement text (parsed from the config key "to").
    target: str
@dataclass
class VocabularyConfig:
    """User vocabulary settings: replacement rules, extra terms, and size limits."""

    # Ordered replacement rules; stripped, validated, and deduplicated during validate().
    replacements: list[VocabularyReplacement] = field(default_factory=list)
    # Extra vocabulary terms; stripped, validated, and deduplicated during validate().
    terms: list[str] = field(default_factory=list)
    # Upper bound on len(replacements); itself limited to 1..5000 during validate().
    max_rules: int = DEFAULT_VOCAB_LIMIT
    # Upper bound on len(terms); itself limited to 1..5000 during validate().
    max_terms: int = DEFAULT_VOCAB_LIMIT
@dataclass
class DomainInferenceConfig:
    """Settings for the domain-inference feature."""

    # Feature switch; must be a boolean (checked during validate()).
    enabled: bool = True
    # Inference mode; lowercased and checked against ALLOWED_DOMAIN_INFERENCE_MODES
    # during validate().
    mode: str = DEFAULT_DOMAIN_INFERENCE_MODE
@dataclass
class Config:
daemon: DaemonConfig = field(default_factory=DaemonConfig)
@ -54,6 +78,8 @@ class Config:
injection: InjectionConfig = field(default_factory=InjectionConfig)
ai: AiConfig = field(default_factory=AiConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
domain_inference: DomainInferenceConfig = field(default_factory=DomainInferenceConfig)
def load(path: str | None) -> Config:
@ -102,10 +128,43 @@ def validate(cfg: Config) -> None:
if not isinstance(cfg.logging.log_transcript, bool):
raise ValueError("logging.log_transcript must be boolean")
cfg.vocabulary.max_rules = _validated_limit(cfg.vocabulary.max_rules, "vocabulary.max_rules")
cfg.vocabulary.max_terms = _validated_limit(cfg.vocabulary.max_terms, "vocabulary.max_terms")
if len(cfg.vocabulary.replacements) > cfg.vocabulary.max_rules:
raise ValueError(
f"vocabulary.replacements cannot exceed vocabulary.max_rules ({cfg.vocabulary.max_rules})"
)
if len(cfg.vocabulary.terms) > cfg.vocabulary.max_terms:
raise ValueError(
f"vocabulary.terms cannot exceed vocabulary.max_terms ({cfg.vocabulary.max_terms})"
)
cfg.vocabulary.replacements = _validate_replacements(cfg.vocabulary.replacements)
cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms)
if not isinstance(cfg.domain_inference.enabled, bool):
raise ValueError("domain_inference.enabled must be boolean")
mode = cfg.domain_inference.mode.strip().lower()
if mode not in ALLOWED_DOMAIN_INFERENCE_MODES:
allowed = ", ".join(sorted(ALLOWED_DOMAIN_INFERENCE_MODES))
raise ValueError(f"domain_inference.mode must be one of: {allowed}")
cfg.domain_inference.mode = mode
def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
has_sections = any(
key in data for key in ("daemon", "recording", "stt", "injection", "ai", "logging")
key in data
for key in (
"daemon",
"recording",
"stt",
"injection",
"ai",
"logging",
"vocabulary",
"domain_inference",
)
)
if has_sections:
daemon = _ensure_dict(data.get("daemon"), "daemon")
@ -114,6 +173,8 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
injection = _ensure_dict(data.get("injection"), "injection")
ai = _ensure_dict(data.get("ai"), "ai")
logging_cfg = _ensure_dict(data.get("logging"), "logging")
vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
domain_inference = _ensure_dict(data.get("domain_inference"), "domain_inference")
if "hotkey" in daemon:
cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey")
@ -129,6 +190,22 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
cfg.ai.enabled = _as_bool(ai["enabled"], "ai.enabled")
if "log_transcript" in logging_cfg:
cfg.logging.log_transcript = _as_bool(logging_cfg["log_transcript"], "logging.log_transcript")
if "replacements" in vocabulary:
cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"])
if "terms" in vocabulary:
cfg.vocabulary.terms = _as_terms(vocabulary["terms"])
if "max_rules" in vocabulary:
cfg.vocabulary.max_rules = _as_int(vocabulary["max_rules"], "vocabulary.max_rules")
if "max_terms" in vocabulary:
cfg.vocabulary.max_terms = _as_int(vocabulary["max_terms"], "vocabulary.max_terms")
if "enabled" in domain_inference:
cfg.domain_inference.enabled = _as_bool(
domain_inference["enabled"], "domain_inference.enabled"
)
if "mode" in domain_inference:
cfg.domain_inference.mode = _as_nonempty_str(
domain_inference["mode"], "domain_inference.mode"
)
return cfg
if "hotkey" in data:
@ -170,6 +247,12 @@ def _as_bool(value: Any, field_name: str) -> bool:
return value
def _as_int(value: Any, field_name: str) -> int:
if isinstance(value, bool) or not isinstance(value, int):
raise ValueError(f"{field_name} must be an integer")
return value
def _as_recording_input(value: Any) -> str | int | None:
if value is None:
return None
@ -178,3 +261,93 @@ def _as_recording_input(value: Any) -> str | int | None:
if isinstance(value, (str, int)):
return value
raise ValueError("recording.input must be string, integer, or null")
def _as_replacements(value: Any) -> list[VocabularyReplacement]:
    """Parse the raw ``vocabulary.replacements`` value into rule objects.

    Each entry must be an object carrying non-empty "from" and "to" strings;
    a ValueError naming the offending index is raised otherwise.
    """
    if not isinstance(value, list):
        raise ValueError("vocabulary.replacements must be a list")
    parsed: list[VocabularyReplacement] = []
    for idx, entry in enumerate(value):
        label = f"vocabulary.replacements[{idx}]"
        if not isinstance(entry, dict):
            raise ValueError(f"{label} must be an object")
        # Both keys are mandatory; report the missing one explicitly.
        for key in ("from", "to"):
            if key not in entry:
                raise ValueError(f"{label}.{key} is required")
        parsed.append(
            VocabularyReplacement(
                source=_as_nonempty_str(entry["from"], f"{label}.from"),
                target=_as_nonempty_str(entry["to"], f"{label}.to"),
            )
        )
    return parsed
def _as_terms(value: Any) -> list[str]:
    """Parse the raw ``vocabulary.terms`` value into a list of non-empty strings."""
    if not isinstance(value, list):
        raise ValueError("vocabulary.terms must be a list")
    # Validate each entry with an index-specific field name for precise errors.
    return [
        _as_nonempty_str(item, f"vocabulary.terms[{i}]") for i, item in enumerate(value)
    ]
def _validated_limit(value: int, field_name: str) -> int:
if isinstance(value, bool) or not isinstance(value, int):
raise ValueError(f"{field_name} must be an integer")
if value <= 0:
raise ValueError(f"{field_name} must be positive")
if value > 5000:
raise ValueError(f"{field_name} cannot exceed 5000")
return value
def _validate_replacements(value: list[VocabularyReplacement]) -> list[VocabularyReplacement]:
    """Sanitize replacement rules: trim whitespace, reject malformed entries, dedupe.

    Duplicate sources (compared case/whitespace-insensitively) are allowed only
    when their targets also agree under the same normalization; the first
    occurrence is kept and later duplicates are dropped. Conflicting targets
    for the same source raise ValueError.
    """
    kept: list[VocabularyReplacement] = []
    first_target_for: dict[str, str] = {}
    for idx, rule in enumerate(value):
        src = rule.source.strip()
        dst = rule.target.strip()
        if not src:
            raise ValueError(f"vocabulary.replacements[{idx}].from cannot be empty")
        if not dst:
            raise ValueError(f"vocabulary.replacements[{idx}].to cannot be empty")
        if src == dst:
            raise ValueError(f"vocabulary.replacements[{idx}] cannot map a term to itself")
        if "\n" in src or "\n" in dst:
            raise ValueError(f"vocabulary.replacements[{idx}] cannot contain newlines")
        # Only the source is checked for wildcard characters.
        if not WILDCARD_CHARS.isdisjoint(src):
            raise ValueError(
                f"vocabulary.replacements[{idx}].from cannot contain wildcard characters"
            )
        key = _normalize_key(src)
        if key not in first_target_for:
            first_target_for[key] = dst
            kept.append(VocabularyReplacement(source=src, target=dst))
        elif _normalize_key(first_target_for[key]) != _normalize_key(dst):
            raise ValueError(f"vocabulary.replacements has conflicting entries for '{src}'")
    return kept
def _validate_terms(value: list[str]) -> list[str]:
    """Sanitize vocabulary terms: trim whitespace, reject malformed entries, dedupe.

    Duplicates are detected case/whitespace-insensitively; the first spelling
    encountered is the one kept.
    """
    kept: list[str] = []
    seen_keys: set[str] = set()
    for idx, raw in enumerate(value):
        term = raw.strip()
        if not term:
            raise ValueError(f"vocabulary.terms[{idx}] cannot be empty")
        if "\n" in term:
            raise ValueError(f"vocabulary.terms[{idx}] cannot contain newlines")
        if not WILDCARD_CHARS.isdisjoint(term):
            raise ValueError(f"vocabulary.terms[{idx}] cannot contain wildcard characters")
        key = _normalize_key(term)
        if key not in seen_keys:
            seen_keys.add(key)
            kept.append(term)
    return kept
def _normalize_key(value: str) -> str:
return " ".join(value.casefold().split())